In [None]:
import os
import torch
from dotenv import load_dotenv #Load HuggingFace Token
load_dotenv()
# --- Hyper-parameters ---
wdir = '.'
# Sequence-length cap.
# NOTE(review): not referenced in the visible cells, which use `max_length` instead — confirm it is still needed.
MAX_LEN = 1024
OUTPUT_DIR = f'{wdir}/models/sft'      # checkpoints and the best LoRA adapter are written under this directory
BATCH_SIZE = 4                         # examples per micro-batch (DataLoader batch size)
GRAD_ACCUM = 8                         # micro-batches accumulated per optimizer update
LR = 2e-4
EPOCHS = 5

# --- Scheduler horizon ---
dataset_size = 10000                   # number of training examples; also the train/eval split index
effective_batch_size = BATCH_SIZE * GRAD_ACCUM  # 4 * 8 = 32 examples per optimizer update
# Ceiling division, so a dataset that is an exact multiple of the effective
# batch does not get a spurious extra update per epoch.
updates_per_epoch = -(-dataset_size // effective_batch_size)
TOTAL_STEP = updates_per_epoch * EPOCHS  # 313 * 5 = 1565 with the values above
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Evaluation cadence and early stopping ---
eval_steps = GRAD_ACCUM * 20           # validate every 160 micro-batches, i.e. every 20 optimizer updates
early_stopping_patience = 3            # stop after this many consecutive evaluations without improvement



In [None]:
from utils.models import get_model, get_tokenizer, ground_truth_reward_model
from utils.data_loader import get_data
from utils.utils import get_log_probs_sft, describe_rewards
from utils.evaluation import evaluate_sft, evaluate_rewards

import torch
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model

from tqdm import tqdm
import gc

# --- Sequence window configuration ---
# NOTE(review): `batch_size` is not referenced in the visible cells (the
# DataLoaders use BATCH_SIZE) — confirm whether it is still needed.
batch_size = 100
prompt_length, max_length = 20, 196

# --- Model and tokenizer ---
model_name = "google/gemma-3-270m"
base_model = get_model(model_name).to(device)
tok = get_tokenizer(model_name)
tok.padding_side = "right"

# --- Data ---
# NOTE(review): the DataLoaders split this at index `dataset_size`; confirm
# get_data returns more examples than that, otherwise the validation split is empty.
base_data = get_data("train", 18000, 25000)

In [None]:
# Freeze every weight of the base model before the LoRA adapters are attached.
for param in base_model.parameters():
    param.requires_grad = False

# LoRA adapter settings: rank-16 adapters on the q/k/v/o projection modules,
# configured for causal language modelling.
lora_settings = dict(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
lora_config = LoraConfig(**lora_settings)

# Wrap the frozen model with the adapters and report how many parameters will train.
base_model = get_peft_model(base_model, lora_config)
base_model.print_trainable_parameters()

In [None]:

# Hand the optimizer only the parameters that are actually trainable.
# Passing base_model.parameters() would also register every frozen weight,
# which never receives a gradient. AdamW raises ValueError if this list is
# empty, which surfaces a mis-configured adapter immediately.
trainable_params = [p for p in base_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LR,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)

# Cosine learning-rate decay, with warm-up over the first 3% of TOTAL_STEP.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.03 * TOTAL_STEP),
    num_training_steps=TOTAL_STEP,
)

In [None]:
def _identity_collate(batch):
    """Return the list of samples unchanged, bypassing the default collation."""
    return batch


# The first `dataset_size` examples are used for training; the rest are held out for validation.
train_split = base_data[:dataset_size]
eval_split = base_data[dataset_size:]

train_dataloader = DataLoader(
    train_split,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_identity_collate,
)
eval_dataloader = DataLoader(
    eval_split,
    batch_size=BATCH_SIZE,
    collate_fn=_identity_collate,
)

In [None]:
# --- Supervised fine-tuning loop ---
# Gradient accumulation, periodic validation, checkpointing and early stopping.
#
# Step accounting: `global_step` counts micro-batches (one forward/backward
# pass), NOT optimizer updates. Weights are updated once every GRAD_ACCUM
# micro-batches and validation runs every `eval_steps` micro-batches.
patience_counter = 0
global_step = 0

# Baseline validation loss before any update; "best_model" is only written
# once training beats this value.
best_eval_loss = evaluate_sft(base_model, eval_dataloader, tok, device, prompt_length, max_length)
print(f"\nStep {global_step}: Validation Loss = {best_eval_loss:.4f}")

# Start from clean gradients so that re-running this cell after an interrupted
# run does not fold stale gradients into the first update.
optimizer.zero_grad()

print("\n--- Starting Training ---")
for epoch in range(EPOCHS):
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}"):
        # Re-enter training mode on every micro-batch: evaluate_sft runs
        # between micro-batches and may leave the model in eval mode (TODO confirm).
        base_model.train()

        # --- Forward pass ---
        batch_nll, batch_valid_tokens = get_log_probs_sft(base_model, tok, batch, device, prompt_length, max_length)

        # Token-averaged negative log-likelihood over the whole micro-batch.
        loss = batch_nll.sum() / batch_valid_tokens.sum()

        # A micro-batch with zero valid tokens yields 0/0 = NaN; back-propagating
        # it would poison the adapter weights, so skip it without counting it.
        if not torch.isfinite(loss):
            print(f"\nSkipping micro-batch with non-finite loss ({loss.item()}).")
            continue

        # --- Backward pass ---
        # Scale so the accumulated gradient is the mean over GRAD_ACCUM micro-batches.
        (loss / GRAD_ACCUM).backward()
        global_step += 1

        # --- Optimizer update (once per accumulation window) ---
        if global_step % GRAD_ACCUM == 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # --- Evaluation, checkpointing and early stopping ---
        # eval_steps is a multiple of GRAD_ACCUM, so this always runs right
        # after an optimizer update.
        if global_step % eval_steps == 0:
            print(f"\nCurrent Step {global_step}")
            eval_loss = evaluate_sft(base_model, eval_dataloader, tok, device, prompt_length, max_length)
            print(f"\nStep {global_step}: Validation Loss = {eval_loss:.4f}")

            # Save a checkpoint at every evaluation step.
            checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-{global_step}")
            print(f"\nSaving checkpoint to {checkpoint_dir}...")
            base_model.save_pretrained(checkpoint_dir)

            if eval_loss < best_eval_loss:
                print(f"\nValidation loss improved from {best_eval_loss} to {eval_loss}. Saving model...")
                best_model_dir = os.path.join(OUTPUT_DIR, "best_model")
                base_model.save_pretrained(best_model_dir)
                best_eval_loss = eval_loss
                patience_counter = 0  # Reset patience
            else:
                patience_counter += 1
                print(f"\nValidation loss did not improve. Patience: {patience_counter}/{early_stopping_patience}")

            if patience_counter >= early_stopping_patience:
                print("\nEarly stopping triggered.")
                break

        # Release Python garbage and cached CUDA blocks after every micro-batch.
        gc.collect()
        torch.cuda.empty_cache()

    # Propagate the early-stopping break out of the epoch loop.
    if patience_counter >= early_stopping_patience:
        break

# Discard gradients from an incomplete accumulation window so they are not
# left attached to the parameters after training.
optimizer.zero_grad()

print("\n--- Training Finished ---")

In [None]:
# Score the fine-tuned model on the held-out split with the ground-truth reward model.
# NOTE(review): this uses the weights currently in memory (the last training
# state), not the saved "best_model" adapter — confirm that is intended.
rewards = evaluate_rewards(
    base_model,
    ground_truth_reward_model,
    tok,
    eval_dataloader,
    prompt_length,
    max_length,
)
describe_rewards(rewards)