# GRPO Tutorial Notebook

This notebook reproduces the code frames from the video tutorial on Group Relative Policy Optimization (GRPO) for small reasoning LMs. Run the cells in order, from top to bottom, to build the environment, the model, the loss functions, and the training loop.

## 1. Install & Imports

Install the package in editable mode and import required modules.

In [None]:
%pip install -e .

import torch
import matplotlib.pyplot as plt
from grpo_slm.env import ReasoningEnv
from grpo_slm.model import ReasoningModel
from grpo_slm.grpo import ppo_loss, grpo_loss

## 2. Define and Test the Environment

In [None]:
# Build the reasoning environment and draw an initial prompt from it.
# `env` and `prompt` are reused by later cells.
MAX_VALUE = 10  # forwarded as ReasoningEnv's `max_value` argument
env = ReasoningEnv(max_value=MAX_VALUE)
prompt = env.reset()
print("Prompt:", prompt)

# Smoke-test step() with an action sampled from the environment's action space.
# No seed is set here, so the sampled action may differ between runs.
action = env.action_space.sample()
step_result = env.step(action)
# step() yields four values; the first one is not needed for this check.
_, reward, done, info = step_result
print(f"Test step -> action: {action}, reward: {reward}, true answer: {info['true']}")

## 3. Instantiate and Test the Model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ReasoningModel('gpt2', device=device)

# Generation check: request a short continuation for the prompt obtained from
# the environment (generate() takes a list of prompts; we pass just one).
generations = model.generate([prompt], max_new_tokens=5)
response = generations[0]
print("Generated response:", response)

# log_probs check: evaluate the same (prompt, response) pair and show the
# resulting value as a Python number.
pair_logps = model.log_probs([prompt], [response])
logp = pair_logps[0]
print("Log-prob of the response:", logp.item())

## 4. Loss Functions

In [None]:
# Toy inputs for the two loss functions, in the same argument order the
# training cell uses: old log-probs, new log-probs, advantages.
old = torch.tensor([1.0, 2.0, 3.0])
new = torch.tensor([1.2, 1.8, 2.5])
adv = torch.tensor([0.5, -0.2, 1.0])

# Evaluate each loss on the identical inputs and print its scalar value.
for label, loss_fn in (("PPO loss:", ppo_loss), ("GRPO loss:", grpo_loss)):
    print(label, loss_fn(old, new, adv).item())

## 5. Training Loop

In [None]:
# Single-batch training demonstration: sample `batch_size` rollouts, score them
# with the environment, and take ONE optimizer step on the GRPO loss.
optimizer = torch.optim.Adam(model.model.parameters(), lr=1e-5)
batch_size = 4


def _parse_action(text):
    """Convert a generated response into an integer action.

    Returns the integer value of `text` (ignoring surrounding whitespace) when
    it consists only of digits, and -1 otherwise.
    """
    cleaned = text.strip()
    if not cleaned.isdigit():
        return -1
    try:
        return int(cleaned)
    except ValueError:
        # str.isdigit() also accepts characters such as "²" that int() rejects;
        # treat those like any other unparsable response instead of crashing.
        return -1


prompts, responses = [], []
reward_values, rollout_logps = [], []

# Rollout phase. This only collects data, so autograd is switched off: the
# "old" log-probs must be constants in the loss. If they stayed attached to the
# graph, gradients would flow through both `old_logps` and `new_logps`, which
# come from the same model and inputs.
with torch.no_grad():
    for _ in range(batch_size):
        p = env.reset()
        r = model.generate([p], max_new_tokens=5)[0]
        lp = model.log_probs([p], [r])[0]
        _, rew, _, _ = env.step(_parse_action(r))
        prompts.append(p)
        responses.append(r)
        reward_values.append(rew)
        rollout_logps.append(lp)

# Explicit float dtype: .mean() is not defined for integer tensors, which is
# what torch.tensor would infer if the environment returned int rewards.
rewards = torch.tensor(reward_values, dtype=torch.float32, device=device)
old_logps = torch.stack(rollout_logps)

# Advantage of each rollout relative to the mean reward of the batch.
adv = rewards - rewards.mean()

# Re-evaluate the sampled responses with autograd enabled; this is the only
# term of the loss that carries gradients back to the model parameters.
new_logps = model.log_probs(prompts, responses)
loss = grpo_loss(old_logps, new_logps, adv)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Training step -> loss: {loss.item():.4f}, avg reward: {rewards.mean().item():.4f}")

## 6. Plotting Reward Histogram

In [None]:
# Histogram of the rewards collected in the training cell. The tensor is moved
# to the CPU and converted to a NumPy array before being handed to matplotlib.
fig, ax = plt.subplots()
ax.hist(rewards.cpu().numpy())
ax.set_title("Reward Distribution")
ax.set_xlabel("Reward")
ax.set_ylabel("Count")
plt.show()