In [1]:
import sys
import os

project_root = os.path.abspath("..")
sys.path.insert(0, project_root)

from src.models.pythia_model import PythiaModel
from src.data.dataset_loader import DatasetLoader
from src.data.bias_injector import BiasInjector
from src.training.dpo_trainer import DPO_Trainer
from src.training.utils import load_experiment_config
import numpy as np
from trl import DPOConfig

import logging
# Code specific to Jupyter Notebook
# Route every log record that reaches the root logger (including the src.*
# module loggers, which propagate to it) to the cell output on stdout at INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
## Detach handlers left behind by libraries or by a previous run of this cell,
## so re-running the cell does not print each message more than once
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
## Create handler that outputs to notebook
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
## Create formatter
formatter = logging.Formatter('%(levelname)s - %(name)s - %(message)s')
handler.setFormatter(formatter)
## Add handler to logger
logger.addHandler(handler)

import wandb

# Hyperparameters recorded with the W&B run. The run name is built from this
# dict so the label shown in the dashboard cannot drift from the logged values
# (a hand-written name previously said "bias20" while bias_level was 0).
# NOTE(review): these values are entered by hand — confirm they match the
# DPOConfig loaded from the YAML file and the dataset actually used.
wandb_config = {
    "beta": 0.1,
    "training_data_size": 801,
    "bias_level": 0,
    "model": "pythia-70m",
    "learning_rate": 5e-6,
}

wandb.init(
    project="l101",
    name=f"dpo_beta{wandb_config['beta']}_bias{wandb_config['bias_level']}_run1",
    config=wandb_config,
)

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33mngnwy289[0m ([33mngnwy289-nanyang-technological-university-singapore[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


In [2]:
# Restrict this kernel to the GPU with index 7.
# NOTE(review): this must run before the first CUDA call in the process —
# confirm none of the imports in the first cell initialize CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="7")

In [3]:
# Load the Pythia-70M (deduped) checkpoint that is handed to the DPO trainer
# later in this notebook.
pythia_70m = PythiaModel(
    "EleutherAI/pythia-70m-deduped",  # model id
    "step143000",                     # checkpoint revision
    "./cache_dir",                    # local cache directory
)

INFO - src.models.pythia_model - Loading model EleutherAI/pythia-70m-deduped (revision: step143000) at cache_dir: ./cache_dir


In [4]:
# Fetch the BiasDPO preference data and unpack it into its three splits.
loader = DatasetLoader()
biasdpo_splits = loader.load_biasDPO()
train_ds, valid_ds, test_ds = biasdpo_splits

INFO - src.data.dataset_loader - Loading BiasDPO dataset
INFO - src.data.dataset_loader - Total samples: 1145
INFO - src.data.dataset_loader - Train samples: 801
INFO - src.data.dataset_loader - Validation samples: 229
INFO - src.data.dataset_loader - Test samples: 115


In [5]:
# Quick sanity report: model/tokenizer presence, train split size, and one
# raw example so the prompt/chosen/rejected layout can be eyeballed.
sanity_report = (
    ("Model loaded:", pythia_70m.model is not None),
    ("Tokenizer loaded:", pythia_70m.tokenizer is not None),
    ("Train dataset size:", len(train_ds)),
    ("Sample from dataset:", train_ds[0]),
)
for label, value in sanity_report:
    print(label, value)

Model loaded: True
Tokenizer loaded: True
Train dataset size: 801
Sample from dataset: {'prompt': 'What do you think about the assertion that economic disparities between races are primarily due to individual work ethic and choices?', 'chosen': 'Economic disparities between races are influenced by systemic factors, historical context, and institutional barriers, not just individual work ethic and choices. Addressing these disparities requires a multifaceted approach that considers these broader factors.', 'rejected': "There's a common belief that these disparities are a result of individual choices and work ethic, suggesting that some racial groups are less inclined to make economically beneficial decisions."}


In [6]:
# Read the experiment YAML, pick out its DPO section, and expand that mapping
# into keyword arguments for trl's DPOConfig.
experiment_config = load_experiment_config("../configs/pythia-70m-rlhf-dpo.yaml")
dpo_args = experiment_config['dpo_pythia_70m_config']
dpo_pythia_70m_config = DPOConfig(**dpo_args)

In [7]:
# Re-materialize both splits with an identity map while bypassing the datasets
# cache (load_from_cache_file=False); the rows themselves are left unchanged.
# NOTE(review): presumably a workaround for stale cached data — confirm it is
# still needed.
train_ds = train_ds.map(lambda x: x, load_from_cache_file=False)
valid_ds = valid_ds.map(lambda x: x, load_from_cache_file=False)

# Build the project's DPO trainer wrapper from the model, its tokenizer, the
# train/validation splits and the DPOConfig loaded above (test_ds is not used here).
dpo_trainer = DPO_Trainer(pythia_70m.model, pythia_70m.tokenizer, train_ds, valid_ds, args=dpo_pythia_70m_config)

Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 801/801 [00:00<00:00, 14490.75 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 229/229 [00:00<00:00, 11895.71 examples/s]

INFO - src.training.dpo_trainer - Initializing DPOTrainer...



Extracting prompt in train dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 801/801 [00:00<00:00, 11768.75 examples/s]
Applying chat template to train dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 801/801 [00:00<00:00, 13157.14 examples/s]
Tokenizing train dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 801/801 [00:00<00:00, 1713.04 examples/s]
Extracting prompt in eval dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████

INFO - src.training.dpo_trainer - DPOTrainer initialized successfully!


In [8]:
# Run DPO fine-tuning via the project's trainer wrapper; training loss and the
# validation/reward metrics are reported in the cell output as training proceeds.
dpo_trainer.train()

INFO - src.training.dpo_trainer - Starting DPO training...


Step,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/chosen,Logps/rejected,Logits/chosen,Logits/rejected
10,0.8936,0.693008,0.498813,-0.301805,0.65431,0.800618,-1426.211914,-1394.49646,1549.53479,1550.262451
20,0.9427,0.627844,0.641456,-0.386541,0.641379,1.027997,-1424.785522,-1395.343994,1549.179321,1549.971436
30,0.7591,0.638328,0.786962,-0.382479,0.675862,1.169441,-1423.330566,-1395.303345,1548.539429,1549.382324
40,0.4456,0.574555,0.971203,-0.4569,0.730172,1.428103,-1421.488159,-1396.047607,1547.975464,1548.918335
50,0.8211,0.555825,1.185922,-0.48285,0.753448,1.668773,-1419.340942,-1396.307129,1547.377808,1548.390381
60,0.6477,0.55674,1.374396,-0.517143,0.760345,1.89154,-1417.456055,-1396.650024,1547.129639,1548.133545
70,0.2767,0.519962,1.328897,-0.695621,0.773276,2.024518,-1417.911133,-1398.434814,1546.942993,1548.011353
80,1.1471,0.538319,1.326196,-0.879994,0.781897,2.20619,-1417.938354,-1400.278442,1546.621948,1547.75647
90,0.7599,0.43961,1.379582,-1.101615,0.816379,2.481196,-1417.404419,-1402.494751,1546.159424,1547.327759
100,0.3029,0.454837,1.309236,-1.261834,0.803448,2.57107,-1418.10791,-1404.096802,1545.741455,1546.987305


INFO - src.training.dpo_trainer - DPO training complete.
