In [1]:
# 定义并加载奖励模型
import torch
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class RewardHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # llm最后输出的隐藏层的维度
        self.hidden_size = config.hidden_size
        # 线性层用来对llm最后输出的隐藏层给奖励
        self.reward = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        # 使用正态分布初始化权重
        nn.init.normal_(
            self.reward.weight,
            std=(1.0 / np.sqrt(self.hidden_size + 1))
        )
        # 将偏置初始化为0
        nn.init.zeros_(self.reward.bias)

    def forward(self, hidden_states):
        # 给出奖励
        return self.reward(hidden_states)

class GPT2RewardHead(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        self.reward_head = RewardHead(self.llm.config)

    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = transformer_outputs.hidden_states[-1]
        # 给出奖励
        reward = self.reward_head(last_hidden_state).squeeze(-1)
        # sigmoid用来将奖励搞到(0,1)范围内
        return torch.sigmoid(reward)

In [2]:
model_name = "/Users/zhangyf/llm/gpt2"
reward_model = GPT2RewardHead(model_name)
reward_model.load_state_dict(torch.load("reward_model.pt", map_location='cpu'))

<All keys matched successfully>

In [3]:
# 定义价值函数头
import torch
from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    ValueHead类为GPT2实现了一个“头”，会为输出的每个token返回一个标量值
    标量值就是这个token的价值，ValueHead就是评论家。
    """
    def __init__(self, config):
        super().__init__()
        # llm最后输出的隐藏层的维度
        self.hidden_size = config.hidden_size
        # 价值网络的输出是标量
        self.value = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(
            self.value.weight,
            std=(1.0 / np.sqrt(self.hidden_size + 1))
        )
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.value(output)

In [4]:
# 策略模型和价值头结合
class ModelForCausalLMWithValueHead(nn.Module):
    """
    GPT2模型+一个价值头
    """
    def __init__(self, model_path):
        super().__init__()
        # 这个要初始化为我们微调出来的gpt2-sft模型
        # actor演员模型：策略模型
        self.llm = AutoModelForCausalLM.from_pretrained(model_path)
        # 添加价值头
        # critic评论家模型：价值函数模型，价值头
        self.v_head = ValueHead(self.llm.config)

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:
        # gpt2-sft模型的输出
        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )
        # 输出的token的概率分布，维度为 `vocab_size`
        lm_logits = transformer_outputs.logits
        # 获取最后一层隐藏层
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # 评估token的价值，评估的是最后一层隐藏层的价值
        value = self.v_head(last_hidden_state).squeeze(-1)
        # 返回输出的token的logits和token的价值
        return lm_logits, value

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

In [5]:
# 要训练的策略模型
model_path = '/Users/zhangyf/llm/gpt2-sft'
model = ModelForCausalLMWithValueHead(model_path)

In [6]:
# 加载数据
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

from datasets import load_dataset
dataset = load_dataset("./sst2")
# print(dataset)
ds_train, ds_val = dataset['train'], dataset['validation']

In [7]:
# 过滤掉长度小于8的样本
print(len(ds_train))
ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)
ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)
print(len(ds_train))
print(len(ds_val))

67349
31105
807


In [8]:
import random
input_min_token_length = 2
input_max_token_length = 8
input_token_length_range = list(range(
    input_min_token_length,
    input_max_token_length))
print(input_token_length_range)
print(random.choice(input_token_length_range))

[2, 3, 4, 5, 6, 7]
2


In [9]:
def tokenize(sample):
    input_size = random.choice(input_token_length_range)
    sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

map_kwargs = {
    "batched": False,
    "remove_columns": ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

In [10]:
tokenized_dataset_train[48]

{'input_ids': [19188, 423, 8288],
 'attention_mask': [1, 1, 1],
 'query': 'would have liked'}

In [11]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

print(tokenized_dataset_train[48])

{'input_ids': tensor([19188,   423,  8288]), 'attention_mask': tensor([1, 1, 1]), 'query': 'would have liked'}


In [12]:
REWARD_TOKEN_ID = tokenizer.eos_token_id

In [13]:
from torch.utils.data import DataLoader

batch_size = 32

def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])

train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

batch = next(iter(train_dataloader))
print(batch)

{'input_ids': [tensor([27080,   428,   588]), tensor([20427,  6235,   837]), tensor([ 271, 1310, 1808,  326,  428,  318,  257]), tensor([2978,  647,  289,  463, 2815, 8404,  284]), tensor([292, 880]), tensor([  896, 47188,   605, 13991,   319,   262,   983]), tensor([10594,  5160]), tensor([ 11, 345, 743, 655, 886]), tensor([ 964,  257, 3721]), tensor([ 270, 5645,  510, 7463, 1790,  355,  257]), tensor([ 1169,  5544,  7559, 43962, 10148,   318,   257]), tensor([ 368, 5902, 4340,  262,  880,   12,   86]), tensor([  7, 264]), tensor([  272, 16403, 13516]), tensor([ 1659, 20234,   546,   920,  2416,   434,   287]), tensor([16544,  7923,   629,  1252,  7228]), tensor([ 9662,   831,  5160, 18647,   705]), tensor([   64,   890,   837, 19222, 37968,   286]), tensor([4364,  274,  705, 3437,  498, 8886,  468]), tensor([ 3911,   780,   262, 30410, 13526,   276]), tensor([ 403, 6668,  837,  612]), tensor([6888, 2664,  286]), tensor([5562,  705,   82,  810,  428, 2646,  815]), tensor([   64,  3968

In [14]:
output_min_length = 5
output_max_length = 16

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
# gpt2-sft输出的配置
# - 模型会从整个词汇表中按照原始概率分布进行采样
# - 每个词被选中的概率完全由模型的原始输出决定
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0, # 所有词汇表中的词都可能被选中
    "top_p": 1.0, # 包含整个概率分布
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

In [15]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
print(sample)
print(generation_kwargs)

{'input_ids': [17250, 11, 428], 'attention_mask': [1, 1, 1]}
{'min_length': -1, 'top_k': 0.0, 'top_p': 1.0, 'do_sample': True, 'pad_token_id': 50256, 'max_new_tokens': 11}


In [16]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0),
    **generation_kwargs
).squeeze(0)
print(query_response)

tensor([17250,    11,   428,  2576,  1244,   407, 14509,   345,   287,   262,
         1182,   220,   220, 16590])


In [17]:
print(tokenizer.decode(query_response))

Hi, this girl might not jack you in the head  ichi


In [18]:
new_tokens

11

In [25]:
# 测试一下奖励模型，能拿到多少奖励
with torch.no_grad():
    query_response_score = torch.cat([
        query_response,
        torch.tensor([REWARD_TOKEN_ID])])
    attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
    score = reward_model(
        query_response_score.unsqueeze(0),
        attention_mask.unsqueeze(0)
    )
    print(score)
    score = score.squeeze(0)[-1]
print(score)

tensor([[0.8801, 0.8562, 0.8565, 0.8546, 0.5753, 0.0261, 0.1038, 0.1169, 0.9319,
         0.8713, 0.7461, 0.5380, 0.4798, 0.6302, 0.3740]])
tensor(0.3740)


In [29]:
query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']
print(len(query_tensors),len(query_attention_masks))

32 32


In [34]:
# 生成一批测试数据看一下
device = torch.device("mps" if torch.mps.is_available() else "cpu")
model = model.to(device)
reward_model = reward_model.to(device)

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = [] # 补全的张量
query_response_tensors = [] # 提示词+补全 的张量
score_tensors = [] # 奖励模型给的分数的张量

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    with torch.no_grad():
        query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
        attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
        score = reward_model(
            query_response_score.unsqueeze(0),
            attention_mask.unsqueeze(0)
        ).squeeze(0)[-1]
        score = 2 * (score - 0.5)
    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response) for response in query_response_tensors]
from pprint import pprint
pprint(batch['response'])
# 奖励分数
pprint(score_tensors)

['avoid this like the dreaded king brown snake  ',
 'balance pointed , heartfelt observations   ',
 'is little question that this is a movie that examines consciousness and '
 'challenges expectation .    ',
 'helmer hudlin tries to fascinate viewers by trying to seem like some weird '
 'hip',
 'as well-made and beautifully acted  étal étal can',
 'its lyrical variations on the game of house music  urn ',
 'will earn his 110th career Emmy    ',
 ', you may just end up in heaven : \xa0your mouth , at bottom',
 'ever a concept who never knew these things before ',
 'it ends up falling short as a motion picture with an original',
 "the fourth `` pokemon '' is a delightfully unpredictable , hilarious , and "
 'unsettling film .    ',
 "emphasizes the well-wrought theme of green card making , and ichiro 's "
 'animated',
 '( soderbergh ) is fond',
 'an appealing blend of retro style and style reminiscent of the big-screen',
 'of despair about entrapment in the new millennium wistful',
 'cam

In [35]:
from copy import deepcopy
# 冻结的参考模型
sft_model = deepcopy(model)

In [38]:
len(query_response_tensors)

32

In [41]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
print(input_data['input_ids'].shape)
print(input_data['attention_mask'].shape)

torch.Size([32, 21])
torch.Size([32, 21])


In [42]:
def compute_rewards(
    input_data,
    query_tensors,
    response_tensors,
    score_tensors
):
    with torch.no_grad():
        # 正在微调的模型所输出的token的logits和token的价值
        # 模型输出所有token的概率分布
        # logits.shape[32,22,50257],values.shape[32,22]
        logits, values = model(**input_data) # b, seq, vocab
        # 冻结的模型的输出和价值
        # ref_logits.shape[32,22,50257]
        ref_logits, _ = sft_model(**input_data)
        # 正在微调的模型的输出的对数概率
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        # 冻结的模型的输出的对数概率 b, seq , 50257
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)
        # 实际生成的token序列
        labels = input_data['input_ids'][:, 1:] # b, seq
        # 使用gather提取实际token的概率
        logp = torch.gather(
            logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1) # batch, seq
        ref_logp = torch.gather(
            ref_logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1) # batch, seq
        # kl散度,shape[32,21]
        kl = logp - ref_logp
        # kl散度的权重
        beta = 0.2
        # 最终奖励的计算shape[32,21]
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:,:] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks

In [43]:
score_tensors

[tensor(-0.9753, device='mps:0'),
 tensor(0.9986, device='mps:0'),
 tensor(0.9932, device='mps:0'),
 tensor(0.2334, device='mps:0'),
 tensor(0.9989, device='mps:0'),
 tensor(0.9055, device='mps:0'),
 tensor(0.9524, device='mps:0'),
 tensor(0.0257, device='mps:0'),
 tensor(0.8203, device='mps:0'),
 tensor(-0.9944, device='mps:0'),
 tensor(0.9994, device='mps:0'),
 tensor(0.9961, device='mps:0'),
 tensor(0.9913, device='mps:0'),
 tensor(0.9989, device='mps:0'),
 tensor(-0.4165, device='mps:0'),
 tensor(0.9977, device='mps:0'),
 tensor(0.9167, device='mps:0'),
 tensor(-0.9980, device='mps:0'),
 tensor(0.9994, device='mps:0'),
 tensor(-0.9989, device='mps:0'),
 tensor(-0.9893, device='mps:0'),
 tensor(-0.9987, device='mps:0'),
 tensor(0.1481, device='mps:0'),
 tensor(-0.9964, device='mps:0'),
 tensor(-0.9413, device='mps:0'),
 tensor(-0.9985, device='mps:0'),
 tensor(0.9961, device='mps:0'),
 tensor(0.9977, device='mps:0'),
 tensor(-0.8038, device='mps:0'),
 tensor(-0.9763, device='mps:0')

In [44]:
logprobs, rewards, values, masks = compute_rewards(
    input_data,
    query_tensors,
    response_tensors,
    score_tensors
)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])
print(masks[0])
print(values[0])
print(rewards.shape)
print(values.shape)
print(masks.shape)

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.9753, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='mps:0')
tensor([27080,   428,   588,   262, 39229,  5822,  7586, 17522,   220,   220,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256], device='mps:0')
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='mps:0')
tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='mps:0')
tensor([ 0.0000, -0.0000,  2.9971, -3.6465, -0.6704, -3.7086, -1.2735,  1.1686,
         0.7444, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000], device='mps:0')
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])


In [46]:
def masked_mean(values, mask):
    # 计算带掩码的平均值
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    # 计算带掩码的方差
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)

def masked_whiten(values, mask):
    '''
    对数据进行带掩码的白化处理，
    让有效数据的方差变为1，但均值保持不变
    '''
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    whitened += mean
    return whitened

def compute_advantage(rewards, values, masks):
    '''
    广义优势估计（GAE）
    '''
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns

In [47]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])
print(values[0])

tensor([-3.3640e-01, -3.5414e-01, -1.6552e+00,  1.1003e+00, -1.1524e-01,
         1.1786e+00,  1.9870e-01, -8.3580e-01, -6.9831e-01,  7.1807e-04,
         7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,
         7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04],
       device='mps:0')
tensor([-3.3640e-01, -3.5414e-01,  1.3419e+00, -2.5462e+00, -7.8567e-01,
        -2.5300e+00, -1.0748e+00,  3.3279e-01,  4.6044e-02,  7.1807e-04,
         7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,
         7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04,  7.1807e-04],
       device='mps:0')
tensor([ 0.0000, -0.0000,  2.9971, -3.6465, -0.6704, -3.7086, -1.2735,  1.1686,
         0.7444, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000], device='mps:0')


In [48]:
# 小批次PPO训练
learning_rate = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# 随机排列一下各个批次大小
np.random.permutation(batch_size)

array([ 2, 13,  7,  6,  1, 15,  5,  0,  9, 19,  4, 22, 28, 10, 11, 17, 27,
        3, 23, 18, 31, 25, 21, 14, 16, 26, 24, 30,  8, 20, 12, 29])

In [49]:
# 最小的批次大小
mini_batch_size = 4
# 训练 4 个 epoch
ppo_epochs = 4
# ε = 0.2
cliprange_ratio = 0.2

v_loss_coeff = 0.1
# 比例的阈值
ratio_threshold = 10

def compute_loss(
    old_logprobs, # 冻结的一份概率
    values, # 价值
    logprobs, # 正在微调的模型输出的对数概率
    vpreds, # 预测的价值
    masks, # 掩码
    advantages, # 广义优势估计
    returns # 回报
):
    # 比率
    ratio = torch.exp(logprobs - old_logprobs)
    # 比率 * 广义优势估计
    pg_loss1 = - ratio * advantages
    # clip(比率，1-ϵ,1+ϵ) * 广义优势估计
    pg_loss2 = - torch.clamp(
        ratio,
        1 - cliprange_ratio,
        1 + cliprange_ratio
    ) * advantages
    # 策略（gpt2-sft）的损失
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)
    # 价值网络（价值头）的损失，mse
    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    # 由于 正在微调的模型 = gpt2-sft + value_head
    # 总的损失 = 策略网络的损失 + 0.1 * 价值网络的损失
    loss = pg_loss + v_loss_coeff * v_loss
    # 计算平均比率
    avg_ratio = masked_mean(ratio, masks)
    # 这一步不在ppo公式中
    # 如果平均比率 > 10
    if avg_ratio > ratio_threshold:
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(batch_size)
        # range(0, 32, 4)
        i = 0
        for start in range(0, batch_size, mini_batch_size):
            # start = 0; end = 4
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data \
                    ['input_ids'] \
                    [mini_batch_inds],
                'attention_mask': input_data \
                    ['attention_mask'] \
                    [mini_batch_inds]
            }
            # 模型的输出是token的logits和value
            mb_logits, mb_vpreds = model(**mb_model_inputs)
            # 去掉最后一个token
            mb_logits = torch.nn.functional.log_softmax(
                mb_logits[:, :-1, :],
                dim=-1
            )
            # 取出真实标签对应的概率
            mb_logprobs = torch.gather(
                mb_logits,
                2,
                mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)
            ).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            i += 1
            print(f'loss/total_{(ep+1) * (i)}', loss.item())
    print('mini-batch training finished')

In [50]:
mini_batch_train()

loss/total 0.0630950778722763
loss/total 0.08078935742378235
loss/total -0.11473453044891357
loss/total 0.4748775362968445
loss/total 0.10602464526891708
loss/total 0.1982056200504303
loss/total 0.15830230712890625
loss/total -0.307778000831604
loss/total 0.24319370090961456
loss/total -0.043497439473867416
loss/total -0.0032646358013153076
loss/total -0.09986380487680435
loss/total -0.3570232689380646
loss/total -0.10384085774421692
loss/total 0.2232135683298111
loss/total -0.11403181403875351
loss/total -0.0358172170817852
loss/total -0.232884019613266
loss/total -0.10820959508419037
loss/total 0.40299054980278015
loss/total -0.12002972513437271
loss/total -0.11467400938272476
loss/total -0.513324499130249
loss/total -0.08656121790409088
loss/total 0.13452798128128052
loss/total -0.12799489498138428
loss/total -0.3488343060016632
loss/total -0.1875464767217636
loss/total -0.3424496054649353
loss/total -0.1724880486726761
loss/total -0.038934655487537384
loss/total 0.30629023909568787

In [51]:
# 正式开始训练
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # 生成补全内容（回复）
        query_tensors = batch['input_ids'] # 提示词的张量
        query_attention_masks = batch['attention_mask']

        response_tensors = [] # 补全的张量
        query_response_tensors = [] # 提示词+补全的张量
        score_tensors = [] # 分数的张量

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            # 随机挑一个补全的长度
            new_tokens = random.choice(list(range(
                output_min_length,
                output_max_length)))
            # 设置补全长度属性
            generation_kwargs["max_new_tokens"] = new_tokens
            # 提示词 + 补全
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)
            # 补全的长度
            response_len = len(query_response) - len(query)
            # 补全的张量
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            with torch.no_grad():
                # 提示词 + 补全 + reward_token
                query_response_score = torch.cat([
                    query_response,
                    torch.tensor([REWARD_TOKEN_ID]).to(device)])
                attention_mask = torch.ones_like(
                    query_response_score,
                    dtype=torch.long)
                # 奖励模型的评分
                score = reward_model(
                    query_response_score.unsqueeze(0),
                    attention_mask.unsqueeze(0)
                ).squeeze(0)[-1]
                # 将奖励模型的评分从(0,1)缩放到(-1,1)
                score = 2 * (score - 0.5)
            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # 奖励和优势
        logprobs, rewards, values, masks = compute_rewards(
            input_data,
            query_tensors,
            response_tensors,
            score_tensors
        )
        advantages, returns = compute_advantage(rewards, values, masks)

        # 小批次训练
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

loss/total 0.17095354199409485
loss/total 0.26528221368789673
loss/total -0.07916722446680069
loss/total 0.17161089181900024
loss/total 0.48755955696105957
loss/total 0.048691775649785995
loss/total 0.28528913855552673
loss/total 0.08579196780920029
loss/total 0.6262637376785278
loss/total -0.18119752407073975
loss/total 0.05027185007929802
loss/total 0.23638170957565308
loss/total -0.05799861624836922
loss/total -0.004532668739557266
loss/total -0.07040932774543762
loss/total 0.008601825684309006
loss/total -0.05938967689871788
loss/total 0.15219682455062866
loss/total 0.26793372631073
loss/total -0.27551257610321045
loss/total -0.6159469485282898
loss/total 0.10920918732881546
loss/total 0.09216439723968506
loss/total 0.4059474468231201
loss/total -0.13892461359500885
loss/total 0.31113505363464355
loss/total 0.3974112272262573
loss/total 0.3811631202697754
loss/total -0.2912694811820984
loss/total -0.005252772010862827
loss/total -0.3749743402004242
loss/total -0.2954665422439575
mi

In [54]:
print(len(tokenized_dataset_val))
val_gen_lengths = [0] * len(tokenized_dataset_val)
for i in range(len(tokenized_dataset_val)):
    val_gen_lengths[i] = random.choice(list(range(
        output_min_length,
        output_max_length)))
val_gen_lengths[:10]

807


[9, 8, 8, 7, 8, 14, 9, 5, 13, 15]

In [55]:
def validate():
    scores = []
    for b, batch in enumerate(val_dataloader):
        # 生成补全内容
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = val_gen_lengths[b * len(query_tensors) + i]
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)
            query_response_score = torch.cat([
                query_response,
                torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(
                query_response_score, dtype=torch.long)
            score = reward_model(
                query_response_score.unsqueeze(0),
                attention_mask.unsqueeze(0)
            ).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('平均分数:', sum(scores) / len(scores))

In [57]:
validate()

平均分数: 0.8099561136451352


In [58]:
# 保存模型
torch.save(model.state_dict(), 'gpt2-ppo.pt')

In [60]:
# 测试一下gpt2-sft模型的输出
model_path = '/Users/zhangyf/llm/gpt2-sft'
model = ModelForCausalLMWithValueHead(model_path).to(device)
validate()

平均分数: 0.16474913472432037
