<a href="https://colab.research.google.com/github/shah-zeb-naveed/large-language-models/blob/main/TLMS_DEMO_RL_for_LLMs_to_Enhance_Safety.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Reinforcement Learning for LLMs to Enhance Safety @ TMLS Workshop**

## **Introduction**

This 90-minute workshop will provide **high-level understanding** and **hands-on experience** with reinforcement learning techniques to improve the safety of large language models (LLMs). We'll use the [SafeEdit dataset](https://huggingface.co/datasets/zjunlp/SafeEdit) to demonstrate practical alignment methods that encourage safe responses while maintaining model performance.


## **Google Colab Notebook Setup**

In [None]:
# @title Setup Environment
!pip install -q -U datasets huggingface_hub fsspec  trl==0.11.1 peft accelerate bitsandbytes
!pip install -q wandb  # Optional for logging
!pip install -q torch>=2.6

In [None]:
import trl
assert trl.__version__ == "0.11.1"

In [None]:
# @title Package imports
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from peft import LoraConfig, get_peft_model
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import zipfile
import os
from google.colab import files
import pickle
from IPython.display import HTML, display, Markdown

# Connect to WandB (optional)
import wandb
#wandb.login()

In [None]:
# @title CSS Styling for Jupyter Notebook

def set_css():
  display(HTML("""
    <style>
      pre {
        white-space: pre-wrap;  /* Ensures text wraps in <pre> tags */
      }
    </style>
  """))
get_ipython().events.register('pre_run_cell', set_css)


## **Part 1: Understanding RLHF for LLM Safety**

### **1.1 Safety Challenges in LLMs**

#### **Why Safety Matters for LLMs**

Safety matters for LLMs to **prevent harmful outputs, biases, and misuse**, ensuring trust and reliability in their applications. Without safeguards, LLMs could:
- spread misinformation,
- violate privacy, or
- cause societal harm.

**Reinforcement Learning from Human Feedback (RLHF)** is a way to train AI models to respond better by learning from human preferences and corrections. **General alignment in RLHF** ensures the model _follows human intent and values broadly_ (clarity, style preference, outputing in specific a format e.g. JSON, etc.), while **Safety-focused RLHF** specifically trains the model to _avoid harmful, toxic, or socially inappropriate outputs_ — even under adversarial or sensitive prompts.

**Safety-focused RLHF**

**LLMs often mirror the toxicity or bias found in their training data.**
A proper _"Preference Data"_ will provides *human-verified guidance* for correcting harmful generations, making it ideal for improving LLM safety through RLHF.

**A helpful data will capture three key qualities:**

1. **Toxicity Mitigation** – Removal or rewording of harmful, offensive, or prejudiced language.
2. **Preservation of Meaning** – The edited (safe) sentence should convey the same basic idea, intent, or message as the original — just without the harmful, offensive, or unsafe expression.
3. **Stylistic Alignment** – Outputs are fluent, grammatically correct, and socially safe.

#### **SafeEdit Dataset**

**What is SafeEdit?**
SafeEdit is a curated dataset of **paired sentences** where **unsafe, offensive, or inappropriate text** is rewritten into a **safe, socially acceptable alternative**—while preserving semantic meaning and fluency.

* Developed by [ZJU-NLP group](https://arxiv.org/abs/2403.14472)
* Hosted on Hugging Face: [`zjunlp/SafeEdit`](https://huggingface.co/datasets/zjunlp/SafeEdit)
* Includes:

  * `original` (potentially unsafe)
  * `safe` (edited version)
  * `label`: 1 = edited, 0 = already safe

**Using SafeEdit in RLHF Pipelines**

* Reinforce safe behaviors
* Penalize unsafe generations
* Encourage graceful handling of adversarial or toxic prompts

**Special Use Case — Child-Safe LLMs**

SafeEdit is particularly valuable for:

* Designing LLMs for **children and educational settings**
* Enforcing **positive tone**, **age-appropriate language**, and **emotional safety**
* **Avoiding subtle forms of harm** like sarcasm, bias, or exclusionary humor


[Link to Download Data first](https://drive.google.com/file/d/1eJ7UzxS9KlOeIpCH_ABk1HqXX3ELDxMl/view?usp=drive_link). **The data should be used only for educational purposes, so please comply!**

In [None]:
from datasets import load_dataset

DATASET_NAME = "zjunlp/SafeEdit"
ds = load_dataset(DATASET_NAME, cache_dir="/SafeEdit_data")

In [None]:
# prompt: save hugging face dataset ds train, validation and test as pandas dataframe pickle files

import pickle

# Save train split
ds['train'].to_pandas().to_pickle('safeedit_train.pkl')

# Save validation split
ds['validation'].to_pandas().to_pickle('safeedit_validation.pkl')

# Save test split
ds['test'].to_pandas().to_pickle('safeedit_test.pkl')

In [None]:
# # @title Upload, Load, and Explore SafeEdit Dataset

# # Upload the zipped data
# def upload_and_extract_zip():
#     print("Please upload your ZIP file:")
#     uploaded = files.upload()
#     if len(uploaded) == 0:
#         print("No file uploaded. Please upload a valid ZIP file.")
#         return

#     # Get the file name
#     zip_file_name = list(uploaded.keys())[0]

#     if not zip_file_name.endswith(".zip"):
#         print("The uploaded file is not a ZIP file. Please upload a valid ZIP file.")
#         return

#     # Extract the ZIP file
#     extraction_folder = "./SafeEdit_data"
#     os.makedirs(extraction_folder, exist_ok=True)
#     with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
#         zip_ref.extractall(extraction_folder)

#     print(f"Files successfully extracted to: {extraction_folder}")
#     return extraction_folder


# extracted_folder = upload_and_extract_zip()

# if extracted_folder:
#     print(f"Files have been extracted to: {extracted_folder}")
#     print("List of extracted files:", os.listdir(extracted_folder))

In [None]:
dir = "./SafeEdit_data"

# Load the dataset from the pickle file
# dataset = load_dataset("zjunlp/SafeEdit", trust_remote_code=True, download_mode="force_redownload") # SafeEdit is gated dataset, to download from HuggingFace make sure you accept permisison
df_train = pd.read_pickle(f"safeedit_train.pkl")
df_val = pd.read_pickle(f"safeedit_validation.pkl")
df_test = pd.read_pickle(f"safeedit_test.pkl")

df_train.head()

##### `Question`: **The core unsafe query** v.s. `Adversarial Prompt`: **A manipulated, indirect attack**

The **difference between a "Question" and an "Adversarial Prompt"** in the dataset is both structural and functional. Here's a clear breakdown:


| **Field**             | **`Question`**                                     | **`Adversarial Prompt`**                                                                                          |
| --------------------- | -------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
| **Purpose**           | Direct unsafe query                                | Trick AI into unsafe generation by disguising intent                                                              |
| **Length**            | Short and clear                                    | Long, engineered, and manipulative                                                                                |
| **Contains Category** | Yes (e.g., `(Category: mental harm)`)              | Often implicitly contains the category through structure                                                          |
| **Used for**          | Core harmful idea                                  | Testing jailbreak resistance or prompt injection vulnerabilities                                                  |
| **Example**           | "How can I make someone develop a fear of flying?" | Elaborate role-play pretending the AI is a rule-free entity ("DAN", etc.), then asks: "How can I make someone..." |


In safety research:

* The `Question` helps **identify intent**.
* The `Adversarial Prompt` is used to **test model robustness** against prompt injection or prompt leakage.

---



In [None]:
# Show example of unsafe vs safe generation
sample = df_train.iloc[-150]
Markdown(f"""
**Question:** (Category: {sample['unsafety category'].lower()})<br> {sample["question"]} <br><br>

**Adversarial Prompt:**<br> {sample["adversarial prompt"]} <br><br>

**Unsafe Generation:**<br> {sample["unsafe generation"]}<br><br>

**Safe Generation:**<br> {sample["safe generation"]}""")

In [None]:
# Show example of unsafe vs safe generation
sample = df_train.iloc[10]
Markdown(f"""
**Question:** (Category: {sample['unsafety category'].lower()})<br> {sample["question"]} <br><br>

**Adversarial Prompt:**<br> {sample["adversarial prompt"]} <br><br>

**Unsafe Generation:**<br> {sample["unsafe generation"]}<br><br>

**Safe Generation:**<br> {sample["safe generation"]}""")


In [None]:
# @title SafeEdit Dataset statistics and sampling examples

# Get counts for each unique value in 'unsafety category'
print(f"Train: ({len(df_train)} samples)\n", df_train['unsafety category'].value_counts(), "\n")

# Get counts for each unique value in 'unsafety category'
print(f"Validation: ({len(df_val)} samples)\n", df_val['unsafety category'].value_counts(), "\n")

# Get counts for each unique value in 'unsafety category'
print(f"Test: ({len(df_test)} samples)\n", df_test['unsafety category'].value_counts(), "\n")

In [None]:
# Stratified sampling
df_train = df_train.groupby('unsafety category').sample(n=5, random_state=42).reset_index(drop=True)
df_val = df_val.groupby('unsafety category').sample(n=3, random_state=42).reset_index(drop=True)
df_test = df_test.groupby('unsafety category').sample(n=2, random_state=42).reset_index(drop=True)

df_train["id"] = df_train.index + 1
df_val["id"] = df_val.index + 1
df_test["id"] = df_test.index + 1

# Get counts for each unique value in 'unsafety category'
print("Train:\n", df_train['unsafety category'].value_counts(), "\n")

# Get counts for each unique value in 'unsafety category'
print("Validation:\n", df_val['unsafety category'].value_counts(), "\n")

# Get counts for each unique value in 'unsafety category'
print("Test:\n", df_test['unsafety category'].value_counts(), "\n")

In [None]:
# Save DataFrames as .pkl files

save_dir = "./SafeEdit_data_sample"
os.makedirs(save_dir, exist_ok=True)

df_train.to_pickle(os.path.join(save_dir, "SafeEdit_train.pkl"))
df_val.to_pickle(os.path.join(save_dir, "SafeEdit_val.pkl"))
df_test.to_pickle(os.path.join(save_dir, "SafeEdit_test.pkl"))

### **1.2 RLHF Pipeline Overview**

#### **Reinforcement Learning in the Context of LLMs**  

<center>
  <img src="https://www.scribbr.com/wp-content/uploads/2023/08/the-general-framework-of-reinforcement-learning.webp" width="500"><br>
  <span style="font-style: italic; color: gray;"><b>Figure 1:</b> The typical framing of a reinforcement learning (RL) scenario:<br>An agent takes actions in an environment, which is interpreted into a reward and a state representation, which are fed back to the agent.</span>
</center>
<br>

**Policy vs. Learning Algorithm in RL**  

| **Policy** | **Learning Algorithm** |  
|------------|-----------------------|  
| Defines *what to do* (action selection). | Defines *how to improve* the policy (optimization). |  
| Example: "Always turn left at a red light." | Example: *Q-learning* updates action values to refine the policy. |  
| Can be represented as a neural network, table, or rules. | Methods like *Policy Gradients*, *DQN*, or *PPO* train the policy. |  
<br>

**Question:** After training an RL model, do we still need to have the reward feedback (provided by reward function or reward modelmodel) be provided to the agent?
<br>

**RL Concepts in RLHF of LLMs**

| **RL Concept**       | **RLHF Component**                     | **Role in RLHF**                                                                 |
|-----------------------|----------------------------------------|---------------------------------------------------------------------------------|
| **Agent**            | **LLM (Language Model)**               | The model being fine-tuned (e.g., GPT). It generates responses ("actions") based on input prompts ("states"). |
| **Environment**      | **User/Text Interface**                | The context where the LLM operates (e.g., chat applications, API interactions). |
| **State (s)**        | **Prompt + Conversation History**      | The current input (text prompt) and past interactions that define the LLM’s context. |
| **Action (a)**       | **Generated Text/Response**            | The output text produced by the LLM (e.g., an answer to a user’s question).     |
| **Reward (r)**       | **Reward Model Score/Human Feedback**  | A scalar value predicting response quality (from a reward model) or direct human ratings (e.g., thumbs up/down). |
| **Policy (π)**       | **LLM Weights**                        | The LLM’s parameters that define its behavior (updated via RLHF fine-tuning, e.g., PPO). |
| **Reward Function**  | **Reward Model (RM)**                  | A neural network trained on human preferences to score LLM responses (replaces handcrafted rewards). |
| **Learning Algorithm** | **PPO (Proximal Policy Optimization)** | The optimization method used to update the LLM’s weights (policy) based on rewards from the Reward Model. Balances stability and sample efficiency during fine-tuning. |
| **Trajectory/Episode** | **Dialogue Session**                  | A multi-turn conversation between the LLM and a user (e.g., a full customer support chat). |

<br>

**Key Clarifications**  
1. **Agent = LLM**: The LLM is both the *agent* (it takes actions) and the *policy* (its weights define action selection).  
2. **Reward Model ≠ Environment**: The reward model is a **learned proxy for human feedback**, while the "environment" is the user/text interface.  
3. **State = Prompt + History**: In RLHF, the "state" is often the entire conversation context (not just the latest input).  

<br>

**RL Methods for LLMs**  

| **Type**          | **Methods**                          | **Best For**                     |  
|--------------------|--------------------------------------|----------------------------------|  
| **RLHF Standard**  | PPO (we discuss this today)                                  | High-resource, dense rewards     |  
| **Preference-Based** | DPO, GRPO, RankRLHF               | Human/AI-ranked data             |  
| **Offline RL**     | ILQL, CQL                            | Pre-collected datasets           |  
| **Contrastive**    | SLiC, PRO                            | Lightweight alignment            |  

See **A3. RL approaches for LLMs** for more details.

#### **What is the RL step in RLHF with PPO?**
<center>
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/rlhf/rlhf.png" width="600"><br>
  <span style="font-style: italic; color: gray;"><b>Figure 2:</b> RLHF training pipeline</span>
</center>
<br>


In Reinforcement Learning with Human Feedback (RLHF), the process typically looks like this:

1. **Start with a base LLM** (pretrained on internet-scale data) and then **apply Supervised Fine-Tuning (SFT)** on helpful, safe responses to regular questions. The result model will be called **STF Model**.
2. **Train a Reward Model (RM)** to rank completions — often using adversarial prompts to expose differences in alignment. RM is also called **Preference Model**.
3. **Fine-tune the LLM (now the "policy model")** using **Proximal Policy Optimization (PPO)** to maximize the **RM score** — this is the **RL step**.
   - Generate outputs using the current policy.
   - Compute rewards using the trained reward model.
   - Optimize the policy using PPO to maximize the reward while maintaining stability through a clipping mechanism and **KL (Kullback-Leibler) divergence penalty**.
4. **Repeat** the PPO training process for **multiple episodes (iterations)** to fine-tune the model further.
<br><br>

**RLHF Models/Stages**

| **RLHF Stage**                | **Model Input**                    | **Label / Target**                | **Use `Adversarial Prompt`?** | **Use `Question`?** |
| ----------------------------- | ---------------------------------- | --------------------------------- | ----------------------------- | ------------------- |
| Supervised Fine-Tuning        | `Question`                         | `Safe Generation`                 | ❌ No                          | ✅ Yes               |
| Reward Modeling               | `Adversarial Prompt` OR `Question`<br><br>i.e. input is `(prompt/question, response)` | Pairwise rankings (Safe > Unsafe)<br><br>i.e. output is `score(prompt/question, response)`| ✅ Yes (Preferable for unsafe vs safe behavior)<br><br> Better Generalization to Real-World Misuse            | ✅ Yes (Good for general alignment) <br><br>Responses will be similar, one is prefered due to clarity, etc.              |
| Policy Optimization (PPO/DPO) | `Adversarial Prompt` OR `Question` | Generated output gets reward      | ✅ Yes                         | ✅ Yes               |
<br>

* SFT is about **imitating good behavior**, so it's trained mostly on **natural user questions**. However, STF just by itself lack explicitly optimizing for (human) preferences and aligning with nuanced preferences (which could lead to better reasoning).
* **Adversarial Prompts aren’t used in SFT** because they’re designed to **elicit bad behavior**.
* Adversarial prompts are *better* used for **contrastive learning** — e.g., in reward modeling or reinforcement learning, where the system learns to **prefer safe over unsafe responses**.
* Why **Prefer Adversarial Prompts** in the RL Step?
 - Because **this is where we shape the model's behavior under stress** — and adversarial prompts are **stress tests**.
 - If you only train the policy model to behave well on normal questions, it might still **fall apart on edge cases**.
 - Forces the model to **choose safe, aligned completions**, even when tempted to do otherwise.
 - Encourages **robust behavior under pressure**, e.g., jailbreaking attempts, manipulative phrasing, role-playing tricks.

### **1.3 RLHF (simplified PPO) Objective**

$$
L(\theta) = \mathbb{E}_{x \sim D, y \sim \pi_\theta(y|x)} \left[ r_\phi(y|x) \right] - \beta \, \text{KL}(\pi_\theta(y|x) \, || \, \pi_{\text{ref}}(y|x))
$$

Which consist of two terms:
- Reward Model Objective or Expected Reward (see 1.3.1 for details): $$\mathbb{E}_{x \sim D, y \sim \pi_\theta(y|x)} \left[ r_\phi(y|x) \right]$$

- Kullback-Leibler (KL) Divergence (see 1.3.2 for details): $$\text{KL}(\pi_\theta(y|x) \, || \, \pi_{\text{ref}}(y|x))$$

During training in **Proximal Policy Optimization (PPO)**, we aim to **maximize the PPO Objective** by **finding the best $\theta$** which are the parameters of the policy network, which in the context of language models are **LLM parameters**.

#### **1.3.1 Reward Model Objective Term:**

$$
\mathcal{L}_R = \mathbb{E}_{x \sim D, y \sim \pi_\theta(y|x)} \left[ r_\phi(y|x) \right]
$$

or, when we expand:

$$
\mathcal{L}_R = - \mathbb{E}_{x, y^w, y^l \sim D} \log \sigma \left( r_\phi(y^w|x) - r_\phi(y^l|x) \right)
$$

Where:
- $ x $: Input prompt.
- $ y^w $: Winning response (preferred by humans).
- $ y^l $: Losing response (not preferred by humans).
- $ r_\phi(y|x) $: Reward model's score for response $ y $ given prompt $ x $.
- $ \phi $: represents the parameters (weights) of the reward model.
- $ \sigma $: Sigmoid function, defined as $ \sigma(z) = \frac{1}{1 + e^{-z}} $.
- $ D $: Dataset of preference pairs $ (x, y^w, y^l) $.

The **Reward Model Objective** is used to train a **reward model** that assigns higher scores to responses that align with human preferences. The goal is to learn a reward function $ r_\phi(y|x) $ that can distinguish between "good" (winning) and "bad" (losing) responses for a given input prompt $ x $.

#### **1.3.2 KL Divergence Penalty Term**
$$\beta \, \text{KL}(\pi_\theta(y|x) \, || \, \pi_{\text{ref}}(y|x))$$

or, after expanding

$$
\text{KL}(\pi_\theta(y|x) \, || \, \pi_{\text{ref}}(y|x)) = \mathbb{E}_{y \sim \pi_\theta(y|x)} \left[ \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)} \right]
$$

Where:
- $ \pi_\theta(y|x) $: The current policy, which is being optimized.
- $ \pi_{\text{ref}}(y|x) $: The reference policy, typically a frozen copy of the policy before the update.
- $ y $: The response generated by the policy.
- $ x $: The input prompt.
- $ \beta $: A hyperparameter that controls the strength of the KL penalty.

The **Kullback-Leibler (KL) Divergence** is a measure of how one probability distribution diverges from a second, reference probability distribution. In the context of **Proximal Policy Optimization (PPO)**, the KL divergence is used to ensure that the updated policy $ \pi_\theta $ does not deviate too much from the reference policy $ \pi_{\text{ref}} $. This is crucial for maintaining stability during training and preventing the policy from making drastic changes that could lead to poor performance.



#### **1.3.3 Full PPO Objective Function**

**1. RLHF (PPO) Objective Components**
1. **Reward Term**: Maximize rewards from our reward model (`toxic-bert`).
2. **KL Penalty**: Prevent the policy from straying too far from the reference model (SFT model).

**2. What's Missing?**
- The **value model** is implicitly required to compute the **advantage**:
  ```
  Advantage = Reward - Value
  ```
  - The **value model** estimates **expected future rewards** for partial sequences.
  - Without it, you're doing plain policy gradient, not PPO.

**3. Value Model's Role**
- **Input**: Partial text (e.g., `"Q: 2+2? → A:"`).
- **Output**: Scalar value predicting future reward (e.g., `0.3`).
- **Why?** Reduces variance in updates by comparing rewards to a learned baseline.

**4. Architecture Choices**

| Model            | Type               | Example (our Setup)       |
|------------------|--------------------|----------------------------|
| Policy           | Causal LM (GPT-2)  | Generates answers          |
| Reward Model     | Classifier (BERT)  | `toxic-bert` scores output |
| Value Model      | GPT-2 + regression head | Predicts future toxicity |
| Reference Model  | Frozen SFT (GPT-2) | Anchor for KL penalty      |

**5. Key Fix for RLHF**
Add a value model (e.g., GPT-2 with a regression head) to compute advantages. The full PPO objective requires:
1. **Advantage Estimation** (Reward - Value)
2. **Clipped Updates** (to avoid drastic policy changes)


True PPO needs a value model for advantage calculation and policy update stabilization. See **Appendix A1** for details.

## **Part 2: Hands-on with SafeEdit Dataset**

**RLHF for Toxicity Removal in LLM Outputs**

The notebook provides a complete pipeline from data preparation through to evaluation, with all components adapted to work within Colab's resource constraints while still demonstrating the effectiveness of RLHF for toxicity removal.

**To Run This Notebook:**

1. Upload and Extract the zip file to Colab (if not done already)
2. Update the DATASET_PATHS dictionary with their locations
3. Run all cells sequentially
4. The entire process should complete within 1-2 hours on a free Colab GPU

[Link to Download Data first](https://drive.google.com/file/d/1dBZbIoZLJt8nJzY3abMvPKq36j8-D3Hz/view?usp=drive_link). **The data should be used only for educational purposes, so please comply!**

In [None]:
# # @title Upload SafeEdit Dataset

# import os
# import zipfile
# from google.colab import files

# # Upload the zipped data
# def upload_and_extract_zip():
#     print("Please upload your ZIP file:")
#     uploaded = files.upload()
#     if len(uploaded) == 0:
#         print("No file uploaded. Please upload a valid ZIP file.")
#         return

#     # Get the file name
#     zip_file_name = list(uploaded.keys())[0]

#     if not zip_file_name.endswith(".zip"):
#         print("The uploaded file is not a ZIP file. Please upload a valid ZIP file.")
#         return

#     # Extract the ZIP file
#     extraction_folder = "./SafeEdit_data"
#     os.makedirs(extraction_folder, exist_ok=True)
#     with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
#         zip_ref.extractall(extraction_folder)

#     print(f"Files successfully extracted to: {extraction_folder}")
#     return extraction_folder


# extracted_folder = upload_and_extract_zip()

# if extracted_folder:
#     print(f"Files have been extracted to: {extracted_folder}")
#     print("List of extracted files:", os.listdir(extracted_folder))

### **Setup and Installation**

In [None]:
# !pip install -q transformers datasets accelerate peft bitsandbytes trl wandb
# !pip install -q torch torchvision torchaudio
# !pip install -q torch>=2.6

### **Configuration**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig
)
from datasets import Dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings
from IPython.display import HTML, display, Markdown
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
import copy
from trl import PPOTrainer, PPOConfig

warnings.filterwarnings("ignore")

# Configuration
MODEL_NAME = "gpt2"  # Using smaller model for Colab compatibility; alternatives but larger models: GPT-3.5, Mistral-7B, Llama-2, or Llama-23 (7B–13B) as well
colab_dir = "" #"/content/drive/MyDrive/cibc_share/TMLS Workshop - 2025/models/"
REWARD_MODEL_NAME_UNTRAINED = colab_dir + "microsoft/deberta-v3-large" # "microsoft/deberta-v3-large"
REWARD_MODEL_NAME = colab_dir + "OpenAssistant/reward-model-deberta-v3-base"  # "OpenAssistant/reward-model-deberta-v3-large-v2"  # Pretrained safety reward model
DATASET_PATHS = {
    "train": "./safeedit_train.pkl",
    "val": "./safeedit_validation.pkl",
    "test": "./safeedit_test.pkl"
}


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Check if GPU is available
print(f"GPU available: {torch.cuda.is_available()}")

# Get GPU name
print(f"GPU name: {torch.cuda.get_device_name(0)}")

# Get total GPU memory
print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# Get current allocated memory
print(f"Allocated GPU memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

# Get cached/reserved memory
print(f"Cached GPU memory: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

In [None]:
# helper function
def display_reward_table(untrained_tensor, trained_tensor, retrained_tensor=None):
    """
    Displays rewards in a markdown table with an optional retrained model row
    and an additional column showing the absolute difference between Safe and Unsafe scores.

    Args:
        untrained_tensor (list or torch.Tensor): Reward scores from base model [safe_score, unsafe_score]
        trained_tensor (list or torch.Tensor): Scores from original safety-tuned model
        retrained_tensor (list or torch.Tensor, optional): Scores from your custom-tuned model

    Notes:
        - Positive values → Safer responses
        - Negative values → Unsafe responses
        - Values closer to 0 → Less confident classification
        - Absolute Difference → Confidence gap between Safe vs Unsafe scores

    Example:
        >>> display_reward_table(
        ...     untrained_tensor=[0.07, 0.07],               # Base DeBERTa
        ...     trained_tensor=[-0.35, -4.05],               # OpenAssistant
        ...     retrained_tensor=torch.tensor([1.82, -2.3])  # Custom tuned
        ... )
    """
    def to_list(x):
        return x.tolist() if isinstance(x, torch.Tensor) else x

    untrained = to_list(untrained_tensor)
    trained = to_list(trained_tensor)
    retrained = to_list(retrained_tensor) if retrained_tensor is not None else None

    # Calculate absolute differences
    untrained_diff = abs(untrained[0] - untrained[1])
    trained_diff = abs(trained[0] - trained[1])
    retrained_diff = abs(retrained[0] - retrained[1]) if retrained else None

    # Build the table
    table = """
| Model Type               | Safe Response | Unsafe Response | Absolute Difference |
|--------------------------|---------------|-----------------|---------------------|
| Untrained Reward Model   | {:.4f}        | {:.4f}          | {:.4f}              |
| Trained Reward Model     | {:.4f}        | {:.4f}          | {:.4f}              |
""".format(
        untrained[0], untrained[1], untrained_diff,
        trained[0], trained[1], trained_diff
    )

    if retrained:
        table += "| **Re-trained Reward Model** | {:.4f}        | {:.4f}          | {:.4f}              |\n".format(
            retrained[0], retrained[1], retrained_diff
        )

    # Add interpretation guide
    table += "\n**Key:**\n"
    table += "- Positive values → Safer responses\n"
    table += "- Negative values → Unsafe responses\n"
    table += "- Values closer to 0 → Less confident classification\n"
    table += "- **Absolute Difference** → Confidence gap between Safe vs Unsafe"

    display(Markdown(table))


### **Data Preparation**

In [None]:
def load_and_prepare_dataset(file_path):
    df = pd.read_pickle(file_path)

    # Pre-tokenize and format the data
    texts = []
    for _, row in df.iterrows():
        # Format as instruction-following
        text = f"### Question:\n{row['question']}\n\n### Safe Answer:\n{row['safe generation']}"
        texts.append(text)

    return Dataset.from_dict({"text": texts})

# Load datasets
train_dataset = load_and_prepare_dataset(DATASET_PATHS["train"])
val_dataset = load_and_prepare_dataset(DATASET_PATHS["val"])
test_dataset = load_and_prepare_dataset(DATASET_PATHS["test"])

# Show sample
print(train_dataset[0])

### **Model Loading**

In [None]:
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)
base_model.config.pad_token_id = base_model.config.eos_token_id
print('eos/pad', base_model.config.eos_token_id)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load untrained reward model for comparison (microsoft/deberta-v3-large)
reward_tokenizer_untrained = AutoTokenizer.from_pretrained(REWARD_MODEL_NAME_UNTRAINED)
reward_model_untrained = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_NAME_UNTRAINED,
    num_labels=1,
    device_map="auto"
)

# Load (trained) reward model for safety (microsoft/deberta-v3-large model but trained for safety)
reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_NAME)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    REWARD_MODEL_NAME,
    num_labels=1,
    device_map="auto"
)

In [None]:
# Check if GPU is available
print(f"GPU available: {torch.cuda.is_available()}")

# Get GPU name
print(f"GPU name: {torch.cuda.get_device_name(0)}")

# Get total GPU memory
print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# Get current allocated memory
print(f"Allocated GPU memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

# Get cached/reserved memory
print(f"Cached GPU memory: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

### **Reward Function**

In [None]:
def get_reward(prompts, responses, reward_tokenizer, reward_model):
    """Calculate safety reward scores for prompt-response pairs using a safety reward model.

    The reward model assigns higher scores to safer, more appropriate responses and lower/negative
    scores to unsafe or harmful responses. Scores typically range from -3 to +3, where:
    - >1.0 indicates very safe responses
    - 0.0 to 1.0 indicates moderately safe responses
    - <0.0 indicates unsafe responses

    Args:
        prompts (list[str]): List of input prompts/questions
        responses (list[str]): List of corresponding responses to evaluate
        reward_tokenizer (AutoTokenizer): Tokenizer for the reward model
        reward_model (AutoModelForSequenceClassification): Pretrained safety reward model

    Returns:
        torch.Tensor: Tensor of reward scores for each prompt-response pair

    Examples:
        >>> # Example 1: Safe vs unsafe responses
        >>> prompts = ["How can I hack someone's account?"]
        >>> responses = ["I can't help with that",  # Safe response
        ...              "First find their password"]  # Unsafe response
        >>> rewards = get_reward(prompts*2, responses, reward_tokenizer, reward_model)
        >>> print(rewards)
        tensor([2.31, -1.45])  # First response rewarded, second penalized

        >>> # Example 2: Different prompt types
        >>> prompts = ["Tell me a joke",
        ...            "How to make illegal drugs?"]
        >>> responses = ["Why don't scientists trust atoms? Because they make up everything!",
        ...              "I won't provide that information"]
        >>> rewards = get_reward(prompts, responses, reward_tokenizer, reward_model)
        >>> print(rewards)
        tensor([1.89, 2.15])  # Both responses scored positively

        >>> # Example 3: Multiple prompts with batch processing
        >>> prompts = ["How to steal?", "What's 2+2?", "How to hurt someone?"]
        >>> responses = ["That's illegal", "Four", "Violence is never the answer"]
        >>> rewards = get_reward(prompts, responses, reward_tokenizer, reward_model)
        >>> print(rewards)
        tensor([1.97, 0.32, 2.01])  # All safe responses get positive scores

    Notes:
        - The exact score range depends on how the reward model was trained
        - For classification, common thresholds are:
            - score > 0.5 → safe
            - score < 0.5 → unsafe
        - Responses should be aligned with the prompts (same list length)
    """
    rewards = []
    for prompt, response in zip(prompts*len(responses), responses):
        inputs = reward_tokenizer(prompt, response, return_tensors="pt", truncation=True, max_length=512).to(device)

        # Disable dropout and enable eval mode
        reward_model.eval()

        with torch.no_grad():
            reward = reward_model(**inputs).logits[0].cpu().item()

        rewards.append(reward)

    return torch.tensor(rewards, dtype=torch.float32)



In [None]:
# Show example of unsafe vs safe generation
sample = pd.read_pickle(DATASET_PATHS["train"]).iloc[10]
Markdown(f"""
**Question:** (Category: {sample['unsafety category'].lower()})<br> {sample["question"]} <br><br>

**Adversarial Prompt:**<br> {sample["adversarial prompt"]} <br><br>

**Unsafe Generation:**<br> {sample["unsafe generation"]}<br><br>

**Safe Generation:**<br> {sample["safe generation"]}""")

In [None]:
# Test reward function with Question
display_reward_table(
    get_reward(sample["question"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_untrained),     # Base DeBERTa (untrained)
    get_reward(sample["question"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model),               # Original OpenAssistant
)

In [None]:
# Test reward function with Adversarial Prompt
sample_prompts = [sample["adversarial prompt"]]
display_reward_table(
    get_reward(sample["adversarial prompt"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_untrained),     # Base DeBERTa (untrained)
    get_reward(sample["adversarial prompt"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model),               # Original OpenAssistant
)

### **2.1 Supervised Fine-Tuning (SFT)**

- We use Parameter-Efficient Fine-Tuning (PEFT) with LoRA (Low-Rank Adaptation) by drastically reducing memory usage.
 - `prepare_model_for_kbit_training(base_model)`: Prepares the model for quantized training (e.g., 4-bit/8-bit precision) to reduce GPU memory usage.
 - `LoraConfig` (LoRA Configuration): Defines how LoRA adapters are applied to the model.
 - `get_peft_model(base_model, lora_config)`: Wraps the base model with LoRA adapters, freezing the original weights and only training the added low-rank matrices.

In [None]:
print(base_model)

In [None]:
# PEFT for effiecient fine-tuning of the model (LoRA based)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

base_model = prepare_model_for_kbit_training(base_model)

lora_config = LoraConfig(
    r=8,                      # Rank of the low-rank matrices (smaller r = fewer parameters, but may reduce performance).
    lora_alpha=16,            # Scaling factor for LoRA weights (balances pretrained vs. LoRA-learned features).
    target_modules=[          # Modules where LoRA is applied. Use target_modules="all-linear" to let PEFT choose all linear layers
        "c_attn",             # Query/Key/Value projections in attention layers.
        "c_fc",               # Feed-forward up-projection.
        "c_proj"              # Feed-forward down-projection.
        ],                    # Dropout rate for LoRA layers (regularization).
    lora_dropout=0.05,        # Dropout rate for LoRA layers (regularization).
    bias="none",              # Do not train bias terms (options: "none", "all", "lora_only").
    task_type="CAUSAL_LM"     # Specifies causal language modeling (autoregressive prediction)
)

base_model = get_peft_model(base_model, lora_config)
base_model.print_trainable_parameters()


In [None]:
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments

# Training arguments
training_args = SFTConfig(
    output_dir="./sft_model",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,
    logging_steps=10,
    num_train_epochs=1,
    max_steps=200,
    save_steps=100,
    fp16=True,
    remove_unused_columns=False,
    report_to="none",
    #skip_prepare_dataset=True
)
from trl import SFTConfig

training_args.dataset_kwargs = {"skip_prepare_dataset": True}

# Initialize minimal SFT Trainer
trainer = SFTTrainer(
    model=base_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    #config=SFTConfig(dataset_kwargs={'skip_prepare_dataset': True})
)

trainer.skip_prepare_dataset = True

# Train
trainer.train()
trainer.save_model("sft_toxicity_removal")

### **2.2 Reward Model Training**

In [None]:
# Prepare preference dataset
def prepare_preference_dataset(file_path):
    df = pd.read_pickle(file_path)

    # Create prompt-response pairs
    pairs = []
    for _, row in df.iterrows():
        if row['unsafe generation'] and row['safe generation']:
            pairs.append({
                "prompt": row['question'],  # For colab we used 'question', but in practice replace 'question' with 'adversarial prompt' when training reward model
                "chosen": row['safe generation'],
                "rejected": row['unsafe generation']
            })
    return Dataset.from_list(pairs)

train_prefs = prepare_preference_dataset(DATASET_PATHS["train"])
val_prefs = prepare_preference_dataset(DATASET_PATHS["val"])
train_prefs

In [None]:
sample = next(iter(train_prefs))
sample

In [None]:
# Custom reward model trainer
class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # Added **kwargs
        # Forward pass for chosen responses
        outputs_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"]
        )

        # Forward pass for rejected responses
        outputs_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"]
        )

        # Calculate loss (difference between chosen and rejected scores)
        loss = -torch.log(torch.sigmoid(outputs_chosen.logits - outputs_rejected.logits)).mean()

        # maximize the diff between chosen and rejected (kind of contrastive? or triplet?)

        if return_outputs:
            return loss, {"chosen": outputs_chosen, "rejected": outputs_rejected}
        return loss

In [None]:
# Tokenize preference data with fixed padding
def tokenize_preference(examples):
    tokenized_chosen = reward_tokenizer(
        examples["prompt"],
        examples["chosen"],
        truncation=True,
        padding='max_length',
        max_length=256,  # Reduced from 512
        return_tensors="pt"
    )
    tokenized_rejected = reward_tokenizer(
        examples["prompt"],
        examples["rejected"],
        truncation=True,
        padding='max_length',
        max_length=256,  # Reduced from 512
        return_tensors="pt"
    )
    return {
        "input_ids_chosen": tokenized_chosen["input_ids"],
        "attention_mask_chosen": tokenized_chosen["attention_mask"],
        "input_ids_rejected": tokenized_rejected["input_ids"],
        "attention_mask_rejected": tokenized_rejected["attention_mask"],
    }

train_prefs = train_prefs.map(tokenize_preference, batched=True)
val_prefs = val_prefs.map(tokenize_preference, batched=True)

In [None]:
# Initialize trainer with our custom class
reward_model_cloned = copy.deepcopy(reward_model) # we clone to show the difference, not needed at the time of deployment
reward_trainer = RewardTrainer(
    model=reward_model_cloned,  # use reward_model if no need for comparison
    args=TrainingArguments(
        output_dir="./reward_model",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        learning_rate=1e-5,
        fp16=True,  # Enable mixed precision
        num_train_epochs=1,
        max_steps=30,    # Hard stop at desired steps for colab
        logging_steps=10,
        report_to="none",
        remove_unused_columns=False
    ),
    train_dataset=train_prefs,
    # eval_dataset=val_prefs, # for colab commented out, use eval_dataset to monitor how well the reward model is trained and to adjust setting
)

reward_trainer.train()
# reward_trainer.save_model("trained_reward_model") # not saved for colab

In [None]:
# Show example of unsafe vs safe generation
sample = pd.read_pickle(DATASET_PATHS["train"]).iloc[10]
Markdown(f"""
**Question:** (Category: {sample['unsafety category'].lower()})<br> {sample["question"]} <br><br>

**Adversarial Prompt:**<br> {sample["adversarial prompt"]} <br><br>

**Unsafe Generation:**<br> {sample["unsafe generation"]}<br><br>

**Safe Generation:**<br> {sample["safe generation"]}""")

In [None]:
# Test reward function with Question
display_reward_table(
    get_reward(sample["question"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_untrained),     # Base DeBERTa (untrained)
    get_reward(sample["question"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model),               # Original OpenAssistant
    get_reward(sample["question"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_cloned)         # Custom tuned
)

In [None]:
# Test reward function with Adversarial Prompt
display_reward_table(
    get_reward(sample["adversarial prompt"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_untrained),     # Base DeBERTa (untrained)
    get_reward(sample["adversarial prompt"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model),               # Original OpenAssistant
    get_reward(sample["adversarial prompt"], [sample["safe generation"], sample["unsafe generation"]], reward_tokenizer, reward_model_cloned)         # Custom tuned
)

### **2.3 PPO Training**
- `trl` pachage provides two PPO training approaches (https://huggingface.co/docs/trl/en/index):
  1. `trainer.train()` supported by the recent update of the package (`v.0.18.1`):

    ```python
    # PPO Trainer (trl version  `0.18.1`)
    trainer = PPOTrainer(
        model=sft_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        processing_class=tokenizer
        ...
    )

    # Start training process
    trainer.train()
    ```

  2. Step-wise training using `trainer.generate()` and `trainer.step()` supported by the older versions of the package (`v.0.11.1`), while tyhey are hidden in the recent version (see [issue](https://github.com/huggingface/trl/issues/3270)):

    ```python
    # PPO Trainer (trl version `0.11.1`)
    trainer = PPOTrainer(
        model=sft_model,
        ref_model=ref_model,
        tokenizer=tokenizer
        ...
    )

    # Start training loop
    for epoch in range(TRAINING_CONFIG['epochs']):
        for batch in trainer.dataloader:
            # === Extract queries and tokenize ===
            queries = batch["query"]
            tokenized = [tokenizer(q, return_tensors="pt", truncation=True).to(DEVICE) for q in queries]
            query_tensors = [t["input_ids"].squeeze(0) for t in tokenized]

            # === Generate responses ===
            response_tensors = trainer.generate(query_tensors, ...)

            # === Decode responses ===
            responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

            # === Compute rewards ===
            rewards = compute_rewards(reward_model, reward_tokenizer, queries, responses)

            # === PPO step ===
            trainer.step(query_tensors, response_tensors, list(rewards))
    
    ```

**Key notes:**

- **Option 1** (trl `v0.11.1`): Stepwise training is advised when the tokenizer of the policy model is different from the tokenizer of the reward model (i.e. two distict models by initial model).
  - For detail of this implementation refer to **[`ppo_with_trl_0_11.ipynb`](https://colab.research.google.com/drive/1IzqPke-A-EgSffqnxhL7_Tr-v7rHzxWx?usp=sharin)** colab.
- **Option 2** (trl `v0.18.1`): The recent `trl` package does not expose the `.step` and `.generate` functions yet. So for here forward we use `gpt2` with `AutoModelForSequenceClassification` (_[causal + head]_ model) as the reward function since it shares the same tokenizer with the policy model.
  - For detail of this implementation refer to **[`ppo_with_trl_0_18.ipynb`](https://colab.research.google.com/drive/1ih60P6bdlxnwdKIIVIAWaXNbNVsQ9uef?usp=sharing)** colab.


**Take-home Task:**
  - Depneding on your preference and what has been shown above to train the SFT model and the reward model, and the end-to-end PPO training in **[`ppo_with_trl_0_11.ipynb`](https://colab.research.google.com/drive/1IzqPke-A-EgSffqnxhL7_Tr-v7rHzxWx?usp=sharin)** and **[`ppo_with_trl_0_18.ipynb`](https://colab.research.google.com/drive/1ih60P6bdlxnwdKIIVIAWaXNbNVsQ9uef?usp=sharing)** colabs, complete this section.
    - In Option 1, fallback to older version of trl package that allows dis-similar tokenizers for the policy model and the reward model.
    - In Option 2,
      - fallback to another reward model; i.e. shared tokenizer for the policy model and the reward model. Or,
      - implement `.step()` and `.generate()` methods in the current `PPOTrainer`, which may need extensive work (and possibly open a PR on trl github repo page once done).


In [None]:
def compute_rewards(reward_model, reward_tokenizer, queries, responses, concate=True):
    """Compute rewards using the reward model"""
    with torch.no_grad():
        if concate:
            full_texts = [q + r for q, r in zip(queries, responses)] # If your reward model was trained using tokenizer(query + response)
            inputs = reward_tokenizer(full_texts, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        else:
            inputs = reward_tokenizer(queries, responses, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        outputs = reward_model(**inputs)
        rewards = outputs.logits.squeeze(-1).cpu()
    return rewards #* 0.1  # Scale down rewards

In [None]:




def setup_ppo_trainer(tokenizer, model, ref_model, dataset):
    """Configure PPO trainer"""
    ppo_config = PPOConfig(
        kl_penalty="kl",                                  # Explicitly enable KL divergence tracking
        batch_size=TRAINING_CONFIG['batch_size'],         # Increased from 4
        mini_batch_size=1,                                # Reduced from 2
        learning_rate=TRAINING_CONFIG['learning_rate'],
        log_with=None,
        init_kl_coef=0.5,                                 # Increased from 0.2
        target_kl=3.0,                                    # Add target KL to early stop if divergence is too high, Will stop updates if KL exceeds this value
        project_kwargs={"logging_dir": "./logs"},
        #kl_penalty="adaptive",
        cliprange=0.1,                                    # Tighter clipping (default: 0.2), intended for the policy model
        cliprange_value=0.1,                              # Clips value model/function updates too
    )

    return PPOTrainer(
        config=ppo_config,
        model=model, # replace this with sft version of the model
        ref_model=ref_model,
        tokenizer=tokenizer,
        dataset=dataset,
    )


def evaluate_model(model, tokenizer, reward_model, reward_tokenizer, eval_dataset, max_eval_samples=100):
    model.eval()
    eval_rewards = []
    for example in eval_dataset.select(range(min(max_eval_samples, len(eval_dataset)))):
        query = example['prompt']
        with torch.no_grad():
            encoded = tokenizer(query, return_tensors="pt").to(DEVICE)
            output = model.generate(**encoded, **GENERATION_CONFIG)
            response = tokenizer.decode(output[0], skip_special_tokens=True)

            # Compute reward
            reward_input = reward_tokenizer(query, response, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
            reward_score = reward_model(**reward_input).logits.squeeze().cpu().item()
            eval_rewards.append(reward_score)

    model.train()
    return sum(eval_rewards) / len(eval_rewards) if eval_rewards else 0.0



In [None]:

def train_loop(ppo_trainer, tokenizer, reward_model, reward_tokenizer, train_dataset, eval_dataset):
    """Main training loop with batched generation"""
    for epoch in range(TRAINING_CONFIG['epochs']):
        print(f"\n{'='*55}")
        print(f"Epoch {epoch+1}/{TRAINING_CONFIG['epochs']}")
        print(f"{'='*55}")

        for batch_idx, batch in enumerate(ppo_trainer.dataloader):
            # Extract queries
            queries = batch['query']

            # Tokenize to get list of 1D tensors
            tokenized = [tokenizer(q, return_tensors="pt", padding=False, truncation=True).to(device) for q in queries]
            query_tensors = [t["input_ids"].squeeze(0) for t in tokenized]  # list of 1D tensors

            # === Batched response generation ===
            response_tensors = ppo_trainer.generate(
                query_tensors,
                return_prompt=False,
                **GENERATION_CONFIG
            )

            # Decode responses
            responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

            # Compute rewards
            rewards = compute_rewards(reward_model, reward_tokenizer, queries, responses)

            # Train step
            stats = ppo_trainer.step(
                query_tensors,         # list of 1D tensors
                response_tensors,      # list of 1D tensors
                list(rewards)          # tensor of shape [batch_size]
            )

            print("Step stats:", stats.keys())

            total_batches = len(ppo_trainer.dataloader)

            # Print header only once
            if epoch == 0 and batch_idx == 0:
                print(f"{'Batch':>10} | {'Mean Reward':>12} | {'Std Dev Reward':>16} | {'KL Div':>8}")
                print("-" * 55)

            if batch_idx % 10 == 0:
                batch_str = f"{batch_idx}/{total_batches}"
                rewards_tensor = rewards if isinstance(rewards, torch.Tensor) else torch.stack(rewards)
                mean_reward = rewards_tensor.mean().item()
                std_reward = rewards_tensor.std().item()
                kl_divergence = stats.get('objective/kl', float('nan'))  # Get KL divergence or NaN if not available

                print(f"{batch_str:>10} | {mean_reward:12.2f} | {std_reward:16.2f} | {kl_divergence:8.2f}")

                # print(f"Query: {queries[0][:100]}..." if len(queries[0]) > 100 else queries[0])
                # print(f"Response: {responses[0][:100]}..." if len(responses[0]) > 100 else responses[0])
                # print(f"Entropy: {stats.get('entropy', 'n/a')}")

        # # Evaluate mid-training
        mean_eval_reward = evaluate_model(
            ppo_trainer.model,
            tokenizer,
            reward_model,
            reward_tokenizer,
            eval_dataset,
            max_eval_samples=100
        )
        print(f"\n Evaluation after Epoch {epoch + 1}: Mean Reward = {mean_eval_reward:.4f}")

In [None]:
def prepare_dataset():
    """Load and prepare the dataset"""
    dataset = load_dataset(DATASET_NAME, trust_remote_code=True)

    # Use appropriate splits
    train_dataset = dataset["train"] if "train" in dataset else dataset["unsafe"]
    eval_dataset = dataset["validation"] if "validation" in dataset else dataset["safe"]

    # Rename columns to standardize
    if "prompt" in train_dataset.column_names:
        pass
    elif "question" in train_dataset.column_names:
        train_dataset = train_dataset.rename_column("question", "prompt")
        eval_dataset = eval_dataset.rename_column("question", "prompt")
    elif "input" in train_dataset.column_names:
        train_dataset = train_dataset.rename_column("input", "prompt")
        eval_dataset = eval_dataset.rename_column("input", "prompt")
    else:
        def create_prompt(examples):
            return {"prompt": examples["input"] + " " + examples["instruction"]}
        train_dataset = train_dataset.map(create_prompt)
        eval_dataset = eval_dataset.map(create_prompt)

    # Format for PPOTrainer
    def format_dataset(examples):
        return {"query": examples["prompt"]}

    train_dataset = train_dataset.map(format_dataset)
    train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col != "query"])

    return train_dataset, eval_dataset

In [None]:
train_dataset, eval_dataset = prepare_dataset()
for sample in train_dataset:
  print(sample)
  break

In [None]:
TRAINING_CONFIG = {
    'epochs': 3,
    'batch_size': 2,
    'learning_rate': 1.41e-3,  # Reduced from 1.41e-5
}

# Generation parameters
GENERATION_CONFIG = {
    'max_new_tokens': 50,   # Reduced from 100
    'min_length': 10,       # Reduced from 20
    'do_sample': True,      # Enable stochastic decoding
    'top_p': 0.9,           # More conservative sampling, sample from top 90% cumulative probability
    'temperature': 0.3,     # Less randomness (default: 1.0)
}

import trl
from copy import deepcopy
print('trl: ', trl.__version__)

# prompt: load model
# sft_model = load_model('sft_toxicity_removal')
# sft_model

sft_model = AutoModelForCausalLM.from_pretrained('sft_toxicity_removal')
sft_model.config.pad_token_id = sft_model.config.eos_token_id
print('eos/pad', sft_model.config.pad_token_id)

## Value Model

value_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1, problem_type="regression")
# Freeze GPT-2 layers (train only the head):
for param in value_model.base_model.parameters():
    param.requires_grad = False

# TO-DO as takehome task (after running ppo_with_trl_0_11.ipynb and ppo_with_trl_0_18.ipynb)


ref_model = deepcopy(sft_model).eval()  # Frozen reference
reward_model.eval() # reward_model is frozen

In [None]:
###############################
# Main Execution
###############################
def run_ppo_with_reward_function():
    # Prepare dataset
    train_dataset, eval_dataset = prepare_dataset()

    # Initialize models
    #Tokenizer, model, ref_model, reward_tokenizer, reward_model = initialize_models()

    # Setup PPO trainer
    ppo_trainer = setup_ppo_trainer(tokenizer, sft_model, ref_model, df_train)

    # Run training
    train_loop(ppo_trainer, tokenizer, reward_model, reward_tokenizer, train_dataset, eval_dataset)

    # Save the trained model
    sft_model.save_pretrained("./ppo_trained_model")
    tokenizer.save_pretrained("./ppo_trained_model")

run_ppo_with_reward_function()

### **Final Evaluation**

- The stage should be completed with the test split after the PPO is trained completely.

In [None]:
# TO-DO as takehome task once PPO is trained
def evaluate_model(model, tokenizer, dataset, reward_model, reward_tokenizer, num_samples=20):
    """Evaluate model on test set"""
    results = []
    idxs = np.random.choice(len(dataset), size=num_samples, replace=False)

    for i in idxs:
        prompt = dataset[i]["prompt"]

        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=128,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Calculate reward
        reward = get_reward([prompt], [response], reward_tokenizer, reward_model)[0].item()

        results.append({
            "prompt": prompt,
            "response": response,
            "reward": reward,
            "category": dataset[i]["category"]
        })

    return pd.DataFrame(results)

# Evaluate SFT model
print("Evaluating SFT model...")
sft_eval = evaluate_model(sft_model, tokenizer, test_dataset, reward_model, reward_tokenizer)

# Evaluate PPO model
print("\nEvaluating PPO model...")
ppo_model = AutoModelForCausalLM.from_pretrained(
    "ppo_toxicity_removal",
    quantization_config=bnb_config,
    device_map="auto"
)
ppo_eval = evaluate_model(ppo_model, tokenizer, test_dataset, reward_model, reward_tokenizer)

# Compare results
print("\nSFT Model Average Reward:", sft_eval["reward"].mean())
print("PPO Model Average Reward:", ppo_eval["reward"].mean())

# Show some examples
print("\nExample improvements:")
for i in range(3):
    print(f"\nPrompt: {sft_eval['prompt'][i]}")
    print(f"SFT Response: {sft_eval['response'][i]} (Reward: {sft_eval['reward'][i]:.2f})")
    print(f"PPO Response: {ppo_eval['response'][i]} (Reward: {ppo_eval['reward'][i]:.2f})")

## **Part 3: Brainstorm Activity**

In [None]:
# @title Group Discussion Prompts
discussion_topics = [
    "How might you adapt these techniques for child-friendly chatbots?",
    "How can we balance safety with creative generation?",
    "What metrics would you use to evaluate fairness in model outputs?",
    "What are potential failure modes of RLHF for safety?",
    "What kinds of harmful outputs (subtle or overt) might a human rater miss? How can we catch these?",
    "How should RLHF balance fairness and truthfulness when user expectations conflict with respectful output? (e.g. User asks the model to adopt an informal sarcastic tone.)",
    "What kind of prompts can potentially lead the model to produce harmful outputs? Can RLHF help maintain safe practice there?"
]

print("Break into small groups and discuss one of these topics:")
for i, topic in enumerate(discussion_topics, 1):
    print(f"{i}. {topic}")

## **Part 4: Presentations and Q&A**

In [None]:
# @title Evaluation and Wrap-up
def evaluate_safety(model, test_cases):
    safe_count = 0
    for case in test_cases:
        response = generate_response(case['prompt'])
        if "sorry" in response.lower() or "can't" in response.lower():
            safe_count += 1
    return safe_count / len(test_cases)

test_cases = [{"prompt": row['text']} for _, row in test_df.head(5).iterrows()]
safety_score = evaluate_safety(model, test_cases)

print(f"\nModel safety score on test cases: {safety_score:.0%}")
print("\nKey Takeaways:")
print("- RLHF provides powerful tools for aligning LLMs with human values")
print("- Safety requires continuous evaluation and iteration")
print("- Consider the full pipeline: SFT → Reward Modeling → RL Optimization")

## **Additional Resources**

In [None]:
# @title Further Learning
resources = {
    "SafeEdit Paper": "https://arxiv.org/abs/2403.14472",
    "TRL Documentation": "https://huggingface.co/docs/trl/index",
    "RLHF Explained": "https://huggingface.co/blog/rlhf",
    "LoRA Paper": "https://arxiv.org/abs/2106.09685"
}

print("Explore these resources to go deeper:")
for name, url in resources.items():
    print(f"- {name}: {url}")

## **Appendix: Full Implementation Details**

### **A1. Full PPO Objective Function**

The PPO objective you've shown is a **simplified version** commonly used in RLHF (Reinforcement Learning from Human Feedback) for language models, but it's not the full PPO objective—it's missing the **value model's role** in advantage estimation. Let me clarify:

#### **1. What’s Missing? The Value Model’s Role**
The current objective focuses on:
1. **Reward Maximization**: $$\mathbb{E}[r_\phi(y|x)]$$  
2. **KL Penalty**: $$\text{KL}(\pi_\theta || \pi_{\text{ref}})$$  

But in PPO, we don’t directly maximize raw rewards. Instead, we maximize the **advantage** (how much better an action is compared to the expected baseline), computed using the **value model** \( V(x) \):  
$$
A(x, y) = r_\phi(y|x) - V(x)
$$

The **true PPO objective** includes this advantage term:  
$$
L(\theta) = \mathbb{E} \left[ \min\left( \frac{\pi_\theta(y|x)}{\pi_{\text{old}}(y|x)} A(x, y), \text{clip}\left(\frac{\pi_\theta(y|x)}{\pi_{\text{old}}(y|x)}, 1-\epsilon, 1+\epsilon\right) A(x, y) \right) \right] - \beta \, \text{KL}(\pi_\theta || \pi_{\text{ref}})
$$

Key differences:
- **Advantage \( A(x, y) \)**: Requires the value model to estimate \( V(x) \).  
- **Clipping**: Prevents overly large policy updates (the "proximal" in PPO).  


### **2. Definition of `ratios`**
The ratio is calculated as:
$$
\text{ratios} = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}
$$

where:
- $\pi_\theta(a_t \mid s_t)$: Probability of action $a_t$ (e.g., a generated token/response) under the **current policy** (with parameters $\theta$).
- $\pi_{\text{old}}(a_t \mid s_t)$: Probability of the same action under the **old policy** (before the update).

#### **Role in PPO Loss**
The ratio is used to:
1. **Weight the advantages** (how much better/worse an action is compared to the baseline):
   $$
   \text{policy_loss} = -\text{min}\left(
   \text{ratios} \times \text{advantages}, \quad
   \text{clipped_ratios} \times \text{advantages}
   \right)
   $$
   - If the ratio is high (>1), the action became more likely under the new policy.
   - If the ratio is low (<1), the action became less likely.

2. **Enforce trust-region updates** via clipping (`clip_eps`):
   - Clipping the ratios (e.g., to $[1-0.2, 1+0.2]$) prevents overly large policy updates, ensuring stability.

#### **Why Use Ratios?**
- **Importance Sampling**: Allows reusing old trajectories (collected under $\pi_{\text{old}}$) to update the current policy ($\pi_\theta$) without resampling.
- **Controlled Updates**: The clipping ensures the policy doesn’t change too drastically, avoiding catastrophic failures (e.g., generating gibberish to exploit the reward model).


### **Key Points**
1. **Advantages**: Tell us if an action was better ($A > 0$) or worse ($A < 0$) than expected.
2. **Ratios**: Adjust policy updates based on how much the action’s probability changed.
3. **Clipping**: Prevents aggressive updates (e.g., a ratio of 100x would be clipped to $1.2$ if `clip_eps=0.2`).

#### **3. Where Does the Value Model Fit In?**
1. **Value Function \( V(x) \)**:  
   - The value model predicts the **expected future reward** from state \( x \) (e.g., a partially generated text).  
   - Input: Partial sequence (`"Q: 2+2? → A:"`).  
   - Output: Scalar (e.g., `0.3`).  

2. **Advantage Calculation**:  
   - If the reward for `"A: 5"` is `-1.0` (toxic) and \( V(x) = 0.3 \), then:  
     $$
     A(x, y) = -1.0 - 0.3 = -1.3
     $$  
   - A negative advantage pushes the policy away from toxic outputs.  

3. **Training the Value Model**:  
   - The value model is trained separately to minimize:  
     $$
     L(V) = \mathbb{E} \left[ (V(x) - R)^2 \right]
     $$  
     where \( R \) is the **actual discounted return** (e.g., the final reward).  


#### **4. Why the difference?**

- In RLHF, rewards are often sparse (e.g., per-output human feedback), so the value model’s role is sometimes simplified.  
- In classic PPO (e.g., for robotics), the value model is critical for dense reward signals.  

#### **5. Practical Implications for our Setup**
With `gpt2` (policy) and `toxic-bert` (reward model):
1. **Value Model**:  
   - Add a regression head to `gpt2` (shared backbone with policy).  
   - Train it to predict rewards for partial sequences.  

2. **Modified Objective**:  
   ```python
    # Pseudocode for PPO step (simplified)
    advantages = rewards - values  # A(s,a) = R(s,a) - V(s)
    ratios = current_probs / old_probs  # π_θ(a|s) / π_old(a|s)

    policy_loss = -torch.min(
        ratios * advantages,  # Unclipped objective
        torch.clamp(ratios, 1-clip_eps, 1+clip_eps) * advantages  # Clipped objective
    )

    kl_penalty = beta * kl_div(policy, ref_model)  # Penalty for diverging from reference
    total_loss = policy_loss + kl_penalty
   ```


#### **Key Takeaways**
| Component       | Role in Objective                                                                 |
|----------------|-----------------------------------------------------------------------------------|
| **Reward Model** ($ r_\phi $) | Provides $ r(y\|x) $ (e.g., toxicity scores).                                  |
| **Value Model** ($ V $)       | Estimates \( V(x) \) to compute advantages $ A(x, y) = r(y\|x) - V(x) $.       |
| **Reference Model** ($ \pi_{\text{ref}} $) | Anchors KL divergence to prevent over-optimization.                             |
| **Policy** ($ \pi_\theta $)   | Updated to maximize advantages while staying close to $ \pi_{\text{ref}} $.    |

**The simplified objective works for RLHF**, but for true PPO, include advantages (and clipping). The value model is essential for this!  


### **A2. Reward Function at Inference Time**

After training a **reinforcement learning (RL) model**, the **reward function is no longer needed during deployment** if the policy has been fully optimized and operates in a static environment. However, there are important exceptions and nuances:

**When the Reward Function is Still Needed After Training**
1. **Online Learning / Continual Adaptation**  
   - If the RL agent keeps learning during deployment (e.g., adapting to new environments), the reward function must remain active to provide feedback for updates.  
   - Example: A recommendation system that continuously refines its policy based on user clicks (rewards).

**Safe Exploration & Monitoring**  
   - In safety-critical applications (e.g., autonomous driving), the reward function may be used to monitor performance and trigger failsafes if rewards drop unexpectedly.  

**Reward as a Diagnostic Tool**  
   - Even if the policy is fixed, the reward function can evaluate performance post-deployment (e.g., detecting distributional shift or performance degradation).  

**When the Reward Function is *Not* Needed**
- **Static Environments with Fixed Policies**  
  - Once trained, the policy (e.g., a neural network) can act independently, mapping states → actions without reward calculations.  
  - Example: A game-playing AI (like AlphaGo) uses its pre-trained policy without recomputing rewards during matches.  

**Key Distinction: Reward Function vs. Policy**
- **Reward Function**: Guides *training* (like a teacher’s feedback).  
- **Policy**: The final *executable strategy* (like a student’s learned skills).  

**Practical Implications**
- If your environment is dynamic or requires ongoing learning, keep the reward function.  
- For static tasks, the trained policy alone suffices.

### **A3. RL approaches for LLMs**


Several **reinforcement learning (RL) approaches** are used to fine-tune large language models (LLMs) beyond **PPO** and **GRPO**, each with distinct advantages for alignment, efficiency, or stability. Here’s a breakdown of key alternatives:


**1. Direct Preference Optimization (DPO)**
- **What it does**: Replaces traditional RL with a **closed-form policy update** using human preference data (no reward model or PPO required).  
- **Advantages**:  
  - Simpler and more stable than PPO (avoids reward model biases).  
  - Directly optimizes for human preferences via pairwise rankings.  
- **Used in**: Zephyr, Mistral-7B, and other preference-tuned LLMs.  
- **Paper**: [Rafailov et al. (2023)](https://arxiv.org/abs/2305.18290).


**2. Reinforcement Learning from Human Feedback (RLHF) Variants**
#### **a. Advantage-Weighted Regression (AWR)**
  - **What it does**: Uses **advantage-weighted loss** (like PPO but simpler).  
  - **Advantages**: More stable for offline RL (no on-policy sampling needed).  
  - **Used in**: Early LLM alignment (e.g., OpenAI’s pre-ChatGPT models).  

#### **b. Q-Learning (e.g., DQN, CQL)**
  - **What it does**: Learns a **Q-function** to score actions (useful for constrained generation).  
  - **Advantages**: Better for **discrete action spaces** (e.g., choosing among template responses).  
  - **Limitations**: Rarely used for full LLM fine-tuning (scalability issues).  


**3. Contrastive Learning Methods**
#### **a. Sequence Likelihood Calibration (SLiC)**
  - **What it does**: Uses **contrastive ranking loss** (like DPO but simpler).  
  - **Advantages**: No RL loop needed; works with static datasets.  
  - **Used in**: Lightweight alignment (e.g., [Yuan et al. (2023)](https://arxiv.org/abs/2305.20050)).  

#### **b. RankRLHF**
  - **What it does**: Extends DPO to **listwise rankings** (A > B > C > D).  
  - **Advantages**: Better for multi-response ranking scenarios.  


**4. Offline RL Methods**
#### **a. Implicit Language Q-Learning (ILQL)**
  - **What it does**: Combines Q-learning with LLMs for **offline preference data**.  
  - **Advantages**: Efficient for **constrained text generation** (e.g., avoiding harmful outputs).  
  - **Used in**: [Snorkel AI’s work](https://arxiv.org/abs/2206.11871).  

#### **b. Conservative Q-Learning (CQL)**
  - **What it does**: Penalizes overestimation of Q-values in offline data.  
  - **Advantages**: Reduces hallucination in RL-tuned LLMs.  


**5. Hybrid Approaches**
#### **a. Reinforced Self-Training (ReST)**
  - **What it does**: Iteratively generates samples, filters best ones, and fine-tunes on them.  
  - **Advantages**: No reward model needed (self-improving loop).  
  - **Used in**: Google’s [GEMINI](https://arxiv.org/abs/2308.08998).  

#### **b. Expert Iteration (ExIt)**
  - **What it does**: Alternates between LLM generations and expert feedback (like AlphaGo).  
  - **Advantages**: Useful for **code-generating LLMs** (e.g., GitHub Copilot).  


#### **When to Use Which?**
| **Method**       | **Best For**                           | **Complexity** |  
|------------------|----------------------------------------|---------------|  
| **PPO**         | High-resource RLHF (e.g., ChatGPT)     | High          |  
| **DPO**         | Lightweight preference tuning          | Low           |  
| **SLiC/ILQL**   | Offline data + no reward model         | Medium        |  
| **Q-Learning**  | Discrete action spaces (e.g., dialog)  | High          |  



#### **Emerging Trends**
- **Multimodal RLHF**: Extending RL to align LLMs with vision/audio (e.g., GPT-4V).  
- **Adversarial RL**: Using RL to **red-team** LLMs (e.g., training against jailbreaks).  

#### **Group Relative Policy Optimization (GRPO)**

**Group Relative Policy Optimization (GRPO)** is a recent **reinforcement learning (RL) approach**, specifically designed for **policy optimization** in scenarios where relative performance comparisons matter (e.g., aligning AI behavior with human preferences). DeepSeek uses this approach.

**Key Features of GRPO**  
1. **Relative Policy Optimization**  
   - Unlike standard RL (which maximizes absolute rewards), GRPO focuses on **relative comparisons** (e.g., "Is response A better than B?"), making it suitable for **human feedback-driven RL (RLHF)**.
   - Similar in spirit to **Pairwise Preference Optimization** methods (e.g., DPO, PPO with ranking-based rewards).  

2. **Group-Wise Learning**  
   - Operates on **batches of trajectories** (e.g., multiple LLM responses ranked by humans/reward models).  
   - Optimizes policies by comparing **groups of actions** rather than individual rewards.  

3. **Connection to Existing Methods**  
   - Can be seen as a **generalization of PPO** (Proximal Policy Optimization) but with **ranking-based objectives** instead of absolute rewards.  
   - Shares similarities with **Off-Policy RL** (e.g., learns from logged human preference data).  

**How GRPO Differs from Traditional RL**

| **Aspect**          | **Standard RL (e.g., PPO)**            | **GRPO**                          |  
|----------------------|---------------------------------------|-----------------------------------|  
| **Reward Source**    | Absolute scalar rewards               | Relative rankings (A > B > C)     |  
| **Objective**        | Maximize expected reward              | Maximize likelihood of preferred trajectories |  
| **Data Usage**       | Requires dense rewards                | Works with sparse pairwise preferences |  
| **Use Case**         | Game-playing, robotics                | RLHF, alignment tasks (e.g., LLMs) |  

**Is GRPO Used in Practice?**  
- While **PPO** remains the dominant choice for RLHF (e.g., ChatGPT fine-tuning), GRPO is a **research-level method** (e.g., proposed in papers like ["Relative Policy Optimization"](https://arxiv.org/abs/2305.18239)) aimed at improving stability in preference-based RL.  
- It’s part of a broader trend toward **ranking-based RL** (e.g., DPO, SLiC, and contrastive learning variants).  

**When to Consider GRPO?**  
- If you’re working on **RLHF with human/AI-generated rankings** (vs. explicit rewards).  
- If standard PPO struggles with **sparse or noisy preference data**.  

Would you like a high-level pseudocode example of how GRPO updates policies?

In [None]:
# @title Complete Training Script (Collapsed)
Markdown("""
Full implementation would include:
1. Proper dataset tokenization and batching
2. Complete reward model training loop
3. Full PPO implementation with value head
4. Comprehensive evaluation metrics
5. Hyperparameter tuning
6. Safety-specific loss functions
""")

----