In [1]:
import config
from model import GPT2RewardModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = str(config.GPT2_PATH)
reward_model = GPT2RewardModel(model_name)
reward_model.load_state_dict(torch.load(str(config.REWARD_MODEL_PATH), map_location='cpu'))

<All keys matched successfully>

In [3]:
from model import ModelForCausalLMWithValueHead

model_path = str(config.GPT2_SFT_PATH)
model = ModelForCausalLMWithValueHead(model_path)

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

from datasets import load_dataset

dataset = load_dataset(str(config.SST2_PATH))
print(dataset)

ds_train, ds_val = dataset['train'], dataset['validation']

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})


In [6]:
print(len(ds_train))
ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)
ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

print(len(ds_train))
print(len(ds_val))

67349


Filter: 100%|██████████| 67349/67349 [00:00<00:00, 621088.39 examples/s]
Filter: 100%|██████████| 872/872 [00:00<00:00, 218654.46 examples/s]

31105
807





In [7]:
import random

input_min_token_length = 2
input_max_token_length = 8
input_token_length_range = list(range(
    input_min_token_length,
    input_max_token_length))
print(input_token_length_range)
print(random.choice(input_token_length_range))

[2, 3, 4, 5, 6, 7]
2


In [8]:
def tokenize(sample):
    input_size = random.choice(input_token_length_range)
    sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample


map_kwargs = {
    "batched": False,
    "remove_columns": ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

Map: 100%|██████████| 31105/31105 [00:01<00:00, 18974.49 examples/s]
Map: 100%|██████████| 807/807 [00:00<00:00, 16626.24 examples/s]


In [9]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

print(tokenized_dataset_train[6])

{'input_ids': tensor([ 1640,   883,  3807, 31006,   508, 13121]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1]), 'query': 'for those moviegoers who complain'}


In [10]:
REWARD_TOKEN_ID = tokenizer.eos_token_id

In [11]:
from torch.utils.data import DataLoader

batch_size = 32


def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])


train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

batch = next(iter(train_dataloader))
print(batch)

{'input_ids': [tensor([ 270,  705,   82, 4071,  284, 1064]), tensor([  272, 14274]), tensor([ 265, 8699]), tensor([24267, 19522]), tensor([ 271,  326,  607, 6628]), tensor([18302, 11184,   284,  1416,    64,   705,    82]), tensor([270, 705,  82]), tensor([ 9930,  2499, 15013,   306,   880]), tensor([  301,  2645, 16772,   837,   262,  3807,   318]), tensor([ 505, 1917,  351,  262, 3807]), tensor([  325,   368,   284,   307, 18951, 13658]), tensor([5171,  691, 2148,  340,  351,  523]), tensor([ 5661,   318, 18700, 49291,   837]), tensor([27144,   452,   689,   290,  2523,   703]), tensor([   64,  4601,    88,    12, 34670]), tensor([1169, 2761,  290, 3435,  340]), tensor([   11,   290,   511, 20929, 17777,  7702]), tensor([   64,  4713,   837, 17774, 10997,   326]), tensor([  505,   286,   995, 22041,   705,    82,   749]), tensor([ 1169, 10092, 11375,   326,   739]), tensor([  65, 1044,  306]), tensor([ 3919, 15119,   994]), tensor([5171,  991,  307]), tensor([  325,  5431,   837, 195

In [12]:
output_min_length = 5
output_max_length = 16

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
# gpt2-sft输出的配置
# - 模型会从整个词汇表中按照原始概率分布进行采样
# - 每个词被选中的概率完全由模型的原始输出决定
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,  # 所有词汇表中的词都可能被选中
    "top_p": 1.0,  # 包含整个概率分布
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

In [13]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
print(sample)

{'input_ids': [17250, 11, 428], 'attention_mask': [1, 1, 1]}


In [14]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0),
    **generation_kwargs
).squeeze(0)
print(query_response)

tensor([17250,    11,   428,   318,   257,  3807,   326,  7138, 23007,   262,
        13843,   286])


In [15]:
print(tokenizer.decode(query_response))

Hi, this is a movie that perfectly captures the imagination of


In [16]:
with torch.no_grad():
    query_response_score = torch.cat([
        query_response,
        torch.tensor([REWARD_TOKEN_ID])])
    attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
    score = reward_model(
        query_response_score.unsqueeze(0),
        attention_mask.unsqueeze(0)
    ).squeeze(0)[-1]
print(score)

tensor(0.9988)


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
reward_model = reward_model.to(device)

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = []
query_response_tensors = []
score_tensors = []

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    with torch.no_grad():
        query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
        attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
        score = reward_model(
            query_response_score.unsqueeze(0),
            attention_mask.unsqueeze(0)
        ).squeeze(0)[-1]
        score = 2 * (score - 0.5)
    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response) for response in response_tensors]
from pprint import pprint

pprint(batch['response'])

[' an entire country that embraces the measure of maturity iced tea descends',
 'oly mess   ',
 ' minutes long , this movie',
 ' is simply overpowered .     ',
 ' grows with every passing frame      ',
 ' upscale , offensive tone throughout but icky , pot-',
 ' sexier and boldier',
 ' in this regard .       ',
 ' flat and unconvincing ',
 " 's direction      ly",
 ' enough to tolerate a fetish iced tea party iced tea party ,',
 ' much sympathy , that it is impossible',
 " and ian park jones 's veracity",
 ' willing the audience was to speak out while others in',
 'y kiddie flick iced',
 ' explores \xa0are rather mind bogg',
 ' changes about every single aspect of them , both',
 ' manages to make a point .',
 ' ironic gifted artists ian holm ian freeman does just that',
 "lies the film 's wonderfully shaped narrative ",
 ' and tragic ética   iken     ike',
 ' .     I might have been disappointed',
 ' a leap of faith in an era of mergers and downs',
 ' weirdly shapable ',
 ' an amusement

In [18]:
from copy import deepcopy

sft_model = deepcopy(model)

In [19]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
print(input_data)

{'input_ids': tensor([[  270,   705,    82,  4071,   284,  1064,   281,  2104,  1499,   326,
         37872,   262,  3953,   286, 24841,   220,  3711,  8887,  1715,  2412,
         50256],
        [  272, 14274,  3366,  2085,   220,   220,   220, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [  265,  8699,  2431,   890,   837,   428,  3807, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [24267, 19522,   318,  2391, 49313,   764,   220,   220,   220,   220,
           220, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [  271,   326,   607,  6628, 13676,   351,   790,  6427,  5739,   220,
           220,   220,   220,   220,   220, 50256, 50256, 50256, 50256, 50256,
         50256],
        [18302, 11184,   284,  1416,    64,   705,    82, 44918,   837,  5859,
          8216,  3690,   475,   

In [20]:
def compute_rewards(
        input_data,
        query_tensors,
        response_tensors,
        score_tensors
):
    with torch.no_grad():
        # 正在微调的模型所输出的token的logits和token的价值
        # 模型输出所有token的概率分布
        logits, values = model(**input_data)  # b, seq, vocab
        # 冻结的模型的输出和价值
        ref_logits, _ = sft_model(**input_data)
        # 正在微调的模型的输出的对数概率
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        # 冻结的模型的输出的对数概率
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)
        # 实际生成的token序列
        labels = input_data['input_ids'][:, 1:]  # b, seq
        # 使用gather提取实际token的概率
        logp = torch.gather(
            logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1)  # batch, seq
        ref_logp = torch.gather(
            ref_logp,
            2,
            labels.unsqueeze(-1)
        ).squeeze(-1)  # batch, seq
        # kl散度
        kl = logp - ref_logp
        # kl散度的权重
        beta = 0.2
        # 最终奖励的计算
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:, :] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks

In [22]:
logprobs, rewards, values, masks = compute_rewards(
    input_data,
    query_tensors,
    response_tensors,
    score_tensors
)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])
print(masks[0])
print(values[0])

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        0.9417, -0.0000])
tensor([  270,   705,    82,  4071,   284,  1064,   281,  2104,  1499,   326,
        37872,   262,  3953,   286, 24841,   220,  3711,  8887,  1715,  2412,
        50256])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.9542, -0.0225,  0.7976,
         0.5902,  1.7268,  3.9876, -0.0290,  1.9629,  2.6377,  1.9755,  0.1747,
         1.0370,  2.4842, -0.6464,  0.0000])


In [23]:
def masked_mean(values, mask):
    # 计算带掩码的平均值
    return (values * mask).sum() / mask.sum()


def masked_var(values, mask):
    # 计算带掩码的方差
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)


def masked_whiten(values, mask):
    '''
    对数据进行带掩码的白化处理，
    让有效数据的方差变为1，但均值保持不变
    '''
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    whitened += mean
    return whitened


def compute_advantage(rewards, values, masks):
    '''
    广义优势估计（GAE）
    '''
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns

In [24]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([ 0.2937,  0.3265,  0.3611,  0.3975,  0.4358, -0.9009,  0.4619, -0.0743,
         0.0853, -0.6937, -2.3060,  0.4204, -0.9437, -1.4515, -1.0439,  0.1875,
        -0.3929, -1.4160,  0.7329, -0.3302])
tensor([ 0.2937,  0.3265,  0.3611,  0.3975,  0.4358,  1.0532,  0.4394,  0.7233,
         0.6755,  1.0331,  1.6816,  0.3914,  1.0192,  1.1862,  0.9316,  0.3622,
         0.6441,  1.0682,  0.0865, -0.3302])


In [25]:
import numpy as np

learning_rate = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# 随机排列一下各个批次大小
np.random.permutation(batch_size)

array([28, 13,  4, 21, 18, 29, 30, 14,  7,  5,  0, 12,  1, 16,  8, 26, 22,
       23, 24,  2, 11, 20, 15, 27,  3, 19, 31,  6, 25, 10, 17,  9])

In [26]:
# 最小的批次大小
mini_batch_size = 4
# 训练 4 个 epoch
ppo_epochs = 4
# ε = 0.2
cliprange_ratio = 0.2

v_loss_coeff = 0.1
# 比例的阈值
ratio_threshold = 10


def compute_loss(
        old_logprobs,
        values,
        logprobs,
        vpreds,
        masks,
        advantages,
        returns
):
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss1 = - ratio * advantages
    pg_loss2 = - torch.clamp(
        ratio,
        1 - cliprange_ratio,
        1 + cliprange_ratio
    ) * advantages
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)

    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss

    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss


def mini_batch_train():
    # 过滤掉输入数据为空的批次
    if input_data['input_ids'].shape[0] == 0:
        return
    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(batch_size)

        for start in range(0, batch_size, mini_batch_size):
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data['input_ids'][mini_batch_inds],
                'attention_mask': input_data['attention_mask'][mini_batch_inds]
            }
            mb_logits, mb_vpreds = model(**mb_model_inputs)
            mb_logits = torch.nn.functional.log_softmax(
                mb_logits[:, :-1, :],
                dim=-1
            )
            mb_logprobs = torch.gather(
                mb_logits,
                2,
                mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)
            ).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss/total', loss.item())
    print('mini-batch training finished')

In [27]:
mini_batch_train()

loss/total 1.727279543876648
loss/total 0.7686131596565247
loss/total 1.262438416481018
loss/total 1.662282943725586
loss/total 1.1627838611602783
loss/total 0.7339403033256531
loss/total 1.5563760995864868
loss/total 0.7673659324645996
loss/total 1.1146233081817627
loss/total 0.8554384708404541
loss/total 0.7849952578544617
loss/total 0.8114762306213379
loss/total 0.8627503514289856
loss/total 0.8448833227157593
loss/total 1.2873870134353638
loss/total 1.0911635160446167
loss/total 0.7411577105522156
loss/total 0.9025058746337891
loss/total 0.7879412770271301
loss/total 0.4668281674385071
loss/total 1.384494423866272
loss/total 0.7904667854309082
loss/total 1.183447003364563
loss/total 1.2399688959121704
loss/total 1.110142469406128
loss/total 0.3123805820941925
loss/total 0.7012345790863037
loss/total 1.2559528350830078
loss/total 0.9507017135620117
loss/total 0.4577327072620392
loss/total 0.9585307240486145
loss/total 1.0084588527679443
mini-batch training finished


In [None]:
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Generate responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']

        response_tensors = []
        query_response_tensors = []
        score_tensors = []

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = random.choice(list(range(
                output_min_length,
                output_max_length)))
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)

            response_len = len(query_response) - len(query)
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            with torch.no_grad():
                query_response_score = torch.cat([
                    query_response,
                    torch.tensor([REWARD_TOKEN_ID]).to(device)])
                attention_mask = torch.ones_like(
                    query_response_score,
                    dtype=torch.long)
                score = reward_model(
                    query_response_score.unsqueeze(0),
                    attention_mask.unsqueeze(0)
                ).squeeze(0)[-1]
                score = 2 * (score - 0.5)
            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # 奖励和优势
        logprobs, rewards, values, masks = compute_rewards(
            input_data,
            query_tensors,
            response_tensors,
            score_tensors
        )
        advantages, returns = compute_advantage(rewards, values, masks)

        # 小批次训练
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

In [29]:
print(len(tokenized_dataset_val))
val_gen_lengths = [0] * len(tokenized_dataset_val)
for i in range(len(tokenized_dataset_val)):
    val_gen_lengths[i] = random.choice(list(range(
        output_min_length,
        output_max_length)))
val_gen_lengths[:10]

807


[7, 8, 15, 12, 14, 5, 5, 9, 9, 7]

In [30]:
def validate(model):
    scores = []
    for b, batch in enumerate(val_dataloader):
        # Generate_responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = val_gen_lengths[b * len(query_tensors) + i]
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)
            query_response_score = torch.cat([
                query_response,
                torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(
                query_response_score, dtype=torch.long)
            score = reward_model(
                query_response_score.unsqueeze(0),
                attention_mask.unsqueeze(0)
            ).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('平均分数:', sum(scores) / len(scores))

In [31]:
validate(model)

平均分数: 0.5725250226414337


In [32]:
torch.save(model.state_dict(), str(config.PPO_MODEL_PATH))

In [33]:
validate(sft_model)

平均分数: 0.11989582124105026
