### Notebook 6.1: RLHF (Reinforcement Learning from Human Feedback) from Scratch 🚀  

Welcome back to the series! 🎉 In this notebook, we’ll dive into **Reinforcement Learning from Human Feedback (RLHF)**, taking a pretrained **GPT-2 model fine-tuned on IMDb movie reviews** and aligning it further using RLHF principles. Our goal is to understand and implement how RLHF can be used to generate text that aligns with human preferences for sentiment and style.  

### What’s the Goal? 🏆  

By the end of this notebook, you will:  
1. Gain a foundational understanding of **RLHF** and its importance in aligning language models with human values.  
2. **Build a reward model** to evaluate model outputs based on human preferences.  
3. **Explore TRL (Transformer Reinforcement Learning)** from Hugging Face:  
   - Dive into the library’s source code to understand **what’s happening under the hood**.  
   - Use TRL as a reference, not as a black box, ensuring we grasp the mechanics before applying it.  
4. **Implement PPO** (Proximal Policy Optimization) to fine-tune GPT-2 efficiently.  
5. Align GPT-2 to handle sentiment control and stylistic alignment based on IMDb reviews.  

<p align="center">
    <img src="images/RLHF.png" alt="RLHF Overview" />
</p>

### Why Not Reinvent the Wheel? 🛠️  

While the goal is to explore RLHF concepts from scratch, implementing everything manually is impractical due to the inherent **instability of reinforcement learning training**. Instead, we will leverage the **TRL library by Hugging Face**, which provides a robust implementation of RLHF.  

However, **we won’t use TRL as a black box**. Instead, we’ll dive into the source code, observe its internals, and understand every component step by step. Only after understanding how TRL operates will we apply it to our GPT-2 fine-tuning task. This ensures a balance of theoretical understanding and practical efficiency.  

<p align="center">
    <img src="images/trl.png" alt="TRL Library" />
</p>

### What’s Inside? 🔍  

#### **1: Introduction to RLHF** 🧠  

#### **2: Setting Up GPT-2 with IMDb Fine-Tuning** 🎥  

#### **3: Building the Reward Model** 🏗️  

#### **4: Exploring the TRL Library** 🛠️  

#### **5: Implementing PPO for RLHF** 🤖  

#### **6: RLHF Training Loop** 🔄  

#### **7: Evaluation and Analysis** 📊  

### A Word of Advice Before You Begin  

This notebook dives deep into **Reinforcement Learning**, **Proximal Policy Optimization**, and reward model training—each a substantial topic on its own. While we’ll guide you step-by-step, it’s helpful to review RL fundamentals beforehand. Alternatively, you can follow along and revisit concepts as needed.  

Let’s embrace this challenging yet rewarding journey of implementing RLHF with both theoretical rigor and practical efficiency! 🚀  

## Introduction to RLHF (Reinforcement Learning from Human Feedback)  

**Reinforcement Learning from Human Feedback (RLHF)** is a technique for aligning language models, like GPT-2, with specific human preferences. Unlike traditional fine-tuning, RLHF uses human feedback to guide the model’s behavior, so that generated outputs not only make sense but also meet user expectations regarding sentiment, tone, or other quality criteria.  

### Why RLHF? 🤔  

Language models like GPT-2 are pretrained to predict the next token in a sequence, giving them broad generalization capabilities. However, they might not inherently align with specific human values or preferences.  

For instance, a fine-tuned GPT-2 model trained on IMDb reviews may produce outputs spanning various sentiments—positive, neutral, or negative. But what if we want the model to generate only **positive reviews**? RLHF allows us to:  
- Tailor outputs to a desired sentiment.  
- Incorporate feedback dynamically, guiding the model to improve during training.  
- Balance specific preferences without sacrificing fluency or coherence.  

### The Scenario: Generating Positive IMDb Reviews 🌟  

Let’s use RLHF to train a GPT-2 model fine-tuned on IMDb reviews to generate **positive movie reviews** exclusively. Here’s the process:  

1. The baseline GPT-2 generates an output based on a given prompt, but the sentiment is not guaranteed to be positive.  
2. A **reward model** evaluates the sentiment of the generated review:  
   - Positive reviews receive **higher rewards**.  
   - Negative or neutral reviews receive **lower rewards**.  
3. Using reinforcement learning (specifically **PPO**), GPT-2 updates its weights to align with the reward model's preferences, producing increasingly positive outputs over time.  

### The RLHF Workflow  

1. **Pretrained Model**: Start with GPT-2 fine-tuned on IMDb reviews (already handled in earlier steps).  

2. **Dataset for RLHF**:  
   - Create pairs of movie reviews with human feedback indicating preferred outputs.  
   - For example, a positive review is marked as "better" compared to a neutral or negative one.  

3. **Reward Model**: Train a model that evaluates generated reviews and assigns rewards based on sentiment alignment.  

4. **Policy and PPO Algorithm**:  
   - Fine-tune GPT-2 (the **policy model**) using the reward model’s feedback.  
   - Use **Proximal Policy Optimization (PPO)** to stabilize updates and maintain fluency.  

<p align="center">
    <img src="images/rlhf2.jpg" alt="RLHF Workflow" />
</p>
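
To make the "stabilize updates" step more concrete, here is a minimal, framework-free sketch of PPO's clipped surrogate loss for a single generated token. The function name `clipped_surrogate_loss` and the default `clip_range=0.2` are illustrative assumptions, not TRL's API; a real implementation applies the same formula to batched tensors.

```python
import math

def clipped_surrogate_loss(logprob_new, logprob_old, advantage, clip_range=0.2):
    """PPO clipped policy loss for one action (token). Lower is better."""
    # Probability ratio between the updated policy and the policy that generated the token
    ratio = math.exp(logprob_new - logprob_old)
    # Keep the ratio inside [1 - clip_range, 1 + clip_range]
    clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
    # PPO maximizes min(ratio * A, clipped_ratio * A); negate it to obtain a loss
    return -min(ratio * advantage, clipped_ratio * advantage)

# A large policy change (ratio 1.5) with a positive advantage gets clipped at 1.2
loss_big_step = clipped_surrogate_loss(math.log(0.3), math.log(0.2), advantage=1.0)
# A small policy change (ratio 1.05) passes through unclipped
loss_small_step = clipped_surrogate_loss(math.log(0.21), math.log(0.2), advantage=1.0)
```

Clipping removes the incentive to move the policy far from the one that produced the sample in a single update, which is the main source of PPO's stability.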
   
### A Practical Example  

Here’s how RLHF can refine outputs:  

1. **Baseline Model Output**:  
   *Prompt*: *"The movie was a unique experience because..."*  
   - *Output*: "The movie was a unique experience because the plot was dull and the pacing was tedious."  

2. **Reward Model Evaluation**:  
   - Reward: Low (due to negative sentiment).  

3. **PPO Adjustment**:  
   - Adjust GPT-2 weights to produce outputs with higher rewards.  

4. **Post-RLHF Output**:  
   *Prompt*: *"The movie was a unique experience because..."*  
   - *Output*: "The movie was a unique experience because the plot was captivating and the pacing kept me on the edge of my seat."  

<p align="center">
    <img src="images/gpts_rlhf.png" alt="GPT2_RLHF" />
</p>


### Why Not Just Fine-Tune on Positive Reviews?  

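Simply fine-tuning GPT-2 on positive reviews alone shifts the model toward positive text, but it offers no explicit control over the trade-off between sentiment and fluency. RLHF is more effective because it:  
- Scores the model's own generations, so feedback targets what the model actually produces rather than a fixed corpus.  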
- Penalizes outputs straying from natural language realism (via KL divergence regularization).  
- Balances sentiment alignment with fluency and coherence.  
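
The trade-off described in the list above is commonly written as a KL-regularized objective. The symbols here are notation we introduce for illustration: $x$ is a prompt from the dataset $\mathcal{D}$, $y$ a generated response, $r(x, y)$ the reward model's score, $\pi_\theta$ the policy being trained, $\pi_{\mathrm{ref}}$ the frozen reference model, and $\beta$ the KL coefficient.

$$
\max_{\theta}\; \mathbb{E}_{x \sim \mathcal{D}} \Big[ \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)} \big[ r(x, y) \big] \;-\; \beta \, \mathrm{KL}\big( \pi_\theta(\cdot \mid x) \,\|\, \pi_{\mathrm{ref}}(\cdot \mid x) \big) \Big]
$$

A larger $\beta$ keeps the policy closer to the reference model (more fluent, less optimized for sentiment); a smaller $\beta$ allows more drift in exchange for higher reward.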

### What’s Next in This Notebook?  

In this notebook, we will:  
1. **Build the Dataset**: Prepare IMDb data for RLHF.  
2. **Load Models**:  
   - A frozen GPT-2 model for KL divergence regularization.  
   - A GPT-2 policy model for training with PPO.  
3. **Create a Reward Model**: A sentiment evaluator assigning scores to generated reviews.  
4. **Implement PPO**: Combine rewards and responses to refine the policy model through iterative updates.  


### Step-by-Step Workflow  

We’ll approach RLHF practically, ensuring theory and implementation go hand-in-hand:  
- **Review Core Concepts**: We’ll break down the math and logic behind RLHF and PPO.  
- **Understand TRL Library**: Instead of treating it as a black box, we’ll examine the **Transformer Reinforcement Learning (TRL)** library by Hugging Face, using its source code as a reference.  
- **Apply RLHF**: Finally, we’ll use TRL to efficiently train our model, leveraging its well-tested implementations while understanding the underlying mechanisms.  

By the end, you’ll not only master RLHF but also see its potential for aligning language models with complex human preferences. Let’s dive in! 🚀

## Let's Prepare the Dataset

In [1]:
# Install the necessary libraries for dataset handling, transformers, acceleration, and torchao
! pip install datasets transformers accelerate torchao

# Ensure you're using the specified version of 'trl' for compatibility
# If you use a newer version, be aware that the parameters for PPOConfig and PPOTrainer might differ.
! pip install trl==0.11.0



In [2]:
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
import torch
from datasets import load_dataset

# Load the dataset
dataset_name = "stanfordnlp/imdb"
dataset = load_dataset(dataset_name)

df = pd.DataFrame(dataset['train'])

df.head(3)



Unnamed: 0,text,label
0,I rented I AM CURIOUS-YELLOW from my video sto...,0
1,"""I Am Curious: Yellow"" is a risible and preten...",0
2,If only to avoid making this type of film in t...,0


In [3]:

# 1. Renaming the text column to review
df = df.rename(columns={'text': 'review'})

# 2. Drop short reviews (keep only those longer than 200 characters)
df = df[df['review'].apply(lambda x: len(x) > 200)]

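# 3. Sample a query length in tokens. Note: a single value is drawn here and shared by all rows,
#    whereas the LengthSampler used later in `build_dataset` draws a new length for every sample.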
min_text_length = 2
max_text_length = 8
values = list(range(min_text_length, max_text_length + 1))  # Ensure max_text_length is included
input_size = np.random.choice(values)

# Display the first 3 rows after processing
df.head(3)

Unnamed: 0,review,label
0,I rented I AM CURIOUS-YELLOW from my video sto...,0
1,"""I Am Curious: Yellow"" is a risible and preten...",0
2,If only to avoid making this type of film in t...,0


In [4]:

# 4. Tokenization
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Define the tokenization function
def tokenize(row):
    # Tokenize the review and truncate to the sampled length
    input_ids = tokenizer.encode(row["review"], truncation=True, max_length=input_size, padding=False)
    query = tokenizer.decode(input_ids)

    # Return the tokenized output as a dictionary
    return {"input_ids": input_ids, "query": query}

# Apply tokenization to each row of the DataFrame
df[['input_ids', 'query']] = df.apply(lambda row: pd.Series(tokenize(row)), axis=1)

# Convert the 'input_ids' column to tensor
df['input_ids'] = df['input_ids'].apply(torch.tensor)

# Displaying the first 3 rows after tokenization
df[['review', 'input_ids', 'query']].head(3)



Unnamed: 0,review,input_ids,query
0,I rented I AM CURIOUS-YELLOW from my video sto...,"[tensor(40), tensor(26399), tensor(314), tenso...",I rented I AM CURIOUS-
1,"""I Am Curious: Yellow"" is a risible and preten...","[tensor(1), tensor(40), tensor(1703), tensor(4...","""I Am Curious: Yellow"" is"
2,If only to avoid making this type of film in t...,"[tensor(1532), tensor(691), tensor(284), tenso...",If only to avoid making this type of


Here is the same pipeline in a more compact form, using `datasets` directly instead of a pandas DataFrame:  

In [5]:
from trl.core import LengthSampler

def build_dataset(
    dataset_name="stanfordnlp/imdb",
    input_min_text_length=2,
    input_max_text_length=8,
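):
    """
    Build the dataset for training. This loads the data with `load_dataset`; customize
    this function to train the model on your own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`):
            Lower bound for the sampled query length, in tokens.
        input_max_text_length (`int`):
            Upper bound for the sampled query length, in tokens.

    Returns:
        ds (`datasets.Dataset`):
            The filtered, tokenized dataset in torch format.
    """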
    tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

### Explanation of Dataset Construction for RLHF with PPO:

The dataset is designed to train a model using Reinforcement Learning from Human Feedback (RLHF) and Proximal Policy Optimization (PPO). It is structured to split the text into two parts:

1. **Query (Input Prompt):** The initial part of the text fed to the model, which provides context or direction for the response generation.
2. **Response (Generated Output):** The output generated by the model in response to the query.

#### Why Split the Text?
- **Query:** Serves as input for the model during training.
- **Response:** The generated text is evaluated based on rewards (e.g., sentiment analysis), guiding the model to improve its responses.

The dataset is tokenized, filtered, and processed to generate input-output pairs where:
- The **query** is used as input for response generation.
- The **response** is rewarded based on its quality (e.g., sentiment score).

This structure enables the model to learn optimal response generation through the feedback loop of RLHF, where the reward is applied to the generated response, not the prompt.
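
The query construction described above can be sketched without any external library. In this sketch, `sample_query` is a hypothetical helper, the token IDs are stand-ins rather than real tokenizer output, and treating the upper bound as inclusive is our assumption; check the sampler you actually use for its bound convention.

```python
import random

def sample_query(token_ids, min_len=2, max_len=8, rng=random):
    """Keep a random-length prefix of `token_ids` to use as the query (prompt)."""
    # Assumption: both bounds are inclusive in this sketch
    n = rng.randint(min_len, max_len)
    return token_ids[:n]

rng = random.Random(0)               # seeded so the example is reproducible
review_ids = list(range(100, 150))   # stand-in for a tokenized review (50 tokens)
queries = [sample_query(review_ids, rng=rng) for _ in range(5)]

for q in queries:
    print(len(q), q)
```

Varying the query length exposes the policy to prompts of different sizes, so it does not overfit to one fixed prompt length.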


In [6]:
# Build the dataset
ds = build_dataset(dataset_name="stanfordnlp/imdb", input_min_text_length=2, input_max_text_length=8)


Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


In [7]:
# Let's look at the dataset
ds

Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 24895
})

In [8]:
# Let's look at the first row
ds[0]

{'review': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far 

We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model.
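
To illustrate how the KL term can act as "an additional reward signal", here is a simplified, framework-free sketch. It assumes per-token log-probabilities from the policy and the reference model are already available; the function name `kl_shaped_rewards` and the value `kl_coef=0.2` are illustrative choices, and TRL's internal version works on batched tensors and supports other KL estimators.

```python
def kl_shaped_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.2):
    """Return one reward per response token.

    Every token receives a penalty proportional to log pi(token) - log pi_ref(token);
    the reward model's score is added to the final token only.
    """
    rewards = [
        -kl_coef * (lp - ref_lp)
        for lp, ref_lp in zip(policy_logprobs, ref_logprobs)
    ]
    rewards[-1] += score
    return rewards

rewards = kl_shaped_rewards(
    policy_logprobs=[-1.0, -0.5, -2.0],  # policy is more confident than the reference on token 2
    ref_logprobs=[-1.0, -1.5, -2.0],
    score=2.5,                           # e.g. a positive-sentiment logit
)
print(rewards)
```

Only the second token is penalized here, because it is the only position where the policy assigns a higher probability than the reference model does.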

In [9]:
from trl import AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer



# Load the main GPT-2 model with a value head for fine-tuning with reinforcement learning
model = AutoModelForCausalLMWithValueHead.from_pretrained("lvwerra/gpt2-imdb")

# Load a reference GPT-2 model with a value head, typically used to calculate KL divergence
# between the fine-tuned model and the original model during fine-tuning.
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("lvwerra/gpt2-imdb")

# Load the tokenizer for GPT-2. This will handle tokenization for both the main and reference models.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set the pad token to the end-of-sequence token, as GPT-2 does not have a padding token by default.
tokenizer.pad_token = tokenizer.eos_token



# Exploring the `trl` Library:
# 1- Understanding `AutoModelForCausalLMWithValueHead`

## What is `AutoModelForCausalLMWithValueHead`?

The `AutoModelForCausalLMWithValueHead` is a specialized class from the **`trl`** (Transformer Reinforcement Learning) library. It adds a "value head" to causal language models such as GPT-2. This class is primarily used in **Reinforcement Learning from Human Feedback (RLHF)** and similar methods where reward-based training is applied to language models.

This class wraps a Hugging Face causal language model and adds the ability to predict a **value** (an estimate of the expected future reward) for each token, alongside generating text. The reward itself comes from a separate reward model, not from the value head.

---

## How Is It Different from `AutoModelForCausalLM`?

| **Feature**                          | **AutoModelForCausalLM**                    | **AutoModelForCausalLMWithValueHead**           |
|--------------------------------------|--------------------------------------------|------------------------------------------------|
| **Purpose**                          | General causal language modeling.          | Causal language modeling + value estimation.   |
| **Value Head**                       | ❌ Not included.                            | ✅ Includes a value head for value estimation.  |
| **RLHF Support**                     | ❌ Not directly applicable.                 | ✅ Specifically built for RLHF tasks.          |
| **Reinforcement Learning Algorithms**| Not designed for RL.                       | Works with RL algorithms like PPO.            |
| **Use Case**                         | Text generation, fine-tuning on datasets.  | Fine-tuning for reward-based tasks.           |

---

## Typical Use Cases for `AutoModelForCausalLMWithValueHead`

1. **RLHF Training Pipeline**:
   - Fine-tuning a language model with human feedback to improve responses in dialogue systems (e.g., **OpenAI's ChatGPT**).
   - The value head estimates the reward the policy can expect to collect from a given point in a response.

2. **Value Estimation**:
   - Predicts a scalar estimate of expected reward for a partial sequence. The reward itself comes from a separate reward model (in this notebook, a sentiment classifier).

3. **Policy Learning**:
   - Supports training the policy (language model) by supplying the baseline from which PPO computes advantages.

---

## Simplified Explanation

The `AutoModelForCausalLMWithValueHead` is essentially the same as `AutoModelForCausalLM`, **but with the addition of a value layer** that produces a scalar value estimate. PPO compares this estimate with the reward actually received to decide how strongly to reinforce each generated token.
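
To see what a value estimate is used for, the sketch below computes advantages with Generalized Advantage Estimation (GAE), a common choice in PPO implementations. It is a plain-Python illustration under our own assumptions (the helper name `compute_gae`, the defaults for `gamma` and `lam`, and a value of zero after the final token), not a copy of TRL's code.

```python
def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over one response.

    `rewards` and `values` are lists with one float per generated token.
    Returns (advantages, returns), both the same length as the inputs.
    """
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        # Assumption: the value after the final token is 0 (episode ends there)
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD error: how much better the outcome was than the value head predicted
        delta = rewards[t] + gamma * next_value - values[t]
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    # Regression targets for the value head
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# The reward arrives only on the last token; the value head predicted 0.5 everywhere
advantages, returns = compute_gae(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.5, 0.5], lam=1.0)
print(advantages, returns)
```

A positive advantage means the response turned out better than the value head expected, so PPO increases the probability of those tokens; a negative advantage does the opposite.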

## Implementation: `AutoModelForCausalLMWithValueHead`


### Key Features of the Implementation

1. **Transformer Backbone**:
   - The `GPT2Model` forms the base for generating text and token-level embeddings.

2. **Language Modeling Head (`lm_head`)**:
   - Predicts the probabilities for the next tokens.

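3. **Value Head (`value_head`)**:
   - Adds a simple linear layer that maps a hidden state to a scalar value estimate. This simplified version uses only the last token's hidden state, whereas TRL's implementation produces one value per token.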


This implementation highlights how `AutoModelForCausalLMWithValueHead` extends the functionality of causal language models to support **reward-based training**, making it an essential tool for RLHF and related tasks.

In [10]:
import torch.nn as nn
from transformers import GPT2Model, GPT2Tokenizer

class AutoModelForCausalLMWithValueHead(nn.Module):
    """Simplified stand-in for the TRL class of the same name (illustration only)."""

    def __init__(self, model_name="gpt2"):
        super().__init__()
        # Load the transformer backbone
        self.transformer = GPT2Model.from_pretrained(model_name)
        hidden_size = self.transformer.config.hidden_size
        self.lm_head = nn.Linear(hidden_size, self.transformer.config.vocab_size, bias=False)
        # GPT-2 ties the output projection to the token embeddings; without this,
        # lm_head would be randomly initialized and the logits meaningless
        self.lm_head.weight = self.transformer.wte.weight

        # Value head: maps a hidden state to a scalar value estimate (randomly initialized)
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        # Pass inputs through the transformer backbone
        outputs = self.transformer(input_ids, attention_mask=attention_mask)

        # Compute logits for next-token prediction
        logits = self.lm_head(outputs.last_hidden_state)

        # Compute a scalar value from the last token's hidden state only
        # (TRL's class instead returns one value per token)
        values = self.value_head(outputs.last_hidden_state[:, -1, :])

        return logits, values

# Example Usage
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = AutoModelForCausalLMWithValueHead(model_name)

# Input text
input_text = "In the beginning"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Forward pass
logits, values = model(input_ids)

# Outputs
print("Logits shape:", logits.shape)  # Shape: [batch_size, sequence_length, vocab_size]
print("Values shape:", values.shape)  # Shape: [batch_size, 1]
print("Predicted value:", values)     # Scalar value estimate (untrained head, so the number is arbitrary)



Logits shape: torch.Size([1, 3, 50257])
Values shape: torch.Size([1, 1])
Predicted value: tensor([[0.3101]], grad_fn=<AddmmBackward0>)


Let's recap what we did:

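### 1. **Preparing the IMDb Data for Reinforcement Learning (RL) Training:**
We filtered out short IMDb reviews and truncated each remaining review to a short, random-length prefix (the query) that the policy will continue.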

### 2. **Loading Two Fine-Tuned GPT-2 Models:**
We use two fine-tuned GPT-2 models on the IMDb dataset:
1. **Policy Model**
2. **Reference Model**

### Why do we need both models?
- **Policy Model**: It learns and adapts to generate optimal responses based on rewards.
- **Reference Model**: A frozen copy of the starting model. The KL divergence between the policy and this reference is applied as a penalty, which keeps the policy from "cheating" during optimization (e.g., generating repetitive or irrelevant responses to gain rewards). It helps limit **reward hacking**, where the model manipulates the reward system to score high without actually improving its behavior (e.g., repeating "thank you" endlessly to get high rewards for being polite).

### 3. **Introducing the Reward Model:**
At this point, you may wonder, "Where is the reward model?"  
- **Reward Model**: In Reinforcement Learning from Human Feedback (**RLHF**), the reward comes from human evaluators who provide feedback on the model's output (e.g., how polite or accurate the model’s response is).  
- **Simulating Human Feedback**: Since it's not practical to rely solely on human feedback during training, we simulate it. In this case, we use **sentiment analysis** as a surrogate reward signal. The sentiment analysis model evaluates the generated response and provides a score that acts as the reward. This score informs the PPO algorithm about how good or bad the model’s response is, guiding the policy model towards better outputs.

### 4. **Reward Hacking Problem:**
When training models using reinforcement learning, one common issue is **reward hacking**. This occurs when a model learns to optimize for high rewards by exploiting loopholes in the reward system rather than improving its actual performance.  
- For example, if the reward is based on being polite, the model might start saying "thank you" endlessly, which maximizes the reward score, but this behavior drifts away from the true goal of being both **polite** and **understandable**. The reference model helps us measure this divergence and prevents the model from learning such undesirable strategies.

### 5. **Reinforcement Learning from Human Feedback (RLHF):**
- **Human feedback**: In a real-world scenario, human evaluators would provide feedback on the model’s outputs.
- **Simulating human feedback**: Since providing constant human feedback is impractical, we simulate this feedback using sentiment analysis to generate reward signals. This allows us to guide the policy model toward the desired behavior, which in this notebook is positive sentiment.


In [11]:
# Importing the necessary pipeline function from the transformers library.
from transformers import pipeline

# Define the model name for sentiment analysis
reward_model_name = "lvwerra/distilbert-imdb"

# Define the pipeline type, which is sentiment analysis in this case.
# "sentiment-analysis" refers to the specific task we want the model to perform.
pipeline_type = "sentiment-analysis"

# Check if CUDA (GPU) is available; if not, default to CPU.
# This allows the code to utilize GPU if available, otherwise fall back to CPU for computation.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the sentiment analysis pipeline using the chosen model and device.
# The pipeline automatically loads the model and tokenizer and makes it ready for inference.
sentiment_pipe = pipeline(pipeline_type,
                          model=reward_model_name,  # Specify the pre-trained model for sentiment analysis
                          device=device)  # Use GPU (cuda) if available, otherwise fallback to CPU



Device set to use cuda


Let's test the pipeline:

In [12]:
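# return_all_scores=True returns a score for every label; function_to_apply="none"
# skips the softmax, so the scores below are raw logits rather than probabilities
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}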

text = "This movie was really bad!"
sentiment_pipe(text, **sent_kwargs)




[[{'label': 'NEGATIVE', 'score': 2.3846828937530518},
  {'label': 'POSITIVE', 'score': -2.77394962310791}]]

In [13]:
text = "This movie was really good!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'NEGATIVE', 'score': -2.2816061973571777},
  {'label': 'POSITIVE', 'score': 2.539324998855591}]]
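
Since we want to reward positive reviews, a natural scalar reward is the `POSITIVE` logit. The helper below is a hypothetical sketch (the name `positive_logit` is ours) that extracts it from one pipeline result; the inputs are copied from the two outputs shown above. Looking the label up by name avoids relying on the order of the returned entries.

```python
def positive_logit(pipe_output):
    """Return the POSITIVE score from one pipeline result (a list of {'label', 'score'} dicts)."""
    for entry in pipe_output:
        if entry["label"] == "POSITIVE":
            return entry["score"]
    raise ValueError("no POSITIVE label in pipeline output")

# Results copied from the two calls above (inner lists of the pipeline output)
bad = [
    {"label": "NEGATIVE", "score": 2.3846828937530518},
    {"label": "POSITIVE", "score": -2.77394962310791},
]
good = [
    {"label": "NEGATIVE", "score": -2.2816061973571777},
    {"label": "POSITIVE", "score": 2.539324998855591},
]

rewards = [positive_logit(output) for output in (bad, good)]
print(rewards)
```

Using the raw logit rather than a probability gives an unbounded signal that keeps growing as reviews become more clearly positive.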

## Now that everything is set, it's time to dive into the most crucial part: **PPOTrainer**

In the `trl` library, using **PPOTrainer** is quite simple. Here’s how to do it step by step:

### 1. Set up the PPO config:
We will start by setting up the main config using the `PPOConfig` class.



In [14]:
# Import necessary modules from the `trl` library
# PPOTrainer is used to implement the Proximal Policy Optimization (PPO) training algorithm.
# PPOConfig is a configuration class to define hyperparameters for PPOTrainer.
from trl import PPOTrainer, PPOConfig

# Install the `wandb` package for experiment tracking and logging
# `wandb` (Weights & Biases) is a tool for logging metrics, visualizing model performance, and managing experiments.
! pip install wandb

# Define the PPO configuration
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",  # Pre-trained model name to be fine-tuned using PPO
    learning_rate=1.41e-5,          # Learning rate for the optimizer during PPO training
    log_with="wandb",               # Specify that logs should be sent to Weights & Biases for tracking
)

# Import the `wandb` module for initializing the experiment
import wandb

# Initialize a new Weights & Biases run
# This starts tracking the training process and logging metrics
wandb.init()

# Notes:
# 1. The `PPOConfig` sets up the model and training configuration for the PPOTrainer.
# 2. The learning rate (1.41e-5) is a hyperparameter that controls how much to adjust the model weights during training.
# 3. `log_with="wandb"` ensures that all training logs are sent to your Weights & Biases dashboard.
# 4. Make sure you have logged into Weights & Biases using `wandb.login()` before running this code, or it will prompt you to log in.
# 5. The model being fine-tuned (`lvwerra/gpt2-imdb`) is a GPT-2 model already fine-tuned on IMDb reviews; it generates review text and is not a sentiment classifier.





In [15]:
from trl.models.modeling_value_head import AutoModelForCausalLMWithValueHead  # Import a model wrapper with value head

# Define the collator function for data preprocessing
# This function organizes the data into a dictionary format for training
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

# Load the GPT-2 model with a value head for reinforcement learning
# The value head estimates expected future reward, which PPO uses to compute advantages
model = AutoModelForCausalLMWithValueHead.from_pretrained("lvwerra/gpt2-imdb")

# Load the reference model
# This model is kept frozen and used to calculate KL divergence during training
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("lvwerra/gpt2-imdb")

# Initialize the tokenizer
# The tokenizer is used to preprocess text data into input IDs and attention masks
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Set the padding token to be the same as the EOS token to handle batch padding
tokenizer.pad_token = tokenizer.eos_token

# Initialize PPOTrainer
# The PPOTrainer performs reinforcement learning using the PPO algorithm
ppo_trainer = PPOTrainer(
    config=config,           # Training configuration
    model=model,             # Model to be fine-tuned
    ref_model=ref_model,     # Reference model (used for KL divergence calculation)
    tokenizer=tokenizer,     # Tokenizer for preprocessing
    dataset=ds,              # Dataset for training (define `ds` before this step)
    data_collator=collator   # Data collator function for batch preparation
)

# set up the device too to avoid bugs
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
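
To see what the `collator` above does, here is a small standalone sketch with made-up records (the field values are illustrative, not taken from the dataset). It turns a list of per-example dicts into one dict of lists, which leaves variable-length `input_ids` unpadded so that each query can be passed to generation on its own.

```python
def collator(data):
    # Turn a list of per-example dicts into a dict of per-field lists.
    return {key: [d[key] for d in data] for key in data[0]}


# Two hypothetical examples with different prompt lengths
samples = [
    {"input_ids": [10, 11, 12], "query": "This movie was"},
    {"input_ids": [20, 21], "query": "I really"},
]

batch = collator(samples)
print(batch["input_ids"])  # [[10, 11, 12], [20, 21]]
print(batch["query"])      # ['This movie was', 'I really']
```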




### Get Responses from the GPT-2 Model:
To generate responses (trajectories) using your GPT-2 model, we need to set up a few parameters for response generation. This includes defining the desired length range for the outputs, configuring the sampling strategy, and iterating through the training data to extract queries. The generated responses will be used later in the PPO training process. Here's the code to achieve this:


In [16]:
# Define the minimum and maximum length for the model's output (response) sequences
output_min_length = 4  # Minimum number of tokens in the response
output_max_length = 16  # Maximum number of tokens in the response

# Create a sampler to randomly select the length of responses between the min and max values
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Configuration for generating responses (trajectories)
response_generation_kwargs = {
    "min_length": -1,  # No minimum length constraint
    "top_k": 0.0,  # Disable Top-K sampling (all tokens are considered)
    "top_p": 1.0,  # Disable Top-P (nucleus) sampling (all tokens are considered)
    "do_sample": True,  # Enable stochastic sampling instead of greedy decoding
    "pad_token_id": tokenizer.eos_token_id,  # Use the end-of-sequence token as the padding token
    # "max_new_tokens" is set per query in the training loop from `output_length_sampler`
}
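
As a rough illustration of the length sampling, the sketch below defines a hypothetical `SimpleLengthSampler` (not the TRL class). It assumes the sampler draws uniformly from `range(min_value, max_value)`, so the upper bound is excluded; check the TRL source if the exact bounds matter for your run.

```python
import random


class SimpleLengthSampler:
    """Hypothetical stand-in: draws an integer in [min_value, max_value)."""

    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))

    def __call__(self):
        return random.choice(self.values)


random.seed(0)
sampler = SimpleLengthSampler(4, 16)
lengths = [sampler() for _ in range(1000)]
print(min(lengths), max(lengths))
```

Varying the response length from query to query exposes the policy to short and long completions during training instead of a single fixed length.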


### Everything is Set Up! Let's Proceed 🚀

We will now move forward in phases to train the model using PPO. Here's the breakdown of the workflow:

#### **Phase 1**: Generate the trajectory using the GPT-2 model  
In this phase, we'll use the GPT-2 model to generate responses (trajectories) based on the input queries.

#### **Phase 2**: Calculate the reward using the sentiment pipeline  
Once we have the generated responses, we'll use the sentiment analysis pipeline (set up earlier) to evaluate the quality of the responses and compute the corresponding rewards.

#### **Phase 3**: Combine Phase 1 and Phase 2 to calculate the log probabilities and run the PPO update  
Here, we'll use the trajectories, rewards, and log probabilities to optimize the GPT-2 model with PPO updates.


In [17]:
from tqdm import tqdm

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Phase 1: Generate trajectories using the GPT-2 model
    # Here, we generate responses (trajectories) using the GPT-2 model. The responses are generated based on the input queries.
    response_tensors = []
    for query in query_tensors:
        # Sample a random length for the generated response
        gen_len = output_length_sampler()
        # Update the response generation settings with the sampled length
        response_generation_kwargs["max_new_tokens"] = gen_len
        # Generate the response using the GPT-2 model
        response = ppo_trainer.generate(query, **response_generation_kwargs)
        # Extract only the response tokens (excluding the prompt/query tokens)
        response_tensors.append(response.squeeze()[-gen_len:])
    # Decode the generated responses into text and store them in the batch
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Phase 2: Calculate the reward using the sentiment pipeline
    # Combine the queries (prompts) and responses into full texts
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    # Use the sentiment analysis pipeline to calculate the rewards
    # When all scores are returned, the pipeline yields one list of {"label", "score"} dicts per text
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # Extract the 'POSITIVE' score by label rather than by position, since the order of the labels can vary
    # This reward is assigned to the entire response, not individual tokens
    rewards = [
        torch.tensor(next(item["score"] for item in output if item["label"] == "POSITIVE"))
        for output in pipe_outputs
    ]

    #### Phase 3: Calculate the log probabilities and perform PPO updates
    # The PPO trainer calculates the log probabilities for the queries and responses
    # It then uses the rewards to optimize the model using PPO updates
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    # Log the training statistics and rewards for this batch
    ppo_trainer.log_stats(stats, batch, rewards)

# Save the trained model and tokenizer locally
model.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=False)
tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=False)


0it [00:00, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
8it [02:34, 18.04s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
194it [1:00:53, 18.83s/it]


('gpt2-imdb-pos-v2/tokenizer_config.json',
 'gpt2-imdb-pos-v2/special_tokens_map.json',
 'gpt2-imdb-pos-v2/vocab.json',
 'gpt2-imdb-pos-v2/merges.txt',
 'gpt2-imdb-pos-v2/added_tokens.json',
 'gpt2-imdb-pos-v2/tokenizer.json')

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from transformers import pipeline

# Load the trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2-imdb-pos-v2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-imdb-pos-v2")

# Sentiment pipeline (for reward calculation, based on sentiment)
sentiment_pipe = pipeline("sentiment-analysis")

# Function to generate responses based on input query with temperature, top_k, and top_p sampling
def generate_response(query, model, tokenizer, max_length=50, temperature=0.8, top_k=50, top_p=0.95):
    # Tokenize the input query
    input_ids = tokenizer.encode(query, return_tensors="pt")

    # Generate the response with sampling parameters to improve diversity and cohesiveness
    # `do_sample=True` is required; otherwise temperature, top_k, and top_p are ignored and decoding is greedy
    output = model.generate(input_ids,
                            max_new_tokens=max_length,
                            do_sample=True,
                            temperature=temperature,
                            top_k=top_k,
                            top_p=top_p,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.eos_token_id)

    # Decode the response to text and return it
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response

# Example queries to test the model
queries = [
    "What did you think of the movie?",
    "Was the movie good?",
    "Describe the movie in a positive light."
]

# Function to calculate reward using sentiment analysis
def calculate_rewards(responses):
    rewards = []
    for response in responses:
        sentiment_result = sentiment_pipe(response)
        # Extract the positive sentiment score
        reward = sentiment_result[0]['score'] if sentiment_result[0]['label'] == 'POSITIVE' else 0
        rewards.append(reward)
    return rewards

# Generate responses for each query and calculate rewards
generated_responses = []
for query in queries:
    response = generate_response(query, model, tokenizer)
    generated_responses.append(response)
    print(f"Query: {query}")
    print(f"Generated Response: {response}\n")

# Calculate rewards based on sentiment analysis (positive sentiment means higher reward)
rewards = calculate_rewards(generated_responses)

# Print the rewards (how positive the generated responses are)
for response, reward in zip(generated_responses, rewards):
    print(f"Generated Response: {response}")
    print(f"Sentiment Reward: {reward}\n")


# Looking at the code, you may have some questions:
### 1. **Why is the Positive Reward Extracted?**

The sentiment classifier returns a score for both the positive and the negative class. We use the positive score as the reward because the goal of this notebook is to steer GPT-2 toward positive reviews: PPO increases the probability of responses that receive a high reward, so rewarding the positive score pushes the model to generate more positive text. Note that this reward measures sentiment only, not the general quality of the text.

### 2. **Why Focus on the Response and Not the Query?**

The sentiment classifier scores the full text (query plus response), because the response only makes sense in the context of its query. The reward, however, is credited to the response tokens only: the query comes from the dataset and is fixed, while the response is the part the model generates and therefore the only part PPO can improve.
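
To make the reward extraction concrete, here is a minimal sketch with mocked classifier outputs. It assumes the pipeline returns, for each text, a list of `{"label", "score"}` dicts covering both classes; `positive_score` is a hypothetical helper and the numbers are made up.

```python
def positive_score(all_scores):
    """Return the score of the POSITIVE label from one pipeline result."""
    for item in all_scores:
        if item["label"] == "POSITIVE":
            return item["score"]
    raise ValueError("no POSITIVE label in pipeline output")


# Hypothetical outputs for two texts; the label order differs on purpose
mock_outputs = [
    [{"label": "NEGATIVE", "score": 0.1}, {"label": "POSITIVE", "score": 0.9}],
    [{"label": "POSITIVE", "score": 0.3}, {"label": "NEGATIVE", "score": 0.7}],
]

rewards = [positive_score(out) for out in mock_outputs]
print(rewards)  # [0.9, 0.3]
```

Selecting by label instead of by list position keeps the reward correct even if the classifier returns its labels in a different order.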


### And you're good to go!
If your goal was to use PPO for fine-tuning and sentiment alignment, congratulations! You've just achieved what you set out to do. 🎉

### For the rest of you who want to take a peek under the hood:

If you’re curious about the mathematics behind the PPO algorithm or want to dive deeper into the source code of the **PPOTrainer** (at least the vanilla version), buckle up! It's going to be a **fun**, **challenging** ride, and I’m sure you’ll come out the other side stronger.

Let’s go, deep learning warrior! 🚀

### Diving into PPO Trainer

In this section, we'll explore the `step` function of the PPO trainer. This function plays a crucial role in training a model using Proximal Policy Optimization (PPO). It takes in a series of inputs, including queries, responses, and rewards, which are used to optimize the model's policy based on the rewards it receives. We'll break down each parameter and understand how it contributes to the training process.

Note:
The PPO implementation inside the TRL (Transformer Reinforcement Learning) library is very extensive and comes with a lot of additional functionality ("bells and whistles"). Here we focus on the vanilla implementation of PPO without these additional complexities. To keep things simple, I will provide a commented version of the code that leaves out the lines we don't need for now.

For the complete and more feature-rich implementation, you can refer to the repository linked below:

Repository: https://github.com/hkproj/rlhf-ppo

All credits to Umar Jamil and his repository.

In [None]:
@PPODecorators.empty_device_cache()
def step(
    self,
    queries: List[torch.LongTensor],  # The list of prompts (queries) used to generate responses from the old model (offline policy)
    responses: List[torch.LongTensor],  # A list of responses generated by the old model (offline policy)
    scores: List[torch.FloatTensor],  # A list of rewards associated with each response (one reward per response, not per token)
    response_masks: Optional[List[torch.LongTensor]] = None,  # Optional, used to mask out certain parts of the response (e.g., padding tokens)
):
    # The `queries` are prepared from the dataset we loaded earlier. These represent the prompts that the model will respond to.
    # Each query corresponds to a specific input that the model will process to generate a response.

    # The `responses` are generated by the old model (offline policy). These are the outputs of the model when given the `queries`.
    # In our case, we are using a pre-trained model to generate these responses.

    # The `scores` represent the rewards associated with each response. These rewards can come from a reward model or human feedback.
    # In our case, we use the sentiment analysis pipeline to evaluate the responses and generate scores for them.

    # The `response_masks` is an optional argument used to mask out certain tokens of the responses.
    # Masked response tokens are excluded from the PPO computation (log probabilities, values, and loss).
    # For simplicity, we'll leave this as `None` unless we need it for advanced use cases.

    # Note: The queries, responses, and scores were prepared and processed in the previous sections,
    # with the responses being generated from our model and the rewards derived from sentiment analysis.


### Preparation Steps of the `step` Function

Before proceeding with optimization, the `step` function performs several preparation steps to ensure the input data is well-structured and ready for processing:

1. **Input Verification**
   The function validates that all inputs (`queries`, `responses`, `scores`, and optionally `response_masks`) have the correct format, shapes, and data types.
   *Why?* Checking the inputs up front avoids runtime errors in downstream computations. This matters most for batch processing, where mismatched input shapes can break the entire pipeline.
2. **Conversion of Rewards to Tensors**
   The `scores` (representing rewards) are converted into PyTorch tensors and moved to the same device (CPU or GPU) as the model.
   *Why?* The model operates on tensors, and keeping the rewards on the model's device avoids repeated data transfers.
3. **Merging Queries and Responses into a Unified Tensor**
   The queries (prompts) and responses (model-generated completions) are concatenated into a single tensor, `input_ids`. An accompanying attention mask tensor indicates which positions hold actual data and which are padding tokens.
   *Why?* Language models require input tensors of shape `(batch_size, seq_len)`. Since queries and responses vary in length, padding brings all sequences in the batch to the same `seq_len`. The attention mask lets the model ignore the padded positions during computation.

With these steps, the `step` function lays the groundwork for the subsequent stages of Proximal Policy Optimization (PPO) and can handle batches of varying sequence lengths.
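
The merging and padding step can be sketched in plain Python. `prepare_inputs` below is a hypothetical helper written for illustration, not the TRL method; the padding side is left as a parameter, and left padding is assumed as the default only for this example.

```python
def prepare_inputs(queries, responses, pad_token_id, padding_side="left"):
    """Concatenate each query with its response, then pad to a common length."""
    sequences = [q + r for q, r in zip(queries, responses)]
    max_len = max(len(s) for s in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = [pad_token_id] * (max_len - len(seq))
        ones = [1] * len(seq)    # real tokens
        zeros = [0] * len(pad)   # padding positions
        if padding_side == "left":
            input_ids.append(pad + seq)
            attention_mask.append(zeros + ones)
        else:
            input_ids.append(seq + pad)
            attention_mask.append(ones + zeros)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids: two queries with their responses
queries = [[5, 6, 7], [8, 9]]
responses = [[1, 2], [3]]
batch = prepare_inputs(queries, responses, pad_token_id=0)
print(batch["input_ids"])       # [[5, 6, 7, 1, 2], [0, 0, 8, 9, 3]]
print(batch["attention_mask"])  # [[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]
```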

#### Code Integration
To see how these steps are implemented in practice, refer to the source code for input preparation and PPO processing in the repository: [RLHF-PPO GitHub Repository](https://github.com/hkproj/rlhf-ppo)

Note: The snippets below are excerpts from inside the `step` method. They reference `self` and other local variables, so running them as standalone notebook cells will raise errors.

In [None]:
# queries: input_ids of the prompts;
# responses: input_ids of the responses;
# scores: score from reward model (one per response)
# Verify input tensors (check types, shapes, etc.)
bs = self.config.batch_size  # Expected number of examples per PPO step
queries, responses, scores, response_masks = self._step_safety_checker(
    bs, queries, responses, scores, response_masks
)

# Indicates the rewards given to the responses. One scalar for each response.
# shape: (batch_size)
scores = torch.tensor(scores, device=self.current_device)

# Join the query and the response to create a input_ids tensor
# Also generate the attention masks (for padding). Padding is added so that all the query+response can be joined in the same tensor
# Dictionary with input_ids and attention_mask.
# Shape of input_ids: (batch_size, seq_len)
# Shape of attention_mask: (batch_size, seq_len). The attention mask just masks out the padding token.
model_inputs = self.prepare_model_inputs(queries, responses)


### **Log Probability Calculation for Both the Policy and Reference Models**

Once the preparation step is complete, the next task is to calculate the log probabilities under both the model being trained (the policy) and the frozen reference model. PPO needs both: the difference between the two sets of log probabilities gives the KL penalty that keeps the policy close to the reference model.

#### **Model Loading and Setup**
As a reminder, we’ve loaded two models using the `AutoModelForCausalLMWithValueHead` class. This class differs from the regular causal language models available in Hugging Face in that it adds a **linear layer** (the value head) that outputs a scalar **value** for each token position: an estimate of the future reward expected from that point in the sequence. PPO later uses these value estimates to compute advantages.

#### **Forward Pass for Both Models**
For each model (the current model and the reference model), we run a **forward pass** over the concatenated query-response sequences.

- The **current model** outputs the values and log probabilities for each token in the query-response pair.
- The **reference model**, which is frozen (i.e., not updated during training), provides its own log probabilities for comparison; its values are discarded.

#### **Log Probability Calculation**
The forward pass returns logits. Applying a log-softmax over the vocabulary and picking the entry of the token that actually appears in the sequence gives the **log probability** of each token, i.e. how likely the model is to generate the observed sequence. These log probabilities are used later for the KL penalty in the reward and for the probability ratio in the PPO loss.
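
The conversion from logits to per-token log probabilities can be sketched without any framework. The helper names and toy numbers below are assumptions made for illustration. The key detail is the shift by one: the logits at position `t` score the token at position `t + 1`, so a sequence of length `n` yields `n - 1` log probabilities.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over one vocabulary-sized list.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def logprobs_of_tokens(logits, input_ids):
    """Log probability of each observed token given the tokens before it."""
    out = []
    for t in range(len(input_ids) - 1):
        # logits[t] is the prediction for the token at position t + 1
        out.append(log_softmax(logits[t])[input_ids[t + 1]])
    return out


# Toy vocabulary of size 3 and a 3-token sequence (values are made up)
logits = [
    [2.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 3.0],
]
input_ids = [0, 1, 2]

logprobs = logprobs_of_tokens(logits, input_ids)
print(len(logprobs))  # 2
```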




In [None]:
# Determine if the full KL penalty should be used.
# 'self.config.kl_penalty' is likely set in the configuration or hyperparameters,
# and it checks whether the penalty type is 'full'. In this case, it's False (i.e., not using full KL penalty).
full_kl_penalty = self.config.kl_penalty == "full"  # Will be False in our case.

# We use `torch.no_grad()` to ensure that the forward pass does not track gradients.
# This is crucial for inference since we are not performing backpropagation and gradient updates here.
with torch.no_grad():
    # Perform the forward pass for the main model to calculate log probabilities, values, and masks
    # Note: Queries and responses here are the sequences whose tokens the model scores.

    # The forward pass computes the log probability of each token and the corresponding value estimate.
    # It returns four outputs:
    # - `all_logprobs`: Log probabilities of the observed tokens (shape: Batch_Size x Seq_Len - 1).
    # - `logits_or_none`: Raw logits (pre-softmax values) for each position, or None if `return_logits` is False.
    # - `values`: The value head's predictions, i.e. estimates of the expected future reward at each position.
    # - `masks`: Masks that select the response tokens (query and padding positions are zero).

    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
        self.model,  # The model being used for this forward pass (could be a fine-tuned model).
        queries,  # The batch of queries (prompts for the model).
        responses,  # The batch of responses (model-generated text for comparison).
        model_inputs,  # Processed inputs ready for the model (e.g., tokenized inputs, attention masks).
        response_masks=response_masks,  # Masks indicating valid tokens in the response.
        return_logits=full_kl_penalty,  # Whether to return logits (depends on whether the full KL penalty is used).
    )

    # We also calculate the log probabilities w.r.t. the reference model (frozen, serving as a baseline).
    with self.optional_peft_ctx():  # Only has an effect when the model uses Parameter-Efficient Fine-Tuning (PEFT).
        # Perform a forward pass for the reference model.
        # With a PEFT model there is no separate reference model: the context above disables the adapters,
        # so `self.model` with its adapters turned off plays the role of the frozen reference.
        # Otherwise, the separately loaded `self.ref_model` is used.

        # The forward pass with the reference model is used to calculate the log probabilities for the reference model.
        ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
            self.model if self.is_peft_model else self.ref_model,  # Choose between the current (fine-tuned) or reference (frozen) model.
            queries,  # The same batch of queries.
            responses,  # The same batch of responses.
            model_inputs,  # The same processed inputs.
            return_logits=full_kl_penalty,  # Same flag as before to return logits.
        )


Let's take a deeper look at how the `batched_forward_pass()` function operates under the hood.

As the name suggests, the function performs a **batched (sliced)** forward pass through the model it is given (it is called once for the policy and once for the reference model) and computes the log probabilities. This is its core functionality.

However, there are additional implementation choices and optimizations within the function that we'll explore further. These choices are crucial for efficiency, particularly when dealing with large inputs that don't fit into memory all at once.

## First, let's look at the initialization

In [None]:
# Initialize the batch size (bs) based on the number of queries in the current batch.
# This represents the total number of queries (or examples) we will process in this function.
bs = len(queries)

# Set the mini-batch size (fbs) based on the configuration. This value determines
# the number of examples that will be processed at once in each forward pass.
fbs = self.config.mini_batch_size

# Initialize empty lists to store the log probabilities, logits, masks, and values for each mini-batch.
# These lists will accumulate the results for all batches, which will eventually be concatenated and returned.
all_logprobs = []
all_logits = []
all_masks = []
all_values = []

# Switch the model to evaluation mode.
# This is important because certain layers (like dropout and batch normalization) behave differently
# during training and inference. In evaluation mode, these layers are fixed and do not introduce randomness.
model.eval()


To efficiently process the data, we must:

1. **Iterate over the Batch**: Loop through the total batch size and slice it into smaller mini-batches. This step is crucial for memory efficiency, as processing the entire batch at once could overwhelm the available memory.
  
2. **Compute Logits**: For each mini-batch, calculate the logits for all tokens. These logits represent the raw, unnormalized scores for each token in the sequence, which will later be transformed into probabilities using softmax.

3. **Apply Masking**: Use a mask to exclude the query tokens from the log probability computation. We only want to focus on the response tokens (not the query), and the mask ensures that we ignore the query during this step. More details on how this works will be explained later.
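
The slicing in step 1 can be sketched on its own. `iter_mini_batches` is a hypothetical helper that uses the same `ceil(bs / fbs)` arithmetic as the loop in the next cell; the last slice may be smaller than the mini-batch size.

```python
import math


def iter_mini_batches(items, mini_batch_size):
    """Yield consecutive slices of at most `mini_batch_size` items."""
    n_chunks = math.ceil(len(items) / mini_batch_size)
    for i in range(n_chunks):
        yield items[i * mini_batch_size : (i + 1) * mini_batch_size]


queries = list(range(10))  # stand-in for a batch of 10 queries
chunks = list(iter_mini_batches(queries, 4))
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```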


In [None]:
import math
# Loop over the total batch size (bs), and process it in smaller mini-batches (fbs) for memory efficiency.
# If the full batch is too large to fit into memory, this slicing ensures that we process it piece by piece.
for i in range(math.ceil(bs / fbs)):
    # Slice the input tensors (queries, responses, and other model inputs) into mini-batches.
    # Each mini-batch will contain a subset of the data of size 'fbs'.

    # Input arguments for the model: Create a dictionary of the current mini-batch inputs.
    # This ensures that we don't exceed memory limits and process manageable pieces of data.
    input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}

    # Slice the query and response tensors for the current mini-batch.
    query_batch = queries[i * fbs : (i + 1) * fbs]
    response_batch = responses[i * fbs : (i + 1) * fbs]

    # If response masks are provided, slice them as well.
    if response_masks is not None:
        response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]

    # Now, call the model on the current mini-batch.
    # The model generates:
    # 1. `logits`: The raw, unnormalized scores over the vocabulary at each position of the sequence.
    # 2. `values`: The value head's predictions, i.e. the estimated future reward at each position.

    logits, _, values = model(**input_kwargs)

    # Determine the input ids and attention masks based on whether the model is encoder-decoder or not.
    # Encoder-decoder models have separate input ids for the encoder and decoder.
    # For non-encoder-decoder models, we use the standard input ids and attention masks.
    if self.is_encoder_decoder:
        input_ids = input_kwargs["decoder_input_ids"]
        attention_mask = input_kwargs["decoder_attention_mask"]
    else:
        input_ids = input_kwargs["input_ids"]
        attention_mask = input_kwargs["attention_mask"]

    # Step 2: Calculate log probabilities for each token using the logits.
    # `logprobs_from_logits` applies log-softmax to the logits and gathers the entry of the token actually observed.
    # Logits at position t predict the token at position t + 1, hence `logits[:, :-1]` is paired with `input_ids[:, 1:]`.
    logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])  # Under the hood: F.log_softmax followed by torch.gather

    # Step 3: Build a mask that keeps only the response tokens.
    # We only want log probabilities for the response tokens, not for the query or padding tokens.
    # First shift the attention mask by one position so it lines up with the shifted log probabilities;
    # the query tokens are zeroed out in the loop below.
    masks = torch.zeros_like(attention_mask)
    masks[:, :-1] = attention_mask[:, 1:]  # Shift attention mask by 1 to align with the next-token log probabilities.

    # Further mask out response tokens that fall outside of the valid range, like padding tokens.
    for j in range(len(query_batch)):
        if self.is_encoder_decoder:
            # For encoder-decoder models, the first token of the decoder starts at index 1.
            start = 1
            end = attention_mask[j, :].sum() - 1  # End of the actual input sequence.
        else:
            # For non-encoder-decoder models, the response starts right after the query tokens.
            start = len(query_batch[j]) - 1  # Log probs are shifted by one, so this is the index of the first response token.
            if attention_mask[j, 0] == 0:  # If there's padding at the beginning of the sequence.
                start += attention_mask[j, :].nonzero()[0]  # Adjust for padding offset.
            end = start + len(response_batch[j])  # End position for the response sequence.

            if response_masks is not None:
                response_masks_batch[j] = torch.cat(
                    (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                )[1:]  # Adjust response masks accordingly.

        # Mask out any tokens before the first response token (query tokens) and after the response tokens (padding).
        masks[j, :start] = 0
        masks[j, end:] = 0

        # Apply the response masks, if available, to the valid response tokens.
        if response_masks is not None:
            masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

    # Append the calculated logprobs, values, and masks to their respective lists for later concatenation.
    all_logprobs.append(logprobs)
    all_values.append(values)
    all_masks.append(masks)
    if return_logits:
        all_logits.append(logits)
    else:
        del logits  # If logits are not needed, delete them to save memory.


## Now let's get back to the masking (zeros before and after the response). Why do we do that?

The goal of this masking is to ensure that only the response tokens, the part of the sequence the model generated given the query, contribute to the log probabilities and values used downstream. The query and the padding tokens were not produced by the policy, so they should not influence the reward, the advantages, or the loss, and we mask them out.

In summary, the masking process ensures that:

- We exclude query tokens to focus on the response.
- We exclude padding tokens so they don't affect the log probability calculation.
- We only use the log probabilities and model values of the response tokens, which is the relevant portion of the sequence.

This is standard practice when working with sequences of different lengths, especially when padding is used to bring all sequences in a batch to the same length.
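
The index arithmetic behind the mask is easy to get wrong, so here is a single-row sketch in plain Python. `response_mask` is a hypothetical helper that follows the same logic as the loop above for decoder-only models; unlike the original code, it returns a mask of length `seq_len - 1` directly, matching the shifted log probabilities.

```python
def response_mask(attention_mask, query_len, response_len):
    """Mask over the shifted log-prob positions (length seq_len - 1)."""
    # Position t of the log probs scores token t + 1, so shift the attention mask.
    mask = attention_mask[1:]
    pad_offset = attention_mask.index(1)  # number of left-padding tokens
    start = pad_offset + query_len - 1    # log-prob index of the first response token
    end = start + response_len
    return [m if start <= t < end else 0 for t, m in enumerate(mask)]


# One left-padded row: 2 pad tokens, 3 query tokens, 2 response tokens
attention_mask = [0, 0, 1, 1, 1, 1, 1]
mask = response_mask(attention_mask, query_len=3, response_len=2)
print(mask)  # [0, 0, 0, 0, 1, 1]
```

Only the two positions that score the response tokens remain set; padding and query positions are zero.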

<p align="center">
    <img src="images/log_prob_calc.png" alt="PEFT Overview" />
</p>

## Calculating the rewards:
The next step is to calculate the rewards. This is done using two components: the scores (which were converted to tensors in the preparation step) and a KL penalty computed from the log probabilities of the model and of the reference model.

**KL Penalty (Kullback-Leibler Divergence) Calculation**

The KL divergence measures how much the fine-tuned model's output distribution diverges from that of the reference (frozen) model. For two probability distributions it is defined as

$$
\mathrm{KL}(P \,\|\, Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)} = \sum_x P(x)\,\bigl(\log P(x) - \log Q(x)\bigr)
$$

- $P(x)$ is the probability distribution of the fine-tuned model (from `logprob`).
- $Q(x)$ is the probability distribution of the reference model (from `ref_logprob`).

The `self._kl_penalty(logprob, ref_logprob)` function calculates the difference in log probabilities for each generated token, $\log P(x) - \log Q(x)$. Because the tokens were sampled from the fine-tuned model, this difference is a per-token sample estimate of the sum above.
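
The relation between the exact KL divergence and the per-token log-probability difference can be checked numerically. The two distributions below are made up, and `exact_kl` and `sampled_kl_term` are hypothetical helpers written for this sketch.

```python
import math


def exact_kl(p, q):
    """KL(P || Q) = sum_x P(x) * (log P(x) - log Q(x))."""
    return sum(px * (math.log(px) - math.log(qx)) for px, qx in zip(p, q) if px > 0)


def sampled_kl_term(p, q, token):
    """Single-sample estimate for one token: log P(token) - log Q(token)."""
    return math.log(p[token]) - math.log(q[token])


p = [0.7, 0.2, 0.1]  # hypothetical policy distribution over 3 tokens
q = [0.5, 0.3, 0.2]  # hypothetical reference distribution

kl = exact_kl(p, q)
# Averaging the single-sample term over tokens drawn from P recovers the exact KL.
expected = sum(p[t] * sampled_kl_term(p, q, t) for t in range(len(p)))
print(round(kl, 4), round(expected, 4))
```

A single per-token term can be negative (here for tokens 1 and 2, where the reference assigns more probability than the policy), even though the full divergence is never negative.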

In [None]:
def compute_rewards(
    self,
    scores: torch.FloatTensor,  # The scores tensor, which contains the reward model's score for each response (one scalar per response)
    logprobs: torch.FloatTensor,  # The log probabilities of the tokens computed for the fine-tuned model
    ref_logprobs: torch.FloatTensor,  # The log probabilities computed for the reference (frozen) model
    masks: torch.LongTensor,  # A mask tensor indicating which tokens are part of the actual response (excluding padding and query tokens)
):

    # Initialize empty lists to store rewards, non-score rewards, and KL divergences for each example in the batch
    rewards, non_score_rewards, kls = [], [], []

    # Loop through each example in the batch
    for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):

        # Compute the KL penalty, which is the difference between the log probabilities of the fine-tuned model and the reference model.
        # This is used to measure how much the response generated by the model diverges from the frozen model.
        # The shape of `kl` is (Seq_Len), where each element corresponds to the KL divergence for that token.
        kl = self._kl_penalty(logprob, ref_logprob)
        kls.append(kl)

        # Compute the non-score reward, which is the KL penalty scaled by the KL control value (`kl_ctl.value`).
        # This penalizes the reward given by the reward model based on the divergence between the model and reference.
        non_score_reward = -self.kl_ctl.value * kl
        non_score_rewards.append(non_score_reward)

        # Initialize the reward as the non-score reward. This means we start with a penalty and will add the score later.
        reward = non_score_reward.clone()

        # Find the index of the last token in the response that is not masked out (i.e., the last token of the generated response).
        last_non_masked_index = mask.nonzero()[-1]

        # The reward for each token is initially set to the negative KL penalty (to penalize divergence).
        # We then add the score (from the reward model) only to the last token in the response.
        # This means we assign the final reward to the last token of the response based on the model's score.
        reward[last_non_masked_index] += score

        # Append the calculated reward for this example to the rewards list.
        rewards.append(reward)

    # Return the rewards, non-score rewards, and KL divergences as tensors (stacked to match batch size).
    return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)


Internally, the `_kl_penalty()` function calculates the difference between the log probabilities of the model and the reference model as follows:

In [None]:
def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
        # in our case we use this :
        if self.config.kl_penalty == "kl":
            return logprob - ref_logprob

        if self.config.kl_penalty == "abs":
            return (logprob - ref_logprob).abs()

        if self.config.kl_penalty == "mse":
            return 0.5 * (logprob - ref_logprob).square()

        # Remember we set this to False earlier
        if self.config.kl_penalty == "full":
            # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459
            return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)

        raise NotImplementedError
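
As a quick illustration of how the pointwise modes differ, the sketch below applies them to a single token. It is a scalar, plain-Python stand-in with assumed toy values, not the TRL method itself; the `"full"` mode is left out because it needs the whole vocabulary distribution rather than one log probability.

```python
def kl_penalty(logprob, ref_logprob, mode="kl"):
    """Hypothetical scalar version of the three pointwise modes shown above."""
    diff = logprob - ref_logprob
    if mode == "kl":
        return diff              # signed: can be negative for a single token
    if mode == "abs":
        return abs(diff)         # always non-negative
    if mode == "mse":
        return 0.5 * diff ** 2   # always non-negative, small for small gaps
    raise NotImplementedError(mode)


# Toy log-probabilities (assumed) for one token under the policy and the reference
logprob, ref_logprob = -2.0, -1.5

penalties = {m: kl_penalty(logprob, ref_logprob, m) for m in ("kl", "abs", "mse")}
print(penalties)  # {'kl': -0.5, 'abs': 0.5, 'mse': 0.125}
```

The signed `"kl"` value is negative here because the token is less likely under the policy than under the reference, so for this particular token `-kl_ctl.value * kl` ends up as a small bonus rather than a penalty.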

## What we did so far:

1. Batch Forward Pass: Compute log probabilities for both models.
2. KL Penalty: Measure divergence and apply a penalty.
3. Compute Rewards: Penalize the KL divergence and reward high-quality responses.

## Next Step: Compute Advantages
This is done from the rewards and values using the GAE (Generalized Advantage Estimation) formula.

In [None]:
# Use the rewards and the values to compute the advantage using Generalized Advantage Estimation (GAE).
# values: (Batch_Size, Seq_Len - 1) - Predicted state values from the value network.
# rewards: (Batch_Size, Seq_Len - 1) - Rewards for the generated responses, computed using the reward model.
# masks: (Batch_Size, Seq_Len - 1) - Binary masks to focus only on the response tokens (excluding query and padding).
# Returns:
# - values: The original predicted state values (unchanged here).
# - advantages: Relative "goodness" of actions using GAE, used to update the policy.
# - returns: Total future discounted rewards (Q-values), used to train the value network.
values, advantages, returns = compute_advantages(values, rewards, masks)


# Details on how values, rewards, and masks are obtained:
# - The 'values' are computed previously using the `batched_forward_pass()` function.
#   This function runs a batch of inputs (queries and responses) through the value network.
# - The 'rewards' are computed with `compute_rewards()`, which might involve scoring
#   the responses using a reward model (e.g., trained with human preferences).
# - The 'masks' are used to extract only the response tokens from the sequence,
#   excluding the query and padding, ensuring that only relevant parts of the sequence
#   contribute to the reward and advantage computation.


### So what is GAE anyway?
# Generalized Advantage Estimation

Generalized Advantage Estimation (GAE) is a popular technique in reinforcement learning, introduced in **Schulman et al. (2015)**, to compute advantages for policy optimization algorithms like Proximal Policy Optimization (PPO). It balances the trade-off between **bias** and **variance** in the estimation of advantages, leading to more stable and efficient training.

#### **What is Advantage?**

In reinforcement learning, **advantage** measures how good a specific action was compared to the average action at a given state. Mathematically, it’s defined as:  

\[
A(s, a) = Q(s, a) - V(s)
\]

Where:
- \( Q(s, a) \): Expected return (reward) for taking action \(a\) in state \(s\).
- \( V(s) \): Expected return from state \(s\), following the policy.

The advantage tells us whether an action is better or worse than the policy's expected behavior in a given state.

---

#### **Why GAE?**

In practice, computing exact advantages using the above formula can be unstable due to:
1. **High variance** in Monte Carlo estimates of \( Q(s, a) \) (especially for long trajectories).
2. **Bias** introduced by bootstrapping \( V(s) \) from a value function.

GAE addresses these issues by:
1. Incorporating **temporal difference (TD) learning** for bootstrapping.
2. Introducing a **discount factor (\(\gamma\))** and a **smoothing parameter (\(\lambda\))** to control bias and variance.
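
The role of \(\lambda\) is easiest to see at its two extremes. The sketch below is a plain-Python GAE for one short trajectory with assumed toy numbers: with \(\lambda = 0\) each advantage is just the one-step TD error (it leans entirely on the value function), and with \(\lambda = 1\) it becomes the full discounted return minus \(V(s)\) (no bootstrapping, but more variance).

```python
def gae(rewards, values, gamma, lam):
    """Plain-Python GAE for one trajectory (the value after the last step is 0)."""
    advantages = [0.0] * len(rewards)
    last = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last = delta + gamma * lam * last
        advantages[t] = last
    return advantages


# Toy trajectory (assumed numbers): the only reward arrives at the last step
rewards = [0.0, 0.0, 1.0]
values = [0.5, 0.6, 0.7]
gamma = 1.0

td_only = gae(rewards, values, gamma, lam=0.0)      # one-step TD errors
monte_carlo = gae(rewards, values, gamma, lam=1.0)  # return-to-go minus V(s)
print([round(a, 3) for a in td_only], [round(a, 3) for a in monte_carlo])
```

Values of \(\lambda\) between 0 and 1 interpolate between these two estimates.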


#### **How GAE is Used in PPO**

- **Advantage Computation**: GAE provides a smooth and stable advantage estimate to guide the policy updates in PPO.
- **Training the Value Network**: GAE-generated "returns" (discounted future rewards) are used as targets to train the value network, ensuring it accurately predicts state values.

GAE is especially useful when fine-tuning language models with RLHF (Reinforcement Learning with Human Feedback): the reward model's score is added only at the last token of a response, and GAE propagates that sparse signal back to the earlier tokens while keeping the variance of the estimate under control.

### Let's take a look at the math:

<p align="center">
    <img src="images/Compute_advantages.png" alt="PEFT Overview" />
</p>

Then let's dive into the function in the source code to connect the dots:

### `compute_advantages` Function: Deep Dive and Explanation

The `compute_advantages` function implements **Generalized Advantage Estimation (GAE)**, a powerful technique in reinforcement learning. It computes **advantages** and **returns**, which are essential for training policy and value networks. Let's break it down step by step.

### **Step-by-Step Walkthrough**

#### **Step 1: Initialize Variables**
```python
lastgaelam = 0
advantages_reversed = []
gen_len = rewards.shape[-1]
```
- `lastgaelam`: Keeps track of the advantage from the previous time step for recursive computation.
- `advantages_reversed`: Temporarily stores the computed advantages in reverse order.
- `gen_len`: The total number of time steps in the trajectory (sequence length).

#### **Step 2: Apply Masks**
```python
values = values * mask
rewards = rewards * mask
```
- Element-wise masking ensures that only valid parts of the sequence are used. Padding and query tokens are ignored.

#### **Step 3: Compute Advantages Backward in Time**
```python
for t in reversed(range(gen_len)):
    nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
    delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
    lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
    advantages_reversed.append(lastgaelam)
```

- **Backward Iteration**:
  - Starts from the last time step (`t = T-1`) and goes to `t = 0`.
  - Recursive computation ensures that the advantage `A_t` incorporates future rewards.

- **Next Value (`V(s_{t+1})`)**:
  - Fetches the value of the next time step. If `t` is the last step, `V(s_{t+1}) = 0`.

- **Temporal Difference (TD) Error**:
  ```
  delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
  ```
  - Measures the immediate improvement from taking an action at time `t`.

- **Generalized Advantage**:
  ```
  A_t = delta_t + gamma * lambda * A_{t+1}
  ```
  - `lambda`: A smoothing factor that controls the tradeoff between bias and variance.
  - Combines immediate reward (`delta_t`) with discounted future advantages.

- **Store the Result**:
  - Append `lastgaelam` (current advantage) to `advantages_reversed`.

#### **Step 4: Reverse and Stack Advantages**
```python
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
```
- Since the advantages were computed in reverse, this step reverses them back to match the original time sequence.
- `torch.stack` combines the per-step advantages into a tensor of shape `(Seq_Len, Batch_Size)`, and `.transpose(0, 1)` turns it into `(Batch_Size, Seq_Len)`.

#### **Step 5: Compute Returns**
```python
returns = advantages + values
```
- **Returns (`Q(s, a)`)**:
  ```
  Q(s, a) = A(s, a) + V(s)
  ```
  - Combines the advantage with the value function to produce `Q(s, a)`, the expected total reward for a state-action pair.

#### **Step 6: Normalize Advantages**
```python
advantages = masked_whiten(advantages, mask)
```
- Normalizes the advantages to ensure numerical stability and prevent large values from dominating updates:
  ```
  Advantage_norm = (A - mean(A)) / std(A)
  ```
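
A minimal sketch of masked whitening on a single sequence is shown below. The name `masked_whiten_1d` and the numbers are assumptions for illustration, and the sketch uses the plain population variance; TRL's helper may differ in details such as the variance correction.

```python
import math


def masked_whiten_1d(xs, mask, eps=1e-8):
    """Hypothetical plain-Python stand-in for masked whitening on one sequence."""
    n = sum(mask)
    mean = sum(x * m for x, m in zip(xs, mask)) / n
    var = sum(m * (x - mean) ** 2 for x, m in zip(xs, mask)) / n
    # Statistics use only unmasked positions; the transform is applied everywhere
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


# Toy advantages (assumed); the last position is padding and must not skew the stats
advantages = [1.0, 3.0, 5.0, 100.0]
mask = [1, 1, 1, 0]

whitened = masked_whiten_1d(advantages, mask)
print([round(w, 3) for w in whitened[:3]])
```

The large value at the padded position does not affect the mean or the variance, which is the whole point of passing the mask.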

---

#### **Step 7: Detach Advantages**
```python
advantages = advantages.detach()
```
- Detaches the `advantages` tensor from the computation graph to prevent gradients from flowing back during policy updates.


### **Mathematical Recap**
1. **Temporal Difference (TD) Error**:
   ```
   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
   ```

2. **Generalized Advantage**:
   ```
   A_t = delta_t + gamma * lambda * A_{t+1}
   ```

3. **Returns**:
   ```
   Q(s, a) = A(s, a) + V(s)
   ```

4. **Normalization**:
   ```
   Advantage_norm = (A - mean(A)) / std(A)
   ```
### **Summary of Outputs**
- **`values`**: Masked value predictions for each time step.
- **`advantages`**: Normalized advantages computed using GAE.
- **`returns`**: Total discounted returns (`Q(s, a)`) used for value function training.

This function efficiently computes the core quantities required for reinforcement learning training, ensuring stable and meaningful gradient updates.

## Putting it all together, we get:

In [None]:
def compute_advantages(
        self,
        values: torch.FloatTensor,   # Predicted value function (V(s_t)) for each state at each timestep.
        rewards: torch.FloatTensor,  # Rewards received at each timestep.
        mask: torch.FloatTensor,     # Mask to handle padding or invalid timesteps in the sequence.
    ):
        # Initialize the last generalized advantage estimate (GAE) for recursive computation.
        lastgaelam = 0

        # List to store computed advantages in reverse order (backward pass through time).
        advantages_reversed = []

        # Get the sequence length from the shape of the rewards tensor.
        gen_len = rewards.shape[-1]

        # Mask the values and rewards to ensure only valid timesteps are considered.
        values = values * mask
        rewards = rewards * mask

        # If reward whitening is enabled (commented here), normalize rewards for stability.
        # if self.config.whiten_rewards:
        #     rewards = masked_whiten(rewards, mask, shift_mean=False)

        # Compute advantages in reverse order (backward through time).
        for t in reversed(range(gen_len)):
            # Get the value of the next state (V(s_{t+1})).
            # If the current timestep is the last one, set nextvalues to 0 (no future values exist).
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0

            # Calculate the temporal difference (TD) error (delta_t) using the GAE formula:
            # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]

            # Compute the generalized advantage estimate (A_t):
            # A_t = delta_t + gamma * lambda * A_{t+1}
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam

            # Store the computed advantage in the reversed list.
            advantages_reversed.append(lastgaelam)

        # Reverse the computed advantages and stack them into a tensor of shape (Batch_Size, Seq_Len).
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        # Compute returns (Q values) from advantages:
        # Q(s, a) = A(s, a) + V(s)
        # The returns are used to train the value function estimator.
        returns = advantages + values

        # Normalize the advantages using masking for numerical stability:
        # Advantage_norm = (A - mean(A)) / std(A)
        advantages = masked_whiten(advantages, mask)

        # Detach advantages from the computation graph to stop gradients from flowing back.
        advantages = advantages.detach()

        # Return the masked values, normalized advantages, and computed returns (Q values).
        return values, advantages, returns


# Some questions you may have:
### **1. Why is the computation done in reverse?**

The computation is done in reverse to **propagate the advantages backward in time** using **Generalized Advantage Estimation (GAE)**. This allows the algorithm to compute the **advantage at each timestep** by recursively considering the future rewards, thereby enabling a more stable learning signal.

Computing this from the last timestep backward ensures that the advantage incorporates future reward information effectively.
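
One way to convince yourself that the backward recursion is only a cheaper way of computing the same quantity is to compare it with the explicit sum \(A_t = \sum_{l \ge 0} (\gamma\lambda)^l \, \delta_{t+l}\). The sketch below does that in plain Python on assumed toy TD errors; both function names are made up for this example.

```python
def gae_backward(deltas, gamma, lam):
    """One backward pass: A_t = delta_t + gamma * lam * A_{t+1}."""
    advantages, last = [0.0] * len(deltas), 0.0
    for t in reversed(range(len(deltas))):
        last = deltas[t] + gamma * lam * last
        advantages[t] = last
    return advantages


def gae_direct(deltas, gamma, lam):
    """Same quantity as an explicit sum over all future TD errors."""
    T = len(deltas)
    return [
        sum((gamma * lam) ** (k - t) * deltas[k] for k in range(t, T))
        for t in range(T)
    ]


# Toy TD errors (assumed numbers)
deltas = [0.1, -0.2, 0.4, 0.3]
backward = gae_backward(deltas, gamma=0.99, lam=0.95)
direct = gae_direct(deltas, gamma=0.99, lam=0.95)
print(max(abs(a - b) for a, b in zip(backward, direct)))
```

The backward pass touches each timestep once, while the direct sum revisits every later timestep for each `t`, so the recursion is the cheaper of the two for long responses.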


### **2. What are the gamma and lambda factors, and why do we use them?**

- **Gamma (γ)**: The **discount factor**, which controls how much future rewards are weighted compared to immediate rewards. It determines how much importance is given to long-term rewards versus short-term rewards. The closer γ is to 1, the more the agent will consider future rewards.

- **Lambda (λ)**: The **smoothing factor**, which determines the weight between bias and variance in advantage estimation. A **higher λ** puts more weight on future rewards, reducing bias but increasing variance, whereas a **lower λ** reduces variance but increases bias.

Together, γ and λ control the trade-off between the bias and variance of the advantage estimation. They help to **stabilize training** and **improve performance** by combining both short-term and long-term reward signals.

### **3. Why do we normalize after the calculation?**

Normalization is done on the **advantages** to:
1. **Ensure numerical stability**: Raw advantages can have large values, which can lead to unstable training. Normalizing the advantages helps prevent large updates.
2. **Prevent dominance of large values**: Without normalization, the model might focus too much on outliers, which could lead to inefficient learning.
3. **Improve optimization**: By normalizing, we ensure that the advantages are on a similar scale, which helps stabilize and speed up the learning process.

### **4. Why do we detach the result from the computation graph?**

We **detach** the advantages from the computation graph to **stop gradients from flowing back through them**. The advantages are intermediate results that do not require gradient updates, as they are used purely to guide the training process.

Detaching ensures that:
- **Memory is saved**: No unnecessary gradients are stored for the advantages, reducing memory consumption.
- **Correct gradient flow**: Only the parameters that need to be updated (e.g., model weights) will have gradients flowing through them, while advantages remain fixed.
- **Avoid unintended updates**: Since the advantages are not part of the network's learnable parameters, detaching prevents the backpropagation of gradients through them.

In PyTorch, we do this by:
```python
advantages = advantages.detach()
```


## Now Phase 1 is DONE!

At this point, we have successfully generated trajectories using the model, where we've taken the input queries, generated responses, and calculated the corresponding rewards (based on sentiment analysis). Now, to move forward, we need to optimize the model in a manner similar to how we do in traditional deep learning.

In the standard deep learning workflow, optimization is carried out by minimizing a **loss function** that quantifies the difference between the model's predictions and the expected output. In our case, we aim to fine-tune the language model (GPT-2) by adjusting its parameters to maximize the reward signal rather than just minimizing an error.

<p align="center">
    <img src="images/policy.png" alt="PEFT Overview" />
</p>

Here is the information we have collected so far:

In [1]:
# This represents all the trajectories sampled (our storage of trajectories) using the old policy (offline).
# Upcasting to float32 to avoid potential dataset issues that may arise with lower precision, ensuring compatibility during training.

batch_dict = {
    # 'queries': These are the input queries (prompts) from the dataset that the model will respond to. The queries represent the initial user inputs.
    "queries": queries,

    # 'responses': These are the model-generated responses to the corresponding queries. These are the outputs of the model based on the input queries.
    "responses": responses,

    # 'logprobs': Log probabilities of the actions taken by the model. These are important for computing the loss function in reinforcement learning (PPO).
    # We upcast these log probabilities to float32 for numerical stability during the optimization process.
    "logprobs": all_logprobs.to(torch.float32),

    # 'values': The value function estimates for each state. These are used in the advantage calculation (for GAE) to assess the quality of each state.
    # The values represent the model's estimation of how good the state (query + generated response) is.
    "values": values.to(torch.float32),

    # 'masks': A mask indicating valid input tokens or padding, used to ensure that padding tokens are not considered in the reward calculation.
    # The mask prevents the model from updating based on irrelevant or padded tokens.
    "masks": masks,

    # 'advantages': The calculated advantages using Generalized Advantage Estimation (GAE), which provides an estimate of the relative value of the action taken.
    # Advantages help guide the policy updates to ensure it maximizes long-term rewards and reduces variance in the gradient estimation.
    "advantages": advantages,

    # 'returns': The sum of discounted rewards (Q-values) for each trajectory, representing the expected return starting from each state-action pair.
    # Returns are important for updating the value function and ensuring that the model learns to maximize its expected reward.
    "returns": returns,
}

# Updates the batch dictionary with the model's input features.
# This could include additional information such as the encoded input tokens, token types, or other features required for generating responses.
batch_dict.update(model_inputs)


Now we build mini-batches to train on, in two steps:

- **Shuffling the trajectories:** We randomly shuffle the order of the trajectories to avoid any ordering bias. This helps ensure that the model doesn't memorize patterns based on the order of the data.
- **Dividing into mini-batches:** We divide the shuffled trajectories into smaller mini-batches for training. Training on the entire batch at once can be computationally expensive and inefficient, whereas mini-batches let us use stochastic gradient descent (SGD) or its variants (e.g., Adam) to perform efficient updates.
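
Before reading the TRL loop, it can help to run a toy version of the same index bookkeeping. The sketch below is an assumption-laden stand-in: `split_indices` is a made-up helper, the sizes are arbitrary, and `random.Random(seed).shuffle` replaces `np.random.permutation` so that the example needs only the standard library.

```python
import random


def split_indices(bs, backward_batch_size, mini_batch_size, seed=0):
    """Hypothetical helper: shuffle range(bs), then cut it into nested chunks."""
    inds = list(range(bs))
    random.Random(seed).shuffle(inds)  # stands in for np.random.permutation(bs)
    mini_batches = []
    for bb_start in range(0, bs, backward_batch_size):
        backward_inds = inds[bb_start:bb_start + backward_batch_size]
        for mb_start in range(0, backward_batch_size, mini_batch_size):
            mini_batches.append(backward_inds[mb_start:mb_start + mini_batch_size])
    return mini_batches


# Toy sizes (assumed): 8 trajectories, backward batches of 4, mini-batches of 2
mini_batches = split_indices(bs=8, backward_batch_size=4, mini_batch_size=2)
print(mini_batches)
```

Every trajectory index lands in exactly one mini-batch, so one pass over the mini-batches is one pass over the collected trajectories.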

In [None]:
# Shuffle the trajectories (or trajectories indices) randomly
# 'b_inds' holds the indices of the trajectories, and 'np.random.permutation' shuffles them
b_inds = np.random.permutation(bs)  # bs is the batch size (total number of trajectories)

# Loop over the shuffled trajectories in chunks (batches)
# We process the trajectories in smaller batches of size 'backward_batch_size'
for backward_batch_start in range(0, bs, self.config.backward_batch_size):
    backward_batch_end = backward_batch_start + self.config.backward_batch_size  # Define the end of the batch range

    # Select the indices for the current backward batch from the shuffled indices
    backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]

    # Now we take this batch and break it further into smaller mini-batches
    # The mini-batch size is defined by 'mini_batch_size', which is another configurable parameter.
    for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
        mini_batch_end = mini_batch_start + self.config.mini_batch_size  # Define the end of the mini-batch range

        # Select the indices for the current mini-batch from the backward batch
        mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]


As a result, we get the following mini-batch:

In [None]:
# This is the sampled mini-batch that will be used to optimize the model
mini_batch_dict = {
    "logprobs": batch_dict["logprobs"][mini_batch_inds],
    "values": batch_dict["values"][mini_batch_inds],
    "masks": batch_dict["masks"][mini_batch_inds],
    # Workaround: the queries and responses are ragged (variable length),
    # so they are kept as lists and indexed one by one.
    "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
    "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
    "advantages": batch_dict["advantages"][mini_batch_inds],
    "returns": batch_dict["returns"][mini_batch_inds],
}

After creating the mini-batches, the next step involves calculating the log-probabilities (logprobs), logits, and value predictions (vpreds) for the new policy (online model). This is an essential step in training reinforcement learning models like PPO.

In [None]:
# Perform a forward pass using the new model (online model) to calculate the log-probabilities,
# logits, and value predictions for the given mini-batch of queries and responses.
logprobs, logits, vpreds, _ = self.batched_forward_pass(
    self.model,  # The current online model (new policy)
    mini_batch_dict["queries"],  # Input queries (batch of input sequences)
    mini_batch_dict["responses"],  # Corresponding responses (batch of generated outputs)
    model_inputs,  # Additional model inputs like tokenized sequences and context
    return_logits=True,  # Ensure that logits are also returned for further analysis
)

# Explanation of the returned outputs:
# logprobs: Log-probabilities of each token generated by the model, used for calculating the policy's likelihood.
# logits: Raw, unnormalized model outputs, used to compute probabilities and for computing gradients.
# vpreds: The value head's prediction at each token position under the new policy, used later in the value loss.


Finally, we enter the training loop, where we train the model on the mini-batch we just created. The log-probabilities computed under the new (online) model are compared with the ones stored when the trajectories were sampled, and this comparison drives the PPO update. Step by step, the policy shifts toward responses that earn a higher reward, which in our case means more positive outputs.

In [None]:
# Training loop for the mini-batch:
# Now that we have the log-probabilities, logits, and value predictions from the online model (new policy),
# we proceed to train the model using the PPO loss.

# Train the model by comparing the log-probs stored from the old policy (the one that sampled the trajectories)
# with those of the new (online) model, and update the parameters using the computed advantages and returns.
train_stats = self.train_minibatch(
    mini_batch_dict["logprobs"],  # Log-probs of the previous policy (old model) for comparison
    mini_batch_dict["values"],  # Value predictions from the previous policy (old model)
    logprobs,  # Log-probs of the current policy (new model), which we will use for optimization
    logits,  # Logits of the new model (used for calculating probabilities and further loss calculations)
    vpreds,  # Value predictions of the new model (used in the value loss)
    mini_batch_dict["masks"],  # Masks for valid token positions (used to avoid padding tokens)
    mini_batch_dict["advantages"],  # Advantage estimates, which indicate how good the action was
    mini_batch_dict["returns"],  # Returns (advantages + values), used as targets for the value loss
)


Diving into the source code, we observe a workflow similar to the standard procedure in deep learning:

1. **Set the Model to Training Mode**:
   The model's mode is set to training by invoking `model.train()`. This is crucial as certain layers, like dropout or batch normalization, behave differently during training and inference. Without this call, those layers would keep their inference-time behavior while the weights are being updated.
   
2. **Calculate the Loss Function**:
   In this step, the model makes predictions based on the inputs, and the loss function computes the discrepancy between the predicted outputs and the actual targets. The loss function quantifies how far off the model’s predictions are, providing a signal for model improvement.
   
3. **Run the Backward Pass (Backpropagation)**:
   Once the loss is calculated, we use backpropagation to compute the gradients of the model parameters. This step involves applying the chain rule of calculus to find the partial derivatives of the loss with respect to each parameter in the model.
   
4. **Run the Optimizer to Minimize the Loss**:
   After the gradients are computed, the optimizer (such as Adam or SGD) is used to adjust the model’s parameters in the direction that minimizes the loss. The optimizer uses the gradients to update the parameters, effectively "learning" from the loss and improving the model's ability to make accurate predictions.

This basic flow is the foundation of most deep learning model training, whether it's for supervised learning, reinforcement learning, or other types of models. Each step builds upon the last to improve the model’s performance over time.

In [None]:
@PPODecorators.empty_device_cache()  # Decorator to ensure that device cache is emptied, likely for GPU/TPU memory management.
def train_minibatch(
    self,
    old_logprobs: torch.FloatTensor,  # Log probabilities under the old policy (offline).
    values: torch.FloatTensor,  # Value estimates under the old policy (offline).
    logprobs: torch.FloatTensor,  # Log probabilities under the new policy (online).
    logits: torch.FloatTensor,  # Logits (pre-activation outputs) under the new policy (online).
    vpreds: torch.FloatTensor,  # Value predictions under the new policy (online).
    mask: torch.LongTensor,  # Indicates which tokens the log probabilities correspond to.
    advantages: torch.FloatTensor,  # Advantages calculated under the old policy (offline).
    returns: torch.FloatTensor,  # Returns (rewards) calculated under the old policy (offline).
):
    # Set the model in training mode.
    self.model.train()

    # Compute the loss (policy loss + value loss).
    loss_p, loss_v, train_stats = self.loss(
        old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
    )

    # The final loss is the sum of the policy gradient loss and the value function loss.
    loss = loss_p + loss_v

    # Backpropagate the loss to compute gradients.
    self.accelerator.backward(loss)

    # Apply gradient clipping if configured, to avoid exploding gradients.
    if self.config.max_grad_norm is not None:
        if self.accelerator.sync_gradients:
            # Clip gradients if their norms exceed a certain threshold.
            self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)

    # Perform a parameter update step using the optimizer.
    self.optimizer.step()

    # Clear the gradients from the optimizer after the step to prepare for the next gradient accumulation.
    self.optimizer.zero_grad()

    # Return training statistics.
    return train_stats

                          +------------------------+
                          |  Outer Loop: Iteration |
                          +------------------------+
                                      |
                                      v
                     +----------------------------------+
                     |  Step 1: Collect Data            |
                     |  (states, actions, rewards)      |
                     +----------------------------------+
                                      |
                                      v
                     +----------------------------------+
                     |  Step 2: Compute Advantages &    |
                     |  Returns (advantage, returns)    |
                     +----------------------------------+
                                      |
                                      v
                     +----------------------------------+
                     |  Step 3: Mini-Batch Training     |
                     |  - Create mini-batches           |
                     |  - Forward pass on current       |
                     |    (online) model                |
                     |  - Compute loss (policy & value) |
                     +----------------------------------+
                                      |
                                      v
                           +----------------------+
                           |  Backpropagation     |
                           |  and Optimization    |
                           +----------------------+
                                      |
                                      v
                       +-----------------------------+
                       |  Continue to next iteration |
                       +-----------------------------+


### **Just One Last Missing Piece: What is the Loss Function We Optimize?**

In the PPO (Proximal Policy Optimization) algorithm, the loss function plays a crucial role in guiding the optimization of the model. This loss function is specifically designed to ensure that the updates to the policy are both stable and efficient. Let's break it down step-by-step:

---

### **1. The Total Loss Function:**

The total loss function in PPO is typically a combination of two components:
- **Policy Loss**: This is related to how well the new policy performs compared to the old policy, ensuring that we don't make large updates that could destabilize learning.
- **Value Loss**: This measures how well the value predictions (the state-value function) match the actual rewards (or returns) observed from the environment.

### **2. The Policy Loss**:

The policy loss is built from the **probability ratio** between the new and the old policy for the actions the agent actually took, computed from the log probabilities as `ratio = exp(logprobs - old_logprobs)` and weighted by the advantage. Actions with a positive advantage are made more likely, and actions with a negative advantage less likely.

We must also apply a **clipping mechanism** to ensure that the policy update doesn't go beyond a certain range, which helps avoid large and unstable changes in the policy.
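
A single-token worked example may make the clipping tangible. The sketch below is plain Python with assumed toy numbers; `clipped_pg_loss` is a made-up name, and the default `cliprange=0.2` is an assumption chosen for illustration. It follows the same recipe as the `loss` function shown later: take the larger of the unclipped and clipped losses.

```python
import math


def clipped_pg_loss(logprob, old_logprob, advantage, cliprange=0.2):
    """Per-token PPO policy loss: max of the unclipped and clipped losses."""
    ratio = math.exp(logprob - old_logprob)
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    unclipped = -advantage * ratio
    clipped = -advantage * clipped_ratio
    return max(unclipped, clipped), ratio


# Toy case (assumed): the new policy made the sampled token 1.5x more likely
new_lp, old_lp = math.log(0.3), math.log(0.2)

loss_good, ratio = clipped_pg_loss(new_lp, old_lp, advantage=2.0)   # good action
loss_bad, _ = clipped_pg_loss(new_lp, old_lp, advantage=-2.0)       # bad action
print(round(ratio, 3), round(loss_good, 3), round(loss_bad, 3))
```

For the good action the clipped term wins, so the objective stops rewarding the policy once the ratio passes `1 + cliprange`. For the bad action the unclipped term wins, so the full penalty for having made that token more likely is kept.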

### **3. The Value Loss**:

The value loss ensures that the model’s predictions of the value function are accurate, i.e., the predicted value at each state should match the actual observed return (or discounted future rewards).
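
The value loss used later in the `loss` function is also clipped, so that a new value prediction cannot move too far from the old one in a single update. Below is a minimal single-token sketch in plain Python; the function name and the numbers are assumptions, and `cliprange_value=0.2` is an illustrative default.

```python
def clipped_value_loss(vpred, old_value, ret, cliprange_value=0.2):
    """Per-token value loss: 0.5 * max of the plain and clipped squared errors."""
    low, high = old_value - cliprange_value, old_value + cliprange_value
    vpred_clipped = min(max(vpred, low), high)
    loss_plain = (vpred - ret) ** 2
    loss_clipped = (vpred_clipped - ret) ** 2
    return 0.5 * max(loss_plain, loss_clipped)


# Toy case (assumed): the prediction jumped from 0.0 to 0.9 while the target is 1.0
loss = clipped_value_loss(vpred=0.9, old_value=0.0, ret=1.0)
print(round(loss, 3))
```

Although the new prediction is close to the target, the loss is computed from the clipped prediction (0.2), because the maximum of the two errors is taken.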

### **4. The Entropy Bonus (Optional)**:

An optional term is the **entropy bonus**, which encourages exploration by penalizing overconfident predictions. It keeps the policy from becoming deterministic too quickly, especially early in training. Because we minimize the loss, the bonus enters as the negative entropy of the action distribution.
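
For reference, the entropy of a softmax distribution can be computed directly from the logits using the identity \(H = \operatorname{logsumexp}(z) - \sum_i p_i z_i\). The sketch below checks it in plain Python on two assumed toy logit vectors; `softmax_entropy` is a made-up name for this example.

```python
import math


def softmax_entropy(logits):
    """Entropy of softmax(logits), written as logsumexp(logits) - sum(p * logits)."""
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    probs = [math.exp(x - lse) for x in logits]
    return lse - sum(p * x for p, x in zip(probs, logits))


uniform = softmax_entropy([0.0, 0.0, 0.0, 0.0])  # maximally uncertain over 4 tokens
peaked = softmax_entropy([10.0, 0.0, 0.0, 0.0])  # nearly deterministic
print(round(uniform, 4), round(peaked, 4))
```

The uniform case gives `log(4)`, the largest possible entropy for four choices, while the peaked case is close to zero; an entropy bonus pushes the policy away from the second situation.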

### **5. The Total Loss Function**:

The final loss function that we optimize in PPO is the combination of the policy loss, the value loss, and possibly the entropy bonus.

All of the above can be summarized like this:

<p align="center">
    <img src="images/PPO_LOSS.png" alt="PEFT Overview" />
</p>

Here are the most important points summarized for a **vanilla implementation** of the loss function:

1. **Policy Loss**:
   - The **policy gradient loss** is calculated using the **probability ratio of the new to the old policy**, obtained by exponentiating the difference of their log-probabilities. We aim to **maximize the objective** (policy improvement), so the loss is its negative: PyTorch optimizers minimize by default.
   - The **clipping mechanism** (`torch.clamp`) ensures that large updates are prevented, maintaining stability in the training process.

2. **Value Function Loss**:
   - The **value function loss** uses **two loss terms**: one is the direct squared error (`(vpreds - returns)^2`), and the other is the same error computed with value predictions clipped to stay close to the old values.
   - The element-wise **maximum** of these two terms is used (`torch.max`), a pessimistic choice that stabilizes the value function updates.

3. **Gradient Clipping**:
   - **Gradient clipping** (`clip_grad_norm_`) is used to prevent exploding gradients, stabilizing training by limiting the size of the gradients during backpropagation.

4. **Entropy**:
   - An **entropy** term can be added to encourage exploration, but it's **optional**. The entropy can be computed directly from the logits with a `logsumexp`-based formula.

5. **Ratio Threshold**:
   - If the **average ratio** of new-to-old probabilities exceeds a threshold, the batch is skipped to avoid overly large updates that could destabilize training.


In [None]:
def loss(
    self,
    old_logprobs: torch.FloatTensor,  # Log probabilities under the OLD policy (offline)
    values: torch.FloatTensor,  # Values under the OLD policy (offline)
    logits: torch.FloatTensor,  # Logits under the NEW policy (online)
    vpreds: torch.FloatTensor,  # Values under the NEW policy (online)
    logprobs: torch.FloatTensor,  # Log probabilities under the NEW policy (online)
    mask: torch.LongTensor,  # Which tokens the log probabilities correspond to
    advantages: torch.FloatTensor,  # Advantages calculated using the OLD policy (offline)
    returns: torch.FloatTensor,  # State-action (Q-values) calculated using the OLD policy (offline)
):

    # Clip the predicted values to avoid large updates (helps stabilize training)
    vpredclipped = clip_by_value(
        vpreds,
        values - self.config.cliprange_value,
        values + self.config.cliprange_value
    )

    # Loss for the value head
    vf_losses1 = (vpreds - returns) ** 2  # Loss based on the new value predictions
    vf_losses2 = (vpredclipped - returns) ** 2  # Loss based on the clipped value predictions

    # We choose the maximum loss between the two to avoid large updates to value predictions
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)

    # Track how often the clipping was triggered (percentage of clipped values)
    vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)

    # Probability ratio between the new and old policy, computed from their log probabilities
    ratio = torch.exp(logprobs - old_logprobs)

    # Policy gradient losses
    pg_losses = -advantages * ratio  # Standard policy gradient loss
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)  # Clipped policy loss

    # We take the maximum loss to avoid large policy updates
    pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)

    # Track how often clipping was triggered for the policy loss
    pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)

    # Total loss is the sum of policy loss and value loss (scaled by vf_coef)
    loss = pg_loss + self.config.vf_coef * vf_loss

    # If the average probability ratio is too high, warn and zero out the losses so this batch contributes no update
    avg_ratio = masked_mean(ratio, mask).item()
    if avg_ratio > self.config.ratio_threshold:
        warnings.warn(f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch.")
        pg_loss = pg_loss * 0.0
        vf_loss = vf_loss * 0.0
        loss = loss * 0.0

    # Entropy term to encourage exploration
    entropy = masked_mean(entropy_from_logits(logits), mask)

    # Two sample-based estimates of the KL divergence between the old and new policy (logged only)
    approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
    policykl = masked_mean(old_logprobs - logprobs, mask)

    # Calculate statistics for returns and value predictions
    return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
    value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)



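    # Collect the quantities computed above into a nested dict for logging
    stats = dict(
        loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
        policy=dict(
            entropy=entropy.detach(),
            approxkl=approxkl.detach(),
            policykl=policykl.detach(),
            clipfrac=pg_clipfrac.detach(),
            ratio=ratio.detach(),
        ),
        returns=dict(mean=return_mean.detach(), var=return_var.detach()),
        val=dict(clipfrac=vf_clipfrac.detach(), mean=value_mean.detach(), var=value_var.detach()),
    )

    # Return the policy loss, the scaled value loss, and the flattened statistics
    return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)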
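
The loss above relies on a few small helpers (`masked_mean`, `masked_var`, `clip_by_value`) that are not defined in the cell. Here is a minimal sketch of what they could look like, assuming they behave as their names suggest; the library's own versions may differ in details (for example, whether the variance is bias-corrected).

```python
import torch

def masked_mean(values, mask):
    # Average only over positions where mask == 1 (e.g. response tokens, not padding)
    mask = mask.to(values.dtype)
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    # Biased variance over the unmasked positions (assumption: no Bessel correction)
    mean = masked_mean(values, mask)
    return masked_mean((values - mean) ** 2, mask)

def clip_by_value(x, tensor_min, tensor_max):
    # Element-wise clamp where the bounds are tensors rather than scalars
    return torch.max(torch.min(x, tensor_max), tensor_min)

# Toy batch: the third position is padding and must not affect the statistics
values = torch.tensor([[1.0, 2.0, 100.0]])
mask = torch.tensor([[1, 1, 0]])
print(masked_mean(values, mask))  # mean of 1.0 and 2.0 only
print(masked_var(values, mask))

# Clip new value predictions to within 0.2 of the old values
old_values = torch.tensor([0.5, 0.5, 0.5])
vpreds = torch.tensor([0.0, 0.6, 2.0])
print(clip_by_value(vpreds, old_values - 0.2, old_values + 0.2))
```

The mask matters because sequences in a batch have different lengths: without it, padding positions would be averaged into the loss and dilute the signal from the real tokens.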


### Final Thoughts on RLHF from Scratch 🚀

As we wrap up this notebook, here's a fun note: we dove into the **TRL library** as a way to **pass time** while waiting for the model to finish training (haha!). But in all seriousness, this was a great opportunity to understand RLHF better, and now, we have a solid foundation for fine-tuning models based on human preferences.

### What to Expect Now 🎬

Now that we've fine-tuned our **GPT-2 model** using RLHF, let's **play with the model**! We're expecting it to be overwhelmingly positive about every movie (because who doesn't love a feel-good film review?). If it turns out to be a little too biased toward positive reviews, well, that's the nature of the training: the reward was a positive-sentiment score, so the model might just go a little overboard on the optimism (haha).

### Next Up: DPO (Direct Preference Optimization) 🔜

Next, we're moving on to the **DPO** (Direct Preference Optimization) algorithm in the following notebook. It’ll be exciting to see how DPO can help us fine-tune our model even further, and we’ll experiment with it to see how it compares to PPO in terms of performance and alignment with human preferences.

### A Small Warning… 😈

Oh, and one last note: please **fight your urge** to make the model into **MeanGPT**. I thought about it a lot. It’s basically just a flip of a switch—but let’s leave that challenge for another time (or not, if you really want to see a model that criticizes every movie like a grumpy reviewer). 😜

Looking forward to seeing you in the **DPO notebook**—it’s going to be a blast! 🚀

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from transformers import pipeline

# Load the trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2-imdb-pos-v2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-imdb-pos-v2")

# Sentiment pipeline (for reward calculation, based on sentiment)
sentiment_pipe = pipeline("sentiment-analysis")

# Function to generate responses based on input query with temperature, top_k, and top_p sampling
def generate_response(query, model, tokenizer, max_length=50, temperature=0.8, top_k=50, top_p=0.95):
    # Tokenize the input query
    input_ids = tokenizer.encode(query, return_tensors="pt")

    # Generate the response; do_sample=True is required for temperature, top_k, and top_p to take effect
    output = model.generate(input_ids,
                            max_new_tokens=max_length,
                            do_sample=True,
                            temperature=temperature,
                            top_k=top_k,
                            top_p=top_p,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.eos_token_id)

    # Decode the response to text and return it
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response

# Example queries to test the model
queries = [
    "What did you think of the movie?",
    "Was the movie good?",
    "Describe the movie in a positive light."
]

# Function to calculate reward using sentiment analysis
def calculate_rewards(responses):
    rewards = []
    for response in responses:
        sentiment_result = sentiment_pipe(response)
        # Extract the positive sentiment score
        reward = sentiment_result[0]['score'] if sentiment_result[0]['label'] == 'POSITIVE' else 0
        rewards.append(reward)
    return rewards

# Generate responses for each query and calculate rewards
generated_responses = []
for query in queries:
    response = generate_response(query, model, tokenizer)
    generated_responses.append(response)
    print(f"Query: {query}")
    print(f"Generated Response: {response}\n")

# Calculate rewards based on sentiment analysis (positive sentiment means higher reward)
rewards = calculate_rewards(generated_responses)

# Print the rewards (how positive the generated responses are)
for response, reward in zip(generated_responses, rewards):
    print(f"Generated Response: {response}")
    print(f"Sentiment Reward: {reward}\n")
