In [8]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [9]:
import os, sys
project_root = os.path.abspath('/Users/subhojit/workspace/saturn/src')
if project_root not in sys.path:
    sys.path.append(project_root)

from reward_model import RewardModel
from dataloader import PreferenceDatasetLite
import matplotlib.pyplot as plt
%matplotlib inline

In [10]:
def pairwise_reward_loss(chosen_rewards, rejected_rewards):
    return -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean()

In [12]:
from huggingface_hub import login
login(token='')

In [14]:
model_name = "Qwen/Qwen1.5-0.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen1.5-0.5B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
from datasets import load_dataset
dataset = load_dataset("Anthropic/hh-rlhf")["train"]

Generating train split: 100%|██████████| 160800/160800 [00:00<00:00, 315531.02 examples/s]
Generating test split: 100%|██████████| 8552/8552 [00:00<00:00, 250338.05 examples/s]


In [16]:
dataset[0]

{'chosen': "\n\nHuman: What are some cuss words in english?\n\nAssistant: Here’s an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama’s boy, faggot, pervert, queer, scumbag, bitch,\n\nHuman: What's your favorite one?\n\nAssistant: I haven't even thought about it.",
 'rejected': "\n\nHuman: What are some cuss words in english?\n\nAssistant: Here’s an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, ji

In [24]:
def preprocess(example):
    chosen = tokenizer(
        example["chosen"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )
    rejected = tokenizer(
        example["rejected"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )
    return {
        "input_ids_chosen": chosen["input_ids"].squeeze(0),
        "attention_mask_chosen": chosen["attention_mask"].squeeze(0),
        "input_ids_rejected": rejected["input_ids"].squeeze(0),
        "attention_mask_rejected": rejected["attention_mask"].squeeze(0),
    }

pre_processed = dataset.map(preprocess)

Map: 100%|██████████| 160800/160800 [01:43<00:00, 1559.14 examples/s]


In [19]:
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Mini wrapper
class RewardDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids_chosen": item["input_ids_chosen"],
            "attention_mask_chosen": item["attention_mask_chosen"],
            "input_ids_rejected": item["input_ids_rejected"],
            "attention_mask_rejected": item["attention_mask_rejected"]
        }

    def __len__(self):
        return len(self.data)

dataloader = DataLoader(RewardDataset(pre_processed), batch_size=4, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [23]:
model.train()
for epoch in range(1):  # You can extend this
    for batch in dataloader:
        # Forward pass for both chosen and rejected
        reward_chosen = model(
            input_ids=batch["input_ids_chosen"],
            attention_mask=batch["attention_mask_chosen"]
        ).logits.squeeze()

        reward_rejected = model(
            input_ids=batch["input_ids_rejected"],
            attention_mask=batch["attention_mask_rejected"]
        ).logits.squeeze()

        # Pairwise loss: max(0, 1 - (r_c - r_r))
        loss = -F.logsigmoid(reward_chosen - reward_rejected).mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch} — Loss: {loss.item():.4f}")

KeyboardInterrupt: 