In [None]:
from utils.models import get_model, get_tokenizer, ground_truth_reward_model
from utils.evaluation import evaluate_dpo, evaluate_ground_truth_rewards
from utils.data_preprocessing import CustomDataset, transform_df_for_dpo, precompute_reference_log_probs_batched
from utils.loss_functions import DPOLoss
from utils.utils import get_log_probs, describe_rewards
from utils.preference_generation import determine_preference

import torch
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, PeftModel

from tqdm import tqdm
import pandas as pd
import numpy as np
import gc, os
import glob

In [None]:
# --- Hyper-parameters ---
wdir = '.'                             # project root; every path below is relative to it
MAX_LEN = 1024                         # sequence-length cap for truncating long examples
QUALITY = 'low'                        # preference-data variant; also names the output folder
OUTPUT_DIR = f"{wdir}/models/DPO_{QUALITY}"              # where to write LoRA adapter & tokenizer
BATCH_SIZE = 16                        # examples per forward/backward pass
GRAD_ACCUM = 2                         # effective batch 32
LR = 1e-5
EPOCHS = 2

# --- Training-step budget (consumed by the LR scheduler) ---
dataset_size = 10000                   # number of leading rows used as the training split
effective_batch_size = BATCH_SIZE * GRAD_ACCUM  # examples contributing to one optimizer update
# Ceiling division: a trailing partial batch still counts as a step, but an exact
# multiple of the effective batch size must not be charged an extra step.
steps_per_epoch = (dataset_size + effective_batch_size - 1) // effective_batch_size
TOTAL_STEP = steps_per_epoch * EPOCHS  # 313 * 2 = 626 with the values above
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Passed to the log-prob helpers as the prompt length and the overall maximum length.
prompt_length = 20
max_length = 196

# Seeded generator and scale forwarded to determine_preference when labelling pairs.
rng = np.random.default_rng(42)
scale = 1.0

In [None]:

# Base checkpoint shared by the policy and the reference model
model_name = 'google/gemma-3-270m'

# Policy: base model with the SFT LoRA adapter loaded under the name 'sft'.
policy_model = get_model(model_name).to(device)
policy_model = PeftModel.from_pretrained(policy_model,
                                         f'{wdir}/models/sft/best_model',
                                         adapter_name = 'sft')

# Reference: a separate copy of the base model with the same SFT adapter.
ref_model = get_model(model_name).to(device)
ref_model = PeftModel.from_pretrained(ref_model,
                                     f'{wdir}/models/sft/best_model')
# The reference is only used to precompute log-probs, so pin it down explicitly:
# eval mode disables dropout, and freezing guarantees it can never be updated.
ref_model.eval()
ref_model.requires_grad_(False)

tok = get_tokenizer(model_name)

# LoRA configuration for the new DPO adapter (attention projections only).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

policy_model.add_adapter("dpo", lora_config)  # Adds new empty trainable adapter
# Activate both adapters so the policy output is base + SFT + DPO.
policy_model.base_model.set_adapter(["sft", "dpo"])

# Activating adapters can re-enable gradients on them; keep only parameters whose
# name mentions the 'dpo' adapter trainable so the SFT weights stay fixed.
for name, param in policy_model.named_parameters():
    if 'dpo' not in name:
        param.requires_grad = False

policy_model.print_trainable_parameters()

In [None]:
# Folder that holds the preference-data CSV files
path = f'{wdir}/data/'

# Collect the CSV file(s) matching the selected quality level
all_files = glob.glob(path + f"dpo_data_{QUALITY}.csv")

print(all_files)
# Fail early with a clear message; pd.concat on an empty list would otherwise
# raise an unhelpful "No objects to concatenate" error.
if not all_files:
    raise FileNotFoundError(f"No preference data found at {path}dpo_data_{QUALITY}.csv")

# Read each CSV file into a DataFrame and stack them row-wise
list_of_dfs = [pd.read_csv(file) for file in all_files]
df = pd.concat(list_of_dfs, ignore_index=True)

# Label every row with a preference; `scale` and the seeded `rng` from the
# config cell are forwarded as extra positional arguments after the row.
df['pref'] = df.apply(determine_preference,
                      axis=1,
                      args=(scale, rng)
)
print(df.pref.value_counts())

# --- DATA PROCESS ---
df = transform_df_for_dpo(df)
# Reference log-probs are computed once up front: the first `dataset_size` rows
# form the training split and the remaining rows the evaluation split.
train_ref_y1, train_ref_y2 = precompute_reference_log_probs_batched(
    ref_model, tok, df[:dataset_size], effective_batch_size // 8,
    device, prompt_length, max_length)
eval_ref_y1, eval_ref_y2 = precompute_reference_log_probs_batched(
    ref_model, tok, df[dataset_size:], effective_batch_size // 8,
    device, prompt_length, max_length)


In [None]:
# --- Create DataLoaders ---
def _keep_batch_as_list(batch):
    """Collate function that returns the sampled items unchanged as a list."""
    return batch


# Training split: first `dataset_size` rows, paired with their reference log-probs.
train_dataset = CustomDataset(df[:dataset_size], train_ref_y1, train_ref_y2)
# Evaluation split: every remaining row.
eval_dataset = CustomDataset(df[dataset_size:], eval_ref_y1, eval_ref_y2)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_keep_batch_as_list,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=_keep_batch_as_list,
)

In [None]:
# --- Training Components ---
loss_fn = DPOLoss(0.1)  # NOTE(review): 0.1 is presumably the DPO beta - confirm against DPOLoss

# Hand the optimizer only the parameters that are actually trainable (the DPO
# adapter), so no optimizer bookkeeping is kept for the frozen base/SFT weights.
trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LR,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)


# Cosine decay with a linear warm-up over the first 3% of the step budget
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps = int(0.03 * TOTAL_STEP),
    num_training_steps = TOTAL_STEP
)

In [None]:
# --- Training loop with gradient accumulation ---
def _apply_optimizer_update():
    """Clip gradients, advance the optimizer and LR schedule, then clear gradients."""
    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


global_step = 0  # number of micro-batches processed so far (not optimizer updates)
print("\n--- Starting Training ---")
# Start from clean gradients so re-running this cell after an interrupted run
# does not fold stale gradients into the first update.
optimizer.zero_grad()
for epoch in range(EPOCHS):

    for i, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")):
        policy_model.train()
        # --- Forward Pass ---
        log_probs_y1_policy = get_log_probs(policy_model, tok, batch, 'first_responses', device, prompt_length, max_length)
        log_probs_y2_policy = get_log_probs(policy_model, tok, batch, 'second_responses', device, prompt_length, max_length)

        # The collate function keeps each batch as a list of items, so the
        # per-item tensors are stacked here.
        choices = torch.stack([item['choices'] for item in batch])
        ref_log_probs_y1 = torch.stack([item['ref_log_probs_y1'] for item in batch])
        ref_log_probs_y2 = torch.stack([item['ref_log_probs_y2'] for item in batch])

        loss = loss_fn(
            log_probs_y1_policy,
            log_probs_y2_policy,
            ref_log_probs_y1.to(device),
            ref_log_probs_y2.to(device),
            choices.to(device)
        )

        # --- Scale the Loss and Backpropagate ---
        # Dividing by GRAD_ACCUM makes the accumulated gradient an average over
        # the micro-batches in one accumulation window.
        loss = loss / GRAD_ACCUM
        loss.backward()

        global_step += 1  # one more micro-batch accumulated

        # --- Optimizer Step ---
        # The counter runs across epochs, so an accumulation window may span
        # an epoch boundary.
        if global_step % GRAD_ACCUM == 0:
            _apply_optimizer_update()

        gc.collect()
        torch.cuda.empty_cache()

# Flush gradients left over from an incomplete accumulation window so the final
# micro-batches still contribute (no-op when the total count divides evenly).
if global_step % GRAD_ACCUM != 0:
    _apply_optimizer_update()

print("\n--- Training Finished ---")

In [None]:
# Put the trained policy into evaluation mode before scoring it.
policy_model.eval()

# Score the policy on the evaluation split with the ground-truth reward model,
# then summarise the resulting rewards.
rewards = evaluate_ground_truth_rewards(
    policy_model,
    ground_truth_reward_model,
    tok,
    eval_dataloader,
    prompt_length,
    max_length,
)
describe_rewards(rewards)