In [1]:
pip install torch transformers peft datasets gradio


Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting gradio
  Downloading gradio-5.22.0-py3-none-any.whl.metadata (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cuf

In [18]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
import gradio as gr
import re
import ast
import operator as op

# Define supported operators for safe evaluation
operators = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Pow: op.pow,
    ast.USub: op.neg
}

# Function to safely evaluate a math expression
def evaluate_expression(expr):
    try:
        expr = expr.replace("÷", "/")  # Replace ÷ with / for Python compatibility
        expr = expr.replace("^", "**")
        tree = ast.parse(expr, mode='eval')
        result = eval_node(tree.body)
        return result, None  # Successful evaluation: return (result, None)
    except Exception as e:
        return None, f"Error evaluating expression: {str(e)}"  # Failed evaluation: return (None, error message)

def eval_node(node):
    if isinstance(node, ast.Num):
        return node.n
    elif isinstance(node, ast.BinOp):
        left = eval_node(node.left)
        right = eval_node(node.right)
        return operators[type(node.op)](left, right)
    elif isinstance(node, ast.UnaryOp):
        operand = eval_node(node.operand)
        return operators[type(node.op)](operand)
    else:
        raise ValueError("Unsupported operation")

# Function to handle factorial
def compute_factorial(n):
    if not isinstance(n, int) or n < 0:
        return None
    if n == 0:
        return 1
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result

# Dataset for fine-tuning GPT-2 (only for explanations)
data = [
    {"input": "Incorrect: 8 ÷ 2(2+2) = 1?", "output": "Explanation: Evaluate parentheses first then perform division and multiplication sequentially."},
    {"input": "Incorrect: 5 + 5 = 20?", "output": "Explanation: Simple addition error."},
    {"input": "Incorrect: 6 * 6 = 36 but 6 / 6 = 6?", "output": "Explanation: A number divided by itself equals 1."},
    {"input": "Incorrect: 2^3 = 6?", "output": "Explanation: 2 cubed is 8."},
    {"input": "Incorrect: √16 = 5?", "output": "Explanation: The square root of 16 is 4."},
    {"input": "Incorrect: 9 - 3 = 3?", "output": "Explanation: Correct subtraction yields 6."},
    {"input": "Incorrect: 4 * 4 = 8?", "output": "Explanation: Multiplication error."},
    {"input": "Incorrect: 10 / 2 = 10?", "output": "Explanation: Division error."},
    {"input": "Incorrect: 15% of 200 = 50?", "output": "Explanation: 15% of 200 equals 30."},
    {"input": "Incorrect: 100 / 4 = 20?", "output": "Explanation: Division error."},
    {"input": "Incorrect: 3 + 7 = 11?", "output": "Explanation: 3 plus 7 equals 10."},
    {"input": "Incorrect: 2 * 3 + 4 = 14?", "output": "Explanation: Follow order of operations: multiply then add."},
    {"input": "Incorrect: 12 / 3 * 2 = 10?", "output": "Explanation: 12 divided by 3 is 4; 4 times 2 is 8."},
    {"input": "Incorrect: 7 * 7 = 42?", "output": "Explanation: Multiplication error."},
    {"input": "Incorrect: 14 - 7 = 8?", "output": "Explanation: Subtraction error."},
    {"input": "Incorrect: (3 + 2) * 2 = 12?", "output": "Explanation: Add first, then multiply."},
    {"input": "Incorrect: 50% of 100 = 60?", "output": "Explanation: 50% is half of 100."},
    {"input": "Incorrect: 9 + 9 = 18 then 18 / 2 = 10?", "output": "Explanation: Division error."},
    {"input": "Incorrect: 5! = 100?", "output": "Explanation: 5 factorial is 120."},
    {"input": "Incorrect: 3^2 + 4^2 = 14?", "output": "Explanation: 9 + 16 equals 25."}
]

# Combine input and output for training
train_data = [{"text": f"{item['input']} {item['output']}"} for item in data]
dataset = Dataset.from_list(train_data)
print("Dataset created with", len(dataset), "examples.")

# Load GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")
print("GPT-2 model loaded.")

# Tokenize the dataset
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, max_length=128, padding="max_length")

tokenized_dataset = dataset.map(tokenize_function, batched=False)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
print("Dataset tokenized.")

# Training arguments
training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=5,
    logging_steps=1,
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.1,
    report_to="none"
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

print("Starting training...")
trainer.train()
print("Training complete!")

# Inference function
def correct_math(prompt):
    # Step 1: Rule-based correction for the correct answer
    prompt = prompt.strip()
    match = re.match(r"(.+?)\s*=\s*(\d+\.?\d*)\?", prompt)
    if not match:
        return f"Incorrect: {prompt} Correct: (unable to parse expression)."

    expr, incorrect_answer = match.groups()
    expr = expr.strip()
    incorrect_answer = float(incorrect_answer)

    # Handle special cases (factorial, percentages, square roots)
    if "!" in expr:
        num = int(re.search(r"(\d+)!", expr).group(1))
        correct_answer = compute_factorial(num)
        expr = f"{num}!"
    elif "% of" in expr:
        percentage, number = map(float, re.search(r"(\d+)% of (\d+)", expr).groups())
        correct_answer = (percentage / 100) * number
    elif "√" in expr:
        number = float(re.search(r"√(\d+)", expr).group(1))
        correct_answer = number ** 0.5
    else:
        correct_answer, error = evaluate_expression(expr)
        if error:
            return f"Incorrect: {prompt} Correct: {error}"

    correct_answer = round(correct_answer, 2) if isinstance(correct_answer, float) else correct_answer

    # Step 2: Use GPT-2 to generate the explanation
    model.eval()
    full_prompt = f"Incorrect: {prompt} Explanation: "
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=50,
        num_beams=4,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id
    )

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Raw output from GPT-2 for {prompt}: {result}")

    # Extract the explanation
    match = re.search(r"Explanation: (.*?)(?:\.)", result)
    explanation = match.group(1) if match else "No explanation generated."

    # Step 3: Combine the rule-based correction with the model's explanation
    return f"Incorrect: {expr} = {incorrect_answer}? Correct: {expr} = {correct_answer}. Explanation: {explanation}"

# Test all prompts in the dataset
prompts = [item["input"].replace("Incorrect: ", "") for item in data]
for prompt in prompts:
    print(f"Prompt: {prompt}")
    print(f"Output: {correct_math(prompt)}\n")

# Gradio interface
gr.Interface(
    fn=correct_math,
    inputs="text",
    outputs="text",
    title="Math Correction Model",
    description="Enter an incorrect math statement (e.g., '5! = 100?') to get the correct answer and explanation."
).launch()

Dataset created with 20 examples.
GPT-2 model loaded.


Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Dataset tokenized.
Starting training...


Step,Training Loss
1,4.8748
2,4.1707
3,3.676
4,3.6013
5,3.1857
6,2.4348
7,2.9286
8,2.3788
9,2.0524
10,2.1368


Training complete!
Prompt: 8 ÷ 2(2+2) = 1?
Output: Incorrect: 8 ÷ 2(2+2) = 1? Correct: Error evaluating expression: Unsupported operation

Prompt: 5 + 5 = 20?
Raw output from GPT-2 for 5 + 5 = 20?: Incorrect: 5 + 5 = 20? Explanation:  Multiplication error. Explanation: Multiplication error. Explanation: Multiplication error. Explanation: Multiplication error. Explanation: Multiplication error. Explanation: Multiplication error. Explanation: Multi
Output: Incorrect: 5 + 5 = 20.0? Correct: 5 + 5 = 10. Explanation:  Multiplication error

Prompt: 6 * 6 = 36 but 6 / 6 = 6?
Output: Incorrect: 6 * 6 = 36 but 6 / 6 = 6? Correct: Error evaluating expression: invalid syntax (<unknown>, line 1)

Prompt: 2^3 = 6?
Raw output from GPT-2 for 2^3 = 6?: Incorrect: 2^3 = 6? Explanation:  Multiplication error. Subtraction error. Subtraction error. Subtraction error. Subtraction error. Subtraction error. Subtraction error. Subtraction error. Subtraction error. Subtraction error
Output: Incorrect: 2^3 = 6.

