In [1]:
import gc
import random
from dataclasses import dataclass, field
from itertools import cycle
from random import randint
from typing import Iterable, Sequence, List, Tuple, Optional

import numpy as np
import matplotlib.pylab as plt
from tqdm.auto import tqdm
from torchtyping import TensorType

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from datasets import load_dataset

from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    set_seed,
)


In [2]:
# NOTE(review): `dataclass` is already imported in the first cell; this re-import is redundant.
from dataclasses import dataclass

# Fix the random seed through transformers' `set_seed` helper before anything stochastic runs.
set_seed(2024)

@dataclass
class Config:
    """Hyperparameters for the PPO fine-tuning run.

    Every attribute carries a type annotation so that ``@dataclass`` registers it
    as a real field: values can be overridden per instance
    (``Config(lr=1e-5)``) and show up in ``repr``. Defaults are unchanged.
    """

    # Hugging Face id of the policy / reference language model.
    model_name: str = "lvwerra/gpt2-imdb"
    # Hugging Face id of the sentiment reward model.
    reward_model_name: str = "EKKam/opt350m_imdb_sentiment_reward"  # "lvwerra/distilbert-imdb"
    # Maximum tokenised length fed to the reward model.
    seq_length: int = 1024    # 1024 compatible with GPT2, Distilbert supports 512
    # Number of prompts per rollout batch (input to the dataloader).
    batch_size: int = 64      # 128
    # Optimizer learning rate.
    lr: float = 7.41e-6       # 6.41e-6    # 9.41e-6     # 1.41e-5         # 0.00006
    # Query length (in tokens) handed to the generator.
    prompt_size: int = 10     # 8      # 10     # 30
    # Number of passes over the prompt dataset.
    epochs: int = 1
    # Number of rollout elements per PPO update step.
    mini_batch_size: int = 4
    # Sampling arguments forwarded to `generate`. Built per instance so that
    # mutating one config's dict (the training loop rewrites `max_new_tokens`)
    # never leaks into another instance.
    gen_kwargs: dict = field(
        default_factory=lambda: {
            'max_new_tokens': 40,     # reduce if run crashes for OOM
            'top_k': 0,
            'top_p': 1.0,
            'do_sample': True,
        }
    )
    # Weight of the per-token KL penalty folded into the rewards.
    kl_coef: float = 0.2      # 0.01
    # GAE discount factor and lambda.
    gamma: float = 1.0
    lam: float = 0.95
    # PPO clipping ranges for the policy ratio and for the value prediction.
    cliprange: float = 0.2
    cliprange_value: float = 0.2
    # Weights of the value-function loss and of the entropy term.
    vf_coef: float = 0.0001
    entropy_coef: float = 0.5


args = Config()


### Data Preparation

In [3]:
# padding_side=left for decoder-only architectures, so that generated tokens
# directly follow the prompt instead of following padding.
# https://discuss.huggingface.co/t/the-effect-of-padding-side/67188/4
tokenizer = AutoTokenizer.from_pretrained(args.model_name,
                                          padding_side='left')

# Reuse EOS as the padding token instead of adding a brand-new '[PAD]' token.
# A newly added token receives an id past the end of the model's embedding
# table unless the model is resized with `resize_token_embeddings`, which this
# notebook never does. Any padded prompt would then index out of range
# ("CUDA error: device-side assert triggered"). EOS is already in the
# vocabulary and is also the `pad_token_id` the training loop passes to `generate`.
tokenizer.pad_token = tokenizer.eos_token


1

Build the dataset used for training.
It is a series of prompts that are used to generate responses, which are then scored to compute the rewards.

In [4]:
def build_dataset(args, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """Load the review dataset and turn every review into a fixed-length prompt.

    Args:
        args: config object; only ``args.prompt_size`` (prompt length in tokens) is read.
        dataset_name: dataset id passed to ``load_dataset`` (its ``train`` split is used).
        input_min_text_length: currently unused; kept for backward compatibility.
        input_max_text_length: currently unused; kept for backward compatibility.

    Returns:
        The dataset in torch format, with the ``text`` column renamed to ``review``
        and the added columns ``input_ids``, ``attention_mask`` and ``query``.
    """
    # load the IMDB dataset
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})

    # Keep only reviews longer than 50 characters (`len` counts characters here, not tokens).
    ds = ds.filter(lambda x: len(x["review"]) > 50, batched=False)

    def tokenize(sample):
        # Keep the first `args.prompt_size` tokens of each review as the prompt;
        # shorter reviews are padded up to that length.
        encoded = tokenizer(sample["review"],
                            padding='max_length',
                            max_length=args.prompt_size,
                            truncation=True,
                            return_attention_mask=True)
        sample["input_ids"] = encoded.input_ids
        # Use the tokenizer's own mask so padded positions are marked 0; the
        # previous all-ones mask made the model attend to padding tokens.
        sample["attention_mask"] = encoded.attention_mask
        sample["query"] = tokenizer.decode(sample["input_ids"])

        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(args)



Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [5]:
def prepare_dataloader(dataset, data_collator=None, batch_size=None):
    """Wrap `dataset` in a shuffling DataLoader that drops the last incomplete batch.

    Args:
        dataset: any object accepted by ``torch.utils.data.DataLoader``.
        data_collator: optional ``collate_fn``; ``None`` selects the DataLoader default.
        batch_size: batch size to use; ``None`` (the default) falls back to
            ``args.batch_size`` from the notebook config.

    Returns:
        The configured ``DataLoader``.
    """
    if batch_size is None:
        batch_size = args.batch_size

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        shuffle=True,
        drop_last=True,
    )

def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass whenever one ends.

    `itertools.cycle` stores every item of the first pass and replays those
    stored items afterwards, so later epochs would reuse the same batches in
    the same order and keep all of them in memory. Re-iterating the loader
    lets `shuffle=True` draw a new order for every pass.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An empty loader would otherwise spin forever without yielding.
            return


dataloader = _endless_batches(prepare_dataloader(dataset, data_collator=None))


### LLM as an RL Agent

In [6]:
class Agent(nn.Module):
    """Causal language model wrapper used either as the policy or as a frozen reference.

    With ``trainable=True`` a small value head is attached on top of the final
    hidden states and ``forward`` returns ``(logits, values)``. With
    ``trainable=False`` the backbone is put in eval mode, its gradients are
    disabled, and ``forward`` returns only the logits.
    """

    def __init__(self, trainable=False):
        super().__init__()
        self.trainable = trainable
        self.model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        if self.trainable:
            # Value head: maps each hidden state to one scalar, kept in the
            # same dtype and on the same device as the backbone.
            hidden_size = self.model.lm_head.in_features
            value_layers = [
                nn.LayerNorm(hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, 4 * hidden_size),
                nn.GELU(),
                nn.Linear(4 * hidden_size, 1),
            ]
            self.value_head = nn.Sequential(*value_layers).to(torch.bfloat16).to(self.model.device)
        else:
            # Frozen copy: inference mode and no gradients.
            self.model = self.model.eval()
            self.model.requires_grad_(False)

        # Output projection of the backbone, applied to the hidden states in `forward`.
        self.logit_head = self.model.get_output_embeddings()

    def generate(self, input_ids, **generate_kwargs):
        """Forward the call to the wrapped model's ``generate``."""
        return self.model.generate(input_ids, **generate_kwargs)

    def forward(self, input_ids, attention_mask=None):
        """Return token logits, plus per-position values when the agent is trainable."""
        backbone_out = self.model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_hidden = backbone_out.hidden_states[-1]
        lm_logits = self.logit_head(final_hidden)
        if not self.trainable:
            return lm_logits
        # Drop the trailing singleton dimension produced by the value head.
        return lm_logits, self.value_head(final_hidden).squeeze(-1)


# Frozen copy (trainable=False): no value head, gradients disabled.
ref_model = Agent(trainable=False)

# Trainable copy (trainable=True): carries the value head; its parameters are the ones given to the optimizer.
model = Agent(trainable=True)


In [7]:
def logprobs_from_logits(logits, labels):
    """Return the log-probability that `logits` assign to each entry of `labels`.

    The log-softmax is taken over the last dimension of `logits`; `labels`
    indexes that dimension and must have the shape of `logits` without it.
    The result has the same shape as `labels`.
    """
    all_logprobs = F.log_softmax(logits, dim=-1)
    picked = all_logprobs.gather(-1, labels.unsqueeze(-1))
    return picked.squeeze(-1)


### Reward Model

In [8]:
# Load the classifier named in the config as the reward model; `num_labels=1`
# asks for a single output logit per input.
reward_model = AutoModelForSequenceClassification.from_pretrained(
     args.reward_model_name,
     torch_dtype=torch.bfloat16,
     device_map="auto",
     num_labels=1,
)

# Inference only: switch to eval mode and park the model on the CPU
# (the training loop moves it to 'cuda' just for scoring and back afterwards).
reward_model = reward_model.eval()
reward_model = reward_model.to('cpu')


In [9]:
def reward_fn(samples, temperature=0.3):
    """Score texts with the reward model and return one value in (0, 1) per text.

    Args:
        samples: list of strings to score.
        temperature: factor the logit is multiplied by before the sigmoid
            (default 0.3, the value that used to be hard-coded).

    Returns:
        A Python list of floats, ``sigmoid(logit * temperature)`` of the first
        (and only) output logit for every sample.
    """
    # NOTE(review): the texts are encoded with the policy tokenizer (loaded from
    # `args.model_name`), not with a tokenizer loaded from `args.reward_model_name`.
    # Confirm that the reward model expects the same vocabulary.
    ins = tokenizer(samples, padding=True, truncation=True, max_length=args.seq_length, return_tensors='pt')

    # Scoring never needs gradients; without no_grad the forward pass keeps the
    # whole autograd graph of the reward model alive.
    with torch.no_grad():
        logits = reward_model(**ins.to(reward_model.device)).logits

    sentiments = torch.sigmoid(logits * temperature)[:, 0].detach().cpu().tolist()
    return sentiments



### Data Classes

In [10]:
@dataclass
class PPORLElement:
    """One rollout sample: a prompt, its generated response and the per-token PPO statistics."""
    # Prompt token ids.
    query_tensor: TensorType["query_size"]
    # Generated token ids (the part of the trajectory after the prompt).
    response_tensor: TensorType["response_size"]
    # Log-probability of each kept response token under the rollout policy
    # (one scalar per token, already gathered, so there is no vocabulary dimension).
    logprobs: TensorType["response_size"]
    # Value-head prediction for each kept response step.
    values: TensorType["response_size"]
    # Per-token reward: the KL penalty, with the reward-model score added on the last kept step.
    rewards: TensorType["response_size"]


@dataclass
class PPORLBatch:
    """A mini-batch of `PPORLElement`s, each field padded to a common length and stacked."""
    # Prompt token ids.
    query_tensors: TensorType["batch_size", "query_size"]
    # Generated token ids.
    response_tensors: TensorType["batch_size", "response_size"]
    # Per-token log-probabilities (one scalar per token, no vocabulary dimension).
    logprobs: TensorType["batch_size", "response_size"]
    # Per-token value predictions.
    values: TensorType["batch_size", "response_size"]
    # Per-token rewards.
    rewards: TensorType["batch_size", "response_size"]


def create_loader(elems, mini_batch_size: int, shuffle: bool) -> DataLoader:
    """Build a DataLoader that collates `PPORLElement`s into padded `PPORLBatch`es.

    Args:
        elems: sequence of `PPORLElement` rollout samples.
        mini_batch_size: number of elements per batch.
        shuffle: whether the DataLoader shuffles the elements.

    Returns:
        A ``DataLoader`` yielding `PPORLBatch` objects. Token-id tensors are
        right-padded with ``tokenizer.pad_token_id``; the float statistics
        (log-probs, values, rewards) are right-padded with ``0.0``.
    """
    def _pad(tensors, padding_value):
        # Stack variable-length 1-D tensors into one (batch, max_len) tensor.
        return pad_sequence(list(tensors), padding_value=padding_value, batch_first=True)

    def collate_fn(elems: Iterable[PPORLElement]):
        elems = list(elems)
        token_pad = tokenizer.pad_token_id
        return PPORLBatch(
            _pad((elem.query_tensor for elem in elems), token_pad),
            _pad((elem.response_tensor for elem in elems), token_pad),
            # The statistics below are real-valued, so padding them with a token
            # id would inject values in the tens of thousands. Zero rewards and
            # zero values make padded steps contribute nothing to the
            # advantages of the real steps that precede them.
            _pad((elem.logprobs for elem in elems), 0.0),
            _pad((elem.values for elem in elems), 0.0),
            _pad((elem.rewards for elem in elems), 0.0),
        )

    return DataLoader(elems, mini_batch_size, shuffle=shuffle, collate_fn=collate_fn)



### Losses

In [11]:
def gae(
    values,
    rewards,
    gamma=None,
    lam=None,
):
    """Compute Generalized Advantage Estimation over the second dimension.

    Args:
        values: value predictions, indexed as ``values[:, t]``.
        rewards: per-step rewards with the same shape as `values`.
        gamma: discount factor; ``None`` (default) uses ``args.gamma``.
        lam: GAE lambda; ``None`` (default) uses ``args.lam``.

    Returns:
        ``(advantages, returns)``, both shaped like `rewards`, where
        ``returns = advantages + values``. The value after the last step is
        taken to be 0.
    """
    if gamma is None:
        gamma = args.gamma
    if lam is None:
        lam = args.lam

    advantages = torch.zeros_like(rewards)
    last_advantage = 0
    next_value = 0  # value of step t + 1; zero beyond the final step
    with torch.no_grad():
        # Walk backwards so that each step can reuse the advantage of its successor.
        for t in reversed(range(rewards.shape[1])):
            delta = rewards[:, t] + gamma * next_value - values[:, t]
            last_advantage = delta + gamma * lam * last_advantage
            advantages[:, t] = last_advantage
            next_value = values[:, t]
        returns = advantages + values
    return advantages, returns


In [12]:
# Source: https://github.com/huggingface/trl/blob/main/trl/core.py

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the categorical distribution defined by `logits` (last dimension).

    Computed as ``logsumexp(logits) - sum(softmax(logits) * logits)``; the
    result drops the last dimension of the input.
    """
    probs = logits.softmax(dim=-1)
    log_partition = logits.logsumexp(dim=-1)
    expected_logit = (probs * logits).sum(dim=-1)
    return log_partition - expected_logit


In [13]:
# masked_mean: https://github.com/huggingface/trl/blob/main/trl/core.py
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Mean of `values` over the positions where `mask` is non-zero.

    Args:
        values: tensor to average.
        mask: 0/1 tensor broadcastable to `values`; positions with 0 are ignored.
        axis: dimension to reduce over; ``None`` reduces over all elements.

    Returns:
        The masked mean. Where the mask selects no element at all the result
        is 0 instead of NaN (0 / 0).
    """
    if axis is not None:
        total = (values * mask).sum(axis=axis)
        count = mask.sum(axis=axis)
    else:
        total = (values * mask).sum()
        count = mask.sum()
    # Only the empty-selection case is altered: a count of 0 becomes 1 so the
    # division yields 0 rather than NaN.
    count = count.masked_fill(count == 0, 1)
    return total / count
    

In [14]:
def ppo_loss(
    logits,
    logprobs,
    values,
    old_logprobs,
    old_values,
    advantages,
    returns,
    mask,
):
    """Clipped PPO loss: policy loss + weighted value loss - weighted entropy bonus.

    Args:
        logits: current-policy logits for the response steps (last dim = vocabulary).
        logprobs: current-policy log-probs of the tokens that were actually taken.
        values: current value predictions.
        old_logprobs: log-probs recorded when the rollout was collected.
        old_values: value predictions recorded when the rollout was collected.
        advantages: advantage estimates for each step.
        returns: value targets for each step.
        mask: 0/1 tensor marking the steps that count towards the loss.
        All per-step tensors are presumably (batch, response_len) - confirm against the caller.

    Returns:
        Scalar loss tensor to minimise.
    """
    ### Value Loss
    # Clip the new value prediction to stay close to the rollout-time prediction
    # and take the worse (larger) of the clipped / unclipped squared errors.
    values_clipped = torch.clamp(values,
                                 old_values - args.cliprange_value,
                                 old_values + args.cliprange_value,)
    vf_loss1 = (values - returns) ** 2
    vf_loss2 = (values_clipped - returns) ** 2
    vf_loss = 0.5 * masked_mean(torch.max(vf_loss1, vf_loss2), mask)
    vf_loss = args.vf_coef * vf_loss

    ### Entropy Bonus
    # Higher entropy means more exploration, so this term is *subtracted* from
    # the minimised loss below. It used to be added, which penalised entropy
    # and pushed the policy towards determinism.
    # NOTE(review): `args.entropy_coef` was tuned with the old sign; re-tune it
    # (values around 0.01 are common) now that the term acts as a bonus.
    entropy_bonus = masked_mean(entropy_from_logits(logits), mask)
    entropy_bonus = args.entropy_coef * entropy_bonus

    ### Policy Gradient Loss
    # Ratio between the new and the old policy probabilities; masked steps get
    # a log-ratio of 0, i.e. a ratio of exactly 1.
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)

    # "minus" sign: the objective is maximised, but torch optimizers minimise.
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), mask)

    ### Total Loss
    loss = pg_loss + vf_loss - entropy_bonus

    # Kept from the original as an out-of-memory workaround. The former `del`
    # of the local names was dropped: the caller still references the inputs
    # and the autograd graph of `loss` keeps the intermediates alive.
    torch.cuda.empty_cache()

    return loss


In [15]:
def loss_fn(mini_batch):
    """Recompute policy outputs for a rollout mini-batch and return its PPO loss.

    Args:
        mini_batch: a `PPORLBatch` with padded queries, responses and the
            rollout-time log-probs, values and rewards.

    Returns:
        ``(loss, reward)``: the scalar loss tensor and, as a Python float, the
        mean of the last reward column of the mini-batch.
    """
    query_tensors = mini_batch.query_tensors
    response_tensors = mini_batch.response_tensors
    old_logprobs = mini_batch.logprobs
    old_values = mini_batch.values
    old_rewards = mini_batch.rewards

    response_length = old_rewards.shape[1]

    advantages, returns = gae(old_values, old_rewards)

    trajectories = torch.hstack([query_tensors, response_tensors])

    # Generation pads finished sequences with EOS, and the rollout masks on EOS
    # as well. Excluding both the tokenizer's pad id and EOS keeps this mask
    # identical to the rollout one whether or not the two ids coincide, so
    # padded steps no longer leak into the loss.
    is_padding = trajectories.eq(tokenizer.pad_token_id) | trajectories.eq(tokenizer.eos_token_id)
    attention_mask = (~is_padding).long()

    logits, values_pred = model(trajectories, attention_mask=attention_mask)

    # Position t predicts token t + 1, so drop the last position and align the
    # logits with the next-token labels.
    values_pred = values_pred[:, :-1]
    logprobs = logprobs_from_logits(logits[:, :-1, :], trajectories[:, 1:])
    attention_mask = attention_mask[:, :-1]

    # Keep only the steps that predict response tokens: they start at the last
    # prompt position.
    start = query_tensors.shape[1] - 1
    end = start + response_length
    logits, logprobs, values_pred, mask = (
        logits[:, start:end, :],
        logprobs[:, start:end],
        values_pred[:, start:end],
        attention_mask[:, start:end],
    )

    loss = ppo_loss(
        logits=logits,
        logprobs=logprobs,
        values=values_pred,
        old_logprobs=old_logprobs,
        old_values=old_values,
        advantages=advantages,
        returns=returns,
        mask=mask,
    )

    # Kept from the original as an out-of-memory workaround. The former `del`
    # of the locals was dropped because they go out of scope on return anyway.
    torch.cuda.empty_cache()

    return loss, old_rewards[:, -1].mean().item()


### Inference before alignment

In [None]:
# Sample continuations for two fixed prompts with the not-yet-trained policy
# and print each text together with its reward-model score.
ins = tokenizer(
    ["This is an action Western.", "I saw this movie recently because"],
    return_tensors='pt',
    padding=True)

with torch.no_grad():
    # The inputs must live on the device of the model that generates (`model`).
    # They used to be moved to `ref_model`'s device, which breaks as soon as
    # the two models sit on different devices (the training loop parks
    # `ref_model` on the CPU).
    outs = model.generate(
        ins['input_ids'].to(model.model.device),
        attention_mask=ins['attention_mask'].to(model.model.device),
        max_new_tokens=128,
        do_sample=True,
        temperature=0.5,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=2.,
    )

for i in range(len(outs)):
    generated_text = tokenizer.decode(outs[i], skip_special_tokens=True)
    print("\n" + "\033[1;30m" + generated_text)
    print("\033[1;32m" +'Score: ', np.round(reward_fn([generated_text]), decimals=4)[0], "\n")


### Training Loop

In [17]:
# AdamW over everything returned by `model.parameters()`, with the learning rate taken from the config.
opt = torch.optim.AdamW(model.parameters(), args.lr)


In [None]:
all_scores = []
iter_in_a_epoch = 380   # for batch_size = 64
log_every = 95          # print a progress report every this many iterations

# `step` replaces the former loop variable `iter`, which shadowed the builtin.
for step, batch in tqdm(enumerate(dataloader), total=iter_in_a_epoch * args.epochs + 1):

    # Phase 2 leaves the model in train mode. Switch back to eval mode so that
    # the rollout (sampling, log-probs, values) is not perturbed by train-only
    # layers such as dropout, if the backbone has any.
    model.eval()

    # Draw a fresh response length for every rollout batch.
    args.gen_kwargs['max_new_tokens'] = randint(8, 16)      #randint(20, 40)

    generate_kwargs = dict(
            args.gen_kwargs,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )

    #### Phase 1: Get trajectories from the offline policy
    query_tensors = batch["input_ids"].to(model.model.device)
    attention_mask = batch['attention_mask'].to(model.model.device)

    trajectories = model.generate( query_tensors,
                                attention_mask=attention_mask,
                                **generate_kwargs)

    response_tensors = trajectories[:, query_tensors.shape[1]:]

    # Finished sequences are padded with EOS by `generate`, so EOS marks padding here.
    attention_mask = trajectories.not_equal(tokenizer.eos_token_id).long()

    ref_model = ref_model.to('cuda')
    with torch.no_grad():
        logits, values = model(
            trajectories,
            attention_mask=attention_mask,
        )

        ref_logits = ref_model(
            trajectories,
            attention_mask=attention_mask,
        )
    ref_model = ref_model.to('cpu')

    # Position t predicts token t + 1: align logits with the next-token labels.
    logprobs = logprobs_from_logits(logits[:, :-1, :], trajectories[:, 1:])
    ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], trajectories[:, 1:])
    n_trajectories = trajectories.shape[0]
    values = values[:, :-1]

    # The response steps start at the last prompt position; each sample keeps
    # one step for that position plus one per non-EOS response token.
    start = batch['input_ids'].shape[1] - 1
    ends = start + attention_mask[:, start:].sum(1)
    truncated_values = [values[i, start : ends[i]] for i in range(n_trajectories)]
    truncated_logprobs = [logprobs[i, start : ends[i]] for i in range(n_trajectories)]

    del logits, ref_logits, values
    torch.cuda.empty_cache()

    texts = tokenizer.batch_decode(trajectories, skip_special_tokens=True)
    trajectories = trajectories.detach().cpu()

    reward_model= reward_model.to('cuda')
    scores = reward_fn(texts)
    reward_model = reward_model.to('cpu')

    # Per-token KL penalty against the reference model; the reward-model score
    # is added on the last kept step of every sample.
    kl_div = -args.kl_coef * (logprobs - ref_logprobs)
    all_rewards = [None] * n_trajectories
    for i in range(n_trajectories):
        rs = kl_div[i][start : ends[i]]
        rs[-1] += scores[i]
        all_rewards[i] = rs

    new_rollout = [
        PPORLElement(
            query_tensor=query_tensors[i],
            response_tensor=response_tensors[i],
            logprobs=truncated_logprobs[i],
            values=truncated_values[i],
            rewards=all_rewards[i],
        )
        for i in range(n_trajectories)
    ]

    score = torch.tensor(scores).mean().item()

    all_scores.append(score)

    if step % log_every == 0:
        epoch = step // iter_in_a_epoch
        rem = step % iter_in_a_epoch
        print(f"Epoch: {epoch}, Iteration: {rem}, Reward: {score:.3f}")
        print("Length of trajectory: ", len(trajectories[5]))
        print("Generated tokens: ", trajectories[0, start:])
        print("="*90)

    train_loader = create_loader(new_rollout, args.mini_batch_size, shuffle=True)


    del (query_tensors, response_tensors, attention_mask, trajectories, logprobs, 
         ref_logprobs, truncated_logprobs, truncated_values, texts, scores, kl_div,
         all_rewards, new_rollout)
    torch.cuda.empty_cache()


    #### Phase 2: loss calculation and PPO Update
    model.train()
    for mini_batch in train_loader:
        loss, _ = loss_fn(mini_batch)
        loss.backward()
        opt.step()
        opt.zero_grad()


    del train_loader
    torch.cuda.empty_cache()
    gc.collect()

    # Same stopping point as the former `iter / iter_in_a_epoch == args.epochs`,
    # written as an integer comparison instead of a float-division equality.
    if step == iter_in_a_epoch * args.epochs:
        break


In [None]:
# Mean reward-model score of each rollout batch, in training order.
fig, ax = plt.subplots()
ax.plot(all_scores)
ax.set(
    title="Mean reward-model score per PPO iteration",
    xlabel="Iteration",
    ylabel="Mean score",
);


### Inference after alignment

In [None]:
# Same two prompts as before training, now sampled from the updated policy.
ins = tokenizer(
    ["This is an action Western.", "I saw this movie recently because"],
    padding=True,
    return_tensors='pt',
)

with torch.no_grad():
    outs = model.generate(
        ins['input_ids'].to(model.model.device),
        attention_mask=ins['attention_mask'].to(model.model.device),
        do_sample=True,
        temperature=0.5,
        repetition_penalty=2.,
        max_new_tokens=128,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Raw token ids first, then the decoded text with its reward-model score.
    print("Length of the outputs: ", outs[0].shape[0], outs[1].shape[0])
    print("Out-0 in encoded form:", outs[0])
    print("Out-1 in encoded form:", outs[1])

    for i in range(outs.shape[0]):
        generated_text = tokenizer.decode(outs[i], skip_special_tokens=True)
        print("\n" + "\033[1;30m" + generated_text)
        print("\033[1;32m" +'Score: ', np.round(reward_fn([generated_text]), decimals=4)[0], "\n")



## References:

1. [Paper - Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)
2. [TRL - PPO Trainer](https://github.com/huggingface/trl)
3. [Reinforcement Learning from Human Feedback: From Zero to chatGPT](https://www.youtube.com/watch?v=2MBJOuVq380)
4. [Coding chatGPT from Scratch | Lecture 3: Full Pipeline](https://www.youtube.com/watch?v=11M_kfuPJ5I)
5. [Reinforcement Learning from Human Feedback explained with math derivations and the PyTorch code](https://www.youtube.com/watch?v=qGyFrqc34yc)
6. [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
