In [8]:
# ===============================================================================
# DPO（Direct Preference Optimization）算法实现
# DPO通过人类偏好数据直接优化语言模型，使其生成更符合人类偏好的输出
# 这里面使用了一个偏好prefer以及两个reject的格式
# ===============================================================================
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy

torch.manual_seed(0)

<torch._C.Generator at 0x110735f10>

In [9]:
# 加载模型
# 创建简化版的Llama模型作为策略模型（将被优化的模型）
policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))
# 创建参考模型（通常是SFT模型，在训练过程中保持不变）
reference_model = deepcopy(policy_model)  # 深度复制确保两个模型初始参数完全相同
policy_model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(1000, 128)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=128, out_features=11008, bias=False)
          (up_proj): Linear(in_features=128, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=128, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((128,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((128,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((128,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbeddin

In [12]:
# 超参数
beta = 0.1  # DPO的温度系数，控制策略模型与参考模型的偏离程度，值越小允许偏离越大

# 准备训练数据
# 在DPO中，我们需要提示(prompt)、优选回答(chosen/good)和拒绝回答(rejected/bad)
prompt_ids = [1, 2, 3, 4, 5, 6]  # 输入提示的token IDs
good_response_ids = [7, 8, 9, 10]  # 优质回答的token IDs
# 多个低质量回答的示例，每个都是token IDs的列表
bad_response_ids_list = [[1, 2, 0, 0], [4, 5, 6, 0]]

In [14]:
# 构建模型输入：将提示与回答拼接
# 创建包含多个序列的批次：[提示+优质回答, 提示+低质回答1, 提示+低质回答2, ...]
input_ids = torch.LongTensor(
    [prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
)
input_ids

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  1,  2,  0,  0],
        [ 1,  2,  3,  4,  5,  6,  4,  5,  6,  0]])

In [15]:
# 准备用于计算语言模型损失的标签
# 在语言模型训练中，标签是输入向右移动一位（预测下一个token）
# -100表示在计算损失时忽略该位置（这里忽略提示部分）
labels = torch.LongTensor(
    [
        [-100] * len(prompt_ids) + good_response_ids,
        *[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
    ]
)[:, 1:]  # 向右移动一位，因为我们预测的是下一个token
labels

tensor([[-100, -100, -100, -100, -100,    7,    8,    9,   10],
        [-100, -100, -100, -100, -100,    1,    2,    0,    0],
        [-100, -100, -100, -100, -100,    4,    5,    6,    0]])

In [17]:
# 创建掩码，用于标识哪些位置参与损失计算（即回答部分）
loss_mask = (labels != -100)
# 将-100替换为0，因为在gather操作中-100是无效索引
labels[labels == -100] = 0
labels

tensor([[ 0,  0,  0,  0,  0,  7,  8,  9, 10],
        [ 0,  0,  0,  0,  0,  1,  2,  0,  0],
        [ 0,  0,  0,  0,  0,  4,  5,  6,  0]])

In [19]:
# ===============================================================================
# 计算策略模型（policy model）的对数概率
# ===============================================================================
# 前向传播，获取每个token位置的预测logits
logits = policy_model(input_ids)["logits"][:, :-1, :]  # 去掉最后一个位置，与label对齐
logits

tensor([[[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,  ...,  0.0505,  0.1151, -0.1636],
         [ 0.1643,  0.1157, -0.0293,  ...,  0.1549, -0.3632, -0.1554],
         ...,
         [-0.3279,  0.2739, -0.4922,  ..., -0.0714,  0.2901, -0.3667],
         [-0.0849,  0.2404, -0.3772,  ..., -0.1673, -0.1023, -0.2427],
         [-0.0503,  0.2694, -0.4878,  ...,  0.0252,  0.3393, -0.6586]],

        [[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,  ...,  0.0505,  0.1151, -0.1636],
         [ 0.1643,  0.1157, -0.0293,  ...,  0.1549, -0.3632, -0.1554],
         ...,
         [ 0.3901,  0.2500,  0.0026,  ...,  0.1113,  0.1034,  0.0109],
         [ 0.2349,  0.2526,  0.0359,  ..., -0.0251, -0.1628, -0.4019],
         [ 0.4680,  0.1479, -0.3143,  ...,  0.0656,  0.0351,  0.1007]],

        [[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,  ...,  0

In [20]:
# 将logits转换为对数概率，并提取每个位置上正确token的对数概率
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps

tensor([[-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -7.0314, -6.7662, -6.5559,
         -6.6359],
        [-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -7.1295, -6.9317, -6.6988,
         -6.4697],
        [-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -6.7405, -6.8247, -7.2198,
         -6.8045]], grad_fn=<SqueezeBackward1>)

In [21]:
# 仅对回答部分（loss_mask=True的位置）求和，得到每个序列的总对数概率
all_logps = (per_token_logps * loss_mask).sum(-1)
all_logps

tensor([-60.4376, -60.6779, -61.0376], grad_fn=<SumBackward1>)

In [22]:
# 分离优质回答和低质量回答的对数概率
policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

In [23]:
policy_good_logps

tensor([-60.4376], grad_fn=<SliceBackward0>)

In [24]:
policy_bad_logps

tensor([-60.6779, -61.0376], grad_fn=<SliceBackward0>)

In [25]:
# ===============================================================================
# 计算参考模型（reference model）的对数概率
# ===============================================================================
with torch.no_grad():  # 不计算梯度，因为参考模型不需要更新
    # 重复与策略模型相同的步骤
    logits = reference_model(input_ids)["logits"][:, :-1, :]
    print("logits:\n",logits)
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    print("per_token_logps:\n",per_token_logps)
    all_logps = (per_token_logps * loss_mask).sum(-1)
    print("all_logps\n",all_logps)
    reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]
    print("reference_good_logps:\n",reference_good_logps)
    print("reference_bad_logps\n",reference_bad_logps)

logits:
 tensor([[[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,  ...,  0.0505,  0.1151, -0.1636],
         [ 0.1643,  0.1157, -0.0293,  ...,  0.1549, -0.3632, -0.1554],
         ...,
         [-0.3279,  0.2739, -0.4922,  ..., -0.0714,  0.2901, -0.3667],
         [-0.0849,  0.2404, -0.3772,  ..., -0.1673, -0.1023, -0.2427],
         [-0.0503,  0.2694, -0.4878,  ...,  0.0252,  0.3393, -0.6586]],

        [[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,  ...,  0.0505,  0.1151, -0.1636],
         [ 0.1643,  0.1157, -0.0293,  ...,  0.1549, -0.3632, -0.1554],
         ...,
         [ 0.3901,  0.2500,  0.0026,  ...,  0.1113,  0.1034,  0.0109],
         [ 0.2349,  0.2526,  0.0359,  ..., -0.0251, -0.1628, -0.4019],
         [ 0.4680,  0.1479, -0.3143,  ...,  0.0656,  0.0351,  0.1007]],

        [[ 0.1456,  0.1039, -0.6784,  ..., -0.0864, -0.0790,  0.1382],
         [ 0.2436,  0.1769, -0.0961,

In [26]:
# ===============================================================================
# 计算DPO损失
# DPO的核心思想：增大策略模型对优质回答的概率，同时减小对低质量回答的概率
# ===============================================================================
# 计算DPO的logits：(策略模型相对于参考模型对好回答的提升) - (对坏回答的提升)
logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
# 应用logsigmoid函数并乘以beta控制优化强度，取负值（因为要最小化损失）
loss = -F.logsigmoid(beta * logits).mean()  # 对所有样本取平均

# 输出损失值
print(loss)

tensor(0.6931, grad_fn=<NegBackward0>)
