# Lab: DPO Training


In [40]:
import math
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import json
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import partial

## Data

In [72]:
from torch.utils.data import Dataset, DataLoader

def format_input(entry):
    """Build the Alpaca-style prompt text for one instruction record.

    Args:
        entry: Mapping with an ``'instruction'`` string and, optionally, an
            ``'input'`` string giving extra context for the task.

    Returns:
        The prompt, ending with the ``### Response:`` header so a completion
        can be appended directly after it.

    Raises:
        KeyError: If ``entry`` has no ``'instruction'`` key.
    """
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    # A record with no 'input' key (or an empty/None value) just omits the
    # input section instead of raising KeyError.
    extra_input = entry.get("input")
    input_text = f"\n\n### Input:\n{extra_input}" if extra_input else ""

    return instruction_text + input_text + "\n\n### Response:\n"

class InstructionDataset(Dataset):
    """Dataset of (prompt + completion, completion) string pairs.

    Each record is rendered into a prompt with ``format_input`` and the
    record's ``'output'`` is appended to form the full training text.
    """

    def __init__(self, data):
        super().__init__()
        # Render every record once, then split the pairs into the two lists.
        rendered = [(format_input(record), record['output']) for record in data]
        self.prompt_with_completions = [prompt + answer for prompt, answer in rendered]
        self.completions = [answer for _, answer in rendered]

    def __len__(self):
        """Return the number of records."""
        return len(self.prompt_with_completions)

    def __getitem__(self, idx):
        """Return ``(prompt_with_completion, completion)`` for record ``idx``."""
        full_text = self.prompt_with_completions[idx]
        answer = self.completions[idx]
        return full_text, answer

def custom_collate_fn(batch, tokenizer, ignore_idx=-100):
    """Tokenize a batch of texts into next-token-prediction inputs/targets.

    Args:
        batch: Sequence of ``(prompt_with_completion, completion)`` string
            pairs, as returned by ``InstructionDataset.__getitem__``. Only the
            first element of each pair is tokenized.
        tokenizer: Callable invoked as ``tokenizer(texts, padding='longest',
            truncation=True, return_tensors='pt')`` that returns a mapping with
            ``'input_ids'`` and, normally, ``'attention_mask'``.
        ignore_idx: Value written into target positions that are padding.

    Returns:
        ``(inputs, targets)``: ``inputs`` is every token but the last and
        ``targets`` is every token but the first, so ``targets[:, t]`` is the
        token following ``inputs[:, t]``. Padded target positions hold
        ``ignore_idx``.
    """
    prompt_with_completions, _ = zip(*batch)

    encoded = tokenizer(list(prompt_with_completions), padding='longest', truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]

    if 'attention_mask' in encoded:
        # The attention mask marks exactly the positions added by padding.
        # Comparing against pad_token_id would also hide real tokens that
        # share that id (e.g. when the pad token doubles as the EOS token).
        padding_mask = encoded['attention_mask'][:, 1:] == 0
    else:
        padding_mask = targets == tokenizer.pad_token_id

    targets = targets.masked_fill(padding_mask, ignore_idx)

    return inputs, targets
    
     
def create_data_loader(data, tokenizer, batch_size=4, max_length=256, 
                       stride=128, shuffle=True, drop_last=True, num_workers=0):
    """Wrap instruction records in a ``DataLoader`` that tokenizes per batch.

    Args:
        data: Iterable of instruction records passed to ``InstructionDataset``.
        tokenizer: Tokenizer handed to ``custom_collate_fn``.
        batch_size, shuffle, drop_last, num_workers: Forwarded unchanged to
            ``DataLoader``.
        max_length, stride: Accepted but not used by this function.

    Returns:
        A ``DataLoader`` whose batches are produced by ``custom_collate_fn``
        with ``ignore_idx=-100``.
    """
    return DataLoader(
        InstructionDataset(data=data),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        # Bind the tokenizer so the loader can call the collate function
        # with the batch alone.
        collate_fn=partial(custom_collate_fn, tokenizer=tokenizer, ignore_idx=-100),
    )

In [73]:
def read_text_data(file_path, url):
    """Load JSON data from ``file_path``, downloading it from ``url`` first if needed.

    Args:
        file_path: Local path of the cached JSON file.
        url: Location fetched with ``urllib.request.urlopen`` when
            ``file_path`` does not exist yet. Not accessed otherwise.

    Returns:
        The object obtained by parsing the file's JSON content.

    Raises:
        urllib.error.URLError: If the download fails.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # `import urllib` alone does not load the `request` submodule, so it has
    # to be imported explicitly.
    import urllib.request

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode('utf-8')

        # Create the cache directory on first use; open() would otherwise
        # fail for a path such as "./instruction_data/data.json".
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    # Parse the text already in memory instead of re-reading the file.
    return json.loads(text_data)

In [74]:
import transformers
from transformers import AutoTokenizer
def test_data_component():
    """Smoke-test the data pipeline by printing the first collated batch.

    Loads the instruction JSON through ``read_text_data``, builds a loader
    with the "Qwen/Qwen2.5-0.5B" tokenizer and default loader settings, and
    prints one batch.
    """
    records = read_text_data(
        "instruction-data2.json",
        "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json",
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    loader = create_data_loader(data=records, tokenizer=tokenizer)

    # One batch is enough for a visual check; an empty loader prints nothing.
    for first_batch in loader:
        print(first_batch)
        break


test_data_component()

(tensor([[ 38214,    374,    458,   7600,    429,  16555,    264,   3383,     13,
           9645,    264,   2033,    429,  34901,  44595,    279,   1681,    382,
          14374,  29051,    510,  31115,    264,  15908,  40158,   6524,    382,
          14374,   5571,    510,     40,   1079,  14589,    369,    537,  33094,
            847,  16319,    389,    882,    382,  14374,   5949,    510,  30665,
            508,    675,  49088,     40,  36879,    369,  33094,    847,  16319,
           3309,     13,    358,   3535,    429,  33094,    975,    389,    882,
            374,   1376,    311,    697,   6950,    304,    752,     13,    358,
           1079,   8480,    369,    279,   7626,    323,    358,   1896,   2480,
          38142,    369,    432,     13,    358,  58283,  22231,    894,  60009,
           8881,    553,    847,   6168,    382,     40,  15440,    429,    358,
            686,   1896,    678,   5871,   7354,    311,   5978,    429,   1741,
            458,  10455,   

### Preference Learning

In [None]:
def preference_loss(
    policy_chosen_logps: torch.FloatTensor, 
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    beta: float,
    label_smoothing: float = 0.0,
    ipo: bool = False,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the per-example preference (DPO / IPO) loss.

    Args:
        policy_chosen_logps: log probabilities of the policy model for the chosen responses, shape: (batch_size,)
        policy_rejected_logps: log probabilities of the policy model for the rejected responses, shape: (batch_size,)
        reference_chosen_logps: log probabilities of the reference model for the chosen responses, shape: (batch_size,)
        reference_rejected_logps: log probabilities of the reference model for the rejected responses, shape: (batch_size,)
        beta: scale applied to the log-ratio difference in the loss and to the rewards.
        label_smoothing: weight given to the flipped preference in the sigmoid loss;
            0.0 gives the plain ``-logsigmoid(beta * logits)`` loss. Ignored when ``ipo`` is set.
        ipo: if True, use the squared loss ``(logits - 1 / (2 * beta)) ** 2`` instead of the sigmoid loss.
        reference_free: if True, drop the reference log-ratio from the loss
            (treat it as 0). The returned rewards still use the reference log probabilities.

    Returns:
        ``(losses, chosen_rewards, rejected_rewards)``, each of shape (batch_size,).
        The rewards are ``beta * (policy_logps - reference_logps)`` and are
        detached from the autograd graph.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps

    # The reference term must be chosen BEFORE the logits are formed;
    # zeroing it afterwards would leave the loss unchanged.
    if reference_free:
        ref_logratios = 0.0
    else:
        ref_logratios = reference_chosen_logps - reference_rejected_logps

    logits = pi_logratios - ref_logratios

    if ipo:
        losses = (logits - 1 / (2 * beta)) ** 2
    else:
        losses = - F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing

    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

def _get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    loss_mask: torch.LongTensor,
    average_log_prob: bool = False,
    ) -> torch.FloatTensor:
    
    """
    Args:
        logits: logits of the model output. Shape: (batch_size, seq_length, vocab_size)
        labels: labels for which token's log probability; label = -100 indicates ignore. Shape (batch_size, seq_length)

    """
    
    
    assert logits.shape[:-1] == labels.shape
    # let the sequence be A, B, C, D
    # labels[:,1ï¼š] are B, C, D
    # logits corresponds to B, C, D, X
    # logits[:,:-1,:] corresponds to B, C, D
    labels = labels[:,1:].clone() # labels 
    logits = logits[:,:-1,:]
    
    loss_mark = loss_mask[:, 1:]
    
    # shape (batch_size, seq_len - 1)
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsequeeze(2)).squeeze(2)
    
    if average_log_prob:
        return (per_token_logps * loss_mark).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
    
def get_logps(outputs, labels, input_mask=None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Split a stacked batch's sequence log probabilities into chosen / rejected halves.

    Args:
        outputs: model output object exposing a ``logits`` tensor.
        labels: label ids aligned with ``outputs.logits``; passed to
            ``_get_batch_logps`` together with the mask.
        input_mask: mask selecting the label positions that count. When
            omitted, positions whose label is not -100 are used.

    Returns:
        ``(chosen_logps, rejected_logps)``: the first and second half of the
        per-sequence summed log probabilities along the batch dimension.

    Raises:
        ValueError: If the batch size is odd, since the two halves could not
            be paired one-to-one.
    """
    if input_mask is None:
        # Lets callers that only carry labels use the function.
        input_mask = labels != -100

    # Cast to float32 before the log-softmax inside _get_batch_logps.
    all_logits = outputs.logits.to(torch.float32)
    all_logps = _get_batch_logps(all_logits, labels, input_mask, average_log_prob=False)

    batch_size = all_logps.shape[0]
    if batch_size % 2 != 0:
        raise ValueError(
            f"expected an even batch of stacked chosen/rejected sequences, got {batch_size}"
        )

    half = batch_size // 2
    return all_logps[:half], all_logps[half:]
    
        

## Training

In [78]:
def compute_batch_loss(input_batch, target_batch, model, device):
    """Compute the mean next-token cross-entropy loss for one batch.

    Args:
        input_batch: token ids fed to the model, shape (batch, seq_len).
        target_batch: target token ids, same shape as ``input_batch``.
            Positions equal to -100 (the default ``ignore_index`` of
            ``F.cross_entropy``) do not contribute to the loss.
        model: callable returning either a logits tensor of shape
            (batch, seq_len, vocab) or an object with a ``logits`` attribute
            holding that tensor.
        device: device both batches are moved to before the forward pass.

    Returns:
        Scalar loss tensor.
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    outputs = model(input_batch)
    # Accept a raw logits tensor as well as an output object carrying `.logits`.
    logits = getattr(outputs, "logits", outputs)

    # Merge batch and sequence dimensions so each token is one sample.
    flat_logits = logits.flatten(0, 1)
    flat_targets = target_batch.flatten()
    return F.cross_entropy(flat_logits, flat_targets)

def train_model_epoch(model, 
                train_loader,
                optimizer,
                device,
                num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Each step zeroes the gradients, computes the loss with
    ``compute_batch_loss``, back-propagates, and steps the optimizer. The
    loss of every step is printed as it is produced.

    Args:
        model: model to train; put into training mode at each epoch start.
        train_loader: iterable yielding ``(input_batch, target_batch)`` pairs.
        optimizer: optimizer updating the model's parameters.
        device: device passed through to ``compute_batch_loss``.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(train_losses, model)``, where ``train_losses`` is the list of
        per-step loss values as Python floats.
    """
    train_losses = []

    for _ in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = compute_batch_loss(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            train_losses.append(step_loss)
            print(step_loss)

    return train_losses, model

In [None]:
def train_dpo(model, ref_model, optimizer, train_loader, train_settings):
    """Run preference (DPO) training of ``model`` against a frozen reference.

    For every batch the policy and the reference model are run on the same
    input, sequence log probabilities are split into chosen / rejected halves
    with ``get_logps``, and the mean ``preference_loss`` is back-propagated.

    Args:
        model: policy model being optimized.
        ref_model: reference model; only evaluated, never updated.
        optimizer: optimizer over the policy model's parameters.
        train_loader: iterable of batches supporting ``batch['labels']``.
        train_settings: object exposing ``num_epochs``.
    """
    # Loop-invariant, so built once.
    loss_kwargs = {"beta": 0.1, "reference_free": False}

    for _ in range(train_settings.num_epochs):

        for batch in train_loader:
            # Without this, gradients from earlier steps keep accumulating.
            optimizer.zero_grad()

            labels = batch['labels']
            # Labels equal to -100 mark positions to ignore (see the
            # _get_batch_logps docstring), so the rest form the loss mask.
            loss_mask = labels != -100

            # NOTE(review): the batch is passed to both models as a single
            # positional argument, exactly as before; if the models expect
            # keyword inputs (input_ids=..., attention_mask=...), this call
            # needs unpacking -- confirm against the loader's batch format.
            model_outputs = model(batch)
            policy_chosen_logps, policy_rejected_logps = get_logps(model_outputs, labels, loss_mask)

            # The reference model is never trained, so no autograd graph is
            # needed for its forward pass.
            with torch.no_grad():
                reference_outputs = ref_model(batch)
                reference_chosen_logps, reference_rejected_logps = get_logps(reference_outputs, labels, loss_mask)

            losses, chosen_rewards, rejected_rewards = preference_loss(
                policy_chosen_logps=policy_chosen_logps,
                policy_rejected_logps=policy_rejected_logps,
                reference_chosen_logps=reference_chosen_logps,
                reference_rejected_logps=reference_rejected_logps,
                **loss_kwargs
            )

            loss = losses.mean()

            loss.backward()

            optimizer.step()
    

In [79]:
def train_main(model_config, train_settings):
    """Build the model, optimizer and data loader, then run supervised training.

    Args:
        model_config: configuration passed to ``LlamaForCausalLM(config=...)``.
        train_settings: object exposing ``seed``, ``file_path``, ``url``,
            ``learning_rate``, ``weight_decay``, ``pretrained_model_name``,
            ``batch_size`` and ``num_epochs``.

    Returns:
        ``(train_losses, model)`` as returned by ``train_model_epoch``.
    """
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    text_data = read_text_data(train_settings.file_path, train_settings.url)

    # NOTE(review): LlamaForCausalLM is not imported or defined in the visible
    # part of this file -- confirm where it comes from.
    model = LlamaForCausalLM(config=model_config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)

    # set up dataloader (built once; an earlier duplicate call whose result
    # was immediately overwritten has been removed)
    tokenizer = AutoTokenizer.from_pretrained(train_settings.pretrained_model_name)
    train_loader = create_data_loader(data=text_data,
                                      tokenizer=tokenizer,
                                      batch_size=train_settings.batch_size,
                                      drop_last=True,
                                      shuffle=True,
                                      num_workers=0)

    train_losses, model = train_model_epoch(model=model,
                                            train_loader=train_loader,
                                            optimizer=optimizer,
                                            num_epochs=train_settings.num_epochs,
                                            device=device)

    return train_losses, model
    

In [None]:
if __name__ == '__main__':
    # Script entry point: build both configurations, train, print the losses.
    # NOTE(review): OmegaConf is not imported in the visible part of this
    # file -- confirm it is brought into scope elsewhere.

    # Model hyper-parameters handed to train_main as `model_config`.
    model_config = OmegaConf.create(
        {
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151643,
            "pad_token_id": 151643,
            "hidden_act": "silu",
            "hidden_size": 896,
            "initializer_range": 0.02,
            "intermediate_size": 4864,
            "max_position_embeddings": 32768,
            "max_window_layers": 24,
            "model_type": "qwen2",
            "num_attention_heads": 14,
            "num_hidden_layers": 24,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "rope_theta": 1000000.0,
            "tie_word_embeddings": True,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.47.1",
            "use_cache": True,
            "use_mrope": False,
            "vocab_size": 151936,
            "qkv_bias": True,
            "o_bias": False,
            "mlp_bias": False,
        }
    )

    # Optimizer, data and run settings handed to train_main as `train_settings`.
    train_settings = OmegaConf.create(
        {
            "pretrained_model_name": "Qwen/Qwen2.5-0.5B",
            "learning_rate": 5e-6,
            "num_epochs": 10,
            "batch_size": 4,
            "weight_decay": 0.1,
            "stride": 128,
            "seed": 1,
            "file_path": "./instruction_data/instruction-data2.json",
            "url": "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json",
        }
    )

    # train model
    train_losses, model = train_main(
        model_config=model_config,
        train_settings=train_settings,
    )

    print(train_losses)


12.114983558654785
11.961651802062988
11.927282333374023
11.826778411865234
11.707758903503418
11.793996810913086
11.565061569213867
11.341635704040527
11.44568920135498
11.51699161529541
11.538795471191406
11.205238342285156
11.476774215698242
11.2799654006958
10.608583450317383
10.964029312133789
11.31577205657959
10.75229549407959
10.765582084655762
10.845158576965332
10.3082857131958
10.781938552856445
11.183990478515625
10.394356727600098
10.934983253479004
9.947456359863281
10.557378768920898
10.630148887634277
10.455920219421387
