### generate.py

In [3]:
import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel

In [4]:
def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that a following argmax performs Gumbel-max style sampling.

    When temperature == 0 the input is returned untouched (no cast, no noise),
    so argmax over the result is plain greedy decoding. Otherwise the result is
    exp(logits) / (-log(u)) ** temperature with u drawn uniformly from [0, 1),
    computed in float64.

    float64 is used on purpose: according to arXiv:2409.02908, for MDM,
    low-precision Gumbel Max improves perplexity score but reduces generation
    quality.
    '''
    if temperature == 0:
        return logits
    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    # Dividing exp(logits) by the tempered noise keeps the same argmax as the
    # additive form in log space.
    tempered_noise = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / tempered_noise

In [5]:
def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens are to be transitioned at each step.

    In the reverse process, the interval [0, 1] is uniformly discretized into
    `steps` intervals. Because LLaDA employs a linear noise schedule (as defined
    in Eq. (8)), the expected number of tokens transitioned at each step should
    be consistent.

    For every row of `mask_index` the number of masked positions is split as
    evenly as possible over `steps`; the first (count % steps) steps receive one
    extra token. Returns an int64 tensor of shape (rows, steps).
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)

    per_step = masked_per_row // steps
    leftover = masked_per_row % steps

    # Column s gets the extra token exactly when s < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device)
    gets_extra = (step_ids < leftover).to(torch.int64)

    return per_step.to(torch.int64).expand(-1, steps) + gets_extra

In [6]:
@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    '''
    Iteratively unmask a fully masked answer appended to the prompt.

    Args:
        model: Mask predictor. Called as model(x) and expected to expose `.logits`
            and `.device`.
        prompt: A tensor of shape (B, L). B == 1 behaves exactly as before; for
            B > 1 all prompts share the same length L (no padding handling is
            done here).
        steps: Sampling steps, less than or equal to gen_length. Must be divisible
            by the number of blocks.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than
            gen_length, it means using semi_autoregressive remasking. Must divide
            gen_length.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.

    Returns:
        A long tensor of shape (B, L + gen_length): the prompt followed by the
        generated tokens.

    Raises:
        AssertionError: if gen_length is not a multiple of block_length, or steps
            is not a multiple of the number of blocks.
        NotImplementedError: for an unknown remasking strategy.
    '''
    batch_size, prompt_len = prompt.shape[0], prompt.shape[1]

    # Start from the prompt followed by gen_length [MASK] tokens.
    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_len] = prompt.clone()

    # Positions that are not masked initially, i.e. the prompt.
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0, 'gen_length must be a multiple of block_length'
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0, 'steps must be a multiple of the number of blocks'
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = prompt_len + (num_block + 1) * block_length

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: second pass with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

            if remasking == 'low_confidence':
                # Confidence = probability the model assigns to its own prediction.
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never unmask positions that belong to later blocks.
            x0_p[:, block_end:] = -np.inf

            # Already-decoded positions keep their token and are excluded from selection.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Per row, commit the k most confident predictions scheduled for this step.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x

### Inference

In [None]:
# Load the LLaDA-8B-Instruct checkpoint from the local directory in bfloat16,
# move it to the GPU and put it in eval mode. trust_remote_code=True is passed
# for both the model and the tokenizer.
device = 'cuda'
model = AutoModel.from_pretrained('./LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16, use_cache=False).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('./LLaDA-8B-Instruct', trust_remote_code=True)

# Confirm where the weights ended up.
print("Loaded model on device:", model.device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]



Loaded model on device: cuda:0


In [9]:
# Example question used for the inference demo.
prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

# Add special tokens for the Instruct model. The Base model does not require the following two lines.
m = [{"role": "user", "content": prompt}, ]
# tokenize=False keeps the result as a formatted string; it is tokenized in a later cell.
prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

In [10]:
print(prompt)

<|startoftext|><|start_header_id|>user<|end_header_id|>

Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [11]:
# Tokenize the chat-formatted prompt and add a batch dimension -> shape (1, L).
input_ids = tokenizer(prompt)['input_ids']
input_ids = torch.tensor(input_ids).to(model.device).unsqueeze(0)

# 128 new tokens in 4 blocks of 32, greedy (temperature 0), no classifier-free guidance.
out = generate(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
# Decode only the generated part, i.e. everything after the prompt tokens.
print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])

Lily can run 12 kilometers per hour for 4 hours, so she runs a total of 12 * 4 = 48 kilometers.
After that, she runs 6 kilometers per hour for the remaining 4 hours, so she runs a total of 6 * 4 = 24 kilometers.
Therefore, Lily can run a total of 48 + 24 = 72 kilometers in 8 hours.
The final result is 72


### Reasoning

In [None]:
# MATH DEMO
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# System message asking the model to answer inside <reasoning>/<answer> tags.
# The literal starts and ends with a newline.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for a fully formatted example answer; only referenced by the
# commented-out 1-shot example in get_gsm8k_questions.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    Missing tags are tolerated: without ``<answer>`` the search starts at the
    beginning of ``text``; without ``</answer>`` it runs to the end.
    """
    after_open = text.rpartition("<answer>")[2]
    inner = after_open.partition("</answer>")[0]
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

# Uncomment the middle messages below to enable 1-shot prompting.
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load a GSM8K split and add chat-style ``prompt`` and cleaned ``answer`` columns.

    ``prompt`` is a system message (SYSTEM_PROMPT) followed by the question as a
    user message; ``answer`` is the text after ``####`` in the reference solution.
    """
    def to_chat_example(example):
        return { # type: ignore
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
                #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
                #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                #    answer="7"
                #)},
                {'role': 'user', 'content': example['question']}
            ],
            'answer': extract_hash_answer(example['answer'])
        }

    raw_split = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    return raw_split.map(to_chat_example) # type: ignore

dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Give 2.0 to every completion whose extracted answer equals the reference, else 0.0.

    Also prints the first question, reference answer, response and extracted
    answer of the batch for debugging.
    """
    texts = [turns[0]['content'] for turns in completions]
    question = prompts[0][-1]['content']
    guesses = [extract_xml_answer(t) for t in texts]
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{guesses[0]}")
    rewards = []
    for guess, truth in zip(guesses, answer):
        rewards.append(2.0 if guess == truth else 0.0)
    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to every completion whose extracted answer consists only of digits, else 0.0."""
    rewards = []
    for turns in completions:
        guess = extract_xml_answer(turns[0]['content'])
        rewards.append(0.5 if guess.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that match the exact line-by-line XML layout, else 0.0.

    The whole completion must be ``<reasoning>``, body, ``</reasoning>``,
    ``<answer>``, body, ``</answer>``, each tag on its own line.
    """
    layout = re.compile(r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$", flags=re.DOTALL)
    rewards = []
    for turns in completions:
        content = turns[0]["content"]
        rewards.append(0.5 if layout.match(content) else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that start with a reasoning block followed by an answer block.

    Whitespace inside and between the blocks is free, and text after
    ``</answer>`` is ignored. The match is anchored at the start of the
    completion (``re.match``), so leading text prevents a reward.
    """
    layout = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", flags=re.DOTALL)
    rewards = []
    for turns in completions:
        content = turns[0]["content"]
        rewards.append(0.5 if layout.match(content) else 0.0)
    return rewards

def count_xml(text) -> float:
    """Score partial XML-format compliance of ``text``.

    Each of the four newline-delimited tag markers is worth 0.125 when it occurs
    exactly once. The two answer markers additionally subtract 0.001 per
    character found after the closing answer tag (the second one allows for a
    single trailing character).
    """
    score = 0.0
    for marker in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(marker) == 1:
            score += 0.125
    if text.count("\n<answer>\n") == 1:
        tail = text.split("\n</answer>\n")[-1]
        score += 0.125
        score -= len(tail)*0.001
    if text.count("\n</answer>") == 1:
        tail = text.split("\n</answer>")[-1]
        score += 0.125
        score -= (len(tail) - 1)*0.001
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion's first message with ``count_xml``."""
    rewards = []
    for turns in completions:
        rewards.append(count_xml(turns[0]["content"]))
    return rewards
# Checkpoints and logs of this run are written below this directory.
output_dir = "outputs/LLaDA-GRPO"
import os
# Disable NCCL peer-to-peer and InfiniBand transports.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"
# GRPO hyper-parameters: 16 sampled completions per prompt, bf16 training,
# cosine schedule with 10% warmup, metrics reported to Weights & Biases.
training_args = GRPOConfig(
    output_dir=output_dir,
    # run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

# LoRA settings. NOTE: this config is built but not used below, because the
# peft_config argument of GRPOTrainer is commented out.
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
# `model` and `tokenizer` are the objects loaded in the Inference section above.
# The reward functions are passed in one list; see the TRL documentation for
# how their outputs are combined.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config
)
trainer.train()
tokenizer.decode([126080, 126346,   3840, 126347,    198,    198,  48851,  26612, 126348, 126346,    598,  10450, 126347,    198,    198,  14455,      0,   2071,560,    331,   6528,    362,   3342,     30])
'<|startoftext|>user\n\nsay hiassistant\n\nHello! How can I assist you today?'
# TODO: UNDERSTAND TOKEN STRUCTURE
# tensor([[126080, 126346,   3840, 126347,    198,    198,  48851,  26612, 126348,
#          126346,    598,  10450, 126347,    198,    198,  14455,      0,   2071,
#             560,    331,   6528,    362,   3342,     30, 126348, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081]],
#        device='cuda:0')
# Bot's reply: Hello! How can I assist you today?

OutOfMemoryError: CUDA out of memory. Tried to allocate 988.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 24.12 MiB is free. Process 35470 has 14.71 GiB memory in use. Of the allocated memory 14.54 GiB is allocated by PyTorch, and 52.41 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## ARC-AGI

### Training

In [None]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Define system prompt for ARC-AGI format
# System message used for every ARC-AGI training prompt. The literal starts and
# ends with a newline and has no leading indentation on any line.
SYSTEM_PROMPT = """
You are solving ARC-AGI puzzles. For each puzzle, provide your reasoning step-by-step and then give your final answer.
Respond in the following format:
<reasoning>
...detailed step-by-step solution...
</reasoning>
<answer>
...your final answer...
</answer>
"""

# Template for a fully formatted example answer.
# NOTE(review): not referenced anywhere else in this cell - confirm whether it is still needed.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    Returns an empty string unless both tags occur somewhere in ``text``.
    """
    has_both_tags = "<answer>" in text and "</answer>" in text
    if not has_both_tags:
        return ""
    after_open = text.rpartition("<answer>")[2]
    inner = after_open.partition("</answer>")[0]
    return inner.strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped text between the last ``<reasoning>`` and the next ``</reasoning>``.

    Returns an empty string unless both tags occur somewhere in ``text``.
    """
    has_both_tags = "<reasoning>" in text and "</reasoning>" in text
    if not has_both_tags:
        return ""
    after_open = text.rpartition("<reasoning>")[2]
    inner = after_open.partition("</reasoning>")[0]
    return inner.strip()

def get_arc_agi_questions(split = "train") -> Dataset:
    """Load one split of the ARC-AGI dataset and reshape it for GRPO training.

    Every example gains a chat-style ``prompt`` (SYSTEM_PROMPT plus the question
    as a user message), an ``answer`` and a ``question_id``.
    """
    def to_chat_example(example):
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': example['question']}
            ],
            'answer': example['answer'],  # assuming the answer field exists
            'question_id': example['id']  # store question ID for evaluation
        }

    # NOTE(review): the dataset id and the 'question'/'answer'/'id' columns are
    # taken as given here - confirm they exist on the Hub.
    raw_split = load_dataset('allenai/arc_agi')[split]
    return raw_split.map(to_chat_example)

# Load dataset
dataset = get_arc_agi_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Give 2.0 when the extracted answer equals the reference, ignoring case and outer whitespace.

    The first sample of the batch is printed for debugging.
    """
    texts = [turns[0]['content'] for turns in completions]
    question = prompts[0][-1]['content']
    guesses = [extract_xml_answer(t) for t in texts]

    # Log for debugging
    print('-'*20)
    print(f"Question:\n{question}")
    print(f"\nExpected Answer:\n{answer[0]}")
    print(f"\nResponse:\n{texts[0]}")
    print(f"\nExtracted:\n{guesses[0]}")

    # ARC-AGI answers are not necessarily plain digits, so compare normalized text.
    rewards = []
    for guess, truth in zip(guesses, answer):
        same = guess.lower().strip() == truth.lower().strip()
        rewards.append(2.0 if same else 0.0)
    return rewards

def reasoning_quality_reward_func(completions, **kwargs) -> list[float]:
    """Reward the length of the extracted reasoning, measured in words.

    The reward grows linearly with the number of whitespace-separated words and
    is capped at 1.0, which is reached at 100 words; the cap stops rambling from
    being rewarded further.
    """
    rewards = []
    for turns in completions:
        reasoning = extract_xml_reasoning(turns[0]['content'])
        word_count = len(reasoning.split())
        rewards.append(min(word_count / 100, 1.0))
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that match the exact line-by-line XML layout, else 0.0.

    The whole completion must be ``<reasoning>``, body, ``</reasoning>``,
    ``<answer>``, body, ``</answer>``, each tag on its own line.
    """
    layout = re.compile(r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$", flags=re.DOTALL)
    rewards = []
    for turns in completions:
        content = turns[0]["content"]
        rewards.append(0.5 if layout.match(content) else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that start with a reasoning block followed by an answer block.

    More lenient than the strict check: whitespace is free and text after
    ``</answer>`` is ignored. The match is still anchored at the start of the
    completion (``re.match``).
    """
    layout = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", flags=re.DOTALL)
    rewards = []
    for turns in completions:
        content = turns[0]["content"]
        rewards.append(0.5 if layout.match(content) else 0.0)
    return rewards

def count_xml(text) -> float:
    """Score partial XML-format compliance of ``text``.

    Each of the four tags is worth 0.125 when it occurs exactly once. Both
    answer tags additionally subtract 0.001 per character that follows the last
    ``</answer>``; the closing-tag term tolerates one trailing character, so a
    completion ending exactly at ``</answer>`` gains 0.001 from it.
    """
    score = 0.0
    for tag in ("<reasoning>", "</reasoning>"):
        if text.count(tag) == 1:
            score += 0.125
    if text.count("<answer>") == 1:
        tail = text.split("</answer>")[-1]
        score += 0.125
        score -= len(tail)*0.001  # Penalty for content after closing answer tag
    if text.count("</answer>") == 1:
        tail = text.split("</answer>")[-1]
        score += 0.125
        score -= (len(tail) - 1)*0.001
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion's first message with ``count_xml``."""
    rewards = []
    for turns in completions:
        rewards.append(count_xml(turns[0]["content"]))
    return rewards

# Setup training configuration
# Checkpoints are written here; the final model is saved to output_dir + "/final_model".
output_dir = "outputs/ARC-AGI-GRPO"

# Environmental variables for distributed training
# (disable NCCL peer-to-peer and InfiniBand transports).
import os
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

# Configure training parameters
training_args = GRPOConfig(
    output_dir=output_dir,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=512,  # Increased for potentially longer ARC-AGI questions
    max_completion_length=1024,  # Increased for more elaborate reasoning
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",  # Weights & Biases logging
    log_on_each_node=False,
)

# Configure LoRA for parameter-efficient fine-tuning
# (this config is passed to GRPOTrainer below).
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# Load the base model and its tokenizer for GRPO training.
# NOTE(review): this is a base (non-chat) checkpoint - confirm its tokenizer
# provides a chat template for the role-based prompts built above.
model_name = "meta-llama/Llama-2-13b-hf"  # Choose an appropriate base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        reasoning_quality_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config  # LoRA is enabled; remove this argument to train without PEFT
)

# Start training
if __name__ == "__main__":
    trainer.train()
    
    # Save the final model (the evaluation script loads it from this path)
    trainer.save_model(output_dir + "/final_model")
    
    # Optional: Run evaluation on test set
    # test_dataset = get_arc_agi_questions("test")
    # metrics = trainer.evaluate(test_dataset)
    # print(f"Test metrics: {metrics}")

### Evaluation

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import json
import os
from tqdm import tqdm
import re

# Helper functions for text extraction
def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    Returns an empty string unless both tags occur somewhere in ``text``.
    """
    has_both_tags = "<answer>" in text and "</answer>" in text
    if not has_both_tags:
        return ""
    after_open = text.rpartition("<answer>")[2]
    inner = after_open.partition("</answer>")[0]
    return inner.strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped text between the last ``<reasoning>`` and the next ``</reasoning>``.

    Returns an empty string unless both tags occur somewhere in ``text``.
    """
    has_both_tags = "<reasoning>" in text and "</reasoning>" in text
    if not has_both_tags:
        return ""
    after_open = text.rpartition("<reasoning>")[2]
    inner = after_open.partition("</reasoning>")[0]
    return inner.strip()

def main():
    """Evaluate the trained ARC-AGI adapter on the test split and write JSON reports.

    Loads the base model plus the PEFT adapter, samples one response per test
    example, extracts the ``<answer>``/``<reasoning>`` sections and records
    answer correctness (case-insensitive, whitespace-stripped string equality)
    and format compliance (all four tags present).

    Side effects: creates ``evaluation_results/`` and writes ``results.json``
    (rewritten every 10 examples and after the last one) and ``metrics.json``;
    prints progress and the final accuracies.
    """
    # Configuration
    model_path = "outputs/ARC-AGI-GRPO/final_model"  # Path to your trained model
    base_model_name = "meta-llama/Llama-2-13b-hf"  # Base model used for training
    output_dir = "evaluation_results"
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "results.json")
    metrics_path = os.path.join(output_dir, "metrics.json")

    # System prompt. Spelled out line by line so that it carries no indentation
    # from this function body and is identical to the SYSTEM_PROMPT used for
    # training; an indented triple-quoted literal would prefix every line with
    # spaces and evaluate the model on a prompt it was not trained with.
    system_prompt = (
        "\n"
        "You are solving ARC-AGI puzzles. For each puzzle, provide your reasoning step-by-step and then give your final answer.\n"
        "Respond in the following format:\n"
        "<reasoning>\n"
        "...detailed step-by-step solution...\n"
        "</reasoning>\n"
        "<answer>\n"
        "...your final answer...\n"
        "</answer>\n"
    )

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    # Load the trained PEFT model
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()

    # Load test dataset
    print("Loading ARC-AGI test dataset...")
    test_data = load_dataset('allenai/arc_agi', split='test')
    total = len(test_data)

    # Prepare for evaluation
    results = []
    correct_count = 0
    format_correct_count = 0

    print("Starting evaluation...")
    for i, example in enumerate(tqdm(test_data)):
        question_id = example['id']
        question = example['question']
        correct_answer = example['answer']

        # Format input with system prompt
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question}
        ]

        # Convert to model input format. add_generation_prompt=True appends the
        # assistant-turn header (when the template defines one) so the model
        # starts a reply instead of continuing the user turn.
        input_text = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

        # Decode only the newly generated tokens (skip the prompt part)
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        # Extract answer and reasoning
        extracted_answer = extract_xml_answer(response)
        extracted_reasoning = extract_xml_reasoning(response)

        # Check format compliance: all four tags must be present
        format_correct = all(
            tag in response
            for tag in ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
        )

        # Check answer correctness (case-insensitive comparison)
        is_correct = extracted_answer.lower().strip() == correct_answer.lower().strip()

        if is_correct:
            correct_count += 1

        if format_correct:
            format_correct_count += 1

        # Store result
        results.append({
            "question_id": question_id,
            "question": question,
            "expected_answer": correct_answer,
            "model_response": response,
            "extracted_answer": extracted_answer,
            "extracted_reasoning": extracted_reasoning,
            "format_correct": format_correct,
            "answer_correct": is_correct
        })

        # Periodically save results so a crash does not lose everything
        if (i + 1) % 10 == 0 or i == total - 1:
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)

    # Calculate metrics; an empty test split yields 0% instead of ZeroDivisionError
    accuracy = correct_count / total * 100 if total else 0.0
    format_accuracy = format_correct_count / total * 100 if total else 0.0

    # Save metrics
    metrics = {
        "total_examples": total,
        "correct_answers": correct_count,
        "answer_accuracy": accuracy,
        "format_correct": format_correct_count,
        "format_accuracy": format_accuracy
    }

    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print("Evaluation complete!")
    print(f"Answer Accuracy: {accuracy:.2f}%")
    print(f"Format Accuracy: {format_accuracy:.2f}%")

if __name__ == "__main__":
    main()