# RLAIF Fine-Tuning

## Mount Google Drive

Run this cell to access a dataset stored in your Google Drive account.

In [None]:
# Mount Google Drive so files stored there are available under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Upgrade bitsandbytes: the 4-bit quantized load (BitsAndBytesConfig below) needs a recent version.
!pip install -U bitsandbytes



## LLM Judge (Reward Function)

Here we define the LLM judge, which stands in for a reward model in Direct-RLAIF.
Instead of using a trained reward model, we use the Google Colab AI API to prompt Gemini for a rating.

In [None]:
import re
import time
from google.colab import ai

# --- 1. Helper Function to Call Your API ---
def call_llm_api(prompt_content):
    """
    Simulates or performs the actual API call.
    Replace the body of this function with your actual API request (e.g., OpenAI, Anthropic).
    """
    try:
        return ai.generate_text(prompt_content)

    except Exception as e:
        print(f"LLM Judge Call failed: {e}")
        return "Score: 5" # Fail-safe neutral score

# --- 2. The Main Scoring Function ---
def get_llm_judge_score(prompt, response):
    """
    Get a reward score from an LLM judge for a given prompt-response pair.
    Uses a 1-10 absolute rating scale, normalized to [-1, 1].

    Args:
        prompt (str): The original prompt/question
        response (str): The model's generated response

    Returns:
        float: Reward score in range [-1, 1]
    """

    # A. Construct the Evaluation Prompt
    # We ask for "Reasoning" first to encourage Chain-of-Thought, which improves scoring accuracy.
    eval_prompt = f"""
You are an expert AI evaluator. Review the following interaction:

[User Instruction]
{prompt}

[AI Response]
{response}

Your task is to evaluate the helpfullness, honestly, and overall quality of the
AI response on a scale of 1 to 10.
- **1-3 (Poor):** Factually incorrect, harmful, or irrelevant.
- **4-6 (Average):** Relevant but vague, incomplete, slightly hallucinated, or
way too verbose for what the question required
- **7-9 (Good):** Accurate, helpful, and concise.
- **10 (Excellent):** Perfect, insightful, and well-structured.

**Evaluation Protocol:**
1. Briefly verify the factual accuracy.
2. Check if all constraints in the instruction were met.
3. Assign a final integer score.

**Be Aware of Your Own Limitations**
Beware of the central tendency bias in LLMs to just assign everything a good
score. YOU SHOULD BE ASSIGNING SOME BAD SCORES AND SOME PERFECT ONES.

**Output Format:**
Reasoning: [Your reasoning here]
Score: [Integer 1-10]
"""

    # B. Get the LLM's opinion
    llm_output = call_llm_api(eval_prompt)

    # C. Extract the Score using Regex
    # This looks for "Score: 7" or just "7" at the end of a line
    match = re.search(r'Score:\s*(\d+)', llm_output, re.IGNORECASE)

    if match:
        raw_score = int(match.group(1))
    else:
        # Fallback if parsing fails
        print(f"Warning: Could not parse score from LLM output. Defaulting to 5.\nOutput: {llm_output[:100]}...")
        raw_score = 5

    # Clamp score to 1-10 just in case
    raw_score = max(1, min(10, raw_score))

    # D. Normalize to [-1, 1] range for PPO
    # Formula: (score - 1) / 9 gives [0, 1]. Then multiply by 2 and subtract 1.
    # 1 -> -1.0
    # 5 -> -0.11
    # 10 -> 1.0
    normalized_score = 2 * ((raw_score - 1) / 9) - 1

    return float(normalized_score)

# Smoke-test the judge on a clearly good answer. This calls the real Gemini judge.
test_prompt = "What is the capital of France?"
test_response = "Paris is the capital and most populous city of France."
test_score = get_llm_judge_score(test_prompt, test_response)
print(f"Test LLM Judge Score [GOOD]: {test_score:.4f}")
print("✓ LLM judge function ready (using Gemini via google.colab.ai)")

Test LLM Judge Score [GOOD]: 1.0000
✓ LLM judge function ready (using placeholder - replace with real API later)


## Policy Model with Value Head

PPO requires two components:
- **Policy (Actor)**: The language model that generates text
- **Value Head (Critic)**: Estimates the expected reward for a given state

We combine both into a single model class.

In [None]:
import torch

# Prefer the GPU when present; this should report "cuda".
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    Linear critic head: maps every token's hidden state to one scalar.

    Input hidden states of width ``config.hidden_size`` give an output with the
    same leading dimensions and a trailing dimension of 1.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.value = nn.Linear(in_features=self.hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        # Small-variance normal weights (std = 1 / sqrt(hidden_size + 1)) and zero bias.
        init_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.value.weight, std=init_std)
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        return self.value(hidden_states)


class ModelForCausalLMWithValueHead(nn.Module):
    """
    Causal LM with a scalar value head on its final hidden states.

    ``forward`` returns ``(lm_logits, values)``, so one module acts as both the
    PPO actor (logits) and the critic (per-token values).
    """

    def __init__(self, model_name_or_path, quantization_config=None):
        super().__init__()
        if quantization_config is not None:
            # Quantized load from the HuggingFace Hub, with automatic device placement.
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
        else:
            # Plain load, e.g. a small local checkpoint such as GPT-2.
            self.llm = AutoModelForCausalLM.from_pretrained(model_name_or_path)

        self.v_head = ValueHead(self.llm.config)

        # With device_map="auto" the LLM sits on the GPU while the freshly built
        # value head is on the CPU; move the head next to the LLM weights.
        if quantization_config is not None:
            try:
                llm_device = next(self.llm.parameters()).device
                self.v_head = self.v_head.to(llm_device)
                print(f"Value head moved to {llm_device}")
            except StopIteration:
                print("Warning: Could not determine LLM device")

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the LLM and the value head.

        Returns:
            (lm_logits, values): the LLM logits, and the value head output with
            its trailing size-1 dimension squeezed away.
        """
        # Call the module (not .forward) so any registered hooks run.
        transformer_outputs = self.llm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        lm_logits = transformer_outputs.logits
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # The hidden states may differ from the head in dtype (e.g. bf16 compute
        # vs. a float32 head) or in device when layers are spread across devices.
        # Align them here so the head never hits a dtype/device mismatch.
        head_weight = self.v_head.value.weight
        last_hidden_state = last_hidden_state.to(device=head_weight.device, dtype=head_weight.dtype)

        value = self.v_head(last_hidden_state).squeeze(-1)
        return lm_logits, value

    def generate(self, *args, **kwargs):
        """Delegate text generation to the wrapped LLM (the value head is unused)."""
        return self.llm.generate(*args, **kwargs)

Using device: cuda


## Load the model from Huggingface

In [None]:
# Load Zephyr 7B from HuggingFace with 4-bit quantization
from transformers import BitsAndBytesConfig

model_name = "HuggingFaceH4/zephyr-7b-beta"

# 4-bit NF4 weights with double quantization; matmuls run in bfloat16.
# NOTE: GPUs without good bfloat16 support may need float16 here. If you change
# it, keep the model's torch_dtype and the value-head dtype in sync.
quant_settings = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
bnb_config = BitsAndBytesConfig(**quant_settings)

# Policy LLM + value head in one module.
model = ModelForCausalLMWithValueHead(model_name, quantization_config=bnb_config)
print("Zephyr 7B model loaded successfully with value head!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Value head moved to cuda:0
Zephyr 7B model loaded successfully with value head!


## Preparing Dataset

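We load the tokenizer and a subset of prompts from the UltraFeedback dataset (`HuggingFaceH4/ultrafeedback_binarized`) on the Hugging Face Hub, then tokenize the prompts for training.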

In [None]:
from transformers import AutoTokenizer

# Load Zephyr tokenizer
# Zephyr's tokenizer. Pad on the left so each prompt ends exactly where
# generation starts, and reuse EOS as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer loaded. Pad token: {tokenizer.pad_token}")

Tokenizer loaded. Pad token: </s>


In [None]:
# Updated Data Loading Code
from datasets import load_dataset
import random

# 1. Load the specific split designed for PPO/Generation
# The 'train_gen' split contains prompts suitable for generation tasks
full_dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_gen")

# 2. Take a reproducible random subset of prompts.
# RECOMMENDATION: ~2,000 for a quick run, 10,000+ for better results.
num_samples = 2000
shuffled = full_dataset.shuffle(seed=2981)
dataset_subset = shuffled.select(range(num_samples))

print(f"Dataset loaded: {len(dataset_subset)} prompts")
print(f"Columns: {dataset_subset.column_names}")
print(f"\nSample prompt: {dataset_subset[0]['prompt']}")

Dataset loaded: 2000 prompts
Columns: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected']

Sample prompt: A 6 month stakeholder engagement plan for climate change-related flooding across NYC. Engaging social groups that are most vulnerable to impacts. Engaging local experts who have insights and knowledge. Weekly schedule of meetings, workshops, surveys, etc.


In [None]:
# Create train/val split from our dataset
# might be able to make this 100% train
n_total = len(dataset_subset)
train_size = int(0.9 * n_total)
# First 90% of the (already shuffled) subset for training, the rest for validation.
ds_train, ds_val = (
    dataset_subset.select(range(lo, hi))
    for lo, hi in ((0, train_size), (train_size, n_total))
)

print(f"Train size: {len(ds_train)}")
print(f"Val size: {len(ds_val)}")

Train size: 1800
Val size: 200


## Tokenize the Prompts
`tokenize` wraps each raw prompt in Zephyr's chat template (`<|user|>\n{prompt}</s>\n<|assistant|>\n`), so the dataset prompts must not already contain this formatting. The cell prints one formatted sample so you can check the result.

In [None]:
# Tokenize prompts using Zephyr's chat template
def tokenize(sample):
    """
    Wrap one raw prompt in Zephyr's chat template and tokenize it.

    Adds to ``sample``:
      - ``input_ids``: token ids of the formatted prompt,
      - ``attention_mask``: all ones (nothing is padded at this stage),
      - ``query``: the formatted prompt text, kept for debugging.
    """
    # Zephyr format: <|user|>\n{prompt}</s>\n<|assistant|>\n
    messages = [{"role": "user", "content": sample['prompt']}]
    try:
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Fallback when the tokenizer has no usable chat template: build Zephyr's
        # format by hand. Based on: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
        formatted_prompt = f"<|user|>\n{sample['prompt']}</s>\n<|assistant|>\n"

    # Encode the rendered text once, the same way for both paths.
    # apply_chat_template(tokenize=True) encodes with add_special_tokens=False,
    # so do the same here. The template path and the fallback then produce
    # identical ids for identical text (no extra BOS on one path only).
    encoded = tokenizer.encode(formatted_prompt, add_special_tokens=False)

    sample['input_ids'] = encoded
    sample['attention_mask'] = [1] * len(encoded)
    sample['query'] = formatted_prompt  # Keep formatted prompt text
    return sample

# Shared .map() settings: process one row at a time, and drop the raw text and
# preference columns once the prompt is tokenized. Other columns such as
# prompt_id and messages are kept.
map_kwargs = dict(
    batched=False,
    remove_columns=['prompt', 'chosen', 'rejected'],
)

tokenized_dataset_train, tokenized_dataset_val = (
    split.map(tokenize, **map_kwargs) for split in (ds_train, ds_val)
)

print(f"Tokenized {len(tokenized_dataset_train)} training prompts")
print(f"Tokenized {len(tokenized_dataset_val)} validation prompts")
first_example = tokenized_dataset_train[0]
print(f"\nSample formatted prompt:")
print(repr(first_example['query']))
print(f"\nToken IDs (first 20): {first_example['input_ids'][:20]}")
print(f"Number of tokens: {len(first_example['input_ids'])}")

Tokenized 1800 training prompts
Tokenized 200 validation prompts

Sample formatted prompt:
'<|user|>\nA 6 month stakeholder engagement plan for climate change-related flooding across NYC. Engaging social groups that are most vulnerable to impacts. Engaging local experts who have insights and knowledge. Weekly schedule of meetings, workshops, surveys, etc.</s>\n<|assistant|>\n'

Token IDs (first 20): [523, 28766, 1838, 28766, 28767, 13, 28741, 28705, 28784, 2102, 15790, 8229, 15613, 2623, 354, 11259, 2268, 28733, 9646, 2175]
Number of tokens: 68


In [None]:
# Make row access return torch tensors for columns that can be converted
# (e.g. input_ids, attention_mask) instead of Python lists.
# Non-numeric columns such as prompt_id stay as they are.
for _tok_ds in (tokenized_dataset_train, tokenized_dataset_val):
    _tok_ds.set_format(type='torch')

In [None]:
# Drop prompts that are too long for the GPU.
# Rough guide: 512-768 tokens is safe on a T4 (Colab free tier); an A100 can go to 2048.
MAX_PROMPT_LENGTH = 512


def filter_long_prompts(sample):
    """Return True if the sample's prompt has at most MAX_PROMPT_LENGTH tokens."""
    return not len(sample['input_ids']) > MAX_PROMPT_LENGTH


print(f"Original Training Set Size: {len(tokenized_dataset_train)}")
tokenized_dataset_train = tokenized_dataset_train.filter(filter_long_prompts)
print(f"Filtered Training Set Size: {len(tokenized_dataset_train)}")

print(f"Original Val Set Size: {len(tokenized_dataset_val)}")
tokenized_dataset_val = tokenized_dataset_val.filter(filter_long_prompts)
print(f"Filtered Val Set Size: {len(tokenized_dataset_val)}")

Original Training Set Size: 1800


Filter:   0%|          | 0/1800 [00:00<?, ? examples/s]

Filtered Training Set Size: 1618
Original Val Set Size: 200


Filter:   0%|          | 0/200 [00:00<?, ? examples/s]

Filtered Val Set Size: 174


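## DataLoaders

Batch the tokenized prompts for generation. Each batch is a dict of per-key lists, so every prompt keeps its own length. (The final reward score is later attached to the last generated token of each response; see **Compute Reward**.)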

In [None]:
from torch.utils.data import DataLoader

# Number of prompts per rollout batch. It is kept small because a 7B policy plus
# a reference copy is memory-heavy; increase it if memory allows.
# The PPO mini-batch size (mini_batch_size) is configured separately below.
batch_size = 2  # Reduced from 32 for 7B model

def collator(batch):
    """Turn a list of row dicts into one dict of per-key lists (no padding or stacking)."""
    keys = batch[0].keys()
    return {key: [row[key] for row in batch] for key in keys}

loader_kwargs = dict(batch_size=batch_size, collate_fn=collator, shuffle=True)
train_dataloader = DataLoader(tokenized_dataset_train, **loader_kwargs)
val_dataloader = DataLoader(tokenized_dataset_val, **loader_kwargs)

print(f"Dataloaders created with batch_size={batch_size}")

Dataloaders created with batch_size=2


In [None]:
# Grab one batch of tokenized prompts to exercise generation and scoring below.
batch = next(iter(train_dataloader))
batch

{'prompt_id': ['2b223ac0550faaca92222a2131bf3d66ac313e4c95068f0d8ad9eddd4389ea87',
  '49d5f2cdb3f6751d392fa79b8b4de2d83d23f0d7325dc1678bcfda90d8114a51'],
 'messages': [[{'content': 'In this task, you are given two sentences. Your task is to classify the given sentences as "Yes" if they have same meaning; otherwise, classify them as "No". \n\nExample Input: Sentence-1: I\'ve never gotten into them.<sep>Sentence-2: Like readers digest .\nExample Output: No\n\nExample Input: Sentence-1: I pick up shifts when I want to.<sep>Sentence-2: I work at the weekends .\nExample Output: Yes\n\nExample Input: Sentence-1: We are still waiting.<sep>Sentence-2: I did get my new office .\nExample Output:',
    'role': 'user'}],
  [{'content': 'Write a Python program that receives a text file as input and analyzes it, extracting the most relevant keywords from the text. The program should prioritize the most common and meaningful words, such as nouns and verbs, and ignore stop words like "the" or "and". T

In [None]:
# Response-length budget: each rollout draws max_new_tokens at random from
# [output_min_length, output_max_length). Generation may still stop earlier at EOS.
output_min_length = 20  # Increased from 5 for Zephyr
output_max_length = 100  # Increased from 16 for more complete responses

# Sampling settings for rollouts: top-k sampling with k=50 and no nucleus
# truncation (top_p=1.0).
# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
generation_kwargs = dict(
    min_length=0,
    top_k=50,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
)

print(f"Generation settings: {output_min_length}-{output_max_length} tokens")

Generation settings: 20-100 tokens


## Sample Generation (Test)

Let's test that our model can generate responses before starting training.

In [None]:
import random
# Random generation budget in [output_min_length, output_max_length).
new_tokens = random.randrange(output_min_length, output_max_length)
generation_kwargs["max_new_tokens"] = new_tokens
# A tiny raw-text prompt for a quick generation test.
sample = tokenizer('Hi, this')
sample

{'input_ids': [1, 15359, 28725, 456], 'attention_mask': [1, 1, 1, 1]}

In [None]:
# Single-prompt generation: add a batch dimension, move to the device, then
# drop the batch dimension from the result (prompt + continuation token ids).
prompt_ids = torch.tensor(sample['input_ids']).unsqueeze(0).to(device)
prompt_mask = torch.tensor(sample['attention_mask']).unsqueeze(0).to(device)
query_response = model.generate(
    input_ids=prompt_ids,
    attention_mask=prompt_mask,
    **generation_kwargs,
).squeeze(0)
query_response

tensor([    1, 15359, 28725,   456,   349,   475, 28723,  8181,   304, 10058,
          298,  1698,  6073,  1704, 28723,   315,  3317,   368, 28742,   267,
          544,  2548,  1162,   304,  7484,   456,  3518, 10865, 28745,   297,
         3154, 28742, 28713,  1704, 28725,   478, 28742,   267,  1404,   298,
         2796,  1581,  4514,   302,  4400,  2621, 17869,   369,   460,  2492,
         8154, 14650,   297,   272,  7153, 13894, 28723,    13,    13,  2565,
         1395,   693,   460,   633,   298,   456,  3518, 28725,   478, 28742,
          267,  9045,   356,  7501, 12076,   395, 12302, 20715,  5202,   298,
         3270,  7080], device='cuda:0')

In [None]:
# Decode the full sequence (prompt + continuation, special tokens included) for a visual check.
tokenizer.decode(query_response)

"<s> Hi, this is J. Lee and welcome to another blog post. I hope you're all doing well and finding this series helpful; in today's post, we're going to cover different types of website design trends that are making splashes in the digital landscape.\n\nFor those who are new to this series, we're focused on providing readers with valuable insights related to online marketing"

In [None]:
# Score the sample generation with the LLM judge.
# Note: query_response holds prompt + continuation, so the judged "response"
# repeats the prompt text here.
with torch.no_grad():
    response_text = tokenizer.decode(query_response, skip_special_tokens=True)
    prompt_text = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
    score = get_llm_judge_score(prompt_text, response_text)

print(f"Prompt: {prompt_text}")
print(f"Response: {response_text}")
print(f"LLM Judge Score: {score:.4f}")

Prompt: Hi, this
Response: Hi, this is J. Lee and welcome to another blog post. I hope you're all doing well and finding this series helpful; in today's post, we're going to cover different types of website design trends that are making splashes in the digital landscape.

For those who are new to this series, we're focused on providing readers with valuable insights related to online marketing
LLM Judge Score: 1.0000


## Batch Generation

Now let's test generating responses for a full batch of prompts.

In [None]:
# Generate one response per prompt in `batch`, then score each with the LLM judge.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# NOTE: The model is already placed via device_map="auto", so it is not moved here.

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = []
query_response_tensors = []
score_tensors = []

print(f"Generating responses for {len(query_tensors)} prompts...")

for query, query_attention_mask in zip(query_tensors, query_attention_masks):
    query = query.to(device)
    query_attention_mask = query_attention_mask.to(device)
    # Random generation budget in [output_min_length, output_max_length).
    new_tokens = random.randrange(output_min_length, output_max_length)
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    # generate() returns prompt + continuation, so slice off the prompt by length.
    # Slicing with [-response_len:] would return the WHOLE sequence when
    # response_len == 0, because -0 == 0.
    response = query_response[len(query):]
    response_tensors.append(response)
    query_response_tensors.append(query_response)

    # Use the LLM judge instead of a reward model.
    with torch.no_grad():
        prompt_text = tokenizer.decode(query, skip_special_tokens=True)
        print(prompt_text)
        response_text = tokenizer.decode(response, skip_special_tokens=True)

        score = get_llm_judge_score(prompt_text, response_text)
        score = torch.tensor(score).to(device)

    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response, skip_special_tokens=True) for response in response_tensors]
print("Generated responses:")
print(batch['response'])

Using device: cuda
Generating responses for 2 prompts...
<|user|>
In this task, you are given two sentences. Your task is to classify the given sentences as "Yes" if they have same meaning; otherwise, classify them as "No". 

Example Input: Sentence-1: I've never gotten into them.<sep>Sentence-2: Like readers digest .
Example Output: No

Example Input: Sentence-1: I pick up shifts when I want to.<sep>Sentence-2: I work at the weekends .
Example Output: Yes

Example Input: Sentence-1: We are still waiting.<sep>Sentence-2: I did get my new office .
Example Output: 
<|assistant|>

<|user|>
Write a Python program that receives a text file as input and analyzes it, extracting the most relevant keywords from the text. The program should prioritize the most common and meaningful words, such as nouns and verbs, and ignore stop words like "the" or "and". The extracted keywords should be sorted by frequency and displayed to the user. Additionally, the program should offer an optional parameter t

## Compute Reward

The reward function combines:
- **Score from LLM judge**: Quality of the response
- **KL penalty**: Prevents the model from diverging too much from the reference (SFT) model

**Reward Formula:**

$\text{reward} = \text{score} - \beta \cdot \log \left(\frac{\pi^{RL}_\theta}{\pi^{SFT}}\right)$

Where:
- $\text{score}$ is from the LLM judge
- $\beta$ is the KL penalty coefficient (controls how much we penalize divergence)
- $\pi^{RL}_\theta$ is the current policy (model being trained)
- $\pi^{SFT}$ is the reference policy (frozen SFT model)

In [None]:
# Put the value head on the model's device and in bfloat16, so it accepts the
# LLM's bfloat16 hidden states without a dtype mismatch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
value_head = model.v_head

value_head.to(device)
print(f"Value head device: {next(value_head.parameters()).device}")

value_head.to(dtype=torch.bfloat16)
print(f"Value head dtype: {value_head.value.weight.dtype}")

Value head device: cuda:0
Value head dtype: torch.bfloat16


In [None]:
# Create the reference (SFT) policy for the KL penalty: a frozen copy of the
# policy taken before any PPO update.
# NOTE: deepcopy duplicates the whole 7B model, roughly doubling GPU memory use.
from copy import deepcopy
print("Creating reference model (this may take a moment for 7B model)...")
sft_model = deepcopy(model)
sft_model.eval()  # inference-mode behaviour (e.g. no dropout)
# eval() alone does not freeze weights. Turn off gradient tracking so the
# reference policy can never accumulate gradients or be updated by mistake.
sft_model.requires_grad_(False)
print("✓ Reference model created")

Creating reference model (this may take a moment for 7B model)...
✓ Reference model created


In [None]:
from transformers import DataCollatorWithPadding

# Pads a list of {'input_ids', 'attention_mask'} dicts into one batch.
# NOTE(review): padding is expected to follow tokenizer.padding_side, which this
# notebook sets to "left". Reward/mask indexing must account for that offset.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Pad the full query+response sequences into one batch for the scoring forward
# passes. Every real token gets an attention value of 1.
features = [
    {'input_ids': ids, 'attention_mask': torch.ones_like(ids)}
    for ids in query_response_tensors
]
input_data = data_collator(features).to(device)
input_data

{'input_ids': tensor([[  523, 28766,  1838, 28766, 28767,    13,   657,   456,  3638, 28725,
           368,   460,  2078,   989, 23748, 28723,  3604,  3638,   349,   298,
           875,  1575,   272,  2078, 23748,   390,   345,  5613, 28739,   513,
           590,   506,  1348,  5746, 28745,  5860, 28725,   875,  1575,   706,
           390,   345,  2501,  2586, 28705,    13,    13, 20275, 11232, 28747,
           318,   308,   636, 28733, 28740, 28747,   315, 28742,   333,  1484,
         10930,   778,   706, 26364, 21571, 28767, 26968,   636, 28733, 28750,
         28747,  5410, 12076, 18922,   842,    13, 20275, 15985, 28747,  1770,
            13,    13, 20275, 11232, 28747,   318,   308,   636, 28733, 28740,
         28747,   315,  3088,   582, 23573,   739,   315,   947,   298, 26364,
         21571, 28767, 26968,   636, 28733, 28750, 28747,   315,   771,   438,
           272,  1819,  2827,   842,    13, 20275, 15985, 28747,  5592,    13,
            13, 20275, 11232, 28747,  

In [None]:
def compute_rewards(input_data, query_tensors, response_tensors, score_tensors,
                    policy_model=None, ref_model=None, beta=0.2):
    """
    Build per-token PPO rewards.

    Every response token gets -beta * (log pi_policy - log pi_ref), and the
    last response token additionally gets the LLM-judge score.

    Args:
        input_data: mapping with padded ``input_ids`` and ``attention_mask``
            (each row is a query followed by its response). Left or right
            padding is supported.
        query_tensors: per-sample unpadded query token tensors.
        response_tensors: per-sample unpadded response token tensors.
        score_tensors: per-sample scalar judge scores.
        policy_model, ref_model: callables ``(input_ids, attention_mask) ->
            (logits, values)``. They default to the notebook globals ``model``
            and ``sft_model``.
        beta: KL penalty coefficient.

    Returns:
        (logp, rewards, values, masks), all of length seq_len - 1 along the
        sequence axis. Position t refers to the prediction of token t + 1, and
        ``masks`` is 1 exactly on response-token positions.
    """
    if policy_model is None:
        policy_model = model
    if ref_model is None:
        ref_model = sft_model

    with torch.no_grad():
        logits, values = policy_model(**input_data)  # b, seq, vocab
        ref_logits, _ = ref_model(**input_data)

        # Clamp logits to keep log_softmax finite (avoid -inf/nan).
        logits = torch.clamp(logits, min=-100, max=100)
        ref_logits = torch.clamp(ref_logits, min=-100, max=100)

        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)

        # Log-prob of the token actually present at each next position.
        labels = input_data['input_ids'][:, 1:]  # b, seq-1
        logp = torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1)
        ref_logp = torch.gather(ref_logp, 2, labels.unsqueeze(-1)).squeeze(-1)

        kl = logp - ref_logp
        rewards = -beta * kl
        attention_mask = input_data['attention_mask']
        masks = attention_mask[:, 1:].clone()

        for j in range(len(query_tensors)):
            # Index of the first real (non-pad) token. With left padding, shorter
            # rows do not start at 0, so shift the query/response boundaries.
            offset = int(attention_mask[j].long().argmax())
            start = offset + len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks


In [None]:
# Score the sampled batch: next-token logprobs, KL-penalised rewards (judge score
# on the last response token), masked values, and the response-token mask.
logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)

In [None]:
# Inspect the first sample: 1s in the mask mark response-token positions, and
# values are zeroed everywhere else.
print(masks[0])
print(values[0])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([  0.0000,  -0.0000,  -0.0000,  -0.0000,   0.0000,   0.0000,  -0.0000,
         -0.0000,   0.0000,   0.0000,  -0.0000,  -0.00

## Compute Advantage

The advantage function estimates how much better an action is compared to the average:
- **Positive advantage**: This action is better than expected
- **Negative advantage**: This action is worse than expected

Uses Generalized Advantage Estimation (GAE) for lower variance.

In [None]:
def masked_mean(values, mask):
    """
    Mean of ``values`` over positions where ``mask`` is non-zero.

    Returns 0 instead of NaN (0 / 0) when the mask selects nothing, so an empty
    mask cannot poison the loss.
    """
    total = mask.sum()
    masked_sum = (values * mask).sum()
    if total == 0:
        return masked_sum  # already zero
    return masked_sum / total


def masked_var(values, mask):
    """Variance of ``values`` over masked positions (population variance)."""
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)


def masked_whiten(values, mask):
    """
    Standardise ``values`` using masked mean and variance.

    The masked mean is subtracted at every position. If the masked variance is
    below 1e-6, values are only centred and not rescaled. The mean is not added
    back, so the result is (x - mu) / sigma.
    """
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    if var < 1e-6:
        return values - mean
    return (values - mean) * torch.rsqrt(var + 1e-6)


def compute_advantage(rewards, values, masks, gamma=1.0, lam=0.95):
    """
    Generalized Advantage Estimation over the sequence axis.

    Args:
        rewards, values, masks: tensors shaped (batch, seq).
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (advantages, returns). ``advantages`` are whitened over masked positions.
        ``returns`` are the raw (un-whitened) GAE advantages plus ``values``;
        they are the critic's regression targets.
    """
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)

    # Returns must come from the raw GAE estimates. Computing them after
    # whitening would shift and rescale the value targets.
    returns = advantages + values
    advantages = masked_whiten(advantages, masks)
    return advantages, returns


In [None]:
# Run GAE on the sampled batch and show the first sample's advantages and returns.
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([-7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2266e-02,
        -7.2266e-02, -7.2266e-02, -7.2266e-02, -7.2

## Mini-batch PPO Training

PPO updates the policy using mini-batches to improve sample efficiency and stability.

### Training Config

In [None]:
# Training hyperparameters
# NOTE: These may need adjustment for Zephyr 7B
learning_rate = 1e-5  # Conservative learning rate for fine-tuning

# Give the optimizer only parameters that can receive gradients.
# NOTE(review): 4-bit bitsandbytes weights are generally not trainable, so with
# the quantized base model this may be a small part of the network. Check the
# count printed below; adapter methods such as LoRA are the usual way to train them.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
print(f"Optimizer created with lr={learning_rate}")
print(f"Trainable parameter tensors: {len(trainable_params)}, "
      f"elements: {sum(p.numel() for p in trainable_params):,}")

Optimizer created with lr=1e-05


In [None]:
# Sanity check: a random ordering of batch indices, as used to form PPO mini-batches.
np.random.permutation(batch_size)

array([0, 1])

In [None]:
# PPO training configuration
mini_batch_size = 2  # Reduced from 4 for memory efficiency with 7B model
ppo_epochs = 4

cliprange_ratio = 0.2  # PPO clipping range
v_loss_coeff = 0.1     # Value loss coefficient
ratio_threshold = 10   # Threshold to detect unstable training

def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
    """
    Compute the clipped PPO loss for one mini-batch.

    Args:
        old_logprobs: per-token log-probs recorded at rollout time.
        values: rollout-time value estimates. Currently unused (no value
            clipping is applied); kept so the call signature stays stable.
        logprobs: per-token log-probs from the current policy (carry gradients).
        vpreds: per-token value predictions from the current model (carry gradients).
        masks: per-token mask passed to ``masked_mean`` to select which
            positions count towards the loss.
        advantages: per-token advantages computed from the rollout.
        returns: per-token return targets for the value loss.

    Returns:
        (loss, v_loss): the total loss ``pg_loss + v_loss_coeff * v_loss`` and
        the value loss. Both are multiplied by zero when the masked mean
        probability ratio exceeds ``ratio_threshold``.
    """
    # Rollout quantities are fixed targets of the PPO objective. Detach them so
    # no gradient flows back into whatever graph produced them. Otherwise a
    # later mini-batch/epoch could try to backward through that graph again.
    old_logprobs = old_logprobs.detach()
    advantages = advantages.detach()
    returns = returns.detach()

    # Probability ratio pi_new / pi_old, computed in log space.
    ratio = torch.exp(logprobs - old_logprobs)

    # Clipped surrogate objective, negated so that minimising it maximises the
    # advantage-weighted ratio. The elementwise max keeps the more pessimistic
    # (clipped or unclipped) term.
    pg_loss_unclipped = -ratio * advantages
    pg_loss_clipped = -torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
    pg_loss = masked_mean(torch.max(pg_loss_unclipped, pg_loss_clipped), masks)

    # Squared-error value loss against the rollout returns.
    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss

    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        # The policy has drifted far from the rollout policy. Zero the loss so
        # backward() yields zero gradients. Note that a stateful optimizer
        # such as AdamW can still move parameters on the following step()
        # through its momentum and weight-decay terms.
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    """
    Run ``ppo_epochs`` epochs of mini-batch PPO updates on the current rollout.

    Reads the rollout from notebook globals assigned by earlier cells:
    ``input_data`` (collated query+response batch), ``logprobs``, ``values``,
    ``masks``, ``advantages`` and ``returns``. Updates ``model`` in place
    through ``optimizer``.

    Mini-batches whose loss is NaN/inf are skipped entirely (no backward, no
    optimizer step), so one bad batch cannot write NaNs into the weights.
    """
    # Use the size of the data actually collected rather than a configured
    # constant, so the permutation never indexes past the end of the rollout.
    current_batch_size = len(input_data['input_ids'])

    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(current_batch_size)

        for start in range(0, current_batch_size, mini_batch_size):
            mini_batch_inds = batch_inds[start:start + mini_batch_size]

            mb_model_inputs = {
                'input_ids': input_data['input_ids'][mini_batch_inds],
                'attention_mask': input_data['attention_mask'][mini_batch_inds]
            }
            mb_logits, mb_vpreds = model(**mb_model_inputs)

            # Bound the logits before log_softmax to avoid overflow -> NaN loss.
            mb_logits = torch.clamp(mb_logits, min=-100, max=100)

            # Log-prob of each actual next token: the logits at position t are
            # used to score token t + 1.
            mb_log_dists = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
            mb_logprobs = torch.gather(
                mb_log_dists, 2, mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)
            ).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            # Check BEFORE backward/step. A non-finite loss yields non-finite
            # gradients, and clip_grad_norm_ + optimizer.step() would then
            # spread NaNs into every parameter.
            if not torch.isfinite(loss):
                print(f"WARNING: Loss is non-finite ({loss.item()}) in epoch {ep}; skipping update.")
                continue

            optimizer.zero_grad()
            loss.backward()
            # Cap the global gradient norm to keep updates of the 7B model stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            print('loss/total', loss.item())
    print('mini-batch training finished')

# Echo the active PPO schedule.
print("PPO config: {} epochs, mini_batch_size={}".format(ppo_epochs, mini_batch_size))

PPO config: 4 epochs, mini_batch_size=2


In [None]:
# One standalone PPO update pass over the rollout currently held in the
# notebook globals (input_data, logprobs, values, masks, advantages, returns).
mini_batch_train()

loss/total 0.10107421875
loss/total 0.06640625
loss/total 0.033203125
loss/total 0.01806640625
mini-batch training finished


## Train with RLAIF (PPO)

The main training loop:
1. Generate a response for each prompt
2. Score each response with the LLM judge
3. Compute rewards (judge score minus a KL penalty)
4. Compute advantages (how much better the response did than the value estimate predicted)
5. Update the policy with PPO

In [None]:
# Number of passes over train_dataloader.
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # --- 1. Rollout: generate one response per prompt ---
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']

        response_tensors = []
        query_response_tensors = []
        score_tensors = []

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)

            # Random response budget in [output_min_length, output_max_length).
            # randrange draws exactly like random.choice(list(range(...))) but
            # without materialising the list.
            new_tokens = random.randrange(output_min_length, output_max_length)
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)

            # Slice from the end of the prompt. The previous form
            # query_response[-response_len:] returned the WHOLE sequence
            # whenever response_len == 0, because [-0:] == [0:].
            response = query_response[len(query):]
            response_tensors.append(response)
            query_response_tensors.append(query_response)

            # --- 2. Score with the LLM judge (replaces a trained reward model) ---
            prompt_text = tokenizer.decode(query, skip_special_tokens=True)
            response_text = tokenizer.decode(response, skip_special_tokens=True)
            score = get_llm_judge_score(prompt_text, response_text)
            score_tensors.append(torch.tensor(score, device=device))

        # --- 3. Collate the query+response sequences into one padded batch ---
        # NOTE(review): every token, including any padding carried over from
        # the padded query, gets attention_mask=1 here. Confirm this is
        # intended for how compute_rewards builds its masks.
        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # --- 4. Rewards and advantages (left in module globals for mini_batch_train) ---
        logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
        advantages, returns = compute_advantage(rewards, values, masks)

        # --- 5. PPO update on this rollout ---
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

loss/total 0.1015625
loss/total 0.059814453125
loss/total 0.03369140625
loss/total 0.0126953125
mini-batch training finished
loss/total 0.09814453125
loss/total 0.06689453125
loss/total 0.04248046875
loss/total 0.02099609375
mini-batch training finished
loss/total 0.09912109375
loss/total 0.083984375
loss/total 0.059326171875
loss/total 0.044189453125
mini-batch training finished
loss/total 0.099609375
loss/total 0.07958984375
loss/total 0.05859375
loss/total 0.030029296875
mini-batch training finished
loss/total 0.099609375
loss/total 0.072265625
loss/total 0.044189453125
loss/total 0.01220703125
mini-batch training finished
loss/total 0.099609375
loss/total 0.072265625
loss/total 0.04296875
loss/total 0.0234375
mini-batch training finished
loss/total 0.099609375
loss/total 0.06640625
loss/total 0.0419921875
loss/total 0.019287109375
mini-batch training finished
loss/total 0.10107421875
loss/total 0.0693359375
loss/total 0.033447265625
loss/total 0.0146484375
mini-batch training finis

In [None]:
# Persist the fine-tuned weights (state_dict only) to the mounted Google Drive.
import os
checkpoint_path = '/content/drive/MyDrive/ppo_model_epoch_1.pt'
torch.save(obj=model.state_dict(), f=checkpoint_path)

## Validation

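Test the trained model on the validation set to see whether its responses improved. Note that the cells below are currently commented out and need updating before use. They still score with `reward_model` rather than the LLM judge, and they use `query_response_score`, whose definition is itself commented out.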

In [None]:
# len(tokenized_dataset_val)

In [None]:
# val_gen_lengths = [0] * len(tokenized_dataset_val)
# for i in range(len(tokenized_dataset_val)):
#     val_gen_lengths[i] = random.choice(list(range(output_min_length, output_max_length)))

In [None]:
# val_gen_lengths[:10]

In [None]:
# def validate():
#     scores = []
#     for b, batch in enumerate(val_dataloader):
#         # Generate_responses
#         query_tensors = batch['input_ids']
#         query_attention_masks = batch['attention_mask']
#         for i, query in enumerate(query_tensors):
#             query = query.to(device)
#             query_attention_mask = query_attention_masks[i].to(device)
#             new_tokens = val_gen_lengths[b * len(query_tensors) + i]
#             generation_kwargs["max_new_tokens"] = new_tokens
#             query_response = model.generate(
#                 input_ids=query.unsqueeze(0),
#                 attention_mask=query_attention_mask.unsqueeze(0),
#                 **generation_kwargs
#                 ).squeeze(0)
#             # query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
#             attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
#             score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
#             score = 2 * (score - 0.5)
#             scores.append(score.item())
#     print('avg score:', sum(scores) / len(scores))

In [None]:
# validate()

In [None]:
# torch.save(model.state_dict(), 'ppo_model_epoch_1.pt')

In [None]:
# model_path = './sft_model_epoch_1'
# model = ModelForCausalLMWithValueHead(model_path).to(device)

In [None]:
# validate()