<a href="https://colab.research.google.com/github/pastrop/kaggle/blob/master/GRPO_Study_Ntbk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRPO Study Notebook

This notebook is purely an example; don't expect production-quality code here. It is a byproduct of my reading the DeepSeek paper and working out how the GRPO algorithm works. I looked at other people's code for it and found it a bit vague.

**Setup**
1. An LLM is used to create training examples for training the Policy (I am using GPT-2 and Facebook's opt-1.3b just because they are easy to run without GPUs).
2. There is a pretrained Reward Model. I am using a fully trained sentiment classifier because I need something quick and easy.
3. The policy is a trivial 2-layer feed-forward net.



In [None]:
%%capture
!pip install torch transformers

# Fine-tuning an LLM with GRPO

Generating training examples using an LLM

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn # Import the torch.nn module and alias it as nn
import torch.optim as optim # Import the torch.optim module and alias it as optim
import torch.nn.functional as F

**Generating candidate responses (scored later by a trained sentiment model used as a Reward Model)**

In [None]:
# Generate responses function for Policy Network training
def generate_responses(prompt, num_responses=5):
    # Load the OPT model & tokenizer (reloaded on every call: slow, but it keeps the demo simple)
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")

    # Set pad token explicitly
    tokenizer.pad_token = tokenizer.eos_token

    # Tokenize the input with padding and attention mask
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)# max_length is specific to facebook/opt-1.3b
    # Generate multiple responses using sampling
    responses = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=50,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.7,
        num_return_sequences=num_responses,
        pad_token_id=tokenizer.pad_token_id  # Ensures padding works correctly
    )

    # Decode responses (remove prompt from output)
    decoded_responses = [
        tokenizer.decode(output[len(inputs["input_ids"][0]):], skip_special_tokens=True).strip()
        for output in responses
    ]

    return decoded_responses

In [None]:
# Test for generate_responses()
prompts = ["What people like or dislike about working out?",
           "What people like or dislike about deserts?",
           "What people like or dislike about traveling?",
           "What people like or dislike about New York City?",
           "What people like or dislike about living in France?"]
responses={}
for prompt in prompts:
    responses[prompt] = generate_responses(prompt)

**Reward Model**

In [None]:
%%capture
# Load sentiment classifier
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name)
reward_tokenizer = AutoTokenizer.from_pretrained(model_name)

def get_reward_scores(responses):
    """Scores responses using a sentiment classifier."""
    inputs = reward_tokenizer(responses, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # No gradient calculation needed
        outputs = reward_model(**inputs)
    scores = torch.softmax(outputs.logits, dim=1)[:, 1]  # Probability of "positive" class
    return scores.detach()  # Detach to prevent computation graph tracking
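
The reward above is simply the softmax probability of class index 1 (the "positive" class of this classifier). As a minimal worked example in plain Python, with made-up logits that are only an illustration and not real model outputs:

```python
import math

def positive_probability(logits):
    """Softmax over [negative, positive] logits; returns the positive-class probability."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max for numerical stability
    return exps[1] / sum(exps)

# Hypothetical logits, chosen only for illustration
print(round(positive_probability([-2.0, 3.0]), 3))  # strongly positive -> close to 1
print(round(positive_probability([0.0, 0.0]), 3))   # undecided -> 0.5
```

Because the score is a probability, every reward lies between 0 and 1, which is why the test outputs in this notebook cluster near 0.0 or 1.0.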

In [None]:
#Test Run for the Reward Model (This is a test cell just to make sure get_reward_scores works properly)
'''
#Example responses
responses = [
    "I love exercising, it makes me feel amazing!",
    "Exercise is okay, but it's tiring.",
    "I hate exercising, it's the worst!"
]
'''
test_responses = responses["What people like or dislike about living in France?"]
# Get reward scores
reward_scores = get_reward_scores(test_responses)

# Print scores
for i, (test_response, score) in enumerate(zip(test_responses, reward_scores)):
    print(f"Response {i+1}: {test_response}\nScore: {score.item():.3f}\n")


Response 1: I don't know, I don't have any French friends and I don't go there that often.
Score: 0.011

Response 2: I don't live in France but I have family there and I have heard many different things. One of the main things I've heard is that the French are a very laid back people.
Score: 0.999

Response 3: I’ve been living in France for a little over a year now. I like it a lot. I think the French people are friendly, helpful, and patient. I like
Score: 1.000

Response 4: You might be surprised to learn that France is a popular destination for people from all over the world. There are many reasons why people from all over the world love France, but I
Score: 1.000

Response 5: I like living in France. I think the people are friendly and the French language is fun to learn.   I dislike the weather. The heat and humidity are so bad. I just
Score: 0.004



training the target model

In [None]:
# Load GPT-2 model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set padding token and enforce left-padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Ensures padding is on the left for GPT-2

# Optimizer for fine-tuning GPT-2
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Simulated training batch
texts = ["Do you like living in Paris?", "What are advantages and disadvantages of the paleo diet?"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

# Ensure device compatibility (if using GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = {key: val.to(device) for key, val in inputs.items()}

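# Step 1: Sample responses. generate() runs under torch.no_grad(), so the
# log-probabilities are recomputed with a regular forward pass in Step 3.
sequences = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
)
prompt_len = inputs["input_ids"].shape[1]

# Step 2: Decode the generated part and score it with the reward model
decoded = tokenizer.batch_decode(sequences[:, prompt_len:], skip_special_tokens=True)
response_scores = get_reward_scores(decoded).to(device)

# Step 3: Log-probability of each response under the current model (with gradients)
full_mask = torch.cat([inputs["attention_mask"], torch.ones_like(sequences[:, prompt_len:])], dim=1)
logits = model(sequences, attention_mask=full_mask).logits
log_probs = F.log_softmax(logits[:, :-1], dim=-1)
token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
response_log_probs = token_log_probs[:, prompt_len - 1:].sum(dim=1)

# Step 4: Policy loss. A simplified stand-in for the GRPO objective: with a single
# update per batch of samples the probability ratio equals 1, so the clipping is
# inactive and the loss reduces to advantage-weighted log-probabilities.
baseline = response_scores.mean()
advantage = response_scores - baseline
policy_loss = -(advantage * response_log_probs).mean()

# Step 5: Backpropagate the loss into GPT-2's parameters
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()

print(f"Policy Loss: {policy_loss.item()}")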
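
The baseline and advantage used above can be checked by hand. Here is a minimal plain-Python sketch of a group-relative advantage, using the five reward scores printed earlier for the France prompt. The function name `group_relative_advantages` and the choice of the population standard deviation are my assumptions for illustration (`torch.std` uses the sample standard deviation by default).

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption made for this sketch
    return [(r - mu) / (sigma + eps) for r in rewards]

# Reward scores of the five responses to the France prompt (from the test run above)
rewards = [0.011, 0.999, 1.000, 1.000, 0.004]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward={r:.3f}  advantage={a:+.3f}")
```

Responses scored above the group mean get a positive advantage and the others a negative one, so no separate value network is needed to produce a baseline.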

Added KL-term

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm

# Assume models and tokenizer are initialized
# target_model: trainable GPT-2 model
# ref_model: frozen GPT-2 model (base)
# reward_model: function taking (query, response) -> reward (float or torch.Tensor)
# tokenizer: GPT2Tokenizer
# device: torch.device("cuda" or "cpu")

target_model.to(device)
ref_model.to(device)
ref_model.eval()

optimizer = Adam(target_model.parameters(), lr=5e-6)
kl_beta = 0.1  # KL divergence penalty coefficient
num_epochs = 3
batch_size = 4
num_samples_per_query = 5

def generate_response(model, prompt, max_length=50, num_samples=1):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    responses = []
    for _ in range(num_samples):
        output = model.generate(
            input_ids,
            do_sample=True,
            max_length=max_length,
            pad_token_id=tokenizer.eos_token_id,
            top_k=50,
            top_p=0.95
        )
        responses.append(output[0])
    return responses

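def compute_log_probs(model, input_ids):
    # Total log-prob of the sequence (prompt + response) under `model`.
    # No torch.no_grad() here: the target model's log-probs must stay in the
    # graph, otherwise loss.backward() has nothing to differentiate.
    outputs = model(input_ids, labels=input_ids)
    # outputs.loss is the mean negative log-likelihood over the seq_len - 1 predicted tokens
    log_probs = -outputs.loss * (input_ids.size(1) - 1)
    return log_probs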

for epoch in range(num_epochs):
    for query_batch in dataloader:  # assumes `dataloader` yields batches of queries
        optimizer.zero_grad()
        batch_loss = 0.0

        for query in query_batch:
            prompt_text = query["text"]
            responses = generate_response(target_model, prompt_text, num_samples=num_samples_per_query)

            rewards = []
            log_probs_target = []
            log_probs_ref = []

            for response in responses:
                prompt_response_ids = response.unsqueeze(0)  # (1, seq_len)

                # Reward
                decoded_response = tokenizer.decode(response, skip_special_tokens=True)
                reward = reward_model(prompt_text, decoded_response)
                rewards.append(reward)

                # Log-probs from target and reference (the reference is frozen, so it never needs gradients)
                log_prob_target = compute_log_probs(target_model, prompt_response_ids)
                with torch.no_grad():
                    log_prob_ref = compute_log_probs(ref_model, prompt_response_ids)

                log_probs_target.append(log_prob_target)
                log_probs_ref.append(log_prob_ref)

            # Convert to tensors
            rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
            log_probs_target = torch.stack(log_probs_target)
            log_probs_ref = torch.stack(log_probs_ref)

            # KL penalty: the log-ratio log pi_target - log pi_ref, evaluated on samples drawn
            # from the target model, is a single-sample estimate of KL(target || ref) (the reverse KL)
            kl_divs = log_probs_target - log_probs_ref

            # GRPO Loss: Negative advantage-weighted log-probs + KL penalty
            # Group-relative advantage: normalize rewards by the mean and std of the
            # responses sampled for this query
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

            loss = -torch.mean(rewards * log_probs_target) + kl_beta * torch.mean(kl_divs)
            batch_loss += loss

        batch_loss /= batch_size
        batch_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} | Loss: {batch_loss.item():.4f}")
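
The cell above penalizes the raw difference of log-probabilities, which can be negative for an individual sample. The GRPO formulation in the DeepSeek paper instead uses the estimator r - log r - 1 with r = pi_ref / pi_theta, which is never negative. A minimal plain-Python sketch follows; the function name `kl_estimate` and the log-probabilities are hypothetical values for illustration only.

```python
import math

def kl_estimate(logp_target, logp_ref):
    """Per-sample estimate of KL(target || ref): r - log(r) - 1 with r = p_ref / p_target."""
    log_ratio = logp_ref - logp_target
    return math.exp(log_ratio) - log_ratio - 1.0

# Hypothetical log-probabilities, for illustration only
print(kl_estimate(-2.0, -2.0))  # identical models -> 0.0
print(kl_estimate(-1.0, -3.0))  # target assigns more probability than the reference
print(kl_estimate(-3.0, -1.0))  # target assigns less probability than the reference
```

In the last case the raw log-ratio `logp_target - logp_ref` equals -2.0, while the estimator stays positive, so the penalty never rewards drifting away from the reference model.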


# Experiments

training policy network with KL term

In [None]:
# Load GPT model & tokenizer
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token # Use eos_token as pad_token for gpt2

'''
Logits reshaping to avoid the dimension mismatch error:
    1. Reshape new_logits: inside the compute_kl_divergence function, the new_logits tensor is reshaped
       and repeated to match the dimensions of original_logits. This ensures that both tensors have
       compatible shapes for calculating the KL divergence.
    2. Apply softmax after reshaping: the softmax function is applied to new_logits after it has been
       reshaped to match the dimensions of original_logits. This ensures that the KL divergence is
       calculated between probability distributions with the correct shape.
'''



def compute_kl_divergence(original_logits, new_logits):
    """Compute KL divergence between the original GPT-2 outputs and the policy network outputs."""
    # Repeat the policy's scalar score across the vocabulary dimension:
    # (batch,) or (batch, 1) -> (batch, vocab). Every row is then constant,
    # so its softmax is the uniform distribution.
    new_logits = new_logits.reshape(-1, 1).expand(-1, original_logits.shape[1])

    original_probs = F.softmax(original_logits, dim=-1)
    new_log_probs = F.log_softmax(new_logits, dim=-1)  # F.kl_div expects log-probabilities as input
    kl_div = F.kl_div(new_log_probs, original_probs, reduction="batchmean")  # KL(original || new)
    return kl_div

def train_policy_network_with_kl(responses, reward_scores, beta=0.01):
    """Updates the policy network using GRPO-style policy gradients with KL divergence."""

    # Get response embeddings from GPT-2
    response_inputs = gpt_tokenizer(responses, return_tensors="pt", padding=True, truncation=True)
    # Ensure all token IDs are within the model's vocabulary
    response_inputs["input_ids"] = response_inputs["input_ids"].clip(0, gpt_model.config.vocab_size -1)
    response_embeddings = gpt_model.transformer.wte(response_inputs["input_ids"]).mean(dim=1)  # Mean token embedding

    # Get logits from the original GPT-2 model (before policy updates)
    with torch.no_grad():
        original_logits = gpt_model(response_inputs["input_ids"]).logits.mean(dim=1)

    # Get predicted scores from the policy network
    predicted_scores = policy_net(response_embeddings).squeeze()

    # Compute Advantage (A = reward - baseline)
    baseline = reward_scores.mean()
    advantage = reward_scores - baseline

    # Compute policy loss (GRPO-style clipped loss)
    clip_ratio = 0.2
    policy_loss = -torch.min(
        predicted_scores * advantage,
        torch.clamp(predicted_scores, 1 - clip_ratio, 1 + clip_ratio) * advantage
    ).mean()

    # Compute KL divergence loss
    new_logits = policy_net(response_embeddings)  # Policy's logits
    kl_loss = compute_kl_divergence(original_logits, new_logits) # Calculate KL divergence

    # Final loss with KL penalty
    total_loss = policy_loss + beta * kl_loss

    # Backpropagation
    policy_optimizer.zero_grad()
    total_loss.backward()
    policy_optimizer.step()

    return total_loss.item(), kl_loss.item()
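
A quick check, in plain Python, of what repeating one scalar score across the vocabulary dimension implies for the KL term. This is a sketch under my reading of the cell above; the four-token vocabulary and the logits are made up. The softmax of a constant vector is uniform, so the KL against it equals log(V) minus the entropy of the original distribution, whatever value the policy network outputs.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

vocab_size = 4                               # tiny hypothetical vocabulary
original = softmax([2.0, 1.0, 0.5, -1.0])    # stand-in for GPT-2's distribution
repeated = softmax([0.7] * vocab_size)       # one scalar score repeated across the vocabulary

print(repeated)  # uniform: every entry is 1 / vocab_size
print(kl(original, repeated), math.log(vocab_size) - entropy(original))  # the two values agree
```

If that reading is right, the KL term in this experiment is a constant with respect to the policy network, which would be consistent with the KL values staying in a narrow band across the epochs printed below.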

In [None]:
# Training loop with KL regularization
num_epochs = 5
for epoch in range(num_epochs):
    training_responses = responses[prompts[epoch]]
    reward_scores = get_reward_scores(training_responses)
    loss, kl = train_policy_network_with_kl(training_responses, reward_scores)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss:.4f}, KL Divergence: {kl:.4f}")


Epoch 1/5 - Loss: 0.0732, KL Divergence: 4.6478
Epoch 2/5 - Loss: 0.1697, KL Divergence: 4.5876
Epoch 3/5 - Loss: 0.1655, KL Divergence: 4.6667
Epoch 4/5 - Loss: 0.1732, KL Divergence: 4.2839
Epoch 5/5 - Loss: 0.2287, KL Divergence: 4.7516


Train the target model using the Policy Network

In [None]:
# Load GPT-2 model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set padding token and enforce left-padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Ensures padding is on the left for GPT-2

# Optimizer for fine-tuning GPT-2
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Simulated training batch
texts = ["Do you like living in Paris?", "What are advantages and disadvantages of the paleo diet?"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

# Ensure device compatibility (if using GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = {key: val.to(device) for key, val in inputs.items()}

# Step 1: Generate responses. Note that generate() runs under torch.no_grad(),
# so nothing it returns carries gradient information.
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
    return_dict_in_generate=True,
    output_hidden_states=True  # Also return the hidden states
)

# Step 2: Extract hidden states from outputs.
# outputs.hidden_states is a tuple with one entry per generation step;
# each entry is itself a tuple with one tensor per layer.
hidden_states = outputs.hidden_states
last_step_hidden_states = hidden_states[-1]  # All layers at the last generation step

# Last layer at the last generation step: shape (batch, 1, hidden_size) when caching is enabled
last_hidden_state = last_step_hidden_states[-1]
response_embeddings = last_hidden_state.mean(dim=1)  # Mean over the token dimension

# Step 3: Define a simple policy network (if not already defined)
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Output: scalar score for each response
        )

    def forward(self, x):
        return self.fc(x)

policy_net = PolicyNetwork(input_dim=768).to(device)  # Ensure it's on the right device

# Step 4: Compute response scores using the policy network
response_scores = policy_net(response_embeddings)

# Step 5: Compute policy loss (GRPO-style clipped loss)
baseline = response_scores.mean()
advantage = response_scores - baseline
policy_loss = -torch.min(response_scores * advantage, torch.clamp(response_scores, 0.8, 1.2) * advantage).mean()

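# Step 6: Backpropagate. Because the hidden states came from generate(), the graph
# stops at policy_net: GPT-2 receives no gradients, and since the optimizer holds
# only GPT-2's parameters, optimizer.step() changes nothing here.
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()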

print(f"Policy Loss: {policy_loss.item()}")


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Policy Loss: 0.2255459576845169


Reward Model to Score Responses (this is an untrained neural net used as an example; it will give random rewards)

In [None]:
class RewardModel(nn.Module):
    """Simple reward model that scores responses."""
    def __init__(self, embedding_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # Output: a single score per response
            nn.Sigmoid()  # Normalize score between 0 and 1
        )

    def forward(self, response_embedding):
        return self.fc(response_embedding)

# Instantiate reward model
reward_model = RewardModel(embedding_dim=768)  # Assuming GPT's embeddings

# Convert responses to embeddings
# Reuse EOS as the pad token: a newly added '[PAD]' token would get an ID outside GPT-2's embedding matrix
tokenizer.pad_token = tokenizer.eos_token
response_inputs = tokenizer(decoded_responses, return_tensors="pt", padding=True, truncation=True)
response_embeddings = model.transformer.wte(response_inputs["input_ids"]).mean(dim=1)  # Averaging token embeddings

# Score each response
reward_scores = reward_model(response_embeddings).squeeze()

# Print scores
for i, (response, score) in enumerate(zip(decoded_responses, reward_scores.tolist())):
    print(f"Response {i+1}: {response} | Score: {score:.3f}")

toy example of the policy network

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, state):
        return self.fc(state)

def compute_advantage(rewards, values, gamma=0.99):
    """Computes advantage estimates using reward and value function."""
    returns = []
    advs = []
    G = 0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G  # Compute return
        returns.insert(0, G)
        advs.insert(0, G - values[t])  # Advantage = Return - Value Estimate
    return torch.tensor(advs, dtype=torch.float32)

# Hyperparameters
epsilon = 0.2
gamma = 0.99
learning_rate = 0.01

# Sample data (dummy example)
state_dim = 4
action_dim = 2
states = torch.rand((5, state_dim))  # 5 sample states
actions = torch.tensor([0, 1, 0, 1, 0])  # Actions taken
old_probs = torch.tensor([0.4, 0.6, 0.5, 0.7, 0.5])  # Old policy probabilities
rewards = [1, 0, 1, 1, 0]  # Rewards received
values = [0.5, 0.4, 0.6, 0.7, 0.3]  # Value estimates

# Initialize network and optimizer
policy_net = PolicyNetwork(state_dim, action_dim)
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

# Compute advantage
advantages = compute_advantage(rewards, values, gamma)

# Compute new policy probabilities
new_probs = policy_net(states).gather(1, actions.view(-1, 1)).squeeze()

# Compute probability ratio
ratios = new_probs / old_probs

# Compute clipped and unclipped loss
clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()

# Gradient descent on the negated objective (equivalent to gradient ascent on the objective)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Policy Loss: {loss.item()}")


Train the Policy with GRPO
Train the policy network to predict higher scores for better responses.
Update it using a reinforcement learning algorithm like GRPO.
Example: Training the Policy Network
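
A minimal sketch of such a training loop, in plain Python rather than PyTorch. Everything here is a toy assumption for illustration: the "policy" is a softmax over three canned responses, the rewards are fixed made-up numbers, and the names `grpo_step`, `group_size` and `lr` are hypothetical. It makes one update per sampled group, so the probability ratio is 1 and no clipping or KL term is needed.

```python
import math
import random

def softmax(logits):
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grpo_step(logits, rewards, rng, group_size=8, lr=0.1, eps=1e-8):
    """One group-relative policy-gradient step on a softmax policy over fixed responses."""
    probs = softmax(logits)
    # Sample a group of responses from the current policy
    group = rng.choices(range(len(logits)), weights=probs, k=group_size)
    group_rewards = [rewards[i] for i in group]
    # Group-relative advantages: normalize by the group's mean and std
    mu = sum(group_rewards) / group_size
    sigma = math.sqrt(sum((r - mu) ** 2 for r in group_rewards) / group_size)
    advantages = [(r - mu) / (sigma + eps) for r in group_rewards]
    # The gradient of log pi(i) with respect to the logits is one_hot(i) - probs
    grad = [0.0] * len(logits)
    for i, adv in zip(group, advantages):
        for j in range(len(logits)):
            grad[j] += adv * ((1.0 if j == i else 0.0) - probs[j]) / group_size
    # Gradient ascent on the advantage-weighted log-probability
    return [l + lr * g for l, g in zip(logits, grad)]

rng = random.Random(0)
rewards = [0.1, 0.5, 0.9]   # hypothetical fixed rewards for three canned responses
logits = [0.0, 0.0, 0.0]    # start from a uniform policy
for _ in range(300):
    logits = grpo_step(logits, rewards, rng)

probs = softmax(logits)
print([round(p, 3) for p in probs])
```

The response with the highest reward can only gain probability relative to the one with the lowest reward, because within any sampled group it is never below the group mean.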

Using the Policy for Action Selection
Once the policy is trained, it can be used to select actions in the environment.

Example: Action Selection in a Trained Policy

In [None]:
import torch

def select_action(policy_net, state):
    """Selects an action based on the trained policy."""
    with torch.no_grad():  # No gradients needed for inference
        action_probs = policy_net(state)
        action = torch.multinomial(action_probs, 1)  # Sample from the policy distribution
    return action.item()

# Example usage:
state = torch.rand((1, 4))  # Example state (assuming 4D input)
action = select_action(policy_net, state)
print(f"Selected Action: {action}")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Simulated training data (for visualization purposes)
num_epochs = 50
loss_values = []
kl_values = []
grad_norms = []

# Dummy policy network (small for visualization purposes)
policy_net = nn.Sequential(
    nn.Linear(768, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-5)

for epoch in range(num_epochs):
    # Simulated loss and KL divergence values
    policy_loss = torch.rand(1).item() * 2  # Randomized for visualization
    kl_loss = torch.rand(1).item() * 0.2
    total_loss = policy_loss + kl_loss

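    # Simulate gradient computation: scale the dummy network's output by the simulated
    # loss so that backward() actually reaches the network's parameters
    policy_optimizer.zero_grad()
    dummy_input = torch.rand(8, 768)
    total_loss_tensor = total_loss * policy_net(dummy_input).pow(2).mean()
    total_loss_tensor.backward()

    # Compute the global L2 gradient norm
    total_norm = 0.0
    for param in policy_net.parameters():
        if param.grad is not None:
            total_norm += param.grad.norm().item() ** 2
    grad_norms.append(total_norm ** 0.5)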

    # Apply gradient step
    policy_optimizer.step()

    # Store values for plotting
    loss_values.append(total_loss)
    kl_values.append(kl_loss)

# Plot the results
fig, ax1 = plt.subplots()
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss / KL Divergence", color="tab:blue")
ax1.plot(loss_values, label="Total Loss", color="tab:red")
ax1.plot(kl_values, label="KL Divergence", color="tab:blue", linestyle="dashed")
ax1.legend(loc="upper right")
ax1.tick_params(axis="y", labelcolor="tab:blue")

ax2 = ax1.twinx()
ax2.set_ylabel("Gradient Norms", color="tab:green")
ax2.plot(grad_norms, label="Gradient Norms", color="tab:green", linestyle="dotted")
ax2.tick_params(axis="y", labelcolor="tab:green")
ax2.legend(loc="lower right")

plt.title("Policy Training: Loss, KL Divergence & Gradient Norms")
plt.show()


Test of using GPT-2 to generate multiple responses based on the prompt

In [None]:
# Load GPT model & tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Define a question (prompt)
prompt = "What people like or dislike about working out?"

# Tokenize the input
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Generate multiple responses using sampling
num_responses = 5  # Generate 5 different completions
responses = model.generate(
    input_ids,
    max_length=50,
    do_sample=True,  # Enables sampling instead of greedy decoding
    top_k=50,  # Consider top 50 tokens at each step
    top_p=0.9,  # Nucleus sampling: keeps top tokens contributing to 90% probability
    temperature=0.7,  # Controls randomness (lower = more deterministic)
    num_return_sequences=num_responses  # Generates multiple responses
)

# Decode and print responses
#decoded_responses = [tokenizer.decode(output, skip_special_tokens=True) for output in responses]
# Decode, remove the prompt, and print
decoded_responses = [
    tokenizer.decode(output[len(input_ids[0]):], skip_special_tokens=True).strip()
    for output in responses
]
for i, response in enumerate(decoded_responses):
    print(f"Response {i+1}: {response}")

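Example of code that does not work

🔴 Problem in the Code
The cell below is meant to fine-tune GPT-2, but policy_loss.backward() cannot do that, because response_scores = policy_net(response_embeddings.mean(dim=1)) barely involves GPT-2's parameters.

torch.no_grad() disables gradient tracking for GPT-2's generation step (and selecting discrete token IDs is not differentiable anyway).
response_embeddings = model.transformer.wte(outputs): this is only an embedding lookup on the generated token IDs, so the transformer blocks never enter the computation graph.
As a result, policy_loss.backward() reaches only the policy network and the token-embedding matrix. The optimizer holds only GPT-2's parameters, so the policy network is never updated either, and the loss carries no signal about which responses GPT-2 should make more likely.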

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the LLM (e.g., GPT-2 for simplicity)
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Define a simple policy network that scores responses
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Output: scalar score for each response
        )

    def forward(self, x):
        return self.fc(x)

# Instantiate the policy network
policy_net = PolicyNetwork(input_dim=768)  # GPT's embedding size
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Simulated training batch
texts = ["What is the capital of France?", "How does photosynthesis work?"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

# Step 1: Generate responses from the LLM
with torch.no_grad():
    outputs = model.generate(inputs["input_ids"], max_length=50)

# Step 2: Get response embeddings (used as input to policy network)
response_embeddings = model.transformer.wte(outputs)  # Word token embeddings
response_scores = policy_net(response_embeddings.mean(dim=1))  # Score each response

# Step 3: Compute policy loss (GRPO-style clipped loss)
baseline = response_scores.mean()  # Baseline for advantage calculation
advantage = response_scores - baseline  # Compute advantage function
policy_loss = -torch.min(response_scores * advantage, torch.clamp(response_scores, 0.8, 1.2) * advantage).mean()

# Step 4: Backpropagate the loss to fine-tune the LLM
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()

print(f"Policy Loss: {policy_loss.item()}")

In [None]:
#test code to delete:

# Define a simple policy network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Output: scalar score for each response
        )

    def forward(self, x):
        return self.fc(x)

# Instantiate the policy network (Assumed Pretrained)
policy_net = PolicyNetwork(input_dim=768)  # GPT-2 embedding size

# Optimizer for fine-tuning GPT-2
optimizer = optim.Adam(model.parameters(), lr=1e-5)