#### Generate new test cases from problem statements using GPT-2

Notes:
- When CUDA is unavailable, the device falls back to **mps** (Apple Silicon); you may need to change it to **cpu**.
- The prompt templates do not record the generation hyperparameters or the model used.

In [None]:
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Directories
STATEMENTS_DIR = "scraped/statements"
BASE_GENERATED_DIR = "generated"

# Ensure base directory exists
os.makedirs(BASE_GENERATED_DIR, exist_ok=True)

# Load GPT-2 model & tokenizer.
# "gpt2-xl" is the largest GPT-2 checkpoint; switch to "gpt2-medium" (or "gpt2")
# for faster, lower-memory runs at the cost of output quality.
model_name = "gpt2-xl"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Pick the best available device: CUDA first, then Apple MPS, otherwise CPU.
# (Falling back to "mps" unconditionally would fail on machines without MPS.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
model.to(device)


def generate_test_case(statement, strategy_id=0, max_length=1024):
    """Generate a test case for a problem statement with the loaded GPT-2 model.

    Args:
        statement: Problem statement text inserted into the prompt template.
        strategy_id: Key selecting the prompt template. Valid ids are
            0, 1, 3, 4, 5 and 6 (there is no template 2).
        max_length: Total token budget passed to ``model.generate``. It counts
            the prompt tokens as well as the generated ones; GPT-2's context
            window is 1024 tokens, so it should not exceed that.

    Returns:
        Only the text the model generated after the prompt (the prompt itself
        is stripped).

    Raises:
        KeyError: If ``strategy_id`` has no prompt template.
    """
    prompt_templates = {
        0: f"Given the following problem description, generate a unique test case:\n\n{statement}\n\nTest case:",
        1: f"Generate an edge case test for the following problem:\n\n{statement}\n\nEdge case test:",
        3: f"{statement}\n\nInput:",
        4: f"{statement}\n\nInput:",
        5: f"{statement}\n\nInput:",
        6: f"{statement}\n\nInput:",
    }

    # Fail fast (before touching the model) with a message listing valid ids.
    if strategy_id not in prompt_templates:
        raise KeyError(
            f"Unknown strategy_id {strategy_id!r}; expected one of {sorted(prompt_templates)}"
        )
    prompt = prompt_templates[strategy_id]

    # Tokenize with an explicit attention mask; since pad_token_id is set to the
    # EOS token below, transformers cannot infer the mask on its own.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    output = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens produced after the prompt so the caller gets just
    # the generated test case, not the echoed problem statement.
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


def save_test_case(problem_id, generated_text, strategy_id):
    """Write a generated test case to ``<BASE_GENERATED_DIR>/gpt2-<strategy_id>/<problem_id>.txt``.

    The per-strategy directory is created if needed. An existing file is never
    overwritten: in that case a warning is printed and nothing is written.
    """
    strategy_dir = os.path.join(BASE_GENERATED_DIR, f"gpt2-{strategy_id}")
    os.makedirs(strategy_dir, exist_ok=True)
    target = os.path.join(strategy_dir, f"{problem_id}.txt")

    if not os.path.exists(target):
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(generated_text)
        print(f"✅ Generated test case for {problem_id} -> {target}")
    else:
        print(f"⚠️ Skipping {problem_id}, already generated for strategy {strategy_id}.")


def process_statements(strategy_id=0, max_statement_chars=2048):
    """Generate and save a test case for every ``.txt`` statement in ``STATEMENTS_DIR``.

    Args:
        strategy_id: Prompt template id forwarded to ``generate_test_case`` and
            used to pick the output directory ``gpt2-<strategy_id>``.
        max_statement_chars: Statements longer than this many characters are
            skipped. This is a character limit, not a token limit, so it only
            approximately keeps the prompt within the model's context window.
    """
    suffix = ".txt"
    output_dir = os.path.join(BASE_GENERATED_DIR, f"gpt2-{strategy_id}")

    # sorted() makes the processing order deterministic; os.listdir's is arbitrary.
    for filename in sorted(os.listdir(STATEMENTS_DIR)):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension (str.replace would also remove a
        # ".txt" occurring elsewhere in the name). Yields IDs like "118-a".
        problem_id = filename[: -len(suffix)]

        # Check for an existing result *before* running the model: generation is
        # the expensive step, and save_test_case would discard its output anyway.
        if os.path.exists(os.path.join(output_dir, f"{problem_id}.txt")):
            print(f"⚠️ Skipping {problem_id}, already generated for strategy {strategy_id}.")
            continue

        # Read problem statement
        with open(os.path.join(STATEMENTS_DIR, filename), "r", encoding="utf-8") as file:
            statement = file.read()

        if len(statement) > max_statement_chars:
            print(f"⚠️ Skipping {problem_id}, too long.")
            continue

        generated_test = generate_test_case(statement, strategy_id)
        save_test_case(problem_id, generated_test, strategy_id)


# Run one prompting strategy per execution; change the id (a key of the prompt
# templates in generate_test_case) and re-run to try a different strategy.
strategy_id = 6
process_statements(strategy_id=strategy_id)

✅ Generated test case for 118-a -> generated/gpt2-6/118-a.txt
✅ Generated test case for 451-a -> generated/gpt2-6/451-a.txt
✅ Generated test case for 750-a -> generated/gpt2-6/750-a.txt
✅ Generated test case for 1692-a -> generated/gpt2-6/1692-a.txt
✅ Generated test case for 266-b -> generated/gpt2-6/266-b.txt
✅ Generated test case for 1807-a -> generated/gpt2-6/1807-a.txt
✅ Generated test case for 1475-a -> generated/gpt2-6/1475-a.txt
✅ Generated test case for 1512-a -> generated/gpt2-6/1512-a.txt
✅ Generated test case for 1335-a -> generated/gpt2-6/1335-a.txt
✅ Generated test case for 271-a -> generated/gpt2-6/271-a.txt
✅ Generated test case for 443-a -> generated/gpt2-6/443-a.txt
✅ Generated test case for 1352-a -> generated/gpt2-6/1352-a.txt
✅ Generated test case for 58-a -> generated/gpt2-6/58-a.txt
✅ Generated test case for 467-a -> generated/gpt2-6/467-a.txt
✅ Generated test case for 148-a -> generated/gpt2-6/148-a.txt
✅ Generated test case for 228-a -> generated/gpt2-6/228-a.tx