# RLHF Fine-Tuning

## Mount Google Drive

```python
from google.colab import drive
drive.mount('/content/drive')
```

# Mount Google Drive to access dataset
from google.colab import drive
drive.mount('/content/drive')

In [24]:
# OLD: Load from Google Drive
# %cp /content/drive/MyDrive/copy\ files/reward_model.pt .
# %cp /content/drive/MyDrive/copy\ files/sft_model_epoch_1.zip .

In [25]:
# OLD: Unzip local file
# !unzip sft_model_epoch_1.zip

## LLM Judge (Reward Function)

Instead of using a trained reward model, we'll use an LLM API (like GPT-4 or Claude) to score the quality of generated responses.

In [26]:
def get_llm_judge_score(prompt, response):
    """
    Get a reward score from an LLM judge for a given prompt-response pair.
    
    Args:
        prompt (str): The original prompt/question
        response (str): The model's generated response
        
    Returns:
        float: Reward score in range [-1, 1]
    """
    # TODO: Replace this with actual API call to GPT-4, Claude, or other LLM
    # Example API call structure:
    # judge_prompt = f"Rate the quality of this response (0-10):\nQuestion: {prompt}\nAnswer: {response}\nRating:"
    # api_response = call_llm_api(judge_prompt)
    # score = extract_score(api_response) / 5.0 - 1.0  # Normalize to [-1, 1]
    
    # PLACEHOLDER: Mock scoring based on response length
    # Longer responses get slightly higher scores (just for testing)
    mock_score = min(len(response) / 200.0, 1.0)  # 0 to 1
    mock_score = 2 * (mock_score - 0.5)  # Convert to [-1, 1] range
    
    return float(mock_score)

# Test the placeholder function
test_prompt = "What is the capital of France?"
test_response = "Paris is the capital and most populous city of France."
test_score = get_llm_judge_score(test_prompt, test_response)
print(f"Test LLM Judge Score: {test_score:.4f}")
print("✓ LLM judge function ready (using placeholder - replace with real API later)")

Test LLM Judge Score: -0.4600
✓ LLM judge function ready (using placeholder - replace with real API later)


In [27]:
# OLD: Load reward model from file
# model_name = "gpt2"
# reward_model = GPT2RewardModel(model_name)
# reward_model.load_state_dict(torch.load("reward_model.pt", map_location='cpu'))

print("Reward model loading skipped - will use LLM judge API instead")

Reward model loading skipped - will use LLM judge API instead


## Policy Model with Value Head

PPO requires two components:
- **Policy (Actor)**: The language model that generates text
- **Value Head (Critic)**: Estimates the expected reward for a given state

We combine both into a single model class.

In [52]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for the model
    that returns a scalar for each output token.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.value = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(self.value.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.value(output)


class ModelForCausalLMWithValueHead(nn.Module):
    """
    Causal LM model with a value head on top.
    """

    def __init__(self, model_name_or_path, quantization_config=None):
        super().__init__()
        # NEW: Support loading from HuggingFace with quantization
        if quantization_config is not None:
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                quantization_config=quantization_config,
                device_map="auto",
            )
        else:
            # OLD: Load from local path (GPT-2)
            self.llm = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        
        # Add the value head
        self.v_head = ValueHead(self.llm.config)

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:

        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )
        lm_logits = transformer_outputs.logits
        # Get the last hidden state
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # Apply the value head
        value = self.v_head(last_hidden_state).squeeze(-1)
        return lm_logits, value

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

Using device: cuda


In [29]:
# OLD: Load GPT-2 from local path
# model_path = './sft_model_epoch_1'
# model = ModelForCausalLMWithValueHead(model_path)

# NEW: Load Zephyr 7B from HuggingFace with 4-bit quantization
from transformers import BitsAndBytesConfig

model_name = "HuggingFaceH4/zephyr-7b-beta"

# Quantization config for 4-bit loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model with value head
model = ModelForCausalLMWithValueHead(model_name, quantization_config=bnb_config)
print("Zephyr 7B model loaded successfully with value head!")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Zephyr 7B model loaded successfully with value head!


## Preparing Dataset

We load prompts from our CSV file and tokenize them for training.

In [30]:
from transformers import AutoTokenizer

# OLD: Load GPT-2 tokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# NEW: Load Zephyr tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Important for batch generation
print(f"Tokenizer loaded. Pad token: {tokenizer.pad_token}")

Tokenizer loaded. Pad token: </s>


In [31]:
# TEMPORARY: Create dummy dataset for testing (remove when using real dataset from Drive)
import os
os.makedirs('datasets', exist_ok=True)

dummy_csv = """prompt,chosen,rejected
Explain why the sky looks blue.,"Answer: The sky appears blue because molecules in the atmosphere scatter shorter wavelengths of sunlight more efficiently than longer ones.","The sky looks blue due to Rayleigh scattering."
How does photosynthesis work in plants?,"Answer: Photosynthesis uses light, water, and carbon dioxide to produce glucose and oxygen in chloroplasts.","Photosynthesis is the way plants turn light into food."
What is the capital city of Japan?,Answer: Tokyo is the capital city of Japan.,Tokyo is the capital city of Japan.
Describe the function of the human heart.,Answer: The heart pumps oxygenated and deoxygenated blood through the body to support cellular activity.,The heart moves blood throughout the body.
Why do objects fall toward Earth?,"Answer: Objects fall because gravity pulls masses toward each other, with Earth exerting a strong attractive force.",Objects fall since gravity pulls them down.
What is a prime number?,Answer: A prime number is an integer greater than one that has no positive divisors other than one and itself.,A prime number is a number divisible only by one and itself.
Explain the water cycle.,"Answer: The water cycle involves evaporation, condensation, precipitation, and collection as water moves through Earth systems.","The water cycle is the movement of water through evaporation, condensation, and precipitation."
How do vaccines help protect people?,Answer: Vaccines stimulate the immune system to recognize specific pathogens so the body can respond quickly if exposed.,Vaccines train the immune system to identify harmful microbes.
"""

with open('datasets/gpt_formatted_dataset_clean.csv', 'w') as f:
    f.write(dummy_csv)

print("✓ Dummy dataset created at datasets/gpt_formatted_dataset_clean.csv")
print("✓ Contains 8 sample prompts for testing")

✓ Dummy dataset created at datasets/gpt_formatted_dataset_clean.csv
✓ Contains 8 sample prompts for testing


In [32]:
# OLD: Load SST2 sentiment dataset
# from datasets import load_dataset
# dataset = load_dataset("sst2")
# dataset

# NEW: Load our prompt dataset
from datasets import load_dataset

# TEMPORARY: Using local dummy dataset for testing in VS Code
# WHEN READY FOR COLAB: Replace with Google Drive path:
# dataset_path = "/content/drive/MyDrive/datasets/gpt_formatted_dataset_clean.csv"
dataset_path = "datasets/gpt_formatted_dataset_clean.csv"

dataset = load_dataset("csv", data_files=dataset_path)
print(f"Dataset loaded: {len(dataset['train'])} prompts")
print(f"Columns: {dataset['train'].column_names}")
print(f"\nSample prompt: {dataset['train'][0]['prompt']}")

Generating train split: 0 examples [00:00, ? examples/s]

Dataset loaded: 8 prompts
Columns: ['prompt', 'chosen', 'rejected']

Sample prompt: Explain why the sky looks blue.


In [33]:
# OLD: Split SST2 dataset
# ds_train, ds_val = dataset['train'], dataset['validation']

# NEW: Create train/val split from our dataset
train_size = int(0.9 * len(dataset['train']))
ds_train = dataset['train'].select(range(train_size))
ds_val = dataset['train'].select(range(train_size, len(dataset['train'])))

print(f"Train size: {len(ds_train)}")
print(f"Val size: {len(ds_val)}")

Train size: 7
Val size: 1


# OLD: Filtering - not needed for our dataset
print("Filtering section skipped - our dataset already has suitable prompts")

In [34]:
# OLD: Check length after filtering
# len(ds_train)

print(f"Train size: {len(ds_train)}")

Train size: 7


In [35]:
# OLD: Filter validation for SST2
# ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)

# NEW: No filtering needed
print("No additional filtering needed for our prompts")

No additional filtering needed for our prompts


In [36]:
# OLD: Check length after filtering
# len(ds_train)

print(f"Train size: {len(ds_train)}")

Train size: 7


In [37]:
# OLD: Filter validation for SST2
# ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

# NEW: No filtering needed
print("No additional filtering needed for validation set")

No additional filtering needed for validation set


In [38]:
# OLD: Check val length
# len(ds_val)

print(f"Val size: {len(ds_val)}")

Val size: 1


In [39]:
# OLD: Random input token length for SST2 truncation
# import random
# input_min_token_length = 2
# input_max_token_length = 8
# input_token_length_range = list(range(input_min_token_length, input_max_token_length))
# print(input_token_length_range)

# NEW: We'll use full prompts from our dataset (no truncation)
import random
print("Using full prompts from dataset (no truncation)")

Using full prompts from dataset (no truncation)


In [40]:
# OLD: Test random choice
# random.choice(input_token_length_range)

print("Random truncation not used - using full prompts")

Random truncation not used - using full prompts


In [41]:
# OLD: Tokenize with random truncation for SST2
# def tokenize(sample):
#     input_size = random.choice(input_token_length_range)
#     sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
#     sample['attention_mask'] = [1] * len(sample['input_ids'])
#     sample['query'] = tokenizer.decode(sample['input_ids'])
#     return sample

# NEW: Tokenize full prompts from our dataset
def tokenize(sample):
    # Use the full prompt from our dataset
    encoded = tokenizer.encode(sample['prompt'], add_special_tokens=True)
    sample['input_ids'] = encoded
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = sample['prompt']  # Keep original prompt text
    return sample

map_kwargs = {
    "batched": False,
    "remove_columns": ['prompt', 'chosen', 'rejected']  # Remove CSV columns, keep tokenized data
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

print(f"Tokenized {len(tokenized_dataset_train)} training prompts")
print(f"Tokenized {len(tokenized_dataset_val)} validation prompts")

Map:   0%|          | 0/7 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Tokenized 7 training prompts
Tokenized 1 validation prompts


In [42]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

In [43]:
tokenized_dataset_train[6]

{'input_ids': tensor([    1, 13702,   426,   272,  2130, 10061, 28723]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]),
 'query': 'Explain the water cycle.'}

## Reward Token

The reward token marks where we compute the final reward score for a generated sequence.

In [45]:
from torch.utils.data import DataLoader

# NOTE: Batch size may need to be reduced for Zephyr 7B (larger than GPT-2)
# Start with small batch size and increase if memory allows
batch_size = 4  # Reduced from 32 for 7B model

def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])

train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

print(f"Dataloaders created with batch_size={batch_size}")

Dataloaders created with batch_size=4


In [46]:
batch = next(iter(train_dataloader))
batch

{'input_ids': [tensor([    1,  1824,   349,   272,  5565,  2990,   302,  4720, 28804]),
  tensor([    1,  4315,   511,  6697,  2949,  4112,  8599, 28804]),
  tensor([    1, 13702,   426,   272,  2130, 10061, 28723]),
  tensor([    1, 27984,   272,   908,   302,   272,  2930,  3031, 28723])],
 'attention_mask': [tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])],
 'query': ['What is the capital city of Japan?',
  'Why do objects fall toward Earth?',
  'Explain the water cycle.',
  'Describe the function of the human heart.']}

In [47]:
# Generation settings for model responses
# NOTE: These values control response length - adjust based on your needs
output_min_length = 20  # Increased from 5 for Zephyr
output_max_length = 100  # Increased from 16 for more complete responses

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

print(f"Generation settings: {output_min_length}-{output_max_length} tokens")

Generation settings: 20-100 tokens


## Sample Generation (Test)

Let's test that our model can generate responses before starting training.

In [53]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
sample

{'input_ids': [1, 15359, 28725, 456], 'attention_mask': [1, 1, 1, 1]}

In [54]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0).to(device),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0).to(device),
    **generation_kwargs
    ).squeeze(0)
query_response

tensor([    1, 15359, 28725,   456,   349,   272,   907,  3798,   315,  1038,
          304,   315,   947,   298, 19411,   272,  2944, 22275, 28713,   298,
         1877,   264,  2488,  1236, 28725,   970,   315,   967,   741,  1178,
         1101,   315, 28742,   333, 12879,   390,   264,   633,  3638,   272,
         4480, 28725,  1236, 28742, 28713,   264,  1877,  1703, 28725,  7368,
          744, 28745,  1236, 28742, 28713,   272,   908,  1101,  4003, 28725,
          478,   506,   272, 15000,  1101,  1537,  1938,   264,  2948,  3216,
        28725,   477], device='cuda:0')

In [55]:
tokenizer.decode(query_response)

"<s> Hi, this is the first video I make and I want to invite the press analysts to check a project here, where I add some data... I've approved as a new task the feature, here's a checklist, completed part; here's the function... Here, we have the approval... So during a specific period, from"

In [56]:
# Test scoring a generated response with LLM judge
with torch.no_grad():
    # Decode the generated response
    response_text = tokenizer.decode(query_response, skip_special_tokens=True)
    prompt_text = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
    
    # Get score from LLM judge (replaces reward model)
    score = get_llm_judge_score(prompt_text, response_text)
    
    print(f"Prompt: {prompt_text}")
    print(f"Response: {response_text}")
    print(f"LLM Judge Score: {score:.4f}")

Prompt: Hi, this
Response: Hi, this is the first video I make and I want to invite the press analysts to check a project here, where I add some data... I've approved as a new task the feature, here's a checklist, completed part; here's the function... Here, we have the approval... So during a specific period, from
LLM Judge Score: 1.0000


## Batch Generation

Now let's test generating responses for a full batch of prompts.

In [57]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# NOTE: Model already on device via device_map="auto", so we don't move it
# model = model.to(device)  # Not needed with device_map="auto"

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = []
query_response_tensors = []
score_tensors = []

print(f"Generating responses for {len(query_tensors)} prompts...")

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    # Use LLM judge instead of reward model
    with torch.no_grad():
        prompt_text = tokenizer.decode(query, skip_special_tokens=True)
        response_text = tokenizer.decode(query_response[-response_len:], skip_special_tokens=True)
        
        # Get score from LLM judge
        score = get_llm_judge_score(prompt_text, response_text)
        score = torch.tensor(score).to(device)
        
    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response, skip_special_tokens=True) for response in response_tensors]
print("Generated responses:")
print(batch['response'])

Using device: cuda
Generating responses for 4 prompts...
Generated responses:
['\nNSW Department of Education: Curriculum Connections: Stages: 2, 3: Arts: Visual Arts. Students analyse, make and present artworks inspired by the works and ideas of other artists.\nStudents analyse, make and present visual artworks in', '\n curiosity. In my sophomore Earth Science class at Padua I encountered a perfect storm of academics and activities that stimulated the curious. I remember sitting in World Geography class during my sophomore year, doodling on the side of my paper while we were learning the distribution of renewable resources across the world. http://www.workcity.com/forums/index.php?showtopic=62165 Mrs Ger', ": I took this photo this afternoon in my garden. The weather here in Sydney has been so incredibly hot of late and I decided to take an old watering can I've had for years and 'play' with it in the garden... I wanted to 'recreate' some of the ways I", '\nextensions: guess the words

## Compute Reward

The reward function combines:
- **Score from LLM judge**: Quality of the response
- **KL penalty**: Prevents the model from diverging too much from the reference (SFT) model

**Reward Formula:**

$\text{reward} = \text{score} - \beta \cdot \log \left(\frac{\pi^{RL}_\theta}{\pi^{SFT}}\right)$

Where:
- $\text{score}$ is from the LLM judge
- $\beta$ is the KL penalty coefficient (controls how much we penalize divergence)
- $\pi^{RL}_\theta$ is the current policy (model being trained)
- $\pi^{SFT}$ is the reference policy (frozen SFT model)

In [58]:
# Create reference model (frozen copy for KL divergence)
# NOTE: For large models like Zephyr 7B, deepcopy might use a lot of memory
# The reference model stays frozen during training
from copy import deepcopy
print("Creating reference model (this may take a moment for 7B model)...")
sft_model = deepcopy(model)
sft_model.eval()  # Set to evaluation mode (frozen)
print("✓ Reference model created")

Creating reference model (this may take a moment for 7B model)...
✓ Reference model created


In [None]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
input_data

{'input_ids': tensor([[  271,   257,  5391,    12,    86,  2175,   289,  2954,   286,   257,
          3807,   764,   220,   220,   220, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 8340,   257,  3807,   475,  2038,   284,   787,   340,  1254, 16425,
           220,   220,   220,   220,   220,   220, 50256, 50256, 50256, 50256,
         50256],
        [23442,   262,  5743,   286,   465,  5852,   284,  2222,   428, 16576,
           925, 10997,   284, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [19188,   326,  4077, 29815,   550,  3750,   257,  2239,  2252,  5633,
           220,   220,   220,   220,   220,   220,   220,   220, 50256, 50256,
         50256],
        [  292, 29408,   287,  3450,   220,   220,   220,   220,   714,   423,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 5832,   705,   260,  9431, 33376,   261,   318,  4451,   764,   220,
           220,   220,   220,  5

In [None]:
def compute_rewards(input_data, query_tensors, response_tensors, score_tensors):
    with torch.no_grad():
        logits, values = model(**input_data) # b, seq, vocab
        ref_logits, _ = sft_model(**input_data)
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)

        labels = input_data['input_ids'][:, 1:] # b, seq

        logp = torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq
        ref_logp = torch.gather(ref_logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq

        kl = logp - ref_logp
        beta = 0.2
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:,:] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks


In [None]:
logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.9969, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='cuda:0')
tensor([  271,   257,  5391,    12,    86,  2175,   289,  2954,   286,   257,
         3807,   764,   220,   220,   220, 50256, 50256, 50256, 50256, 50256,
        50256], device='cuda:0')
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       device='cuda:0')


In [None]:
print(masks[0])
print(values[0])

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tensor([-0.0000, -0.0000, -0.0000,  0.0000,  0.0000, -1.6054, -0.1053, -1.0586,
        -1.0855, -0.2100, -1.4899,  3.4527,  0.9113,  1.7163,  0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='cuda:0')


## Compute Advantage

The advantage function estimates how much better an action is compared to the average:
- **Positive advantage**: This action is better than expected
- **Negative advantage**: This action is worse than expected

Uses Generalized Advantage Estimation (GAE) for lower variance.

In [None]:
def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)

def masked_whiten(values, mask):
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    whitened += mean
    return whitened

def compute_advantage(rewards, values, masks):
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns


In [None]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([ 0.1104,  0.0946,  0.0779,  0.0603,  0.0418,  0.9713,  0.1140,  0.6619,
         0.6910,  0.1882,  0.9330, -1.9612, -0.5839, -1.1121,  0.4116,  0.4116,
         0.4116,  0.4116,  0.4116,  0.4116], device='cuda:0')
tensor([ 0.1104,  0.0946,  0.0779,  0.0603,  0.0418, -0.6341,  0.0088, -0.3967,
        -0.3945, -0.0218, -0.5569,  1.4915,  0.3275,  0.6042,  0.4116,  0.4116,
         0.4116,  0.4116,  0.4116,  0.4116], device='cuda:0')


## Mini-batch PPO Training

PPO updates the policy using mini-batches to improve sample efficiency and stability.

### Training Config

In [None]:
# Training hyperparameters
# NOTE: These may need adjustment for Zephyr 7B
learning_rate = 1e-5  # Conservative learning rate for fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f"Optimizer created with lr={learning_rate}")

In [None]:
np.random.permutation(batch_size)

array([13, 17, 20,  7, 26,  1, 25, 23, 11, 21,  4, 14,  6,  3,  9,  2,  0,
       15,  5, 10, 30, 19, 29, 22, 24,  8, 28, 12, 27, 18, 16, 31])

In [None]:
# PPO training configuration
mini_batch_size = 2  # Reduced from 4 for memory efficiency with 7B model
ppo_epochs = 4

cliprange_ratio = 0.2  # PPO clipping range
v_loss_coeff = 0.1     # Value loss coefficient
ratio_threshold = 10   # Threshold to detect unstable training

def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
    """
    Compute PPO loss with clipping.
    
    Args:
        old_logprobs: Log probabilities from the old policy
        values: Value estimates from the old policy
        logprobs: Log probabilities from the current policy
        vpreds: Value predictions from the current policy
        masks: Attention masks
        advantages: Computed advantages
        returns: Computed returns (advantages + values)
    
    Returns:
        loss: Combined policy and value loss
        v_loss: Value loss component
    """
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss1 = - ratio * advantages
    pg_loss2 = - torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)

    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss

    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        # Unstable training detected - zero out gradients
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    """Run mini-batch PPO training for multiple epochs."""
    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(batch_size)

        for start in range(0, batch_size, mini_batch_size):
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data['input_ids'][mini_batch_inds],
                'attention_mask': input_data['attention_mask'][mini_batch_inds]
            }
            mb_logits, mb_vpreds = model(**mb_model_inputs)
            mb_logits = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
            mb_logprobs = torch.gather(mb_logits, 2, mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss/total', loss.item())
    print('mini-batch training finished')

print(f"PPO config: {ppo_epochs} epochs, mini_batch_size={mini_batch_size}")

In [None]:
mini_batch_train()

loss/total -0.6469668745994568
loss/total -1.084718942642212
loss/total -0.9519063234329224
loss/total -0.5013611912727356
loss/total -0.7945360541343689
loss/total -0.4254860281944275
loss/total -1.0511527061462402
loss/total -0.9178730845451355
loss/total -0.31340065598487854
loss/total -1.4144445657730103
loss/total -0.74092698097229
loss/total -1.5609691143035889
loss/total -0.8986895084381104
loss/total -1.1131693124771118
loss/total -1.3536183834075928
loss/total -0.83698570728302
loss/total -0.8659009337425232
loss/total -1.666944980621338
loss/total -0.9409258365631104
loss/total -1.2696202993392944
loss/total -0.6142494678497314
loss/total -1.3109718561172485
loss/total -0.7585713267326355
loss/total -0.9758541584014893
loss/total -1.3880150318145752
loss/total -0.833629310131073
loss/total -0.638252317905426
loss/total -1.1900427341461182
loss/total -0.4314558207988739
loss/total -1.388943076133728
loss/total -1.1693115234375
loss/total -1.4005227088928223
mini-batch training

## Train RLHF

The main training loop:
1. Generate responses for each prompt
2. Score responses with LLM judge
3. Compute rewards (score - KL penalty)
4. Compute advantages (how good was this response?)
5. Update the policy with PPO

In [None]:
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Generate responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']

        response_tensors = []
        query_response_tensors = []
        score_tensors = []

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = random.choice(list(range(output_min_length, output_max_length)))
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)

            response_len = len(query_response) - len(query)
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            # Use LLM judge instead of reward model
            with torch.no_grad():
                prompt_text = tokenizer.decode(query, skip_special_tokens=True)
                response_text = tokenizer.decode(query_response[-response_len:], skip_special_tokens=True)
                
                # Get score from LLM judge
                score = get_llm_judge_score(prompt_text, response_text)
                score = torch.tensor(score).to(device)
                
            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # rewards and advantages
        logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
        advantages, returns = compute_advantage(rewards, values, masks)

        # mini batch training
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

## Validation

Test the trained model on the validation set to see if responses improved.

In [None]:
len(tokenized_dataset_val)

807

In [None]:
val_gen_lengths = [0] * len(tokenized_dataset_val)
for i in range(len(tokenized_dataset_val)):
    val_gen_lengths[i] = random.choice(list(range(output_min_length, output_max_length)))

In [None]:
val_gen_lengths[:10]

[5, 8, 7, 11, 10, 8, 8, 9, 8, 10]

In [None]:
def validate():
    scores = []
    for b, batch in enumerate(val_dataloader):
        # Generate_responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = val_gen_lengths[b * len(query_tensors) + i]
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)
            query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
            score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('avg score:', sum(scores) / len(scores))

In [None]:
validate()

avg score: 0.6572876147916621


In [None]:
torch.save(model.state_dict(), 'ppo_model_epoch_1.pt')

In [None]:
model_path = './sft_model_epoch_1'
model = ModelForCausalLMWithValueHead(model_path).to(device)

In [None]:
validate()

avg score: 0.07222306484626571
