#### Generate a new test case for each problem statement using GPT-2 (simple problems only)

In [43]:
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm.auto import tqdm

# Define paths
# input_file_path: text file of scraped problems that parse_problems() reads.
# output_dir: folder that receives one generated .txt file per parsed problem.
# Alternative settings (commented out) for a different input file / output folder:
# input_file_path = "scraped/simple/simple-01.txt"
# output_dir = "generated/gpt2-simple-1"
input_file_path = "scraped/simple/simple-04.txt"
output_dir = "generated/gpt2-simple-2"

In [44]:
# Function to parse dataset
def parse_problems(file_path):
    """
    Read a problem file and split it into one record per problem.

    The file is expected to separate problems with "===" and to introduce
    every test case inside a problem with "### Test Case ###".

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        A list of dicts with the keys:
            "id"         - 1-based position of the chunk in the file (chunks
                           that are skipped still consume a number),
            "problem"    - text before the first test-case marker, stripped,
            "test_cases" - list of stripped test-case texts,
            "raw"        - the chunk exactly as it appeared between separators.
        Chunks without any test-case marker are left out.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        text = handle.read().strip()

    records = []
    for position, chunk in enumerate(text.split("==="), start=1):
        # First piece is the statement; everything after it is a test case.
        statement, *cases = chunk.strip().split("### Test Case ###")
        if not cases:
            # No marker found: malformed entry, skip it.
            continue

        records.append({
            "id": position,
            "problem": statement.strip(),
            "test_cases": [case.strip() for case in cases],
            "raw": chunk,
        })

    return records

# Parse the input file into a list of problem records (generation happens in a later cell)
problems = parse_problems(input_file_path)

In [None]:
# Load the GPT-2 tokenizer and model
# NOTE(review): the original hint suggested "gpt2-small" as an alternative, but
# that does not appear to be a valid Hugging Face hub id - confirm before using it.
model_name = "gpt2"

# Pick the device: Apple MPS when available, otherwise CPU.
# Timings recorded by the author:
#   mps: 2.12s/it
#   cpu: 1.56s/it
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Pad on the left so every prompt ends right where generation starts;
# the EOS token doubles as the padding token.
tokenizer = GPT2Tokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)

def generate_test_cases_batched(prompts, batch_size=64):
    """
    Generate one continuation per prompt, running the model on at most
    `batch_size` prompts at a time.

    Args:
        prompts: A list of prompt strings (a single string is treated as a
            one-element list).
        batch_size: Maximum number of prompts per `model.generate` call.
            Must be at least 1.

    Returns:
        A list of decoded strings, one per prompt and in the same order.
        Each string is decoded from the full sequence returned by
        `model.generate`, with special tokens skipped.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if isinstance(prompts, str):
        prompts = [prompts]
    else:
        prompts = list(prompts)

    generated_texts = []
    # Walk the prompts in slices so memory use is bounded by batch_size;
    # an empty prompt list simply skips the loop and returns [].
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]

        # NOTE(review): truncation=True with no max_length cuts over-long
        # prompts at the tokenizer's maximum length; presumably that removes
        # the END of the prompt - confirm prompts are short enough.
        tokenized_inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True)
        input_ids = tokenized_inputs["input_ids"].to(device)
        # The padding token is the EOS token here, so padded positions cannot
        # be told apart from real tokens without the explicit mask.
        attention_mask = tokenized_inputs["attention_mask"].to(device)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=24,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                # repetition_penalty=1.1
            )

        # Move back to CPU and decode
        generated_texts.extend(
            tokenizer.decode(ids.cpu(), skip_special_tokens=True) for ids in output_ids
        )

    return generated_texts


Using device: mps


In [51]:
# Create the destination folder if it is not there yet
os.makedirs(output_dir, exist_ok=True)

# One prompt per problem: the raw problem text followed by an open
# test-case marker for the model to continue.
prompts = [f"{entry['raw']}\n### Test Case ###" for entry in problems]

generated_cases = generate_test_cases_batched(prompts, batch_size=64)

# Save each generated text as "<problem id>.txt" in the output folder
for entry, updated in zip(problems, generated_cases):
    file_name = f"{entry['id']}.txt"
    output_file_path = os.path.join(output_dir, file_name)
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(updated)

saved_count = len(generated_cases)
print(f"Saved {saved_count} statements.")


Saved 20 statements.
