# RLHF Fine-Tuning

## Mount Google Drive

```python
from google.colab import drive
drive.mount('/content/drive')
```

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!ls '/content/drive/MyDrive/datasets'

gpt_formatted_dataset_clean.csv


In [3]:
!pip install -U bitsandbytes



# Mount Google Drive to access dataset
from google.colab import drive
drive.mount('/content/drive')

## LLM Judge (Reward Function)

Instead of using a trained reward model, we'll use an LLM API (like GPT-4 or Claude) to score the quality of generated responses.

In [4]:
def get_llm_judge_score(prompt, response):
    """
    Get a reward score from an LLM judge for a given prompt-response pair.

    Args:
        prompt (str): The original prompt/question
        response (str): The model's generated response

    Returns:
        float: Reward score in range [-1, 1]
    """
    # TODO: Replace this with actual API call to GPT-4, Claude, or other LLM
    # Example API call structure:
    # judge_prompt = f"Rate the quality of this response (0-10):\nQuestion: {prompt}\nAnswer: {response}\nRating:"
    # api_response = call_llm_api(judge_prompt)
    # score = extract_score(api_response) / 5.0 - 1.0  # Normalize to [-1, 1]

    # PLACEHOLDER: Mock scoring based on response length
    # Longer responses get slightly higher scores (just for testing)
    mock_score = min(len(response) / 200.0, 1.0)  # 0 to 1
    mock_score = 2 * (mock_score - 0.5)  # Convert to [-1, 1] range

    return float(mock_score)

# Test the placeholder function
test_prompt = "What is the capital of France?"
test_response = "Paris is the capital and most populous city of France."
test_score = get_llm_judge_score(test_prompt, test_response)
print(f"Test LLM Judge Score: {test_score:.4f}")
print("✓ LLM judge function ready (using placeholder - replace with real API later)")

Test LLM Judge Score: -0.4600
✓ LLM judge function ready (using placeholder - replace with real API later)


In [5]:
# OLD: Load reward model from file
# model_name = "gpt2"
# reward_model = GPT2RewardModel(model_name)
# reward_model.load_state_dict(torch.load("reward_model.pt", map_location='cpu'))

print("Reward model loading skipped - will use LLM judge API instead")

Reward model loading skipped - will use LLM judge API instead


## Policy Model with Value Head

PPO requires two components:
- **Policy (Actor)**: The language model that generates text
- **Value Head (Critic)**: Estimates the expected reward for a given state

We combine both into a single model class.

In [6]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for the model
    that returns a scalar for each output token.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.value = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(self.value.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.value(output)


class ModelForCausalLMWithValueHead(nn.Module):
    """
    Causal LM model with a value head on top.
    """

    def __init__(self, model_name_or_path, quantization_config=None):
        super().__init__()
        # NEW: Support loading from HuggingFace with quantization
        if quantization_config is not None:
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
        else:
            # OLD: Load from local path (GPT-2)
            self.llm = AutoModelForCausalLM.from_pretrained(model_name_or_path)

        # Add the value head
        self.v_head = ValueHead(self.llm.config)

        # IMPORTANT: Move value head to same device as the LLM
        # With device_map="auto", the LLM is on GPU but v_head might be on CPU
        if quantization_config is not None:
            # Find which device the LLM is on
            try:
                llm_device = next(self.llm.parameters()).device
                self.v_head = self.v_head.to(llm_device)
                print(f"Value head moved to {llm_device}")
            except StopIteration:
                print("Warning: Could not determine LLM device")

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:

        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )
        lm_logits = transformer_outputs.logits
        # Get the last hidden state
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # Apply the value head
        value = self.v_head(last_hidden_state).squeeze(-1)
        return lm_logits, value

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

Using device: cuda


In [7]:
# OLD: Load GPT-2 from local path
# model_path = './sft_model_epoch_1'
# model = ModelForCausalLMWithValueHead(model_path)

# NEW: Load Zephyr 7B from HuggingFace with 4-bit quantization
from transformers import BitsAndBytesConfig

model_name = "HuggingFaceH4/zephyr-7b-beta"

# Quantization config for 4-bit loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model with value head
model = ModelForCausalLMWithValueHead(model_name, quantization_config=bnb_config)
print("Zephyr 7B model loaded successfully with value head!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Value head moved to cuda:0
Zephyr 7B model loaded successfully with value head!


## Preparing Dataset

We load prompts from our CSV file and tokenize them for training.

In [8]:
from transformers import AutoTokenizer

# OLD: Load GPT-2 tokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# NEW: Load Zephyr tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Important for batch generation
print(f"Tokenizer loaded. Pad token: {tokenizer.pad_token}")

Tokenizer loaded. Pad token: </s>


In [9]:
# TEMPORARY: Create dummy dataset for testing (remove when using real dataset from Drive)
import os
os.makedirs('datasets', exist_ok=True)

# NOTE: For Zephyr, we need to format prompts using the chat template
# The tokenizer has a built-in apply_chat_template() method
# We'll store just the user questions in the CSV and format them later during tokenization

dummy_csv = """prompt,chosen,rejected
Explain why the sky looks blue.,"Answer: The sky appears blue because molecules in the atmosphere scatter shorter wavelengths of sunlight more efficiently than longer ones.","The sky looks blue due to Rayleigh scattering."
How does photosynthesis work in plants?,"Answer: Photosynthesis uses light, water, and carbon dioxide to produce glucose and oxygen in chloroplasts.","Photosynthesis is the way plants turn light into food."
What is the capital city of Japan?,Answer: Tokyo is the capital city of Japan.,Tokyo is the capital city of Japan.
Describe the function of the human heart.,Answer: The heart pumps oxygenated and deoxygenated blood through the body to support cellular activity.,The heart moves blood throughout the body.
Why do objects fall toward Earth?,"Answer: Objects fall because gravity pulls masses toward each other, with Earth exerting a strong attractive force.",Objects fall since gravity pulls them down.
What is a prime number?,Answer: A prime number is an integer greater than one that has no positive divisors other than one and itself.,A prime number is a number divisible only by one and itself.
Explain the water cycle.,"Answer: The water cycle involves evaporation, condensation, precipitation, and collection as water moves through Earth systems.","The water cycle is the movement of water through evaporation, condensation, and precipitation."
How do vaccines help protect people?,Answer: Vaccines stimulate the immune system to recognize specific pathogens so the body can respond quickly if exposed.,Vaccines train the immune system to identify harmful microbes.
"""

with open('datasets/gpt_formatted_dataset_clean.csv', 'w') as f:
    f.write(dummy_csv)

print("✓ Dummy dataset created at datasets/gpt_formatted_dataset_clean.csv")
print("✓ Contains 8 sample prompts for testing")
print("✓ Prompts will be formatted with Zephyr chat template during tokenization")

✓ Dummy dataset created at datasets/gpt_formatted_dataset_clean.csv
✓ Contains 8 sample prompts for testing
✓ Prompts will be formatted with Zephyr chat template during tokenization


In [10]:
# OLD: Load SST2 sentiment dataset
# from datasets import load_dataset
# dataset = load_dataset("sst2")
# dataset

# NEW: Load our prompt dataset
from datasets import load_dataset

# TEMPORARY: Using local dummy dataset for testing in VS Code
# WHEN READY FOR COLAB: Replace with Google Drive path:
dataset_path = "/content/drive/MyDrive/datasets/gpt_formatted_dataset_clean.csv"
# dataset_path = "datasets/gpt_formatted_dataset_clean.csv"

dataset = load_dataset("csv", data_files=dataset_path)
print(f"Dataset loaded: {len(dataset['train'])} prompts")
print(f"Columns: {dataset['train'].column_names}")
print(f"\nSample prompt: {dataset['train'][0]['prompt']}")

Dataset loaded: 307 prompts
Columns: ['prompt', 'chosen', 'rejected']

Sample prompt: Explain why the sky looks blue.


In [11]:
# OLD: Split SST2 dataset
# ds_train, ds_val = dataset['train'], dataset['validation']

# NEW: Create train/val split from our dataset
train_size = int(0.9 * len(dataset['train']))
ds_train = dataset['train'].select(range(train_size))
ds_val = dataset['train'].select(range(train_size, len(dataset['train'])))

print(f"Train size: {len(ds_train)}")
print(f"Val size: {len(ds_val)}")

Train size: 276
Val size: 31


# OLD: Filtering - not needed for our dataset
print("Filtering section skipped - our dataset already has suitable prompts")

In [12]:
# OLD: Check length after filtering
# len(ds_train)

print(f"Train size: {len(ds_train)}")

Train size: 276


In [13]:
# OLD: Filter validation for SST2
# ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)

# NEW: No filtering needed
print("No additional filtering needed for our prompts")

No additional filtering needed for our prompts


In [14]:
# OLD: Check length after filtering
# len(ds_train)

print(f"Train size: {len(ds_train)}")

Train size: 276


In [15]:
# OLD: Filter validation for SST2
# ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

# NEW: No filtering needed
print("No additional filtering needed for validation set")

No additional filtering needed for validation set


In [16]:
# OLD: Check val length
# len(ds_val)

print(f"Val size: {len(ds_val)}")

Val size: 31


In [17]:
# OLD: Random input token length for SST2 truncation
# import random
# input_min_token_length = 2
# input_max_token_length = 8
# input_token_length_range = list(range(input_min_token_length, input_max_token_length))
# print(input_token_length_range)

# NEW: We'll use full prompts from our dataset (no truncation)
import random
print("Using full prompts from dataset (no truncation)")

Using full prompts from dataset (no truncation)


In [18]:
# OLD: Test random choice
# random.choice(input_token_length_range)

print("Random truncation not used - using full prompts")

Random truncation not used - using full prompts


In [19]:
# OLD: Tokenize with random truncation for SST2
# def tokenize(sample):
#     input_size = random.choice(input_token_length_range)
#     sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
#     sample['attention_mask'] = [1] * len(sample['input_ids'])
#     sample['query'] = tokenizer.decode(sample['input_ids'])
#     return sample

# NEW: Tokenize prompts using Zephyr's chat template
def tokenize(sample):
    # Format the prompt using Zephyr's chat template
    # Zephyr expects: <|user|>\n{prompt}</s>\n<|assistant|>\n

    try:
        # Try using the built-in chat template
        messages = [
            {"role": "user", "content": sample['prompt']}
        ]

        # Use tokenize=True to get token IDs directly
        encoded = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors=None  # Return list, not tensor
        )

        # Get the formatted text for debugging
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        # Fallback: Manually format using Zephyr's template
        # Based on: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
        formatted_prompt = f"<|user|>\n{sample['prompt']}</s>\n<|assistant|>\n"
        # Use add_special_tokens=True to properly handle special tokens
        encoded = tokenizer.encode(formatted_prompt, add_special_tokens=True)

    sample['input_ids'] = encoded
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = formatted_prompt  # Keep formatted prompt text
    return sample

map_kwargs = {
    "batched": False,
    "remove_columns": ['prompt', 'chosen', 'rejected']  # Remove CSV columns, keep tokenized data
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

print(f"Tokenized {len(tokenized_dataset_train)} training prompts")
print(f"Tokenized {len(tokenized_dataset_val)} validation prompts")
print(f"\nSample formatted prompt:")
print(repr(tokenized_dataset_train[0]['query']))
print(f"\nToken IDs (first 20): {tokenized_dataset_train[0]['input_ids'][:20]}")
print(f"Number of tokens: {len(tokenized_dataset_train[0]['input_ids'])}")

Tokenized 276 training prompts
Tokenized 31 validation prompts

Sample formatted prompt:
'<|user|>\nExplain why the sky looks blue.</s>\n<|assistant|>\n'

Token IDs (first 20): [523, 28766, 1838, 28766, 28767, 13, 966, 19457, 2079, 272, 7212, 4674, 5045, 28723, 2, 28705, 13, 28789, 28766, 489]
Number of tokens: 24


In [20]:
# Check if tokenizer has chat template
print("Has chat_template:", hasattr(tokenizer, 'chat_template'))
print("Chat template:", tokenizer.chat_template if hasattr(tokenizer, 'chat_template') else "None")

# Test with simple example
test_messages = [{"role": "user", "content": "Hello"}]
try:
    result = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)
    print("\nFormatted:", repr(result))

    tokens = tokenizer.apply_chat_template(test_messages, tokenize=True, add_generation_prompt=True)
    print("Tokens:", tokens)
    print("Max token ID:", max(tokens))
    print("Vocab size:", tokenizer.vocab_size)
except Exception as e:
    print("Error:", e)

Has chat_template: True
Chat template: {% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}

Formatted: '<|user|>\nHello</s>\n<|assistant|>\n'
Tokens: [523, 28766, 1838, 28766, 28767, 13, 16230, 2, 28705, 13, 28789, 28766, 489, 11143, 28766, 28767, 13]
Max token ID: 28789
Vocab size: 32000


In [21]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

In [22]:
tokenized_dataset_train[6]

{'input_ids': tensor([  523, 28766,  1838, 28766, 28767,    13,   966, 19457,   272,  2130,
         10061, 28723,     2, 28705,    13, 28789, 28766,   489, 11143, 28766,
         28767,    13]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'query': '<|user|>\nExplain the water cycle.</s>\n<|assistant|>\n'}

## Reward Token

The reward token marks where we compute the final reward score for a generated sequence.

In [23]:
from torch.utils.data import DataLoader

# NOTE: Batch size may need to be reduced for Zephyr 7B (larger than GPT-2)
# Start with small batch size and increase if memory allows
batch_size = 4  # Reduced from 32 for 7B model

def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])

train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

print(f"Dataloaders created with batch_size={batch_size}")

Dataloaders created with batch_size=4


In [24]:
batch = next(iter(train_dataloader))
batch

{'input_ids': [tensor([  523, 28766,  1838, 28766, 28767,    13,   966, 19457,   910,  6183,
          14428,   460, 13507, 28723,     2, 28705,    13, 28789, 28766,   489,
          11143, 28766, 28767,    13]),
  tensor([  523, 28766,  1838, 28766, 28767,    13,  5660,  1235,   396,  9274,
          28755, 11630,   574,  8208, 28804,     2, 28705,    13, 28789, 28766,
            489, 11143, 28766, 28767,    13]),
  tensor([  523, 28766,  1838, 28766, 28767,    13,  7638,   511,  9027,  1468,
            506,  1160,   274, 28804,     2, 28705,    13, 28789, 28766,   489,
          11143, 28766, 28767,    13]),
  tensor([  523, 28766,  1838, 28766, 28767,    13,  2469,   473,   272,  1618,
           3646,   269,   466,  3216, 28723,     2, 28705,    13, 28789, 28766,
            489, 11143, 28766, 28767,    13])],
 'attention_mask': [tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [25]:
# Generation settings for model responses
# NOTE: These values control response length - adjust based on your needs
output_min_length = 20  # Increased from 5 for Zephyr
output_max_length = 100  # Increased from 16 for more complete responses

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
generation_kwargs = {
    "min_length": 0,
    "top_k": 50,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

print(f"Generation settings: {output_min_length}-{output_max_length} tokens")

Generation settings: 20-100 tokens


## Sample Generation (Test)

Let's test that our model can generate responses before starting training.

In [26]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
sample

{'input_ids': [1, 15359, 28725, 456], 'attention_mask': [1, 1, 1, 1]}

In [27]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0).to(device),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0).to(device),
    **generation_kwargs
    ).squeeze(0)
query_response

tensor([    1, 15359, 28725,   456,   349,   264, 17455,  2884,   354,  4797,
          367,   279, 28717,   992, 28711, 28725,   690,  2825,   456,  1338,
          349,   459,  5489,   356,   456,  3455, 28723], device='cuda:0')

In [28]:
tokenizer.decode(query_response)

'<s> Hi, this is a placeholder page for James Pitcairn, which means this person is not currently on this site.'

In [29]:
# Test scoring a generated response with LLM judge
with torch.no_grad():
    # Decode the generated response
    response_text = tokenizer.decode(query_response, skip_special_tokens=True)
    prompt_text = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)

    # Get score from LLM judge (replaces reward model)
    score = get_llm_judge_score(prompt_text, response_text)

    print(f"Prompt: {prompt_text}")
    print(f"Response: {response_text}")
    print(f"LLM Judge Score: {score:.4f}")

Prompt: Hi, this
Response: Hi, this is a placeholder page for James Pitcairn, which means this person is not currently on this site.
LLM Judge Score: 0.0500


## Batch Generation

Now let's test generating responses for a full batch of prompts.

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# NOTE: Model already on device via device_map="auto", so we don't move it
# model = model.to(device)  # Not needed with device_map="auto"

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = []
query_response_tensors = []
score_tensors = []

print(f"Generating responses for {len(query_tensors)} prompts...")

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    # Use LLM judge instead of reward model
    with torch.no_grad():
        prompt_text = tokenizer.decode(query, skip_special_tokens=True)
        print(prompt_text)
        response_text = tokenizer.decode(query_response[-response_len:], skip_special_tokens=True)

        # Get score from LLM judge
        score = get_llm_judge_score(prompt_text, response_text)
        score = torch.tensor(score).to(device)

    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response, skip_special_tokens=True) for response in response_tensors]
print("Generated responses:")
print(batch['response'])

Using device: cuda
Generating responses for 4 prompts...
<|user|>
Explain how credit scores are calculated. 
<|assistant|>

<|user|>
How does an ATM verify your identity? 
<|assistant|>

<|user|>
Why do magnets have poles? 
<|assistant|>

<|user|>
Define the Enlightenment period. 
<|assistant|>

Generated responses:
['Credit scores are calculated using complex algorithms applied by credit bureaus based on various factors. The most commonly used credit scoring model is the FICO score, which ranges from 300 to 850. Here are some key factors that are considered when calculating a credit score:\n\n1. Payment History (35%): This is the most important factor in determining a credit score.', 'An ATM (Automated Teller Machine) verifies your identity in the following ways:\n\n1. Debit Card: When you insert your debit card into the', 'Magnets have poles because they are made of materials that have a property called magnetization. Magnetization occurs when the atoms in a material align themselves

## Compute Reward

The reward function combines:
- **Score from LLM judge**: Quality of the response
- **KL penalty**: Prevents the model from diverging too much from the reference (SFT) model

**Reward Formula:**

$\text{reward} = \text{score} - \beta \cdot \log \left(\frac{\pi^{RL}_\theta}{\pi^{SFT}}\right)$

Where:
- $\text{score}$ is from the LLM judge
- $\beta$ is the KL penalty coefficient (controls how much we penalize divergence)
- $\pi^{RL}_\theta$ is the current policy (model being trained)
- $\pi^{SFT}$ is the reference policy (frozen SFT model)

In [31]:
# Adding these so that types match up correctly

# Move the value head to the same device as the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.v_head.to(device)

# Verify it works
print(f"Value head device: {next(model.v_head.parameters()).device}")

# Convert the value head to float16 to match the base model's output
model.v_head.to(dtype=torch.bfloat16)

# Verify the fix
print(f"Value head dtype: {model.v_head.value.weight.dtype}")

Value head device: cuda:0
Value head dtype: torch.bfloat16


In [32]:
# Create reference model (frozen copy for KL divergence)
# NOTE: For large models like Zephyr 7B, deepcopy might use a lot of memory
# The reference model stays frozen during training
from copy import deepcopy
print("Creating reference model (this may take a moment for 7B model)...")
sft_model = deepcopy(model)
sft_model.eval()  # Set to evaluation mode (frozen)
print("✓ Reference model created")

Creating reference model (this may take a moment for 7B model)...
✓ Reference model created


In [33]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [34]:
input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
input_data

{'input_ids': tensor([[    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,   523, 28766,  1838, 28766,
         28767,    13,   966, 19457,   910,  6183, 14428,   460, 13507, 28723,
             2, 28705,    13, 28789, 28766,   489, 11143, 28766, 28767,    13,
         28743, 10709, 14428,   460, 13507,  1413,  4630, 18539,  7589,   486,
          6183,   287,   482,  1899,  2818,   356,  4118,  8612, 28723,   415,
          1080, 14473,  1307,  6183, 20310,  2229,   349,   272,   401,  1604,
         28762,  7420, 28725,   690, 19441,   477, 28705, 28770, 28734, 28734,
           298, 28705, 28783, 28782, 28734, 28723,  4003,   460,   741,  1945,
          8612,   369,   460,  4525,   739,  4900,  1077,   264,  6183,  7420,
         28747,    13,    13, 28740, 28723, 27294,  6866,   325, 28770, 28782,
         28823,  1329,   851,   349,   272,  1080,  2278,  6999,   297, 23689,
           264,  6183,  7420, 28723],


In [35]:
def compute_rewards(input_data, query_tensors, response_tensors, score_tensors):
    with torch.no_grad():
        logits, values = model(**input_data) # b, seq, vocab
        ref_logits, _ = sft_model(**input_data)

        # FIX 1: Clamp logits to avoid -inf/nan in log_softmax
        logits = torch.clamp(logits, min=-100, max=100)
        ref_logits = torch.clamp(ref_logits, min=-100, max=100)

        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)

        # FIX 2: These lines were missing! They select the prob of the actual token used.
        labels = input_data['input_ids'][:, 1:] # b, seq
        logp = torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq
        ref_logp = torch.gather(ref_logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq

        kl = logp - ref_logp
        beta = 0.2
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:,:] = attention_mask[:, 1:]

        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks


In [36]:
logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])

tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., 1., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
       device='cuda:0', dtype=torch.bfloat16)
tensor([    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,   523, 28766,  1838, 28766,
        28767,    13,   966, 19457,   910,  6183, 14428,   460, 13507, 28723,
            2, 28705,    13, 28789, 28766,   489,

In [37]:
print(masks[0])
print(values[0])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0], device='cuda:0')
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,  2.5469,
         7.5938, -0.9336,  7.7500,  7.3438,  3.2031,  6.5000,  8.3750,  2.2656,
         2.4062, 12.0000, -4.0312, -0.7148,  7.1562,  6.4375,  4.7500, -7.6875,
         3.7344, -8.6250, 11.0625, -0.0581, 11.5625,  1.1719, -5.1250,  8.5625,
        13.2500, -4.8125, -4.2500, -2.2031, -3.2656,  6.6875,  5.1875,  3.5469,
     

## Compute Advantage

The advantage function estimates how much better an action is compared to the average:
- **Positive advantage**: This action is better than expected
- **Negative advantage**: This action is worse than expected

Uses Generalized Advantage Estimation (GAE) for lower variance.

In [38]:
def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)

# def masked_whiten(values, mask):
#     mean, var = masked_mean(values, mask), masked_var(values, mask)
#     whitened = (values - mean) * torch.rsqrt(var + 1e-8)
#     whitened += mean
#     return whitened
def masked_whiten(values, mask):
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    # Add a larger epsilon (1e-6 instead of 1e-8) for stability
    # If variance is 0, just return the centered values without scaling
    if var < 1e-6:
        return values - mean
    whitened = (values - mean) * torch.rsqrt(var + 1e-6)
    return whitened  # Note: Standard whitening returns (x-mu)/sigma, not adding mean back

def compute_advantage(rewards, values, masks):
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns


In [39]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([ 0.2715,  0.2812,  0.2930,  0.3027,  0.3145,  0.3262,  0.3398,  0.3516,
         0.3652,  0.3809,  0.3945,  0.4121,  0.4316,  0.4473,  0.4668,  0.4863,
         0.5078,  0.5312,  0.5547,  0.5781,  0.6055,  0.6328,  0.6641,  0.2266,
        -0.7031,  0.8281, -0.7422, -0.7070,  0.0148, -0.5977, -0.9844,  0.0977,
         0.0723, -1.7031,  1.1562,  0.5977, -0.8281, -0.7461, -0.4785,  1.8047,
        -0.2236,  2.0469, -1.4922,  0.4785, -1.6484,  0.1768,  1.3516, -1.1172,
        -2.0469,  1.1641,  1.1172,  0.7891,  1.0234, -0.7617, -0.5312, -0.2617,
        -0.6094, -2.2969, -0.7422,  0.8203,  0.8477, -0.2344,  1.4609,  1.6016,
         1.8047, -0.0840,  0.4883,  0.3203, -1.1953,  0.6094,  1.2422,  0.2188,
        -0.5391, -0.7578, -0.9844,  0.4414,  0.9492, -0.1719,  0.5000, -1.6406,
        -0.6211,  0.9375,  0.9375, -0.0237, -0.3652, -0.5312, -1.2344,  1.0078,
         0.3906,  0.3730, -0.2812, -0.6641, -0.6484,  0.0669,  1.1172,  1.7578,
         1.0000,  0.7852,  0.9297, -0.35

## Mini-batch PPO Training

PPO updates the policy using mini-batches to improve sample efficiency and stability.

### Training Config

In [40]:
# Training hyperparameters
# NOTE: These may need adjustment for Zephyr 7B
learning_rate = 1e-5  # Conservative learning rate for fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f"Optimizer created with lr={learning_rate}")

Optimizer created with lr=1e-05


In [41]:
np.random.permutation(batch_size)

array([3, 1, 0, 2])

In [42]:
# # PPO training configuration
# mini_batch_size = 2  # Reduced from 4 for memory efficiency with 7B model
# ppo_epochs = 4

# cliprange_ratio = 0.2  # PPO clipping range
# v_loss_coeff = 0.1     # Value loss coefficient
# ratio_threshold = 10   # Threshold to detect unstable training

# def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
#     """
#     Compute PPO loss with clipping.

#     Args:
#         old_logprobs: Log probabilities from the old policy
#         values: Value estimates from the old policy
#         logprobs: Log probabilities from the current policy
#         vpreds: Value predictions from the current policy
#         masks: Attention masks
#         advantages: Computed advantages
#         returns: Computed returns (advantages + values)

#     Returns:
#         loss: Combined policy and value loss
#         v_loss: Value loss component
#     """
#     ratio = torch.exp(logprobs - old_logprobs)
#     pg_loss1 = - ratio * advantages
#     pg_loss2 = - torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
#     pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)

#     v_loss = masked_mean((vpreds - returns) ** 2, masks)
#     loss = pg_loss + v_loss_coeff * v_loss

#     avg_ratio = masked_mean(ratio, masks)
#     if avg_ratio > ratio_threshold:
#         # Unstable training detected - zero out gradients
#         pg_loss = pg_loss * 0.0
#         v_loss = v_loss * 0.0
#         loss = loss * 0.0

#     return loss, v_loss

# def mini_batch_train():
#     """Run mini-batch PPO training for multiple epochs."""
#     for ep in range(ppo_epochs):
#         batch_inds = np.random.permutation(batch_size)

#         for start in range(0, batch_size, mini_batch_size):
#             end = start + mini_batch_size
#             mini_batch_inds = batch_inds[start:end]

#             mb_model_inputs = {
#                 'input_ids': input_data['input_ids'][mini_batch_inds],
#                 'attention_mask': input_data['attention_mask'][mini_batch_inds]
#             }
#             mb_logits, mb_vpreds = model(**mb_model_inputs)
#             mb_logits = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
#             mb_logprobs = torch.gather(mb_logits, 2, mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)

#             loss, loss_v = compute_loss(
#                 logprobs[mini_batch_inds],
#                 values[mini_batch_inds],
#                 mb_logprobs,
#                 mb_vpreds[:, :-1],
#                 masks[mini_batch_inds],
#                 advantages[mini_batch_inds],
#                 returns[mini_batch_inds]
#             )

#             optimizer.zero_grad()
#             loss.backward()

#             # Added in: Add gradient clipping
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

#             optimizer.step()
#             print('loss/total', loss.item())
#     print('mini-batch training finished')

# print(f"PPO config: {ppo_epochs} epochs, mini_batch_size={mini_batch_size}")

# PPO training configuration
mini_batch_size = 8  # Reduced from 4 for memory efficiency with 7B model
ppo_epochs = 4

cliprange_ratio = 0.2  # PPO clipping range
v_loss_coeff = 0.1     # Value loss coefficient
ratio_threshold = 10   # Threshold to detect unstable training

def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
    """
    Compute PPO loss with clipping.
    """
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss1 = - ratio * advantages
    pg_loss2 = - torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)

    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss

    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        # Unstable training detected - zero out gradients
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    """Run mini-batch PPO training for multiple epochs."""
    # FIX 1: Get actual batch size from input data (prevents index out of bounds)
    current_batch_size = len(input_data['input_ids'])

    for ep in range(ppo_epochs):
        # FIX 1: Use current_batch_size for permutation
        batch_inds = np.random.permutation(current_batch_size)

        for start in range(0, current_batch_size, mini_batch_size):
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data['input_ids'][mini_batch_inds],
                'attention_mask': input_data['attention_mask'][mini_batch_inds]
            }
            mb_logits, mb_vpreds = model(**mb_model_inputs)

            # FIX 2: Clamp logits here to prevent NaN loss!
            mb_logits = torch.clamp(mb_logits, min=-100, max=100)

            mb_logits = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
            mb_logprobs = torch.gather(mb_logits, 2, mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (you already had this, keep it!)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Debug print
            if torch.isnan(loss):
                print(f"WARNING: Loss is NaN in epoch {ep}!")
            else:
                print('loss/total', loss.item())
    print('mini-batch training finished')

print(f"PPO config: {ppo_epochs} epochs, mini_batch_size={mini_batch_size}")

PPO config: 4 epochs, mini_batch_size=8


In [43]:
mini_batch_train()

loss/total 0.099609375
loss/total 0.055419921875
loss/total 0.033447265625
loss/total 0.016845703125
mini-batch training finished


## Train RLHF

The main training loop:
1. Generate responses for each prompt
2. Score responses with LLM judge
3. Compute rewards (score - KL penalty)
4. Compute advantages (how good was this response?)
5. Update the policy with PPO

In [44]:
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Generate responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']

        response_tensors = []
        query_response_tensors = []
        score_tensors = []

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = random.choice(list(range(output_min_length, output_max_length)))
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)

            response_len = len(query_response) - len(query)
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            # Use LLM judge instead of reward model
            with torch.no_grad():
                prompt_text = tokenizer.decode(query, skip_special_tokens=True)
                response_text = tokenizer.decode(query_response[-response_len:], skip_special_tokens=True)

                # Get score from LLM judge
                score = get_llm_judge_score(prompt_text, response_text)
                score = torch.tensor(score).to(device)

            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # rewards and advantages
        logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
        advantages, returns = compute_advantage(rewards, values, masks)

        # mini batch training
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

loss/total 0.0986328125
loss/total 0.08203125
loss/total 0.061767578125
loss/total 0.0361328125
mini-batch training finished
loss/total 0.09912109375
loss/total 0.07958984375
loss/total 0.06005859375
loss/total 0.0458984375
mini-batch training finished
loss/total 0.099609375
loss/total 0.08349609375
loss/total 0.06298828125
loss/total 0.0458984375
mini-batch training finished
loss/total 0.099609375
loss/total 0.08251953125
loss/total 0.06201171875
loss/total 0.0439453125
mini-batch training finished
loss/total 0.10009765625
loss/total 0.08154296875
loss/total 0.064453125
loss/total 0.04833984375
mini-batch training finished
loss/total 0.0986328125
loss/total 0.0830078125
loss/total 0.061279296875
loss/total 0.039306640625
mini-batch training finished
loss/total 0.1005859375
loss/total 0.08447265625
loss/total 0.056396484375
loss/total 0.03759765625
mini-batch training finished
loss/total 0.1005859375
loss/total 0.07763671875
loss/total 0.0458984375
loss/total 0.029052734375
mini-batch 

In [50]:
# save the model
torch.save(model.state_dict(), 'ppo_model_epoch_1_BAD.pt')

## Validation

Test the trained model on the validation set to see if responses improved.

In [45]:
len(tokenized_dataset_val)

31

In [46]:
val_gen_lengths = [0] * len(tokenized_dataset_val)
for i in range(len(tokenized_dataset_val)):
    val_gen_lengths[i] = random.choice(list(range(output_min_length, output_max_length)))

In [47]:
val_gen_lengths[:10]

[23, 81, 39, 48, 57, 22, 48, 44, 69, 41]

In [48]:
def validate():
    scores = []
    for b, batch in enumerate(val_dataloader):
        # Generate_responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = val_gen_lengths[b * len(query_tensors) + i]
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)
            # query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
            score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('avg score:', sum(scores) / len(scores))

In [49]:
validate()

NameError: name 'REWARD_TOKEN_ID' is not defined

In [None]:
torch.save(model.state_dict(), 'ppo_model_epoch_1.pt')

In [None]:
model_path = './sft_model_epoch_1'
model = ModelForCausalLMWithValueHead(model_path).to(device)

In [None]:
validate()