In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import warnings
from transformers import LlamaConfig, LlamaForCausalLM

# 设置随机种子以保证可复现性
torch.manual_seed(42)

# 定义示例数据 - 整个流程将使用这些数据进行演示
input_ids = torch.tensor([[3, 5, 2, 8, 1, 4]])           # [1, 6] - 提示部分
output_ids = torch.tensor([[7, 9, 6, 0]])                # [1, 4] - 生成部分
full_ids = torch.cat([input_ids, output_ids], dim=1)     # [1, 10] - 完整序列
full_mask = torch.ones_like(full_ids)                    # [1, 10] - 序列掩码

print("输入ID形状:", input_ids.shape)
print("输出ID形状:", output_ids.shape)
print("完整序列形状:", full_ids.shape)
print("掩码形状:", full_mask.shape)

输入ID形状: torch.Size([1, 6])
输出ID形状: torch.Size([1, 4])
完整序列形状: torch.Size([1, 10])
掩码形状: torch.Size([1, 10])


In [32]:
# 创建策略模型和参考模型
policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=12, num_hidden_layers=1, hidden_size=32))
reference_model = deepcopy(policy_model)  # 深度复制确保参数完全相同

# 冻结参考模型参数
for param in reference_model.parameters():
    param.requires_grad = False

# 对两个模型进行简单测试
with torch.no_grad():
    policy_outputs = policy_model(full_ids)
    ref_outputs = reference_model(full_ids)

print("策略模型输出形状:", policy_outputs.logits.shape)  # [batch_size, seq_len, vocab_size]
print("参考模型输出形状:", ref_outputs.logits.shape)

策略模型输出形状: torch.Size([1, 10, 12])
参考模型输出形状: torch.Size([1, 10, 12])


In [33]:
# 创建奖励模型
class RewardModel(nn.Module):
    def __init__(self, vocab_size=12, hidden_size=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)
    
    def forward(self, input_ids, masks=None):
        # [bs, seq_len] -> [bs, seq_len, hidden]
        x = self.embedding(input_ids)
        print(f"奖励模型 - 嵌入输出形状: {x.shape}")
        
        outputs, _ = self.lstm(x)
        print(f"奖励模型 - LSTM输出形状: {outputs.shape}")
        
        # 只取序列最后一个有效位置的输出
        if masks is not None:
            last_indices = masks.sum(dim=1) - 1
            batch_indices = torch.arange(outputs.size(0))
            last_hidden = outputs[batch_indices, last_indices]
            print(f"奖励模型 - 最后隐藏状态形状: {last_hidden.shape}")
        else:
            last_hidden = outputs[:, -1]
            
        # 生成标量奖励
        reward = self.head(last_hidden).squeeze(-1)  # [bs]
        print(f"奖励模型 - 输出奖励形状: {reward.shape}")
        
        return reward

reward_model = RewardModel()
# 测试奖励模型
reward = reward_model(full_ids, full_mask)
print(f"奖励值: {reward.item()}")

奖励模型 - 嵌入输出形状: torch.Size([1, 10, 8])
奖励模型 - LSTM输出形状: torch.Size([1, 10, 8])
奖励模型 - 最后隐藏状态形状: torch.Size([1, 8])
奖励模型 - 输出奖励形状: torch.Size([1])
奖励值: 0.29774409532546997


In [34]:
# 创建价值模型
class CriticModel(nn.Module):
    def __init__(self, vocab_size=12, hidden_size=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)
    
    def forward(self, input_ids):
        x = self.embedding(input_ids)
        print(f"价值模型 - 嵌入输出形状: {x.shape}")
        
        outputs, _ = self.lstm(x)
        print(f"价值模型 - LSTM输出形状: {outputs.shape}")
        
        values = self.head(outputs).squeeze(-1)  # [bs, seq_len]
        print(f"价值模型 - 输出价值形状: {values.shape}")
        
        return values

critic_model = CriticModel()
# 测试价值模型
values = critic_model(full_ids)
print("价值估计:")
print(values)

价值模型 - 嵌入输出形状: torch.Size([1, 10, 8])
价值模型 - LSTM输出形状: torch.Size([1, 10, 8])
价值模型 - 输出价值形状: torch.Size([1, 10])
价值估计:
tensor([[-0.2702, -0.2315, -0.2532, -0.2238, -0.2914, -0.2477, -0.2275, -0.3206,
         -0.2904, -0.2328]], grad_fn=<SqueezeBackward1>)


In [35]:
def masked_mean(values, mask):
    """计算掩码均值"""
    result = (values * mask).sum() / mask.sum()
    print(f"掩码均值 - 输入形状: {values.shape}, 掩码形状: {mask.shape}, 输出: {result.item()}")
    return result

def masked_var(values, mask):
    """计算掩码方差"""
    mean = masked_mean(values, mask)
    result = masked_mean((values - mean) ** 2, mask)
    print(f"掩码方差 - 输入形状: {values.shape}, 输出: {result.item()}")
    return result

def masked_whiten(values, mask, shift_mean=True):
    """对数值进行标准化处理"""
    print(f"标准化前 - 值形状: {values.shape}, 掩码形状: {mask.shape}")
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8) if shift_mean else values * torch.rsqrt(var + 1e-8)
    result = whitened * mask
    print(f"标准化后 - 输出形状: {result.shape}")
    return result

# 测试掩码操作
test_values = torch.randn(1, 10)  # 随机值
test_mask = torch.ones(1, 10)  # 全1掩码
whitened_values = masked_whiten(test_values, test_mask)

print("原始值:", test_values[0])
print("标准化后:", whitened_values[0])

标准化前 - 值形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10])
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: -0.13686081767082214
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: -0.13686081767082214
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: 1.0405241250991821
掩码方差 - 输入形状: torch.Size([1, 10]), 输出: 1.0405241250991821
标准化后 - 输出形状: torch.Size([1, 10])
原始值: tensor([-0.0111, -0.3385, -0.7628,  0.2919, -0.6887,  2.1098, -0.6621, -1.4626,
         1.1491, -0.9935])
标准化后: tensor([ 0.1232, -0.1977, -0.6136,  0.4203, -0.5410,  2.2025, -0.5149, -1.2997,
         1.2607, -0.8398])


In [36]:
def logprobs_from_logits(logits, labels):
    """计算给定标签的对数概率"""
    print(f"对数概率计算 - 输入logits形状: {logits.shape}, 标签形状: {labels.shape}")
    logp = F.log_softmax(logits, dim=-1)
    print(f"对数概率分布形状: {logp.shape}")
    
    logp_labels = torch.gather(logp, dim=-1, index=labels.unsqueeze(-1))
    print(f"收集对应标签的对数概率形状: {logp_labels.shape}")
    
    result = logp_labels.squeeze(-1)
    print(f"最终对数概率形状: {result.shape}")
    return result

def entropy_from_logits(logits):
    """计算策略熵"""
    print(f"熵计算 - 输入logits形状: {logits.shape}")
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -torch.sum(probs * log_probs, dim=-1)
    print(f"熵输出形状: {entropy.shape}")
    return entropy

# 测试对数概率和熵计算
policy_outputs = policy_model(full_ids)
policy_logits = policy_outputs.logits
policy_logprobs = logprobs_from_logits(policy_logits, full_ids)
entropy = entropy_from_logits(policy_logits)

print("策略对数概率:", policy_logprobs[0])
print("策略熵:", entropy[0])

对数概率计算 - 输入logits形状: torch.Size([1, 10, 12]), 标签形状: torch.Size([1, 10])
对数概率分布形状: torch.Size([1, 10, 12])
收集对应标签的对数概率形状: torch.Size([1, 10, 1])
最终对数概率形状: torch.Size([1, 10])
熵计算 - 输入logits形状: torch.Size([1, 10, 12])
熵输出形状: torch.Size([1, 10])
策略对数概率: tensor([-2.5491, -2.4207, -2.4306, -2.5236, -2.4441, -2.6406, -2.2948, -2.5483,
        -2.4693, -2.4820], grad_fn=<SelectBackward0>)
策略熵: tensor([2.4801, 2.4831, 2.4791, 2.4823, 2.4804, 2.4769, 2.4790, 2.4770, 2.4779,
        2.4822], grad_fn=<SelectBackward0>)


In [37]:
def _kl_penalty(policy_logprobs, ref_logprobs):
    """计算KL散度惩罚项"""
    print(f"KL散度计算 - 策略对数概率形状: {policy_logprobs.shape}, 参考对数概率形状: {ref_logprobs.shape}")
    # KL散度: D_KL(P||Q) = E_P[log P - log Q]，这里P是参考模型分布
    kl = ref_logprobs - policy_logprobs
    print(f"KL散度输出形状: {kl.shape}")
    print(f"KL散度均值: {kl.mean().item()}")
    return kl

# 计算参考模型的对数概率
ref_outputs = reference_model(full_ids)
ref_logits = ref_outputs.logits
ref_logprobs = logprobs_from_logits(ref_logits, full_ids)

# 测试KL散度计算
kl = _kl_penalty(policy_logprobs, ref_logprobs)
print("KL散度值:", kl[0])

对数概率计算 - 输入logits形状: torch.Size([1, 10, 12]), 标签形状: torch.Size([1, 10])
对数概率分布形状: torch.Size([1, 10, 12])
收集对应标签的对数概率形状: torch.Size([1, 10, 1])
最终对数概率形状: torch.Size([1, 10])
KL散度计算 - 策略对数概率形状: torch.Size([1, 10]), 参考对数概率形状: torch.Size([1, 10])
KL散度输出形状: torch.Size([1, 10])
KL散度均值: 0.0
KL散度值: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward0>)


In [38]:
def compute_rewards(scores, logprobs, ref_logprobs, masks, kl_coef=0.1):
    """计算每个token的奖励，包含KL惩罚"""
    print(f"奖励计算 - 奖励分数形状: {scores.shape}, logprobs形状: {logprobs.shape}")
    rewards, non_score_rewards, kls = [], [], []
    
    for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
        # 1. 计算KL散度惩罚
        kl = _kl_penalty(logprob, ref_logprob)  # [seq_len]
        kls.append(kl)
        
        # 2. 计算KL惩罚奖励分量 (-kl_coef * kl)
        non_score_reward = -kl_coef * kl  # [seq_len]
        print(f"KL惩罚奖励形状: {non_score_reward.shape}")
        non_score_rewards.append(non_score_reward)
        
        # 3. 初始化总奖励
        reward = non_score_reward.clone()  # [seq_len]
        
        # 4. 找到最后一个非掩码位置索引
        last_non_masked_index = mask.nonzero()[-1]
        print(f"最后一个token位置: {last_non_masked_index.item()}")
        
        # 5. 将奖励模型评分添加到最后一个有效token
        # 这里体现了信用分配：整体奖励由序列末尾向前传播
        reward[last_non_masked_index] += score
        print(f"最终奖励形状: {reward.shape}")
        
        rewards.append(reward)
    
    stacked_rewards = torch.stack(rewards)
    print(f"批次奖励形状: {stacked_rewards.shape}")
    return stacked_rewards, torch.stack(non_score_rewards), torch.stack(kls)

# 测试奖励计算
reward_scores = reward_model(full_ids, full_mask)  # [1]
rewards, kl_rewards, kls = compute_rewards(reward_scores, policy_logprobs, ref_logprobs, full_mask)

print("最终奖励值:")
print(rewards[0])

奖励模型 - 嵌入输出形状: torch.Size([1, 10, 8])
奖励模型 - LSTM输出形状: torch.Size([1, 10, 8])
奖励模型 - 最后隐藏状态形状: torch.Size([1, 8])
奖励模型 - 输出奖励形状: torch.Size([1])
奖励计算 - 奖励分数形状: torch.Size([1]), logprobs形状: torch.Size([1, 10])
KL散度计算 - 策略对数概率形状: torch.Size([10]), 参考对数概率形状: torch.Size([10])
KL散度输出形状: torch.Size([10])
KL散度均值: 0.0
KL惩罚奖励形状: torch.Size([10])
最后一个token位置: 9
最终奖励形状: torch.Size([10])
批次奖励形状: torch.Size([1, 10])
最终奖励值:
tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        0.2977], grad_fn=<SelectBackward0>)


In [40]:
def compute_advantages(values, rewards, mask, gamma=0.99, lam=0.95):
    """计算广义优势估计(GAE)，修正版本"""
    print(f"GAE计算 - 价值形状: {values.shape}, 奖励形状: {rewards.shape}")
    
    batch_size = values.shape[0]
    seq_len = values.shape[1]
    advantages = torch.zeros_like(values)  # [batch, seq]
    
    # 应用掩码
    values = values * mask
    rewards = rewards * mask
    
    print("\n== GAE计算详细步骤 ==")
    
    # 对每个批次单独计算GAE
    for b in range(batch_size):
        # 初始化最后一个GAE为0
        lastgaelam = 0.0
        
        # 反向遍历序列
        for t in reversed(range(seq_len)):
            # 1. 确定下一状态值 
            if t == seq_len - 1:
                nextvalue = 0.0
            else:
                nextvalue = values[b, t + 1].item()
            
            # 2. 计算TD误差: r_t + γV_(t+1) - V_t
            current_reward = rewards[b, t].item()
            current_value = values[b, t].item()
            delta = current_reward + gamma * nextvalue - current_value
            
            # 3. 更新GAE: δ_t + γλ*GAE_(t+1)
            lastgaelam = delta + gamma * lam * lastgaelam
            
            # 4. 存储到张量中
            advantages[b, t] = lastgaelam
            
            # 5. 打印详情
            print(f"步骤 t={t}:")
            print(f"  奖励 r_{t} = {current_reward:.6f}")
            print(f"  当前值 V_{t} = {current_value:.6f}")
            print(f"  下一值 V_{t+1} = {nextvalue:.6f}")
            print(f"  TD误差 δ_{t} = {delta:.6f}")
            print(f"  GAE_{t} = {lastgaelam:.6f}")
    
    # 计算回报 = 优势 + 价值
    returns = advantages + values
    
    # 标准化优势
    advantages = masked_whiten(advantages, mask)
    
    print(f"\n最终结果 - 优势形状: {advantages.shape}, 回报形状: {returns.shape}")
    return values, advantages, returns

# 测试GAE计算
_, advantages, returns = compute_advantages(values, rewards, full_mask)

GAE计算 - 价值形状: torch.Size([1, 10]), 奖励形状: torch.Size([1, 10])

== GAE计算详细步骤 ==
步骤 t=9:
  奖励 r_9 = 0.297744
  当前值 V_9 = -0.232775
  下一值 V_10 = 0.000000
  TD误差 δ_9 = 0.530519
  GAE_9 = 0.530519
步骤 t=8:
  奖励 r_8 = -0.000000
  当前值 V_8 = -0.290370
  下一值 V_9 = -0.232775
  TD误差 δ_8 = 0.059923
  GAE_8 = 0.558876
步骤 t=7:
  奖励 r_7 = -0.000000
  当前值 V_7 = -0.320555
  下一值 V_8 = -0.290370
  TD误差 δ_7 = 0.033088
  GAE_7 = 0.558711
步骤 t=6:
  奖励 r_6 = -0.000000
  当前值 V_6 = -0.227476
  下一值 V_7 = -0.320555
  TD误差 δ_6 = -0.089873
  GAE_6 = 0.435595
步骤 t=5:
  奖励 r_5 = -0.000000
  当前值 V_5 = -0.247652
  下一值 V_6 = -0.227476
  TD误差 δ_5 = 0.022451
  GAE_5 = 0.432128
步骤 t=4:
  奖励 r_4 = -0.000000
  当前值 V_4 = -0.291388
  下一值 V_5 = -0.247652
  TD误差 δ_4 = 0.046212
  GAE_4 = 0.452628
步骤 t=3:
  奖励 r_3 = -0.000000
  当前值 V_3 = -0.223812
  下一值 V_4 = -0.291388
  TD误差 δ_3 = -0.064661
  GAE_3 = 0.361035
步骤 t=2:
  奖励 r_2 = -0.000000
  当前值 V_2 = -0.253204
  下一值 V_3 = -0.223812
  TD误差 δ_2 = 0.031629
  GAE_2 = 0.371183
步骤 t=1:
 

In [41]:
def clip_by_value(x, min_val, max_val):
    """裁剪张量值到指定范围"""
    return torch.max(torch.min(x, max_val), min_val)

def entropy_from_logits(logits):
    """计算策略熵"""
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -torch.sum(probs * log_probs, dim=-1)  # [batch, seq]
    return entropy

def ppo_loss(old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns,
           cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    """计算PPO损失函数"""
    print(f"PPO损失 - 旧logprobs: {old_logprobs.shape}, 新logprobs: {logprobs.shape}")
    
    # 1. 计算概率比率 r(θ) = π_θ/π_θ_old
    ratio = torch.exp(logprobs - old_logprobs)  # [batch, seq]
    print(f"概率比率形状: {ratio.shape}, 均值: {ratio.mean().item():.4f}")
    
    # 2. 计算策略损失
    # 原始策略梯度: -advantages * ratio
    # 裁剪策略梯度: -advantages * clip(ratio, 1-ε, 1+ε)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    
    # 取两者中较大值（较小的收益）
    pg_loss = masked_mean(torch.max(pg_losses1, pg_losses2), mask)
    
    # 3. 计算价值损失
    # 裁剪新的价值预测，防止过大更新
    vpredclipped = clip_by_value(
        vpreds,
        values - cliprange_value,
        values + cliprange_value,
    )
    
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
    
    # 4. 计算总损失
    loss = pg_loss + vf_coef * vf_loss
    
    # 5. 计算熵和KL散度统计信息
    entropy = masked_mean(entropy_from_logits(logits), mask)
    approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
    
    return pg_loss, vf_coef * vf_loss, {
        "policy_loss": pg_loss.item(),
        "value_loss": vf_loss.item(),
        "total_loss": loss.item(),
        "entropy": entropy.item(),
        "approx_kl": approxkl.item()
    }

# 生成新的预测值
new_policy_outputs = policy_model(full_ids)
new_policy_logits = new_policy_outputs.logits
new_policy_logprobs = logprobs_from_logits(new_policy_logits, full_ids)
new_values = critic_model(full_ids)

# 测试PPO损失
pg_loss, vf_loss, stats = ppo_loss(
    policy_logprobs, values, new_policy_logits, new_values, new_policy_logprobs,
    full_mask, advantages, returns
)

对数概率计算 - 输入logits形状: torch.Size([1, 10, 12]), 标签形状: torch.Size([1, 10])
对数概率分布形状: torch.Size([1, 10, 12])
收集对应标签的对数概率形状: torch.Size([1, 10, 1])
最终对数概率形状: torch.Size([1, 10])
价值模型 - 嵌入输出形状: torch.Size([1, 10, 8])
价值模型 - LSTM输出形状: torch.Size([1, 10, 8])
价值模型 - 输出价值形状: torch.Size([1, 10])
PPO损失 - 旧logprobs: torch.Size([1, 10]), 新logprobs: torch.Size([1, 10])
概率比率形状: torch.Size([1, 10]), 均值: 1.0000
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: -1.4305115314527939e-07
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: 0.19876843690872192
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: 2.4798240661621094
掩码均值 - 输入形状: torch.Size([1, 10]), 掩码形状: torch.Size([1, 10]), 输出: 0.0
