In [None]:
from utils.models import get_model, get_tokenizer, ground_truth_reward_model
from utils.evaluation import evaluate_dpo, evaluate_ground_truth_rewards
from utils.data_preprocessing import CustomDataset, transform_df_for_dpo, precompute_reference_log_probs_batched
from utils.loss_functions import BetaDPOLoss
from utils.utils import get_log_probs, describe_rewards
from utils.preference_generation import determine_preference

import torch
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, PeftModel

from tqdm import tqdm
import pandas as pd
import numpy as np
import gc, os
import glob

config.json:   0%|          | 0.00/929 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/501M [00:00<?, ?B/s]

Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model.safetensors:   0%|          | 0.00/501M [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Device set to use cuda:0


In [None]:
# --- Hyper‑parameters ---
wdir = '.'
MAX_LEN = 1024                         # NOTE(review): not referenced by the cells shown; lengths below come from `max_length`
QUALITY = 'high'
OUTPUT_DIR = f"{wdir}/models/DPO_{QUALITY}"              # where to write LoRA adapter & tokenizer
BATCH_SIZE = 2                          # micro-batch size fed to the policy
GRAD_ACCUM = 16                         # micro-batches per optimizer step -> effective batch 32
LR = 1e-5
EPOCHS = 2
SEED = 42

# Training split size; the remaining rows of the preference data form the eval split
dataset_size = 10000
effective_batch_size = BATCH_SIZE * GRAD_ACCUM  # per_device_batch_size * num_gpus * grad_accum
# Number of optimizer steps the training loop actually takes: it steps once every GRAD_ACCUM
# micro-batches, counting micro-batches continuously across epochs. Matching this exactly
# makes the cosine schedule finish on the last update.
micro_batches_per_epoch = -(-dataset_size // BATCH_SIZE)       # ceil(dataset_size / BATCH_SIZE)
TOTAL_STEP = (micro_batches_per_epoch * EPOCHS) // GRAD_ACCUM  # 625 with the values above
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Length settings passed to the log-prob / evaluation helpers
prompt_length = 20
max_length = 196

# Seed the RNGs the notebook draws from: the numpy Generator handed to determine_preference,
# and torch's global RNG (adapter init, dropout, β-guided multinomial sampling).
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
scale = 1.0   # passed to determine_preference together with `rng`

# β-DPO hyper-parameters
beta0        = 0.1      # base β
alpha        = 0.6      # β sensitivity to batch informativeness
ema_momentum = 0.9      # EMA momentum for the center/scale of the implicit gaps M
rho_keep     = 0.8      # fraction of each effective batch kept by β-guided weighted sampling

# EMA state of the implicit-gap distribution. The training cell updates these, so re-run this
# cell before re-running training to start from a fresh state.
M_center, M_scale = 0.0, 1.0

In [None]:

# Get Model
# Build the policy and reference models. Both are the same base checkpoint with the SFT LoRA
# adapter loaded from models/sft/best_model; the policy additionally gets a trainable "dpo" adapter.
model_name = 'google/gemma-3-270m'

# Policy: base model + SFT adapter, registered under the adapter name "sft"
policy_model = get_model(model_name).to(device)
policy_model = PeftModel.from_pretrained(policy_model,
                                         f'{wdir}/models/sft/best_model',
                                         adapter_name = 'sft')
# Reference: a separate copy of base + SFT adapter (default adapter name). In the cells shown it is
# only used to pre-compute reference log-probs before training.
# NOTE(review): its parameters are not explicitly frozen here; confirm the pre-compute helper runs
# under torch.no_grad().
ref_model = get_model(model_name).to(device)
ref_model = PeftModel.from_pretrained(ref_model,
                                     f'{wdir}/models/sft/best_model')

tok = get_tokenizer(model_name)

# LoRA config for the new DPO adapter stacked on top of the SFT adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    # Adapt the q/k/v/o projection modules
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    # Causal language-model task
    task_type="CAUSAL_LM"
)

policy_model.add_adapter("dpo", lora_config)  # Adds a new, freshly initialized adapter named "dpo"
# Activate both adapters so the policy forward pass composes SFT + DPO
policy_model.base_model.set_adapter(["sft", "dpo"])

# Freeze every parameter whose name does not contain "dpo", i.e. the base weights and the SFT
# adapter. Assumes PEFT embeds the adapter name in the parameter path (e.g. "...lora_A.dpo.weight").
for name, param in policy_model.named_parameters():
    if 'dpo' not in name:
        param.requires_grad = False

# Sanity check: only the DPO adapter weights should be reported as trainable
policy_model.print_trainable_parameters()

config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/133 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

trainable params: 1,474,560 || all params: 271,047,296 || trainable%: 0.5440


In [None]:
# Path to the folder containing the preference CSV(s)
path = f'{wdir}/data/'

# Preference file(s) for the selected QUALITY level; sorted so the concat order is deterministic
all_files = sorted(glob.glob(path + f"dpo_data_{QUALITY}.csv"))
print(all_files)
if not all_files:
    # pd.concat([]) would otherwise fail with the opaque "No objects to concatenate"
    raise FileNotFoundError(f"No preference data matching {path}dpo_data_{QUALITY}.csv")

# Read every matching CSV and stack them row-wise
list_of_dfs = [pd.read_csv(file) for file in all_files]
df = pd.concat(list_of_dfs, ignore_index=True)

# Sample a preference label for each row. `rng` is the seeded numpy Generator from the config
# cell and is consumed here, so re-running this cell on its own yields different labels.
df['pref'] = df.apply(determine_preference,
                      axis=1,
                      args=(scale, rng))  # extra positional args: scale, numpy Generator
print(df.pref.value_counts())

# --- DATA PROCESS ---
df = transform_df_for_dpo(df)
if len(df) <= dataset_size:
    raise ValueError(
        f"Need more than dataset_size={dataset_size} rows to leave an eval split; got {len(df)}"
    )

# First `dataset_size` rows train, the rest evaluate. Reference log-probs are computed once here;
# the training loop reads these cached values instead of running ref_model.
REF_BATCH_SIZE = effective_batch_size // 8   # batch size for the reference pre-compute pass
train_ref_y1, train_ref_y2 = precompute_reference_log_probs_batched(ref_model, tok, df.iloc[:dataset_size], REF_BATCH_SIZE, device, prompt_length, max_length)
eval_ref_y1, eval_ref_y2 = precompute_reference_log_probs_batched(ref_model, tok, df.iloc[dataset_size:], REF_BATCH_SIZE, device, prompt_length, max_length)


['/content/drive/Othercomputers/My MacBook Pro/Oracle DPO/data/dpo_data_high.csv']
pref
1 > 2    6021
2 > 1    5979
Name: count, dtype: int64
Pre-computing reference log probabilities in batches...


Batches: 100%|██████████| 2500/2500 [14:40<00:00,  2.84it/s]


Pre-computing reference log probabilities in batches...


Batches: 100%|██████████| 500/500 [02:55<00:00,  2.85it/s]


In [None]:
# --- Create DataLoaders ---
def _identity_collate(batch):
    """Return the micro-batch as the plain list of dataset items (no tensor stacking).

    A named function (rather than a lambda) stays picklable if num_workers > 0 is ever used.
    """
    return batch

# Same row split as the reference log-prob pre-computation: first `dataset_size` rows train, the rest evaluate
train_dataset = CustomDataset(df.iloc[:dataset_size], train_ref_y1, train_ref_y2)
eval_dataset = CustomDataset(df.iloc[dataset_size:], eval_ref_y1, eval_ref_y2)

# A dedicated seeded generator makes the training shuffle order reproducible across kernel
# restarts (42 matches the seed used for the numpy Generator in the config cell).
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=_identity_collate,
                              generator=torch.Generator().manual_seed(42))
eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, collate_fn=_identity_collate)

In [None]:
# --- Training Components ---
loss_fn = BetaDPOLoss()

# Hand the optimizer only the parameters left trainable by the adapter setup (the "dpo" LoRA
# weights). AdamW would skip the frozen ones anyway (their .grad stays None), so this only
# makes the intent explicit and keeps the optimizer's parameter list small.
trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LR,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)

# Linear warm-up over the first 3% of optimizer steps, then cosine decay over the rest of TOTAL_STEP
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps = int(0.03 * TOTAL_STEP),
    num_training_steps = TOTAL_STEP
)

In [None]:
# --- β-DPO training loop ---
# One optimizer step per effective batch of GRAD_ACCUM micro-batches, in two stages:
#   Stage A (every micro-batch): a no-grad policy pass computes the implicit gap M = β0 * Δ̂
#       for every pair, and the micro-batch is cached for Stage B.
#   Stage B (every GRAD_ACCUM-th micro-batch): update the EMA center/scale of M, derive the
#       batch-level β, sample a rho_keep fraction of the pairs with Gaussian weights around
#       the EMA center, and back-propagate the β-DPO loss over the kept pairs only.
# There is no early stopping; the final adapters are saved at the end of the cell.
print("\n--- Starting Training ---")

BETA_FLOOR = 1e-3     # lower bound for β_batch: β <= 0 would zero out or invert the DPO objective
WEIGHT_FLOOR = 1e-12  # keeps every sampling weight > 0 so torch.multinomial cannot fail

global_step = 0  # counts micro-batches (NOT optimizer steps), continuously across epochs
M_list = []      # per-micro-batch CPU tensors of implicit gaps M for the current effective batch
caches = []      # micro-batches (lists of dataset items) awaiting Stage B

for epoch in range(EPOCHS):
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}"):
        # ---- Stage A: score this micro-batch without gradients ----
        with torch.no_grad():
            policy_model.eval()
            choices_mb = torch.stack([item['choices'] for item in batch]).float()  # [B]
            ref_log_probs_y1_mb = torch.stack([item['ref_log_probs_y1'] for item in batch])  # [B]
            ref_log_probs_y2_mb = torch.stack([item['ref_log_probs_y2'] for item in batch])  # [B]

            log_probs_y1_policy = get_log_probs(policy_model, tok, batch, 'first_responses',  device, prompt_length, max_length)  # [B]
            log_probs_y2_policy = get_log_probs(policy_model, tok, batch, 'second_responses', device, prompt_length, max_length)  # [B]

            # Δ_raw = (log π(y1) - log π_ref(y1)) - (log π(y2) - log π_ref(y2))
            Delta_raw = (log_probs_y1_policy - ref_log_probs_y1_mb.to(device)) - (log_probs_y2_policy - ref_log_probs_y2_mb.to(device))
            # sign is +1 where choices == 0 and -1 where choices == 1.
            # NOTE(review): earlier comments disagreed on whether choices == 1 marks the first or
            # the second response as preferred; confirm against transform_df_for_dpo / BetaDPOLoss
            # that this orients Δ̂ toward the preferred response.
            sign = 1.0 - 2.0 * choices_mb.to(device)
            Delta_est = sign * Delta_raw  # [B]

            M_list.append(beta0 * Delta_est.detach().cpu())  # CPU scalars
            caches.append(batch)  # list of dicts; kept items are sub-selected in Stage B

        # ---- Stage B: one optimizer step per GRAD_ACCUM micro-batches ----
        if (global_step + 1) % GRAD_ACCUM == 0:
            M_eff = torch.cat(M_list)  # [B_eff] on CPU
            M_bar = float(M_eff.mean().item())

            # EMA of the effective-batch mean / std of M
            M_center = ema_momentum * M_center + (1 - ema_momentum) * M_bar
            M_scale  = ema_momentum * M_scale  + (1 - ema_momentum) * (float(M_eff.std().item()) + 1e-6)

            # Batch-level β, floored so an unusually uninformative batch cannot flip its sign
            beta_batch = max((1 + alpha * (M_bar - M_center)) * beta0, BETA_FLOOR)

            # β-guided filtering: sample k_keep pairs without replacement, weighted by a Gaussian
            # score around the EMA center. NaN scores are zeroed and all scores floored, because
            # multinomial(replacement=False) raises when fewer than k_keep weights are positive
            # (scores of far outliers underflow to exactly 0.0).
            scores = torch.exp(-0.5 * ((M_eff - M_center) / (M_scale + 1e-6)) ** 2)
            scores = torch.nan_to_num(scores, nan=0.0).clamp_min(WEIGHT_FLOOR)
            k_keep = min(len(scores), max(1, int(rho_keep * len(scores))))
            weights = scores / scores.sum()
            keep_idx = torch.multinomial(weights, k_keep, replacement=False)
            keep_global_idx = set(keep_idx.tolist())

            policy_model.train()

            # Replay the cached micro-batches in order, keeping only the sampled pairs of each
            cursor = 0
            for batch_cpu in caches:
                kept_idx = [j for j in range(len(batch_cpu)) if (cursor + j) in keep_global_idx]
                cursor += len(batch_cpu)
                if not kept_idx:
                    continue
                sub_batch = [batch_cpu[j] for j in kept_idx]

                # Forward with grad
                log_probs_y1_policy = get_log_probs(policy_model, tok, sub_batch, 'first_responses',  device, prompt_length, max_length)  # [B_kept]
                log_probs_y2_policy = get_log_probs(policy_model, tok, sub_batch, 'second_responses', device, prompt_length, max_length)  # [B_kept]

                ref_log_probs_y1 = torch.stack([item['ref_log_probs_y1'] for item in sub_batch]).to(device)
                ref_log_probs_y2 = torch.stack([item['ref_log_probs_y2'] for item in sub_batch]).to(device)
                choices_kept     = torch.stack([item['choices']        for item in sub_batch]).float().to(device)

                # β-DPO loss: -log σ(β * Δ)
                loss_mb = loss_fn(
                    log_probs_y1_policy,
                    log_probs_y2_policy,
                    ref_log_probs_y1,
                    ref_log_probs_y2,
                    choices_kept,
                    beta_batch
                )

                # NOTE(review): every kept micro-batch is divided by GRAD_ACCUM regardless of how many
                # of its pairs survived filtering. If loss_fn returns a per-batch mean, pairs from
                # sparsely-kept micro-batches get more weight; consider loss_mb * len(sub_batch) / k_keep.
                (loss_mb / GRAD_ACCUM).backward()

            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Start a fresh effective batch
            M_list = []
            caches = []

            # Release Python garbage and cached CUDA blocks once per optimizer step
            # rather than after every micro-batch.
            gc.collect()
            torch.cuda.empty_cache()

        # Must run on every micro-batch: the Stage B condition above is keyed on this counter.
        global_step += 1

if caches:
    print(f"Note: {len(caches)} trailing micro-batch(es) did not complete an effective batch and were not trained on.")

print("\n--- Training Finished ---")

# Persist the trained adapters and tokenizer. PEFT saves each adapter not named "default" into a
# sub-folder of OUTPUT_DIR named after it (here `sft/` and `dpo/`).
policy_model.save_pretrained(OUTPUT_DIR)
tok.save_pretrained(OUTPUT_DIR)


--- Starting Training ---


Epoch 1: 100%|██████████| 5000/5000 [1:27:50<00:00,  1.05s/it]
Epoch 2: 100%|██████████| 5000/5000 [1:27:12<00:00,  1.05s/it]


--- Training Finished ---





In [None]:
# --- Evaluate the trained policy against the ground-truth reward model ---
policy_model.eval()
# Inference only: disable autograd so no graph is kept for any evaluation pass,
# regardless of what the helper does internally.
with torch.no_grad():
    rewards = evaluate_ground_truth_rewards(policy_model, ground_truth_reward_model, tok, eval_dataloader, prompt_length, max_length)
describe_rewards(rewards)