# DPO off-policy Demonstration

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch.nn.functional as F

# -----------------------
# Configuration
# -----------------------
MODEL_NAME = "gpt2"  # checkpoint identifier handed to from_pretrained
BATCH_SIZE = 2       # preference pairs per optimizer step
LR = 5e-6            # learning rate for the optimizer
EPOCHS = 1           # passes over the training split
BETA = 0.1           # scaling hyperparameter used in the DPO loss

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# -----------------------
# Load the model and tokenizer
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)

# Reference policy: a second copy of the same checkpoint that must stay fixed
# for the whole run.
ref_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
ref_model.eval()  # evaluation mode; this model is never trained
# eval() only switches the module mode and leaves requires_grad=True on every
# parameter. Turn gradients off explicitly so the reference weights are really
# frozen and can never accumulate gradients or be picked up by an optimizer.
ref_model.requires_grad_(False)

# -----------------------
# Load the preference dataset
# Record format: { "prompt": ..., "chosen": ..., "rejected": ... }
# -----------------------
# Ask for the training split directly instead of indexing the returned
# collection of splits.
dataset = load_dataset("Dahoas/rm-static", split="train")

# -----------------------
# Tokenization helper
# -----------------------
def tokenize_pair(example):
    """Tokenize the prompt and both candidate responses of one example.

    Each of ``example["prompt"]``, ``example["chosen"]`` and
    ``example["rejected"]`` is passed on its own through the module-level
    ``tokenizer``; the first row of the resulting ``input_ids`` is kept.

    Returns:
        A dict with the keys ``"prompt_ids"``, ``"chosen_ids"`` and
        ``"rejected_ids"`` holding those id sequences.
    """
    def encode(text):
        # return_tensors="pt" yields a batch of one; index 0 drops that axis.
        encoded = tokenizer(text, return_tensors="pt")
        return encoded.input_ids[0]

    # (output key, source field) pairs, processed in this order.
    field_map = (
        ("prompt_ids", "prompt"),
        ("chosen_ids", "chosen"),
        ("rejected_ids", "rejected"),
    )
    return {out_key: encode(example[src_key]) for out_key, src_key in field_map}

# Tokenize every example. The mapped columns are stored in Arrow, so by
# default rows come back with plain Python lists rather than the tensors that
# tokenize_pair returned. Expose the three id columns in torch format so each
# row yields tensors that can be moved to a device; the raw text columns are
# not needed after tokenization and are left out of the formatted view.
dataset = dataset.map(tokenize_pair).with_format(
    "torch",
    columns=["prompt_ids", "chosen_ids", "rejected_ids"],
)

# -----------------------
# DataLoader
# -----------------------
def collate_fn(batch):
    """Return the sampled examples unchanged, without stacking or padding.

    The very same object that was passed in is handed back, so the caller
    receives the individual examples and can process them one at a time.
    """
    return batch

# Shuffled mini-batches; collate_fn decides how the sampled examples are
# packaged for the training loop.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# -----------------------
# Optimizer
# -----------------------
# AdamW over the parameters of the policy model only.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

# -----------------------
# Training loop
# -----------------------
def _fit_to_context(prompt_ids, response_ids, max_len):
    """Trim a (prompt, response) pair so the concatenation fits in max_len tokens.

    The response is capped first (leaving room for at least one prompt token),
    then the oldest prompt tokens are dropped so the text closest to the
    response is kept. With ``max_len`` of None nothing is trimmed.
    """
    if max_len is None:
        return prompt_ids, response_ids
    response_ids = response_ids[: max(max_len - 1, 0)]
    overflow = prompt_ids.numel() + response_ids.numel() - max_len
    if overflow > 0:
        prompt_ids = prompt_ids[overflow:]
    return prompt_ids, response_ids


def _response_log_prob(lm, prompt_ids, response_ids):
    """Return log p(response | prompt) under ``lm`` as a 0-dim tensor.

    Args:
        lm: causal language model whose output exposes ``.logits``.
        prompt_ids: 1-D tensor of prompt token ids.
        response_ids: 1-D tensor of response token ids.

    The prompt and response are concatenated and scored in a single forward
    pass; only the log-probabilities of the response tokens are summed.
    """
    input_ids = torch.cat([prompt_ids, response_ids]).unsqueeze(0)  # (1, T)
    logits = lm(input_ids=input_ids).logits  # (1, T, vocab)

    # The logits at position t are the prediction for token t + 1, so drop the
    # last position and line the rest up against the sequence shifted by one.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    token_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)

    # In the shifted view the first response token sits at index
    # len(prompt) - 1. With an empty prompt the first response token has no
    # context to be predicted from, so scoring starts at index 0.
    first_response = max(prompt_ids.numel() - 1, 0)
    return token_log_probs[:, first_response:].sum()


# Longest sequence the model can take (None when the config does not say).
max_context = getattr(model.config, "max_position_embeddings", None)

for epoch in range(EPOCHS):
    epoch_loss_sum = 0.0
    num_batches = 0

    for batch in loader:
        optimizer.zero_grad()
        losses = []

        for item in batch:
            # as_tensor accepts both tensors and the plain lists that a mapped
            # dataset returns; force int64 so empty sequences stay integer.
            prompt_ids = torch.as_tensor(item["prompt_ids"], dtype=torch.long, device=DEVICE)
            chosen_ids = torch.as_tensor(item["chosen_ids"], dtype=torch.long, device=DEVICE)
            rejected_ids = torch.as_tensor(item["rejected_ids"], dtype=torch.long, device=DEVICE)

            # A pair with an empty response carries no preference signal.
            if chosen_ids.numel() == 0 or rejected_ids.numel() == 0:
                continue

            chosen_prompt, chosen_ids = _fit_to_context(prompt_ids, chosen_ids, max_context)
            rejected_prompt, rejected_ids = _fit_to_context(prompt_ids, rejected_ids, max_context)

            # Sequence log-probs under the policy being trained.
            chosen_seq_log_prob = _response_log_prob(model, chosen_prompt, chosen_ids)
            rejected_seq_log_prob = _response_log_prob(model, rejected_prompt, rejected_ids)

            # Same quantities under the fixed reference policy; no graph needed.
            with torch.no_grad():
                chosen_seq_log_prob_ref = _response_log_prob(ref_model, chosen_prompt, chosen_ids)
                rejected_seq_log_prob_ref = _response_log_prob(ref_model, rejected_prompt, rejected_ids)

            # DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio)).
            # The margin is multiplied by BETA, and logsigmoid avoids the
            # underflow of log(sigmoid(x)) for large negative margins.
            log_ratio_chosen = chosen_seq_log_prob - chosen_seq_log_prob_ref
            log_ratio_rejected = rejected_seq_log_prob - rejected_seq_log_prob_ref
            loss = -F.logsigmoid(BETA * (log_ratio_chosen - log_ratio_rejected))
            losses.append(loss)

        # Every example in the batch was skipped: nothing to optimize.
        if not losses:
            continue

        batch_loss = torch.stack(losses).mean()
        batch_loss.backward()
        optimizer.step()

        epoch_loss_sum += batch_loss.item()
        num_batches += 1

    # Report the mean over all batches of the epoch rather than whichever batch
    # happened to come last; also well-defined when no batch was processed.
    mean_loss = epoch_loss_sum / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")
