# RLHF-PPO Pytorch Implemention

Author: *xiaodongguaAIGC*

Github: *dhcode-cpp*

# define basic function

## config

In [1]:
vocab_size = 32
hidden_size = 256
intermediate_size = 512
num_hidden_layers = 2
num_attention_heads = 4
num_key_value_heads = 4
batch_size = 2
length_x = 6
max_new_tokens = 5

## pretrained model

In [2]:
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification

torch.manual_seed(1)

# 加载模型
config = LlamaConfig(vocab_size=vocab_size,      # default is 32000
                     hidden_size=hidden_size,
                     intermediate_size=intermediate_size,
                     num_hidden_layers=num_hidden_layers,
                     num_attention_heads=num_attention_heads,
                     num_key_value_heads=num_key_value_heads,
                     )
model = LlamaForCausalLM(config)
model.config.pad_token_id = model.config.eos_token_id
model.save_pretrained('./lm_pretrained')

In [5]:
x = torch.randint(0, vocab_size, (batch_size, length_x))
xy = torch.randint(0, vocab_size, (batch_size, length_x + max_new_tokens))
print(x.shape)
print(x)

In [6]:
mask_zero = torch.zeros(batch_size, length_x)
mask = torch.ones(batch_size, length_x+max_new_tokens)
mask[:, :length_x] = mask_zero
print(mask)

## lm model generate

https://huggingface.co/docs/transformers/main_classes/text_generation

In [7]:
def get_generate(model, x, max_new_tokens):
    idx = {'input_ids': x}  # ignore mask
    y = model.generate(**idx,
                       max_new_tokens=max_new_tokens,
                       forced_eos_token_id=True)
    return y


print('input:', x)
gen_xy = get_generate(model, x, max_new_tokens)
print('lm output:\n', gen_xy)
print('lm output shape:\n', gen_xy.shape)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


## reward model

In [8]:
model_rm = LlamaForSequenceClassification.from_pretrained(
    './lm_pretrained', num_labels=1)
print(model_rm)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at ./lm_pretrained and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## rm get reward

In [11]:
def get_reward(rm_model, x):
    idx = {'input_ids': x}  # ignore mask
    y = model_rm(**idx).logits
    return y


print(gen_xy)
rm_xy = get_reward(model_rm, gen_xy)
print(rm_xy)

# PPO Pipeline

### ppo config

In [12]:
class PPOConfig():
    def __init__(self):
        self.ppo_epochs = 3
        self.mini_batch_size = 1
        self.epochs = 2
        self.kl_ctl = 0.01
        self.vl_coef = 0.01
        self.lam = 0.9
        self.gamma = 0.9
        self.cliprange_value = 0.2
        self.cliprange = 0.2
        # self.param = {'gamma':0.9, 'lambda':0.9,}

    def __str__(self):
        return f'ppo_epochs:{self.ppo_epochs}\nmini_batch_size:{self.mini_batch_size}\nepochs:{self.epochs}\nkl_ctl:{self.kl_ctl}'


ppo_config = PPOConfig()
print(ppo_config)

### PPO models

In [13]:
# model -> [sft] -> model_ref
# model_ref -> model_rm
# model_ref -> model_actor
# model_rm -> model_critic
model_ref = LlamaForCausalLM(config)
# model_critic = LlamaForSequenceClassification.from_pretrained('./lm_pretrained', num_labels=1)

### model critic

In [14]:
config = LlamaConfig(vocab_size=vocab_size,      # default is 32000
                     hidden_size=hidden_size,
                     intermediate_size=intermediate_size,
                     num_hidden_layers=num_hidden_layers,
                     num_attention_heads=num_attention_heads,
                     num_key_value_heads=num_key_value_heads,
                     )
model_base = LlamaForCausalLM(config)
model_base.config.pad_token_id = model_base.config.eos_token_id


class ModelValueHead(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.dropout = torch.nn.Dropout(0.05)
        self.summary = torch.nn.Linear(model.config.hidden_size, 1)

    def forward(self, xy):
        hidden_states = model(**xy, output_hidden_states=True).hidden_states
        last_hidden_states = hidden_states[-1]
        # print(last_hidden_states.shape)
        output = self.dropout(last_hidden_states)
        output = self.summary(output)[:, :, 0]
        return output


def get_values(model, xy):
    idx = {'input_ids': xy}
    value = model(idx)
    return value


model_critic = ModelValueHead(model_base)

print(xy)
value = get_values(model_critic, xy)
print(value)

In [15]:
class PPOModels():
    def __init__(self, model_actor, model_ref, model_rm, model_critic):
        self.actor = model_actor
        self.ref = model_ref
        self.rm = model_rm
        self.critic = model_critic


model_ref.eval()
model_rm.eval()
models = PPOModels(model, model_ref, model_rm, model_critic)

## PPO Pipeline

In [16]:
def void_ppo_step(models, ppo_config, x, xy, rewards):
    # just for show pipeline
    return 0.1, 0.2

In [17]:
def ppo_train(models,
              ppo_config,
              ppo_step,
              x):
    for i in range(ppo_config.epochs):
        # use dataloader get batch
        print(x)
        xy = get_generate(models.actor, x, max_new_tokens)
        print(xy)
        rewards = get_reward(models.rm, xy)
        print(rewards)
        loss, reward = ppo_step(models, ppo_config, x, xy, rewards)
    return loss, reward


loss, reward = ppo_train(models, ppo_config, void_ppo_step, x)
print(loss)
print(reward)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


## ppo_step pipeline

1. make experience
2. ppo_train

### define ppo experience dict

In [18]:
ppo_batchs = {
    'prompt': None,
    'sequence': None,
    'mask': None,
    'logprobs_ref': None,
    'logprobs_old': None,
    'logprobs': None,
    'values_old': None,
    'values': None,
    'rewards': None,
    'rewards_kl': None,
    'loss': None,
    'logits': None,
}

ppo_batchs['prompt'] = x
ppo_batchs['sequence'] = xy
ppo_batchs['mask'] = mask

### Forward Get Policy

In [19]:
def logprobs_from_logits(logits, labels, gather=True):
    logp = F.log_softmax(logits, dim=2)
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy, logp


# get policy
logits = torch.randn(1, 5, vocab_size)
labels = torch.tensor([[3, 2, 4, 3, 2]])
logpy, logp = logprobs_from_logits(logits, labels)

# get prob
idx = 2
print('one of logits:', logits[:, idx, :])
print('one of lobp:', logp[:, idx, :])
print('logp get labels[-1]', logp[:, idx, labels[:, idx]])
print('one of logpy:', logpy[:, idx])

In [20]:
def get_logits(model, xy):
    idx = {'input_ids': xy}  # ignore mask
    logits = model(**idx).logits
    return logits


print(xy)
logits = get_logits(models.actor, xy)
print(logits.shape)
probs, _ = logprobs_from_logits(logits, xy, True)
print(probs.shape)

### PPO Forward Code

In [21]:
def batch_forward(model_policy, model_value, xy):
    logits = get_logits(model_policy, xy)
    logprobs, _ = logprobs_from_logits(logits, xy, True)
    values = None
    if model_value != None:
        values = get_values(model_value, xy)
    return logprobs, values, logits


with torch.no_grad():
    logprobs_ref, _, _ = batch_forward(models.ref, None, xy)
logprobs_old, values_old, _ = batch_forward(models.actor, models.critic, xy)

print(logprobs_ref.shape)
print(logprobs_old.shape)
print(values_old.shape)

ppo_batchs['logprobs_ref'] = logprobs_ref
ppo_batchs['logprobs_old'] = logprobs_old
ppo_batchs['values_old'] = values_old

### reward with KL

In [33]:
def compute_rewards_kl(
    reward,
    ref_logprobs,
    old_logprobs,
    kl_ctl,
):
    kl = logprobs_old[:, :] - logprobs_ref[:, :]
    kl_reward = -kl_ctl * kl
    kl_reward[:, -1] += reward[:, 0]
    return kl_reward


kl_ctl = 0.01
score = 100.0  # obeserver the value
kl = logprobs_old - logprobs_ref
reward = -kl_ctl * kl
print('only kl')
print(reward[0, :])
print(reward[0, -1])
print('reward with kl')
reward[:, -1] += (rm_xy[:, 0] + score)
print(reward[0, :])
print(reward[0, -1])

print('compute rewards with kl for PPO')
rm_xy = torch.randn(2, 1)
reward_kl = compute_rewards_kl(rm_xy, logprobs_ref, logprobs_old, 0.01)
print(reward_kl)
ppo_batchs['rewards'] = rm_xy
ppo_batchs['rewards_kl'] = reward_kl

about reward tips

pad token id = 1

input :

x1 : 3, 4, 6, 8, 20, 29, 30, 2

x2 : 8, 2, 6, 9, 13, 2,  1, 1

x1 mask:

x1 : 1, 1, 1, 1, 1,  1,  1,  1

x2 mask:

x1 : 1, 1, 1, 1, 1,  1,  0,  0

KL+rewards

x1 : 1, 1, 1, 1, 1,  1,  1,  [1]+rewards

x1 : 1, 1, 1, 1, 1,  [1]+rewards,  0,  0

### data before ppo_train

In [34]:
'''
1. we use forward get policy,policy_old,value,
2. we compute reward with kl
'''


def ppo_prepare(models, ppo_config, xy, rm_xy):
    with torch.no_grad():
        logprobs_ref, _ = batch_forward(models.ref, None, xy)
    logprobs_old, values_old = batch_forward(models.actor, models.critic, xy)
    reward_kl = compute_rewards_kl(
        rm_xy, logprobs_ref, logprobs_old, ppo_config.kl_ctl)
    return logprobs_old, values_old, logprobs_ref, reward_kl

# logprobs_old, values_old, logprobs_ref, reward_kl = ppo_prepare(models, ppo_config, xy, rm_xy)


ppo_batchs


[1m{[0m
[2;32m│   [0m[32m'prompt'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m7[0m, [1;36m17[0m,  [1;36m6[0m,  [1;36m1[0m, [1;36m11[0m, [1;36m28[0m[1m][0m,
[2;32m│   │   [0m[1m[[0m[1;36m11[0m,  [1;36m5[0m, [1;36m30[0m, [1;36m22[0m, [1;36m18[0m,  [1;36m5[0m[1m][0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'sequence'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m3[0m,  [1;36m2[0m, [1;36m21[0m, [1;36m16[0m, [1;36m31[0m, [1;36m21[0m, [1;36m18[0m, [1;36m27[0m,  [1;36m8[0m, [1;36m11[0m,  [1;36m4[0m[1m][0m,
[2;32m│   │   [0m[1m[[0m [1;36m7[0m,  [1;36m3[0m, [1;36m11[0m, [1;36m27[0m, [1;36m13[0m,  [1;36m4[0m, [1;36m19[0m, [1;36m29[0m, [1;36m26[0m,  [1;36m3[0m, [1;36m23[0m[1m][0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'mask'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m1[0m., [1;36m1[0m

# PPO Train Step

In [35]:
# if batch size is 32
# the mini_bacth_size = 8
# like data loader would run 4 step
def get_minibatch(ppo_batchs, batch_size, mini_batch_size):
    step = batch_size // mini_batch_size
    ppo_batchs_iter = []
    print(step)
    for i in range(step):
        ppo_batchs_iter.append(ppo_batchs)
    return ppo_batchs_iter


def void_compute_loss(models, mini_batchs, ppo_config):
    print('roll mini batch to compute ppo loss')
    print(mini_batchs)
    loss = 1.0
    pl_loss = 2.0
    vl_loss = 3.0
    loss = pl_loss + ppo_config.vl_coef * vl_loss
    return loss, pl_loss, vl_loss


def ppo_train_step(models, ppo_batchs, ppo_config, compute_loss):
    losses = []
    for i in range(ppo_config.ppo_epochs):
        ppo_batchs_iter = get_minibatch(
            ppo_batchs, batch_size, ppo_config.mini_batch_size)
        for mini_batchs in ppo_batchs_iter:
            # get training policy
            logprobs, values, logits = batch_forward(
                models.actor, models.critic, mini_batchs['sequence'])
            mini_batchs['logprobs'] = logprobs
            mini_batchs['values'] = values
            mini_batchs['logits'] = logits
            # compute loss
            loss, pl_loss, vl_loss = compute_loss(
                models, mini_batchs, ppo_config)
            # loss.backward()
            losses.append(loss)
            break  # just for debug,
        break  # just for debug,
        mini_batchs['loss'] = losses
    return losses


loss = ppo_train_step(models, ppo_batchs, ppo_config, void_compute_loss)
print(torch.tensor(loss).mean())

### Compute ppo loss

when calulative ppo loss in TOKEN-LEVEL
PPO loss pipeline:
1. get GAE
2. get Critic Value loss
3. get Policy loss
4. get entropy loss

#### Get GAE

In [36]:
def get_GAE(rewards, mask, values, gamma, lam):
    lastgaelam = 0
    advantages_reversed = []
    gen_len = rewards.shape[-1]
    print(gen_len)

    values = values * mask
    rewards = rewards * mask

    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
    return advantages


# for oberserve gae data
mini_batchs = ppo_batchs.copy()
tmp_reward = torch.ones_like(mini_batchs['rewards_kl']) * 10.0
tmp_reward[:, :-1] = 0.0
print(tmp_reward)
tmp_value = torch.ones_like(mini_batchs['values_old']) * 0.0

result = get_GAE(tmp_reward,
                 mini_batchs['mask'],
                 tmp_value,
                 0.9,
                 0.9)
print(result)


# for oberserve gae data
mini_batchs = ppo_batchs.copy()
print(mini_batchs['rewards_kl'])
tmp_reward = torch.ones_like(mini_batchs['rewards_kl']) * 10.0
tmp_reward[:, :-1] = 0.0
print(tmp_reward)
tmp_value = torch.ones_like(mini_batchs['values_old']) * 1.0

result = get_GAE(tmp_reward,
                 mini_batchs['mask'],
                 tmp_value,
                 0.9,
                 0.9)
print(result)

In [37]:
# use mini_batchs data
mini_batchs = ppo_batchs.copy()
GAE = get_GAE(mini_batchs['rewards_kl'],
              mini_batchs['mask'],
              mini_batchs['values'],
              ppo_config.gamma,
              ppo_config.lam)
print(GAE)

#### Get Critic Value loss

In [38]:
def masked_mean(values, mask, axis=None):
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()


def clip_by_value(x, tensor_min, tensor_max):
    clipped = torch.max(torch.min(x, tensor_max), tensor_min)
    return clipped


def get_value_loss(advantages, values, values_old, mask, cliprange_value):
    returns = advantages + values_old
    advantages = advantages.detach()
    vpredclipped = clip_by_value(
        values,
        values_old - cliprange_value,
        values_old + cliprange_value,
    )

    vf_losses1 = (values - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2

    # print("vf_losses2:", vf_losses2)
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
    vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
    return vf_loss


vf_loss = get_value_loss(GAE,
                         mini_batchs['values'],
                         mini_batchs['values_old'],
                         mini_batchs['mask'],
                         ppo_config.cliprange_value)
print(vf_loss)

#### Get Policy loss

In [39]:
def get_policy_loss(logprobs, lobprobs_old, advantages, mask, cliprange):
    ratio = torch.exp(logprobs - lobprobs_old)

    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * \
        torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

    pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
    pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
    return pg_loss


pg_loss = get_policy_loss(mini_batchs['logprobs'],
                          mini_batchs['logprobs_old'],
                          GAE,
                          mini_batchs['mask'],
                          ppo_config.cliprange,)
print(pg_loss)

#### Get Entropy loss

In [40]:
def get_entropy_loss(logits, mask):
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, axis=-1) - \
        torch.sum(pd * logits, axis=-1)
    return entropy


entropy = get_entropy_loss(mini_batchs['logits'], mini_batchs['mask'])
print(entropy)

### gather all loss

In [41]:
def compute_loss(mini_batchs, ppo_config):

    GAE = get_GAE(mini_batchs['rewards_kl'],
                  mini_batchs['mask'],
                  mini_batchs['values'],
                  ppo_config.gamma,
                  ppo_config.lam)

    vl_loss = get_value_loss(GAE,
                             mini_batchs['values'],
                             mini_batchs['values_old'],
                             mini_batchs['mask'],
                             ppo_config.cliprange_value)

    pg_loss = get_policy_loss(mini_batchs['logprobs'],
                              mini_batchs['logprobs_old'],
                              GAE,
                              mini_batchs['mask'],
                              ppo_config.cliprange,)

    loss = pg_loss + ppo_config.vl_coef * vl_loss

    return loss, pg_loss, vl_loss


mini_batchs = ppo_batchs.copy()
loss, _, _ = compute_loss(mini_batchs, ppo_config)
print(loss)