# Experiment 3:
# GRPO for Enhanced Math Reasoning

In [None]:
import itertools
import re

from datasets import load_dataset
from tunix.environments import MathEnvironment
from tunix.trainers import GRPOLearner

# GSM8K test split, loaded here so it is available for evaluation
gsm8k_test = load_dataset("gsm8k", "main", split="test")

# Math environment used by the learner
math_env = MathEnvironment(
    dataset="gsm8k",
    reward_type="correctness",  # reward is based on a correct final answer
    format_check=True,
)

# Starting model (the base checkpoint, or substitute your DPO model)
base_model = FlaxAutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")

# GRPO hyperparameters, kept together so they are easy to tweak
_grpo_hparams = {
    "learning_rate": 1e-6,
    "group_size": 4,  # generate 4 responses per prompt
    "num_iterations": 1000,
    "temperature": 0.7,
    "batch_size": 8,
    "eval_frequency": 100,
}
grpo_learner = GRPOLearner(
    model=base_model,
    tokenizer=tokenizer,
    environment=math_env,
    **_grpo_hparams,
)

# Run reinforcement learning
print("Starting GRPO training on GSM8K...")
print("This will take several hours...\n")

rl_metrics = grpo_learner.train()

_rule = "=" * 50
print("\n" + _rule)
print("GRPO Training Results")
print(_rule)

# One row per reported figure: (label, key in rl_metrics, format spec, suffix)
_report_rows = (
    ("Pass@1 Accuracy", "pass_at_1", ".2%", ""),
    ("Pass@5 Accuracy", "pass_at_5", ".2%", ""),
    ("Answer Accuracy", "answer_accuracy", ".2%", ""),
    ("Partial Accuracy", "partial_accuracy", ".2%", ""),
    ("Format Accuracy", "format_accuracy", ".2%", ""),
    ("Average Reward", "avg_reward", ".3f", ""),
    ("Training time", "training_time_hours", ".1f", " hours"),
)
for _label, _key, _spec, _suffix in _report_rows:
    print(f"  {_label}: {rl_metrics[_key]:{_spec}}{_suffix}")


### Detailed Evaluation

In [None]:
# A number with optional sign, thousands separators and decimals. The
# lookbehind stops a match from starting inside a word or a longer number, and
# keeps the "-" in expressions such as "5-3" from being read as a sign.
_NUMBER_RE = re.compile(
    r"(?<![\w.])-?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?|\.\d+)"
)


def _extract_numbers(text):
    """Return every number found in *text* as a float, in order of appearance."""
    return [float(match.replace(",", "")) for match in _NUMBER_RE.findall(text)]


def _extract_predicted_answer(response):
    """Return the final numeric answer in *response*, or None if there is none.

    The first number after a ``####`` marker wins; without a marker the last
    number anywhere in the response is used.
    """
    if "####" in response:
        numbers = _extract_numbers(response.split("####", 1)[1])
        return numbers[0] if numbers else None
    numbers = _extract_numbers(response)
    return numbers[-1] if numbers else None


def evaluate_on_gsm8k(model, tokenizer, test_data, num_samples=1319):
    """Evaluate a model on up to ``num_samples`` GSM8K test problems.

    Args:
        model: Object exposing ``generate(**inputs, max_length=..., temperature=...)``.
        tokenizer: Callable that encodes a prompt and exposes ``decode``.
        test_data: Iterable of examples, each a mapping with a ``'question'``
            string and an ``'answer'`` string ending in ``#### <number>``.
        num_samples: Maximum number of examples to evaluate.

    Returns:
        A dict with the fraction of evaluated examples that were
        ``"pass@1"`` (answer within 0.01 of the truth), ``"partial"`` (not
        exact but within 10% of the truth) and ``"format"`` (response
        contains both ``<<`` and ``>>``). All rates are 0.0 when no example
        was evaluated.
    """
    correct = 0
    partial_correct = 0
    format_correct = 0
    evaluated = 0

    # Walk the data row by row. Slicing (``test_data[:n]``) is avoided on
    # purpose: slicing a Hugging Face ``Dataset`` yields a dict of columns, so
    # looping over the slice would visit column names instead of examples.
    examples = itertools.islice(test_data, num_samples)
    for evaluated, example in enumerate(examples, start=1):
        # Generate model response
        prompt = f"Solve this math problem step by step:\n{example['question']}"
        inputs = tokenizer(prompt, return_tensors="jax", padding=True)
        outputs = model.generate(**inputs, max_length=512, temperature=0.0)
        # Some generate() implementations return an output object whose token
        # ids live in ``.sequences``; others return the id array directly.
        sequences = getattr(outputs, "sequences", outputs)
        response = tokenizer.decode(sequences[0], skip_special_tokens=True)

        # Reference answers may carry thousands separators ("#### 1,000").
        true_answer = float(
            example['answer'].split('####')[1].strip().replace(',', '')
        )

        if '<<' in response and '>>' in response:
            format_correct += 1

        pred_answer = _extract_predicted_answer(response)
        if pred_answer is not None:
            error = abs(pred_answer - true_answer)
            if error < 0.01:
                correct += 1
            # Same test as 0.9 <= pred/true <= 1.1, written without a division
            # so a true answer of 0 cannot raise ZeroDivisionError.
            elif error <= 0.1 * abs(true_answer):
                partial_correct += 1

        if evaluated % 100 == 0:
            print(f"Evaluated {evaluated}/{num_samples} examples...")

    if evaluated == 0:
        return {"pass@1": 0.0, "partial": 0.0, "format": 0.0}

    # Normalise by the examples actually seen, which can be fewer than
    # num_samples when the dataset is short.
    return {
        "pass@1": correct / evaluated,
        "partial": partial_correct / evaluated,
        "format": format_correct / evaluated
    }

def _relative_change(before, after):
    """Format the relative change from *before* to *after* as a signed percentage.

    Returns "n/a" when *before* is 0, since the relative change is undefined
    there and the division would raise ZeroDivisionError.
    """
    if before == 0:
        return "n/a"
    return f"{(after - before) / before:+.1%}"


# Evaluate before and after GRPO
print("\nEvaluating base model...")
base_results = evaluate_on_gsm8k(base_model, tokenizer, gsm8k_test)

print("\nEvaluating GRPO-trained model...")
# NOTE(review): `rl_model` is not assigned in the visible cells; presumably it
# is the trained policy produced by the GRPO learner. Confirm it is defined
# before running this cell.
rl_results = evaluate_on_gsm8k(rl_model, tokenizer, gsm8k_test)

print(f"\n{'='*60}")
print("GSM8K Evaluation Results Comparison")
print(f"{'='*60}")
print(f"{'Metric':<25} {'Base Model':<15} {'After GRPO':<15} {'Improvement'}")
print(f"{'-'*60}")
# One table row per metric: (display label, key in the results dicts)
for _label, _key in (
    ("Pass@1 Accuracy", "pass@1"),
    ("Partial Accuracy", "partial"),
    ("Format Accuracy", "format"),
):
    _base_score = base_results[_key]
    _rl_score = rl_results[_key]
    print(
        f"{_label:<25} {_base_score:<15.1%} {_rl_score:<15.1%} "
        f"{_relative_change(_base_score, _rl_score)}"
    )
