In [1]:
import torch
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class RewardHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # llm最后输出的隐藏层的维度
        self.hidden_size = config.hidden_size
        # 线性层用来对llm最后输出的隐藏层给奖励
        self.reward = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        # 使用正态分布初始化权重
        nn.init.normal_(
            self.reward.weight,
            std=(1.0 / np.sqrt(self.hidden_size + 1))
        )
        # 将偏置初始化为0
        nn.init.zeros_(self.reward.bias)

    def forward(self, hidden_states):
        # 给出奖励
        return self.reward(hidden_states)

class GPT2RewardHead(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        self.reward_head = RewardHead(self.llm.config)

    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = transformer_outputs.hidden_states[-1]
        # 给出奖励
        reward = self.reward_head(last_hidden_state).squeeze(-1)
        # sigmoid用来将奖励搞到(0,1)范围内
        return torch.sigmoid(reward)

In [2]:
model_name = "/Users/zhangyf/llm/gpt2"
reward_model = GPT2RewardHead(model_name)
reward_model.load_state_dict(torch.load("reward_model.pt", map_location='cpu'))

<All keys matched successfully>

In [3]:
import torch
from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    ValueHead类为GPT2实现了一个“头”，会为输出的每个token返回一个标量值
    标量值就是这个token的价值，ValueHead就是评论家。
    """
    def __init__(self, config):
        super().__init__()
        # llm最后输出的隐藏层的维度
        self.hidden_size = config.hidden_size
        # 价值网络的输出是标量
        self.value = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(
            self.value.weight,
            std=(1.0 / np.sqrt(self.hidden_size + 1))
        )
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.value(output)

In [4]:
class ModelForCausalLMWithValueHead(nn.Module):
    """
    GPT2模型+一个价值头
    """
    def __init__(self, model_path):
        super().__init__()
        # 这个要初始化为我们微调出来的gpt2-sft模型
        # actor演员模型：策略模型
        self.llm = AutoModelForCausalLM.from_pretrained(model_path)
        # 添加价值头
        # critic评论家模型：价值函数模型，价值头
        self.v_head = ValueHead(self.llm.config)

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:
        # gpt2-sft模型的输出
        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )
        # 输出的token的概率分布，维度为 `vocab_size`
        lm_logits = transformer_outputs.logits
        # 获取最后一层隐藏层
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # 评估token的价值，评估的是最后一层隐藏层的价值
        value = self.v_head(last_hidden_state).squeeze(-1)
        # 返回输出的token的logits和token的价值
        return lm_logits, value

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

In [5]:
model_path = '/Users/zhangyf/llm/gpt2-sft'
model = ModelForCausalLMWithValueHead(model_path)

In [6]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

from datasets import load_dataset
dataset = load_dataset("./sst2")
print(dataset)

ds_train, ds_val = dataset['train'], dataset['validation']

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})


In [7]:
print(len(ds_train))
ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)
ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

print(len(ds_train))
print(len(ds_val))

67349
31105
807


In [8]:
import random
input_min_token_length = 2
input_max_token_length = 8
input_token_length_range = list(range(
    input_min_token_length,
    input_max_token_length))
print(input_token_length_range)
print(random.choice(input_token_length_range))

[2, 3, 4, 5, 6, 7]
7


In [9]:
def tokenize(sample):
    input_size = random.choice(input_token_length_range)
    sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

map_kwargs = {
    "batched": False,
    "remove_columns": ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

In [10]:
tokenized_dataset_val[0]

{'input_ids': [270, 705, 82, 257, 23332, 290, 1690],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1],
 'query': "it 's a charming and often"}

In [11]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

print(tokenized_dataset_train[6])

{'input_ids': tensor([ 1640,   883,  3807, 31006,   508, 13121]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1]), 'query': 'for those moviegoers who complain'}


In [12]:
REWARD_TOKEN_ID = tokenizer.eos_token_id

In [13]:
from torch.utils.data import DataLoader

batch_size = 32

def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])

train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

batch = next(iter(train_dataloader))
print(batch)

{'input_ids': [tensor([  11, 1534, 1531]), tensor([8043,  837, 2647,  837]), tensor([2502, 8988]), tensor([6888,  299,  470,  766, 1521,  597]), tensor([   76,  3316,   705,    82, 39769,   318]), tensor([ 1462,   852,   262, 25203,    12]), tensor([   64,  8303,   379, 11783,   289]), tensor([ 265, 1661,  257, 1643, 7758,  375]), tensor([19188,   423]), tensor([  271, 14169,   837,   572, 12945]), tensor([4360,  772,  981,  465, 3435]), tensor([ 270,  705,   82,  257, 9058,  523]), tensor([21754,  1577,  7559, 12692, 10148,   257]), tensor([ 338, 5729]), tensor([ 271,  517, 3499,  357,  290]), tensor([39240,   298,   290, 20868]), tensor([5562, 4071, 2646, 3025]), tensor([ 5661,  2646,   837,   588,   262, 12470,  2801]), tensor([1169, 1877,   12, 9526, 9891]), tensor([  272, 30690,  5022,   286,  3223, 35704,   290]), tensor([ 732, 1239, 1254]), tensor([ 11, 262, 717]), tensor([  271,   281,   555, 48544,   306]), tensor([ 4360,   991,  2407, 25103,   290, 23310,   477]), tensor([ 26

In [14]:
output_min_length = 5
output_max_length = 16

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
# gpt2-sft输出的配置
# - 模型会从整个词汇表中按照原始概率分布进行采样
# - 每个词被选中的概率完全由模型的原始输出决定
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0, # 所有词汇表中的词都可能被选中
    "top_p": 1.0, # 包含整个概率分布
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

In [15]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
print(sample)

{'input_ids': [17250, 11, 428], 'attention_mask': [1, 1, 1]}


In [16]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0),
    **generation_kwargs
).squeeze(0)
print(query_response)

tensor([17250,    11,   428,   318,   262,   845,   717, 25168,   286,   262,
         8663,  2277,  3807])


In [17]:
print(tokenizer.decode(query_response))

Hi, this is the very first installment of the franchise hit movie


In [18]:
new_tokens

10

In [19]:
with torch.no_grad():
    query_response_score = torch.cat([
        query_response,
        torch.tensor([REWARD_TOKEN_ID])])
    attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
    score = reward_model(
        query_response_score.unsqueeze(0),
        attention_mask.unsqueeze(0)
    ).squeeze(0)[-1]
print(score)

tensor(0.9939)


In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
reward_model = reward_model.to(device)

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = [] # 补全的张量
query_response_tensors = [] # 提示词+补全 的张量
score_tensors = [] # 奖励模型给的分数的张量

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    with torch.no_grad():
        query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
        attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
        score = reward_model(
            query_response_score.unsqueeze(0),
            attention_mask.unsqueeze(0)
        ).squeeze(0)[-1]
        score = 2 * (score - 0.5)
    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response) for response in response_tensors]
from pprint import pprint
pprint(batch['response'])

[' obsessions , icky jokes and',
 ' and whimsy of love . ',
 ' korean maternal instincts with vitality and tenderness , which is something '
 'no other',
 ' movie could ever possibly be as successful in the',
 ' well on its way to',
 'burningly bad version of the wildly over-',
 'ooliganism in the head ̶ ̶ ̶ it',
 'ramatic but  ive never felt more',
 ' a great recommendation to watch , especially iced tea',
 ' and hilarious    icated on a different ethnic background and',
 ' are intermittently funny , their jokes are',
 ' poorly shot that by mid',
 ' whirl       if it',
 ' amused at how engaging all this is    that the laughs are so',
 " accurate ) than the man '",
 ' really shines     as the world and empire decomposes . ',
 ' using of voice-over actors in action movies can only inspire awe ',
 '-timed bride-and-dame remake , is altern',
 ' ersatzness icky salt rises to the level of enthusiasm',
 ' , throughout , some deep sadness and psychological',
 " we 're in danger when we see ba

In [21]:
score_tensors

[tensor(-0.9561),
 tensor(0.9919),
 tensor(0.9979),
 tensor(-0.9763),
 tensor(0.9562),
 tensor(-0.9988),
 tensor(-0.9871),
 tensor(0.2387),
 tensor(0.8755),
 tensor(0.9992),
 tensor(0.2430),
 tensor(-0.9997),
 tensor(0.2599),
 tensor(0.9940),
 tensor(0.9944),
 tensor(0.9956),
 tensor(-0.8964),
 tensor(-0.9049),
 tensor(0.6486),
 tensor(0.7333),
 tensor(-0.9872),
 tensor(-0.1871),
 tensor(0.9091),
 tensor(0.9994),
 tensor(-0.9970),
 tensor(0.9938),
 tensor(-0.9895),
 tensor(0.9503),
 tensor(0.9925),
 tensor(-0.9017),
 tensor(-0.9826),
 tensor(0.9989)]

In [22]:
from copy import deepcopy
# 冻结的模型
sft_model = deepcopy(model)

In [23]:
query_response_tensors

[tensor([   11,  1534,  1531, 10201,  6202,   837,   220, 17479, 14532,   290]),
 tensor([ 8043,   837,  2647,   837,   290, 29923,  1837,   286,  1842,   764,
           220]),
 tensor([ 2502,  8988,   479, 29456, 22160, 26744,   351, 41687,   290, 15403,
          1108,   837,   543,   318,  1223,   645,   584]),
 tensor([6888,  299,  470,  766, 1521,  597, 3807,  714, 1683, 5457,  307,  355,
         4388,  287,  262]),
 tensor([   76,  3316,   705,    82, 39769,   318,   880,   319,   663,   835,
           284]),
 tensor([ 1462,   852,   262, 25203,    12, 10899,  4420,  2089,  2196,   286,
           262, 20278,   625,    12]),
 tensor([   64,  8303,   379, 11783,   289,   970,  5516,  1042,   287,   262,
          1182,   220, 48869,   220, 48869,   220, 48869,   340]),
 tensor([ 265, 1661,  257, 1643, 7758,  375,  859, 1512,  475,  220,  220,  425,
         1239, 2936,  517]),
 tensor([19188,   423,   257,  1049, 15602,   284,  2342,   837,  2592,   220,
          3711,  8887])

In [24]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
print(input_data)

{'input_ids': tensor([[   11,  1534,  1531, 10201,  6202,   837,   220, 17479, 14532,   290,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 8043,   837,  2647,   837,   290, 29923,  1837,   286,  1842,   764,
           220, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 2502,  8988,   479, 29456, 22160, 26744,   351, 41687,   290, 15403,
          1108,   837,   543,   318,  1223,   645,   584, 50256, 50256, 50256,
         50256],
        [ 6888,   299,   470,   766,  1521,   597,  3807,   714,  1683,  5457,
           307,   355,  4388,   287,   262, 50256, 50256, 50256, 50256, 50256,
         50256],
        [   76,  3316,   705,    82, 39769,   318,   880,   319,   663,   835,
           284, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 1462,   852,   262, 25203,    12, 10899,  4420,  2089,  2196,   286,
           262, 20278,   625,   

  arr = np.array(obj)
  arr = np.array(obj)


In [25]:
def compute_rewards(
    input_data,
    query_tensors,
    response_tensors,
    score_tensors
):
    with torch.no_grad():
        # 正在微调的模型所输出的token的logits和token的价值
        # 模型输出所有token的概率分布
        logits, values = model(**input_data) # b, seq, vocab
        # 冻结的模型的输出和价值
        ref_logits, _ = sft_model(**input_data)
        # 正在微调的模型的输出的对数概率
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        # 冻结的模型的输出的对数概率
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)
        # 实际生成的token序列
        labels = input_data['input_ids'][:, 1:] # b, seq
        # 使用gather提取实际token的概率
        logp = torch.gather(
            logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1) # batch, seq
        ref_logp = torch.gather(
            ref_logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1) # batch, seq
        # kl散度
        kl = logp - ref_logp
        # kl散度的权重
        beta = 0.2
        # 最终奖励的计算
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:,:] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks

In [26]:
score_tensors


[tensor(-0.9561),
 tensor(0.9919),
 tensor(0.9979),
 tensor(-0.9763),
 tensor(0.9562),
 tensor(-0.9988),
 tensor(-0.9871),
 tensor(0.2387),
 tensor(0.8755),
 tensor(0.9992),
 tensor(0.2430),
 tensor(-0.9997),
 tensor(0.2599),
 tensor(0.9940),
 tensor(0.9944),
 tensor(0.9956),
 tensor(-0.8964),
 tensor(-0.9049),
 tensor(0.6486),
 tensor(0.7333),
 tensor(-0.9872),
 tensor(-0.1871),
 tensor(0.9091),
 tensor(0.9994),
 tensor(-0.9970),
 tensor(0.9938),
 tensor(-0.9895),
 tensor(0.9503),
 tensor(0.9925),
 tensor(-0.9017),
 tensor(-0.9826),
 tensor(0.9989)]

In [27]:
logprobs, rewards, values, masks = compute_rewards(
    input_data,
    query_tensors,
    response_tensors,
    score_tensors
)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])
print(masks[0])
print(values[0])

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.9561, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000])
tensor([   11,  1534,  1531, 10201,  6202,   837,   220, 17479, 14532,   290,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([ 0.0000, -0.0000,  4.9234, -4.2811,  1.9289,  2.6755,  0.7618,  5.4120,
        -0.9051,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000])


In [28]:
def masked_mean(values, mask):
    # 计算带掩码的平均值
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    # 计算带掩码的方差
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)

def masked_whiten(values, mask):
    '''
    对数据进行带掩码的白化处理，
    让有效数据的方差变为1，但均值保持不变
    '''
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    whitened += mean
    return whitened

def compute_advantage(rewards, values, masks):
    '''
    广义优势估计（GAE）
    '''
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns

In [29]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([-0.8922, -0.8971, -3.2641,  1.0217, -1.8616, -2.2757, -1.4354, -3.6997,
        -0.8219, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987,
        -0.7987, -0.7987, -0.7987, -0.7987])
tensor([-0.8922, -0.8971,  1.6592, -3.2594,  0.0673,  0.3998, -0.6735,  1.7123,
        -1.7270, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987, -0.7987,
        -0.7987, -0.7987, -0.7987, -0.7987])


In [30]:
learning_rate = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# 随机排列一下各个批次大小
np.random.permutation(batch_size)

array([ 2, 28,  8, 15,  5, 14, 27, 23, 26,  7, 30, 20, 16, 11, 22, 24,  9,
        0, 29,  4, 17, 13, 31, 21, 25, 18, 10,  3,  1, 19,  6, 12])

In [31]:
# 最小的批次大小
mini_batch_size = 4
# 训练 4 个 epoch
ppo_epochs = 4
# ε = 0.2
cliprange_ratio = 0.2

v_loss_coeff = 0.1
# 比例的阈值
ratio_threshold = 10

def compute_loss(
    old_logprobs, # 冻结的一份概率
    values, # 价值
    logprobs, # 正在微调的模型输出的对数概率
    vpreds, # 预测的价值
    masks, # 掩码
    advantages, # 广义优势估计
    returns # 回报
):
    # 比率
    ratio = torch.exp(logprobs - old_logprobs)
    # 比率 * 广义优势估计
    pg_loss1 = - ratio * advantages
    # clip(比率，1-ϵ,1+ϵ) * 广义优势估计
    pg_loss2 = - torch.clamp(
        ratio,
        1 - cliprange_ratio,
        1 + cliprange_ratio
    ) * advantages
    # 策略（gpt2-sft）的损失
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)
    # 价值网络（价值头）的损失，mse
    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    # 由于 正在微调的模型 = gpt2-sft + value_head
    # 总的损失 = 策略网络的损失 + 0.1 * 价值网络的损失
    loss = pg_loss + v_loss_coeff * v_loss
    # 计算平均比率
    avg_ratio = masked_mean(ratio, masks)
    # 这一步不在ppo公式中
    # 如果平均比率 > 10
    if avg_ratio > ratio_threshold:
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(batch_size)
        # range(0, 32, 4)
        for start in range(0, batch_size, mini_batch_size):
            # start = 0; end = 4
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data \
                    ['input_ids'] \
                    [mini_batch_inds],
                'attention_mask': input_data \
                    ['attention_mask'] \
                    [mini_batch_inds]
            }
            # 模型的输出是token的logits和value
            mb_logits, mb_vpreds = model(**mb_model_inputs)
            # 去掉最后一个token
            mb_logits = torch.nn.functional.log_softmax(
                mb_logits[:, :-1, :],
                dim=-1
            )
            # 取出真实标签对应的概率
            mb_logprobs = torch.gather(
                mb_logits,
                2,
                mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)
            ).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss/total', loss.item())
    print('mini-batch training finished')

In [32]:
mini_batch_train()

loss/total 1.5086708068847656
loss/total 1.9685832262039185
loss/total 1.4921576976776123
loss/total 1.9890475273132324
loss/total 1.4088155031204224
loss/total 1.8091586828231812
loss/total 1.9725093841552734
loss/total 1.543008804321289
loss/total 1.4625900983810425
loss/total 1.6584177017211914
loss/total 1.389695405960083
loss/total 1.21249258518219
loss/total 1.3319447040557861
loss/total 1.2083057165145874
loss/total 1.6089109182357788
loss/total 1.2492483854293823
loss/total 1.3308871984481812
loss/total 0.9942128658294678
loss/total 1.2945976257324219
loss/total 1.1975150108337402
loss/total 1.4258742332458496
loss/total 1.4625556468963623
loss/total 1.2934088706970215
loss/total 1.2743065357208252
loss/total 1.2525415420532227
loss/total 1.312488079071045
loss/total 1.1764097213745117
loss/total 1.2857826948165894
loss/total 1.299506425857544
loss/total 1.6054877042770386
loss/total 1.113140344619751
loss/total 0.9546226263046265
mini-batch training finished


In [33]:
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # 生成补全内容（回复）
        query_tensors = batch['input_ids'] # 提示词的张量
        query_attention_masks = batch['attention_mask']

        response_tensors = [] # 补全的张量
        query_response_tensors = [] # 提示词+补全的张量
        score_tensors = [] # 分数的张量

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            # 随机挑一个补全的长度
            new_tokens = random.choice(list(range(
                output_min_length,
                output_max_length)))
            # 设置补全长度属性
            generation_kwargs["max_new_tokens"] = new_tokens
            # 提示词 + 补全
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)
            # 补全的长度
            response_len = len(query_response) - len(query)
            # 补全的张量
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            with torch.no_grad():
                # 提示词 + 补全 + reward_token
                query_response_score = torch.cat([
                    query_response,
                    torch.tensor([REWARD_TOKEN_ID]).to(device)])
                attention_mask = torch.ones_like(
                    query_response_score,
                    dtype=torch.long)
                # 奖励模型的评分
                score = reward_model(
                    query_response_score.unsqueeze(0),
                    attention_mask.unsqueeze(0)
                ).squeeze(0)[-1]
                # 将奖励模型的评分从(0,1)缩放到(-1,1)
                score = 2 * (score - 0.5)
            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # 奖励和优势
        logprobs, rewards, values, masks = compute_rewards(
            input_data,
            query_tensors,
            response_tensors,
            score_tensors
        )
        advantages, returns = compute_advantage(rewards, values, masks)

        # 小批次训练
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

loss/total 0.7508983612060547
loss/total 1.0324691534042358
loss/total 1.0763437747955322
loss/total 1.1622765064239502
loss/total 1.107187271118164
loss/total 0.945980429649353
loss/total 1.047513723373413
loss/total 0.9599798321723938
loss/total 1.0468543767929077
loss/total 1.109917163848877
loss/total 0.4298473298549652
loss/total 0.6093572974205017
loss/total 1.1704490184783936
loss/total 0.7546617388725281
loss/total 0.8411296606063843
loss/total 0.5838620066642761
loss/total 0.4295908212661743
loss/total 0.8948656320571899
loss/total 0.6778674125671387
loss/total 0.3236315846443176
loss/total 0.8596951365470886
loss/total 0.9914372563362122
loss/total 0.7081918716430664
loss/total 0.9933494329452515
loss/total 0.4940328299999237
loss/total 0.7501310110092163
loss/total 0.37403276562690735
loss/total 0.9266468286514282
loss/total 0.9018948674201965
loss/total 1.1137652397155762
loss/total 0.5823561549186707
loss/total 0.742218017578125
mini-batch training finished
loss/total 0.81

KeyboardInterrupt: 