<a href="https://colab.research.google.com/github/vinodkraman/RL4LLMs/blob/main/grpo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRPO from Scratch

In [None]:
import math, os, random
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import (
  AutoTokenizer,
  AutoModelForCausalLM,
  AutoModelForSequenceClassification,
)

def set_seed(seed: int = 42):
    """Seed every RNG used by this notebook.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices.
    """
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

@dataclass
class Batch:
    """Container for one DataLoader batch of raw prompt strings (built by ``collate_prompts``)."""

    prompts: List[str]  # one prompt per dataset row, in batch order

def collate_prompts(features, prompt_column: str):
    """DataLoader ``collate_fn``: pull ``prompt_column`` out of every row and wrap the list in a ``Batch``."""
    prompts = [row[prompt_column] for row in features]
    return Batch(prompts=prompts)

def reward_score(rm, rm_tok, prompt: str, response: str, max_length: int, device: torch.device) -> float:
    """Score ONE (prompt, response) pair with a sequence-classification reward model.

    The stripped prompt and response are joined by a blank line, tokenized
    (truncated to ``max_length`` tokens) and run through ``rm`` without gradients.
    With a single-logit head the reward is that logit; otherwise it is the mean
    of the logits of the single example.

    NOTE(review): some reward models (e.g. OpenAssistant's) document a text-pair
    encoding ``rm_tok(prompt, response)``; confirm the blank-line join matches the
    format the chosen reward model was trained on.
    """
    joined = "\n\n".join((prompt.strip(), response.strip()))
    inputs = rm_tok(
        joined,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        add_special_tokens=True,
    ).to(device)
    with torch.no_grad():
        logits = rm(**inputs).logits
        if logits.size(-1) != 1:
            # Multi-logit head: collapse the class logits of the lone example.
            return float(logits.squeeze(0).mean(dim=-1))
        return float(logits.squeeze())

def generate_with_scores(model, tokenizer, prompts: List[str], max_new_tokens: int, temperature: float, top_p: float, device: torch.device):
    """Sample one continuation per prompt (no gradients) and decode it.

    Returns a 4-tuple ``(out, texts, input_len, prompt_lens)``:
      * ``out`` -- the ``generate`` output, including ``sequences`` and per-step ``scores``;
      * ``texts`` -- decoded continuations only (prompt removed, special tokens skipped);
      * ``input_len`` -- width of the padded prompt tensor;
      * ``prompt_lens`` -- per-prompt count of attended (non-padding) prompt tokens.
    """
    enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True).to(device)
    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**enc, **sampling_kwargs)

    input_len = enc["input_ids"].shape[1]
    prompt_lens = enc["attention_mask"].sum(dim=1).tolist()
    # Every row shares the same padded prompt width, so slicing at input_len keeps only new tokens.
    texts = [tokenizer.decode(row[input_len:], skip_special_tokens=True) for row in out.sequences]
    return out, texts, input_len, prompt_lens


In [None]:
class GRPOTrainer:
    """Group Relative Policy Optimization (GRPO) trainer for a causal LM.

    For every prompt, ``group_size`` completions are sampled from a frozen
    snapshot of the policy (``policy_old``) and scored by a sequence-classification
    reward model. The scores become group-relative advantages ``r - mean(r_group)``,
    optionally divided by the group std. The policy is then updated ``mu`` times per
    batch with the PPO clipped surrogate plus an optional per-token KL penalty
    (k3 estimator) towards a reference model.
    """

    def __init__(self,
                 dataset_name="HuggingFaceH4/ultrachat_200k",
                 split="train[:1024]",
                 prompt_column="text",
                 policy_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                 ref_name=None,
                 reward_name="OpenAssistant/reward-model-deberta-v3-base",
                 output_dir="./grpo-policy",
                 epochs=1, batch_size=2, group_size=4, mu=3,
                 gen_max_new_tokens=64, rm_max_length=512,
                 temperature=1.0, top_p=0.95, lr=1e-5,
                 clip_ratio=0.2, kl_coef=0.01, max_grad_norm=1.0,
                 seed=42, device=None, scale_rewards=False):
        """Load models, tokenizers, optimizer and the prompt DataLoader.

        Args:
            dataset_name, split: Hugging Face dataset and split providing the prompts.
            prompt_column: field of each example used as the prompt text.
            policy_name: model being trained. It also provides the tokenizer and
                the initial ``policy_old``.
            ref_name: model for the KL penalty. ``None`` (or ``kl_coef <= 0``)
                disables the KL term.
            reward_name: sequence-classification reward model.
            output_dir: where the final policy and tokenizer are saved. A falsy
                value skips saving.
            epochs, batch_size: passes over the dataset and prompts per batch.
            group_size: completions sampled per prompt (K). Must be >= 1.
            mu: policy updates per sampled batch. Must be >= 1.
            gen_max_new_tokens, temperature, top_p: sampling settings.
            rm_max_length: truncation length for reward-model inputs.
            lr, max_grad_norm: AdamW learning rate and gradient-clipping norm.
            clip_ratio: PPO clipping epsilon.
            kl_coef: weight of the KL penalty.
            seed: value passed to ``set_seed``.
            device: torch device. Defaults to CUDA when available, else CPU.
            scale_rewards: if True, also divide group-centred rewards by the group
                std (GRPO's normalised advantage). Default False keeps mean-centring only.

        Raises:
            ValueError: if ``group_size`` or ``mu`` is smaller than 1.
        """
        # Validate before any expensive model/dataset loading.
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")
        if mu < 1:
            raise ValueError(f"mu must be >= 1, got {mu}")

        set_seed(seed)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.epochs = epochs
        self.group_size = group_size
        self.gen_max_new_tokens = gen_max_new_tokens
        self.rm_max_length = rm_max_length
        self.temperature = temperature
        self.top_p = top_p
        self.mu = mu
        self.clip_ratio = clip_ratio
        self.kl_coef = kl_coef
        self.max_grad_norm = max_grad_norm
        self.output_dir = output_dir
        self.scale_rewards = scale_rewards

        # Left padding keeps every prompt flush against its continuation in generate().
        self.tokenizer = AutoTokenizer.from_pretrained(policy_name, padding_side="left", use_fast=True, model_max_length=512)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.policy = AutoModelForCausalLM.from_pretrained(policy_name).to(self.device)
        self.policy_old = AutoModelForCausalLM.from_pretrained(policy_name).to(self.device)
        for p in self.policy_old.parameters():
            p.requires_grad_(False)

        self.ref = None
        if ref_name is not None and kl_coef > 0.0:
            self.ref = AutoModelForCausalLM.from_pretrained(ref_name).to(self.device)
            for p in self.ref.parameters():
                p.requires_grad_(False)

        self.rm_tok = AutoTokenizer.from_pretrained(reward_name, use_fast=True)
        self.rm = AutoModelForSequenceClassification.from_pretrained(reward_name).to(self.device)
        for p in self.rm.parameters():
            p.requires_grad_(False)

        self.opt = torch.optim.AdamW(self.policy.parameters(), lr=lr)

        self.ds = load_dataset(dataset_name, split=split)
        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=True,
                             collate_fn=lambda feats: collate_prompts(feats, prompt_column))

    # ------------------------------------------------------------------ helpers

    def _continuation_logprobs(self, model, sequences, T: int) -> torch.Tensor:
        """Log-probability under ``model`` of each of the last ``T`` tokens of ``sequences``.

        Returns a ``[N, T]`` tensor for ``sequences`` of shape ``[N, L]``. Autograd
        follows ``model``/the caller's grad mode.
        """
        attn = (sequences != self.tokenizer.pad_token_id).long()
        logits = model(sequences, attention_mask=attn).logits          # [N, L, V]
        start = sequences.size(1) - T                                  # first continuation position
        # Logits at position t predict token t+1, so positions start-1 .. L-2 score the continuation.
        # Explicit start index (instead of [-T:]) keeps T == 0 an empty slice rather than the whole row.
        cont_logits = logits[:, start - 1:-1, :]                       # [N, T, V]
        cont_targets = sequences[:, start:]                            # [N, T]
        return cont_logits.log_softmax(dim=-1).gather(-1, cont_targets.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _completion_mask(cont_ids, eos_id):
        """Float mask over continuation tokens: 1 up to and including the first EOS, 0 after it.

        ``generate`` is called with ``pad_token_id=eos_token_id``, so finished
        rows are filled with EOS. Those filler tokens must not contribute to the loss.
        """
        if eos_id is None:
            return torch.ones_like(cont_ids, dtype=torch.float32)
        is_eos = (cont_ids == eos_id).long()
        eos_strictly_before = is_eos.cumsum(dim=1) - is_eos
        return (eos_strictly_before == 0).float()

    @staticmethod
    def _masked_mean(values, mask):
        """Average ``values`` over unmasked tokens of each sequence (last dim), then over sequences."""
        per_seq = (values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
        return per_seq.mean()

    def _score_groups(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
        """Reward for every completion, shaped ``[B, K]`` (responses are grouped K per prompt)."""
        K = self.group_size
        rows = [
            [
                reward_score(self.rm, self.rm_tok, prompt, responses[i * K + j],
                             self.rm_max_length, self.device)
                for j in range(K)
            ]
            for i, prompt in enumerate(prompts)
        ]
        return torch.tensor(rows, dtype=torch.float32, device=self.device)

    def _group_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        """Group-relative advantages ``r - mean(r)`` per prompt row, optionally std-scaled."""
        adv = rewards - rewards.mean(dim=1, keepdim=True)
        if self.scale_rewards and rewards.size(1) > 1:
            # Small epsilon avoids division by zero when all K rewards are equal.
            adv = adv / (rewards.std(dim=1, keepdim=True) + 1e-4)
        return adv

    # --------------------------------------------------------------- public API

    def tokens_logprob_from_forward(self, sequences, T: int) -> torch.Tensor:
        """
        Compute the log-probs of the last T continuation tokens with autograd enabled.

        Uses the current (trainable) policy. Returns a ``[N, T]`` tensor.
        """
        return self._continuation_logprobs(self.policy, sequences, T)

    def step(self, prompts: List[str]):
        """Run one GRPO iteration (rollout, scoring, ``mu`` updates) on a batch of prompts.

        Returns:
            dict with ``loss``, ``pg_loss`` and ``kl_loss`` from the last of the
            ``mu`` updates, and ``reward_mean`` over all B*K completions.
        """
        B, K = len(prompts), self.group_size
        prompts_rep = [p for p in prompts for _ in range(K)]

        # 1. Snapshot the current policy as the behaviour ("old") policy.
        self.policy_old.load_state_dict(self.policy.state_dict())

        # 2. Sample K completions per prompt from the old policy.
        out, responses, _, _ = generate_with_scores(
            self.policy_old, self.tokenizer, prompts_rep,
            self.gen_max_new_tokens, self.temperature, self.top_p, self.device,
        )
        sequences = out.sequences                                      # [B*K, L]
        T = len(out.scores)                                            # continuation length

        # 3. Scalar reward per completion.
        rewards = self._score_groups(prompts, responses)               # [B, K]

        # 4. Only real completion tokens (through the first EOS) count in the losses.
        cont_ids = sequences[:, sequences.size(1) - T:]
        mask = self._completion_mask(cont_ids, self.tokenizer.eos_token_id).reshape(B, K, T)

        # 5. Fixed log-probs under the old policy and, if enabled, the reference model.
        with torch.no_grad():
            logp_old = self._continuation_logprobs(self.policy_old, sequences, T).reshape(B, K, T)
            logp_ref = None
            if self.ref is not None and self.kl_coef > 0.0:
                logp_ref = self._continuation_logprobs(self.ref, sequences, T).reshape(B, K, T)

        # 6. Group-relative advantage, broadcast over tokens.
        adv = self._group_advantages(rewards).unsqueeze(-1)            # [B, K, 1]

        # 7. mu policy updates on the same rollouts.
        for _ in range(self.mu):
            logp_new = self.tokens_logprob_from_forward(sequences, T).reshape(B, K, T)

            # PPO clipped surrogate.
            ratio = (logp_new - logp_old).exp()
            clipped = torch.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio)
            surrogate = torch.min(ratio * adv, clipped * adv)
            pg_loss = -self._masked_mean(surrogate, mask)

            # k3 KL estimator: exp(d) - d - 1 with d = log p_ref - log p_new.
            # Computed in log space, so no epsilon fudge is needed.
            if logp_ref is not None:
                log_ratio = logp_ref - logp_new
                kl_tok = log_ratio.exp() - log_ratio - 1.0
                kl_loss = self.kl_coef * self._masked_mean(kl_tok, mask)
            else:
                kl_loss = torch.tensor(0.0, device=self.device)

            loss = pg_loss + kl_loss

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.opt.step()

        # policy_old is re-synced at the start of the next call.
        return {
            "loss": float(loss.item()),
            "pg_loss": float(pg_loss.item()),
            "kl_loss": float(kl_loss.item()),
            "reward_mean": float(rewards.mean().item()),
        }

    def train(self):
        """Run ``self.epochs`` passes of GRPO over the prompt DataLoader, then save the policy."""
        # NOTE(review): train mode enables dropout if the checkpoint configures any. That would
        # make logp_new differ from logp_old even before the first update; consider eval()
        # for dropout-bearing models.
        self.policy.train()
        step = 0

        for epoch in range(self.epochs):
            # Algorithm 1, step 3: reset the reference model to the current policy once per epoch.
            if self.ref is not None:
                self.ref.load_state_dict(self.policy.state_dict())

            for batch in self.dl:
                # Algorithm 1, steps 6-11. step() snapshots policy_old itself before sampling.
                stats = self.step(batch.prompts)

                step += 1
                if step % 5 == 0:
                    print(
                        f"epoch {epoch} step {step} | "
                        f"loss {stats['loss']:.4f} | "
                        f"pg {stats['pg_loss']:.4f} | "
                        f"kl {stats['kl_loss']:.4f} | "
                        f"R {stats['reward_mean']:.3f}"
                    )

        # Save final model
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
            self.policy.save_pretrained(self.output_dir)
            self.tokenizer.save_pretrained(self.output_dir)



In [None]:
# Example usage (adjust model sizes to your GPU):
# Builds a trainer on the UltraChat "train_gen" prompts, trains it, then reports where it saved.
trainer = GRPOTrainer(
    # --- prompt data ---
    dataset_name="HuggingFaceH4/ultrachat_200k",
    split="train_gen",
    prompt_column="prompt",
    # --- models ---
    policy_name="EleutherAI/pythia-70m",
    ref_name="EleutherAI/pythia-70m",  # set None to disable KL
    reward_name="OpenAssistant/reward-model-deberta-v3-base",
    # --- schedule / grouping ---
    epochs=1,
    batch_size=2,
    group_size=4,
    # --- sampling & reward truncation ---
    gen_max_new_tokens=64,
    rm_max_length=512,
    temperature=1.0,
    top_p=0.95,
    # --- optimisation ---
    lr=1e-5,
    clip_ratio=0.2,
    kl_coef=1e-5,
    max_grad_norm=1.0,
)
trainer.train()
print('Saved to', trainer.output_dir)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


epoch 0 step 5 | loss -0.2922 | pg -0.2922 | kl 0.0000 | R -1.541
epoch 0 step 10 | loss -0.4817 | pg -0.4817 | kl 0.0000 | R -1.038
epoch 0 step 15 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 2.900
epoch 0 step 20 | loss -0.2717 | pg -0.2717 | kl 0.0000 | R 0.817
epoch 0 step 25 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 5.850
epoch 0 step 30 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 1.636
epoch 0 step 35 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R -1.101
epoch 0 step 40 | loss 0.0001 | pg -0.0000 | kl 0.0001 | R 2.779
epoch 0 step 45 | loss 0.0001 | pg -0.0000 | kl 0.0001 | R -0.184
epoch 0 step 50 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 0.526
epoch 0 step 55 | loss 0.0001 | pg -0.0000 | kl 0.0001 | R 1.532
epoch 0 step 60 | loss 0.0001 | pg -0.0000 | kl 0.0001 | R 3.850
epoch 0 step 65 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 1.552
epoch 0 step 70 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R -2.325
epoch 0 step 75 | loss 0.0000 | pg -0.0000 | kl 0.0000 | R 1.760
epoch 0 step 80 | 