# Experiment 1:
# Supervised Fine-Tuning with LoRA on GSM8K

### Loading GSM8K Dataset

In [None]:
from datasets import load_dataset
import jax
import jax.numpy as jnp

# Fetch the "main" configuration of GSM8K and keep a handle on each split;
# the names below are reused by the later cells.
dataset = load_dataset("gsm8k", "main")
train_data, test_data = dataset["train"], dataset["test"]

# Report the size of both splits.
print(
    f"Training examples: {len(train_data)}\n"
    f"Test examples: {len(test_data)}"
)

# Show the first training record as a sanity check of the raw fields.
sample = train_data[0]
print(
    f"\nSample Question: {sample['question']}\n"
    f"Sample Answer: {sample['answer']}"
)


### Preparing the Dataset for Training

In [None]:
def format_gsm8k_for_training(example):
    """Turn one GSM8K example into an instruction/response pair.

    Args:
        example: Mapping with a ``'question'`` entry (the problem text) and
            an ``'answer'`` entry (the reference solution).

    Returns:
        A dict with two keys: ``"instruction"``, the question prefixed with
        a fixed "solve step by step" prompt line, and ``"response"``, the
        answer passed through unchanged.
    """
    question = example['question']
    solution = example['answer']
    prompt = f"Solve this math problem step by step:\n{question}"
    return {"instruction": prompt, "response": solution}

# Convert every raw example in both splits to the instruction/response layout.
formatted_train = list(map(format_gsm8k_for_training, train_data))
formatted_test = list(map(format_gsm8k_for_training, test_data))

print(f"Formatted {len(formatted_train)} training examples")

# Preview the first converted example; the response is cut to its first
# 100 characters to keep the output short.
print("\nExample formatted input:")
print(formatted_train[0]['instruction'])
print("\nExpected output:")
print(formatted_train[0]['response'][:100] + "...")


### The Training Code

In [None]:
from tunix.trainers import PeftTrainer
from flax import linen as nn
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

# Load the base causal LM and its matching tokenizer by model name.
model_name = "google/gemma-2b"
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to EOS for padding only when the tokenizer defines no pad token
# of its own. Overwriting an existing pad token would make padding and
# end-of-sequence share one id, so any step that masks padding would also
# mask the EOS the model is supposed to learn to emit.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA settings handed to the trainer below.
# NOTE(review): the key names and their exact meaning are defined by the
# trainer; presumably rank/alpha/dropout are the usual LoRA hyperparameters
# and target_modules names the attention projections to adapt. Confirm
# against the installed tunix version.
lora_config = {
    "rank": 8,
    "alpha": 16,
    "dropout": 0.1,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
}

# Initialize the PeftTrainer (only for supervised learning).
# NOTE(review): verify that PeftTrainer accepts these keyword arguments in
# the tunix release in use; the constructor signature is not visible here.
trainer = PeftTrainer(
    model=model,
    tokenizer=tokenizer,
    lora_config=lora_config,
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=4,
    max_length=512,
    warmup_steps=100,
)

# Train on the first 1000 formatted examples only, to keep the demo short.
print("Starting training...")
metrics = trainer.train(formatted_train[:1000])  # Using subset for demo

# NOTE(review): the metric keys read below are assumed to be returned by
# trainer.train(); confirm them against the trainer's documentation.
print("\nTraining completed!")
print(f"  Initial loss: {metrics['initial_loss']:.4f}")
print(f"  Final loss: {metrics['final_loss']:.4f}")
print(f"  Training time: {metrics['training_time_minutes']:.1f} minutes")

# Save the fine-tuned model
trainer.save_checkpoint("./gemma-2b-gsm8k-lora")