# Experiment 1:
# Supervised Fine-Tuning with LoRA on GSM8K

### Loading GSM8K Dataset

In [None]:
from datasets import load_dataset
import jax
import jax.numpy as jnp

# Fetch GSM8K ("main" configuration). Later cells read `train_data` and
# `test_data` by name, so those globals must keep these names.
dataset = load_dataset("gsm8k", "main")
train_data, test_data = dataset["train"], dataset["test"]

# Report split sizes.
for split_label, split_rows in (("Training", train_data), ("Test", test_data)):
    print(f"{split_label} examples: {len(split_rows)}")

# Show the first training record.
sample = train_data[0]
print(f"\nSample Question: {sample['question']}\nSample Answer: {sample['answer']}")


: 

### Preparing the Dataset for Training

In [None]:
def format_gsm8k_for_training(example):
    """Turn one GSM8K record into an instruction/response pair for SFT.

    Args:
        example: Mapping with ``"question"`` and ``"answer"`` fields.

    Returns:
        dict with ``"instruction"`` (the question behind a fixed
        "solve step by step" prompt) and ``"response"`` (the reference
        answer, unchanged).
    """
    question = example['question']
    answer = example['answer']
    return dict(
        instruction=f"Solve this math problem step by step:\n{question}",
        response=answer,
    )

# Process the dataset
# Convert both splits row by row; `formatted_train` feeds the SFT and
# distillation cells later on.
formatted_train = list(map(format_gsm8k_for_training, train_data))
formatted_test = list(map(format_gsm8k_for_training, test_data))

# Preview the first converted training example.
first_example = formatted_train[0]
print(f"Formatted {len(formatted_train)} training examples")
print("\nExample formatted input:")
print(first_example['instruction'])
print("\nExpected output:")
print(first_example['response'][:100] + "...")


### The Training Code

In [None]:
from tunix.trainers import PeftTrainer
from flax import linen as nn
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

# Supervised fine-tuning stage: train LoRA adapters on the instruction/response
# pairs in `formatted_train`. Later cells reuse `tokenizer` and the checkpoint
# directory saved below.
# NOTE(review): PeftTrainer's constructor arguments and the metric keys read
# below are assumed — confirm them against the installed Tunix version.

# Load LLM model and tokenizer
model_name = "google/gemma-2b"
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to EOS for padding only when the tokenizer defines no pad token.
# Overwriting an existing pad token with EOS makes padding and end-of-sequence
# indistinguishable, so any loss masking keyed on pad_token_id would also hide
# the EOS the model has to learn to emit.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Configure LoRA parameters (target_modules names the projection layers to adapt)
lora_config = {
    "rank": 8,
    "alpha": 16,
    "dropout": 0.1,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
}

# Initialize the PeftTrainer (supervised learning only)
trainer = PeftTrainer(
    model=model,
    tokenizer=tokenizer,
    lora_config=lora_config,
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=4,
    max_length=512,
    warmup_steps=100
)

# Train on the first 1,000 formatted examples to keep the demo short
print("Starting training...")
metrics = trainer.train(formatted_train[:1000])

# Save immediately after training so that a failure while reporting the
# metrics below (e.g. a missing key) cannot lose the fine-tuned weights.
trainer.save_checkpoint("./gemma-2b-gsm8k-lora")

print(f"\nTraining completed!")
print(f"  Initial loss: {metrics['initial_loss']:.4f}")
print(f"  Final loss: {metrics['final_loss']:.4f}")
print(f"  Training time: {metrics['training_time_minutes']:.1f} minutes")

# Experiment 2:
# Direct Preference Optimization (DPO)

### Creating Preference Pairs from GSM8K

In [None]:
def create_preference_pairs(examples, model, tokenizer, max_pairs=500):
    """Create preference pairs for DPO training.

    Each row yields a ``{"prompt", "chosen", "rejected"}`` dict:

    * ``chosen`` is the reference answer verbatim.
    * ``rejected`` is the same answer truncated before the ``####`` marker.
      The ``<<`` / ``>>`` delimiters are removed, but the text between them
      is kept.

    Args:
        examples: Iterable of rows with ``"question"`` and ``"answer"``
            fields, such as a list of dicts or a Hugging Face ``Dataset``.
        model: Currently unused. Kept for interface compatibility, as a
            placeholder for generating rejected responses from model outputs.
        tokenizer: Currently unused, as above.
        max_pairs: Maximum number of rows to convert (default 500). ``None``
            converts every row.

    Returns:
        list of preference-pair dicts, in input order.
    """
    from itertools import islice

    preference_data = []

    # islice instead of examples[:max_pairs]: slicing a HF Dataset returns a
    # dict of columns, and iterating that dict yields column names, not rows.
    for example in islice(examples, max_pairs):
        prompt = f"Solve this math problem:\n{example['question']}\n\nProvide step-by-step reasoning."

        # The "chosen" response is the correct, well-formatted answer
        chosen = example['answer']

        # Create a "rejected" response (less structured)
        # In practice, you'd generate these or use model outputs
        rejected = example['answer'].split('####')[0].strip()
        rejected = rejected.replace('<<', '').replace('>>', '')

        preference_data.append({
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected
        })

    return preference_data

# Build the DPO preference set from the training split.
preference_data = create_preference_pairs(train_data, model, tokenizer)

# Summarise, then show the first triple truncated to 100 characters per field.
print(f"Created {len(preference_data)} preference pairs")
print("\nExample preference pair:")
first_pair = preference_data[0]
for field_heading, field_key in (
    ("Prompt", "prompt"),
    ("\nChosen", "chosen"),
    ("\nRejected", "rejected"),
):
    print(f"{field_heading}: {first_pair[field_key][:100]}...")


### Training with DPO

In [None]:
from tunix.trainers import DPOTrainer

# DPO stage: continue from the SFT checkpoint and optimise on the
# (prompt, chosen, rejected) triples in `preference_data`.
# NOTE(review): DPOTrainer's constructor arguments and the metric keys read
# below are assumed — confirm them against the installed Tunix version.
# NOTE(review): `preference_data` prompts use "Solve this math problem:\n..."
# while SFT used "Solve this math problem step by step:\n..."; confirm the
# template mismatch is intended.

# Load the SFT model from previous step.
# NOTE(review): this directory was written by PeftTrainer.save_checkpoint. If
# that stores LoRA adapter weights only (not a full Hugging Face model
# directory), from_pretrained cannot load it — confirm the format, or merge
# the adapters before saving.
model = FlaxAutoModelForCausalLM.from_pretrained("./gemma-2b-gsm8k-lora")

# Initialize DPO trainer
dpo_trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    beta=0.1,  # KL penalty coefficient
    learning_rate=5e-7,
    num_epochs=2,
    batch_size=2,
    max_length=512
)

# Train with preference optimization
print("Starting DPO training...")
dpo_metrics = dpo_trainer.train(preference_data)

# Save the aligned model immediately so that a failure while reporting the
# metrics below (e.g. a missing key) cannot lose the trained weights.
dpo_trainer.save_checkpoint("./gemma-2b-gsm8k-dpo")

print(f"\nDPO Training Metrics:")
print(f"  Reward Margin: {dpo_metrics['reward_margin']:.3f}")
print(f"  Preference Accuracy: {dpo_metrics['accuracy']:.2%}")
print(f"  KL Divergence: {dpo_metrics['kl_div']:.4f}")
print(f"  Training time: {dpo_metrics['training_time_hours']:.1f} hours")

# Experiment 3:
# GRPO for Enhanced Math Reasoning

In [None]:
from tunix.trainers import GRPOLearner
from tunix.environments import MathEnvironment
from datasets import load_dataset

# GRPO stage: reinforcement learning on GSM8K with rewards from
# MathEnvironment. `gsm8k_test`, `base_model` and `grpo_learner` defined here
# are reused by the evaluation cell.
# NOTE(review): GRPOLearner / MathEnvironment arguments and the metric keys
# read below are assumed — confirm against the installed Tunix version.

# Load GSM8K test set for evaluation
gsm8k_test = load_dataset("gsm8k", "main", split="test")

# Initialize the math environment
math_env = MathEnvironment(
    dataset="gsm8k",
    reward_type="correctness",  # Reward based on correct final answer
    format_check=True  # presumably also scores answer formatting — confirm
)

# Load base model (or your DPO model, saved to "./gemma-2b-gsm8k-dpo")
# NOTE(review): `tokenizer` was loaded for "google/gemma-2b" in the SFT cell,
# while this model is "google/gemma-2-2b-it" — confirm the tokenizers are
# interchangeable, or load the tokenizer that matches this checkpoint.
base_model = FlaxAutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")

# Configure GRPO learner
grpo_learner = GRPOLearner(
    model=base_model,
    tokenizer=tokenizer,
    environment=math_env,
    learning_rate=1e-6,
    group_size=4,  # To generate 4 responses per prompt
    num_iterations=1000,
    temperature=0.7,  # presumably the rollout sampling temperature — confirm
    batch_size=8,
    eval_frequency=100
)

# Train with reinforcement learning
print("Starting GRPO training on GSM8K...")
print("This will take several hours...\n")

rl_metrics = grpo_learner.train()
# NOTE(review): nothing in this cell saves the trained policy, yet the
# distillation cell loads "./gemma-2b-gsm8k-grpo". Save it here (the other
# trainers in this notebook use save_checkpoint(...)) if GRPOLearner supports it.

print(f"\n{'='*50}")
print("GRPO Training Results")
print(f"{'='*50}")
print(f"  Pass@1 Accuracy: {rl_metrics['pass_at_1']:.2%}")
print(f"  Pass@5 Accuracy: {rl_metrics['pass_at_5']:.2%}")
print(f"  Answer Accuracy: {rl_metrics['answer_accuracy']:.2%}")
print(f"  Partial Accuracy: {rl_metrics['partial_accuracy']:.2%}")
print(f"  Format Accuracy: {rl_metrics['format_accuracy']:.2%}")
print(f"  Average Reward: {rl_metrics['avg_reward']:.3f}")
print(f"  Training time: {rl_metrics['training_time_hours']:.1f} hours")


### Detailed Evaluation

In [None]:
import re

# Signed integer or decimal, allowing thousands separators: "-3", "1,250", "0.75".
_GSM8K_NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def _to_float(token):
    """Convert a number token matched by _GSM8K_NUMBER_RE (e.g. "1,250") to float."""
    return float(token.replace(",", ""))


def _extract_predicted_answer(response):
    """Return the model's final numeric answer in *response*, or None.

    Uses the first number after a GSM8K-style "####" marker when one is
    present, otherwise the last number anywhere in the response.
    """
    if "####" in response:
        match = _GSM8K_NUMBER_RE.search(response.split("####", 1)[1])
        return _to_float(match.group()) if match else None
    numbers = _GSM8K_NUMBER_RE.findall(response)
    return _to_float(numbers[-1]) if numbers else None


def evaluate_on_gsm8k(model, tokenizer, test_data, num_samples=1319):
    """Greedy-decode GSM8K questions and score the final numeric answers.

    Args:
        model: Causal LM exposing ``generate(**inputs, ...)``, e.g. a Hugging
            Face Flax model.
        tokenizer: Matching tokenizer. It is called with
            ``return_tensors="jax"`` and used to ``decode`` generated ids.
        test_data: Indexable sequence of rows with ``"question"`` and
            ``"answer"`` fields, such as a list of dicts or a Hugging Face
            ``Dataset``.
        num_samples: Maximum number of rows to evaluate (default 1319, the
            size of the GSM8K test split). Clamped to ``len(test_data)``.

    Returns:
        dict of fractions of the evaluated rows:

        * ``"pass@1"``: |pred - true| < 0.01.
        * ``"partial"``: not exact, but pred/true is within [0.9, 1.1] for a
          non-zero reference answer.
        * ``"format"``: the response contains ``<<`` and ``>>`` calculator
          annotations.

        All values are 0.0 when no rows are evaluated.

    Raises:
        ValueError: if a reference answer has no number after "####".
    """
    total = max(0, min(num_samples, len(test_data)))
    correct = 0
    partial_correct = 0
    format_correct = 0

    for i in range(total):
        # Index row by row: slicing a HF Dataset (test_data[:n]) returns a dict
        # of columns, and iterating that dict yields column names, not rows.
        example = test_data[i]

        # Generate model response (greedy; temperature=0.0 is not a valid
        # sampling temperature, so request non-sampled decoding explicitly).
        prompt = f"Solve this math problem step by step:\n{example['question']}"
        inputs = tokenizer(prompt, return_tensors="jax", padding=True)
        outputs = model.generate(**inputs, max_length=512, do_sample=False)
        # HF Flax generate returns an output object whose `.sequences` holds,
        # for decoder-only models, the prompt followed by the continuation.
        # Decode only the continuation so numbers in the question are never
        # mistaken for the model's answer.
        sequences = getattr(outputs, "sequences", outputs)
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(sequences[0][prompt_length:], skip_special_tokens=True)

        # Reference answer: the number after "####" in the GSM8K solution.
        _, marker, final_part = example["answer"].partition("####")
        true_match = _GSM8K_NUMBER_RE.search(final_part) if marker else None
        if true_match is None:
            raise ValueError(f"No final answer after '####' in: {example['answer']!r}")
        true_answer = _to_float(true_match.group())

        if "<<" in response and ">>" in response:
            format_correct += 1

        # A response with no number counts as wrong (it used to default to 0,
        # which scored as correct whenever the true answer was 0).
        pred_answer = _extract_predicted_answer(response)
        if pred_answer is not None:
            if abs(pred_answer - true_answer) < 0.01:
                correct += 1
            elif true_answer != 0 and 0.9 <= pred_answer / true_answer <= 1.1:
                partial_correct += 1

        if (i + 1) % 100 == 0:
            print(f"Evaluated {i + 1}/{total} examples...")

    if total == 0:
        return {"pass@1": 0.0, "partial": 0.0, "format": 0.0}
    return {
        "pass@1": correct / total,
        "partial": partial_correct / total,
        "format": format_correct / total
    }

def _relative_change(new, old):
    """Format (new - old) / old as a signed percentage, or "n/a" when old is 0."""
    if old == 0:
        return "n/a"
    return f"{(new - old) / old:+.1%}"


# No visible cell defines `rl_model`; reuse one if it already exists, else take
# the learner's policy.
# NOTE(review): assumes GRPOLearner exposes its trained policy as `.model` —
# confirm against the Tunix API, or assign `rl_model` yourself beforehand.
if "rl_model" not in globals():
    rl_model = getattr(grpo_learner, "model", None)
if rl_model is None:
    raise RuntimeError(
        "GRPO-trained model not found: set `rl_model` to the trained policy "
        "before running this cell."
    )

# Evaluate before and after GRPO.
# NOTE(review): `base_model` is the object handed to GRPOLearner; if the learner
# updates it in place, these "base" numbers describe the trained model. Reload
# "google/gemma-2-2b-it" for a clean baseline if in doubt.
print("\nEvaluating base model...")
base_results = evaluate_on_gsm8k(base_model, tokenizer, gsm8k_test)

print("\nEvaluating GRPO-trained model...")
rl_results = evaluate_on_gsm8k(rl_model, tokenizer, gsm8k_test)

print(f"\n{'='*60}")
print("GSM8K Evaluation Results Comparison")
print(f"{'='*60}")
print(f"{'Metric':<25} {'Base Model':<15} {'After GRPO':<15} {'Improvement'}")
print(f"{'-'*60}")
# Relative change per metric; "n/a" instead of a ZeroDivisionError when the
# base model scored 0 on that metric.
for metric_label, metric_key in (
    ("Pass@1 Accuracy", "pass@1"),
    ("Partial Accuracy", "partial"),
    ("Format Accuracy", "format"),
):
    base_score = base_results[metric_key]
    rl_score = rl_results[metric_key]
    print(f"{metric_label:<25} {base_score:<15.1%} {rl_score:<15.1%} {_relative_change(rl_score, base_score)}")


# Experiment 4:
# Knowledge Distillation

In [None]:
from tunix.trainers import DistillationTrainer

# Knowledge-distillation stage: a smaller student learns from the GRPO teacher
# on the first 2,000 formatted GSM8K training examples.
# NOTE(review): DistillationTrainer's arguments and the metric keys read below
# are assumed — confirm against the installed Tunix version.

# Load teacher (our best GRPO model) and student models.
# NOTE(review): no visible cell writes "./gemma-2b-gsm8k-grpo" (the GRPO cell
# does not save its trained policy); save it there before running this cell.
teacher_model = FlaxAutoModelForCausalLM.from_pretrained("./gemma-2b-gsm8k-grpo")
# NOTE(review): confirm "google/gemma-1b" is a valid Hub id (the 1B Gemma
# checkpoints are Gemma 3, e.g. "google/gemma-3-1b-pt") and that its vocabulary
# matches the teacher's. Logit-level distillation generally needs both models
# to share one tokenizer, and only `tokenizer` is passed below.
student_model = FlaxAutoModelForCausalLM.from_pretrained("google/gemma-1b")

# Prepare distillation data from GSM8K
distill_data = formatted_train[:2000]

# Configure distillation
distill_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    tokenizer=tokenizer,
    distillation_type="logit",  # or "attention" for attention transfer
    temperature=2.0,  # Softens probability distributions
    alpha=0.5,  # Balance between distillation and task loss
    learning_rate=5e-5,
    num_epochs=5,
    batch_size=8
)

# Distill knowledge
print("Starting knowledge distillation...")
distill_metrics = distill_trainer.train(distill_data)

# Save the distilled model right away so that a failure during the evaluation
# or reporting below cannot lose the trained weights.
distill_trainer.save_checkpoint("./gemma-1b-gsm8k-distilled")

print(f"\nDistillation completed!")
print(f"  Initial distillation loss: {distill_metrics['initial_loss']:.4f}")
print(f"  Final distillation loss: {distill_metrics['final_loss']:.4f}")
print(f"  Training time: {distill_metrics['training_time_hours']:.1f} hours")

# Evaluate student model on GSM8K
# NOTE(review): assumes the trainer updates `student_model` in place — confirm,
# otherwise evaluate the trainer's trained student instead.
print("\nEvaluating student model on GSM8K...")
student_results = evaluate_on_gsm8k(student_model, tokenizer, gsm8k_test)

# Teacher score comes from the GRPO evaluation cell (`rl_results`).
teacher_pass_at_1 = rl_results['pass@1']
student_pass_at_1 = student_results['pass@1']
# "n/a" instead of a ZeroDivisionError when the teacher scored 0%.
retention = (
    f"{student_pass_at_1 / teacher_pass_at_1:.1%}" if teacher_pass_at_1 else "n/a"
)

print(f"\n{'='*60}")
print("Distillation Results")
print(f"{'='*60}")
print(f"  Teacher Pass@1: {teacher_pass_at_1:.1%}")
print(f"  Student Pass@1: {student_pass_at_1:.1%}")
print(f"  Accuracy Retention: {retention}")
print(f"  Model Size: {distill_metrics['teacher_size_mb']:.1f} MB → {distill_metrics['student_size_mb']:.1f} MB")
print(f"  Size Reduction: {(1 - distill_metrics['student_size_mb']/distill_metrics['teacher_size_mb']):.1%}")
print(f"  Inference Speedup: {distill_metrics['speedup']:.2f}x")
