<a href="https://colab.research.google.com/github/paulkroe/minireason/blob/main/grop_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from transformers import AutoModelForCausalLM, AutoTokenizer

def reward_function(prompt, completion):
    """Toy reward: the number of distinct characters appearing in `completion`.

    `prompt` is part of the signature but does not influence the score.
    """
    distinct_chars = {ch for ch in completion}
    return len(distinct_chars)

def generate_completions(model, tokenizer, prompt, num_generations=3, max_length=50, temperature=1.0):
    """
    Sample `num_generations` sequences for `prompt` and return them decoded.

    Args:
        model: model exposing `generate`, `parameters`, `training`, `eval` and `train`.
        tokenizer: tokenizer used to encode `prompt` and decode the outputs.
        prompt: text to condition the generation on.
        num_generations: number of sequences sampled for the prompt.
        max_length: forwarded unchanged to `generate` as `max_length`.
        temperature: sampling temperature forwarded to `generate`.

    Returns:
        A list of decoded sequences (special tokens stripped), one per sample.
    """
    # Run on whatever device the model lives on so CPU and GPU both work.
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    # A single, unpadded prompt: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    # Some tokenizers (e.g. GPT-2's) define no pad token; fall back to EOS so
    # `generate` does not have to guess how to pad finished sequences.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Sample with dropout disabled, then restore the caller's train/eval mode.
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            do_sample=True,
            num_return_sequences=num_generations,
            temperature=temperature,
            pad_token_id=pad_token_id,
        )
    finally:
        model.train(was_training)
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

def compute_log_probs(model, input_ids, attention_mask):
    """
    Return, for every token in `input_ids`, the log-probability the model
    assigns to that token given the tokens before it.

    A causal LM's logits at position t score the token at position t + 1, so
    the logits are shifted by one position before gathering. Position 0 has no
    preceding context and is filled with 0.0, which keeps the output aligned
    with `input_ids` (`result[:, t]` belongs to `input_ids[:, t]`).

    Shape: [batch_size, seq_len]
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # [B, L, V]
    # Drop the last position: it predicts a token beyond the end of the input.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)  # [B, L-1, V]
    # Token t is scored by the distribution produced at position t - 1.
    targets = input_ids[:, 1:].unsqueeze(-1)  # [B, L-1, 1]
    next_token_log_probs = torch.gather(log_probs, 2, targets).squeeze(-1)  # [B, L-1]
    # Left-pad with one zero column (none for an empty sequence) to restore [B, L].
    batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
    first_column = next_token_log_probs.new_zeros(batch_size, min(seq_len, 1))
    return torch.cat([first_column, next_token_log_probs], dim=1)

def train_step(
    model,
    ref_model,
    tokenizer,
    optimizer,
    prompt,
    num_generations=3,
    epsilon=0.2,
    beta=0.1
):
    """
    Run a single PPO-like, per-token update of `model` on one prompt.

    Completions are sampled from `model`, scored with `reward_function`, and
    the rewards are normalized within the group to give one scalar advantage
    per completion. For every completion token the clipped surrogate objective
    and a KL penalty are combined, averaged over tokens and then over
    completions. The per-token KL estimate is

        KL = exp(ref_logp - cur_logp) - (ref_logp - cur_logp) - 1

    which is never negative. `ref_model` supplies both the "old" log-probs
    used in the probability ratio and the anchor of the KL penalty.

    Args:
        model: policy being trained; must match `optimizer`'s parameters.
        ref_model: model providing the reference/old log-probabilities.
        tokenizer: tokenizer shared by both models.
        optimizer: optimizer stepping `model`'s parameters.
        prompt: text prompt to sample completions for.
        num_generations: number of completions sampled for the prompt.
        epsilon: clipping range of the probability ratio.
        beta: weight of the KL penalty.

    Returns:
        The scalar loss of this step as a Python float.
    """
    # 1. Generate completions.
    completions = generate_completions(model, tokenizer, prompt, num_generations=num_generations)

    # Build every tensor on the model's device so CPU and GPU runs both work.
    device = next(model.parameters()).device

    # 2. Compute rewards & advantages (one reward per generated completion).
    rewards = [reward_function(prompt, comp) for comp in completions]
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=device)
    if rewards_tensor.numel() > 1:
        mean_reward = rewards_tensor.mean()
        std_reward = rewards_tensor.std() + 1e-8  # avoid division by zero
        advantages = (rewards_tensor - mean_reward) / std_reward
    else:
        # The sample std of a single value is NaN; a lone completion carries
        # no group-relative signal, so its advantage is zero.
        advantages = torch.zeros_like(rewards_tensor)

    # Tokenize prompt to get its length.
    prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids
    prompt_length = prompt_ids.shape[1]

    total_loss = 0.0
    for i, comp in enumerate(completions):
        # 3. Prepare the full input (prompt + completion).
        # NOTE(review): the text is re-tokenized here; this assumes the first
        # `prompt_length` tokens of the result are the prompt - confirm.
        full_text = prompt + comp[len(prompt):]
        full_ids = tokenizer(full_text, return_tensors='pt').input_ids.to(device)
        attention_mask = torch.ones_like(full_ids)  # single sequence, no padding

        # 4. Per-token log probabilities for the completion part only.
        cur_log_probs = compute_log_probs(model, full_ids, attention_mask)[0, prompt_length:]  # shape [T]
        with torch.no_grad():
            ref_log_probs = compute_log_probs(ref_model, full_ids, attention_mask)[0, prompt_length:]

        # 5. Per-token probability ratios and their clipped counterpart.
        ratio = torch.exp(cur_log_probs - ref_log_probs)  # shape [T]
        clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

        # One scalar advantage per completion, broadcast over its tokens.
        adv = advantages[i]
        ppo_loss = -torch.min(ratio * adv, clipped_ratio * adv)

        # 6. Per-token KL estimate.
        ref_minus_cur = ref_log_probs - cur_log_probs
        per_token_kl = torch.exp(ref_minus_cur) - ref_minus_cur - 1

        # 7. Combine the PPO loss and the weighted KL term (per token).
        token_loss = ppo_loss + beta * per_token_kl

        # 8. All completion tokens are valid (mask of ones). The denominator is
        #    clamped so an empty completion contributes 0 instead of NaN.
        completion_mask = torch.ones_like(cur_log_probs)
        loss_i = (token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
        total_loss = total_loss + loss_i

    # 9. Average loss over the completions that were actually scored.
    total_loss = total_loss / len(completions)

    # 10. Backpropagation and parameter update.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    print(f"Prompt: {prompt}")
    print(f"Completions: {completions}")
    print(f"Rewards: {rewards}")
    print(f"Advantages: {advantages.tolist()}")
    print(f"Loss: {total_loss.item()}")
    return total_loss.item()

def main():
    """Run a short demo: a few `train_step` updates of GPT-2 on one prompt."""
    # Load a small model & tokenizer (e.g. GPT-2).
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.train()

    # Reference model (could be updated less frequently; here we update every step for simplicity).
    ref_model = copy.deepcopy(model)
    ref_model.eval()

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    prompt = "Once upon a time"

    # Run several training steps.
    for step in range(3):
        print(f"\n=== Training Step {step} ===")
        # Refresh the reference weights before every step. `train_step` uses
        # `ref_model` as the "old" policy in the probability ratio, so it has
        # to match the model that samples this step's completions.
        ref_model.load_state_dict(model.state_dict())
        train_step(model, ref_model, tokenizer, optimizer, prompt, num_generations=3, epsilon=0.2, beta=0.1)

if __name__ == "__main__":
    main()


In [None]:
# Notebook cell: the same demo as `main()`, run at top level so that
# `model`, `ref_model`, `tokenizer` and `optimizer` stay available afterwards.

# Load a small model & tokenizer (e.g. GPT-2).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.train()

# Reference model (could be updated less frequently; here we update every step for simplicity).
ref_model = copy.deepcopy(model)
ref_model.eval()

optimizer = optim.Adam(model.parameters(), lr=1e-4)
prompt = "Once upon a time"

# Run several training steps.
for step in range(3):
    print(f"\n=== Training Step {step} ===")
    # Refresh the reference weights before every step. `train_step` uses
    # `ref_model` as the "old" policy in the probability ratio, so it has to
    # match the model that samples this step's completions.
    ref_model.load_state_dict(model.state_dict())
    train_step(model, ref_model, tokenizer, optimizer, prompt, num_generations=3, epsilon=0.2, beta=0.1)