In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

In [None]:
model_name = "gpt2"
SEED = 0

# Seed before anything stochastic: generation samples tokens, and the reward
# model's classification head below is freshly (randomly) initialised.
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token; reuse EOS so padding and generation have a valid id.
tokenizer.pad_token = tokenizer.eos_token

# Policy model (optimised) and a frozen reference copy for the KL penalty.
policy = AutoModelForCausalLM.from_pretrained(model_name)
policy_ref = AutoModelForCausalLM.from_pretrained(model_name)
# eval() only switches dropout off; it does NOT stop gradients, so `policy`
# stays trainable. With dropout on, the policy and its weight-identical
# reference report different log-probs for the same tokens, which turns the
# PPO ratio and the KL penalty into dropout noise.
policy.eval()
policy_ref.eval()
policy_ref.requires_grad_(False)  # the reference is never updated

# SIMPLIFIED: should be a separate, pretrained model
reward_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=1
)
# GPT-2's sequence-classification head needs config.pad_token_id to locate the
# last real token; without it, batches larger than one are rejected.
reward_model.config.pad_token_id = tokenizer.pad_token_id
reward_model.eval()

In [None]:
prompt = "I am studying"
inputs = tokenizer(
    prompt, return_tensors="pt", padding=True
)
input_ids = inputs["input_ids"]

# Sample a response from the current policy (stochastic: top-k sampling).
with torch.no_grad():
    gen_ids = policy.generate(
        input_ids=input_ids,
        # pad_token == eos_token here, so generate() cannot reliably infer the
        # mask from the ids; hand it the tokenizer's mask explicitly.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
        do_sample=True,
        top_k=50,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
# generate() returns prompt + continuation; split off the newly sampled tokens.
response_ids = gen_ids[:, input_ids.shape[-1]:]
query_response = torch.cat([input_ids, response_ids], dim=1)

In [None]:
# Log-probabilities of the sampled response under the policy and the reference.
# Remaining simplifications (TODO):
# 1. Log-probs are recomputed with a second forward pass instead of being
#    recorded during generation; they only match while the weights are
#    unchanged and dropout is off.
# 2. Tokens are scored under the model's full softmax, not the top-k /
#    temperature-adjusted distribution they were actually sampled from.
def get_log_prob_sum(model, input_ids, response_start=1, requires_grad=False):
    """Return the summed token log-probabilities ``model`` assigns to ``input_ids``.

    Logit position ``t`` predicts token ``t + 1``, so token 0 is never scored.
    Only tokens at index ``>= response_start`` contribute to the sum, which lets
    the caller score just the generated response while still conditioning on
    the prompt.

    Args:
        model: causal LM called as ``model(input_ids=...)`` and expected to
            return ``.logits`` of shape (batch, seq, vocab).
        input_ids: (batch, seq) token ids. Assumes no padding, since no
            attention mask is passed to the model. TODO confirm for batches.
        response_start: index of the first token to score (>= 1). The default
            of 1 scores every token the model predicts.
        requires_grad: if True, keep the autograd graph so the result can be
            back-propagated (needed for the policy being optimised); if False,
            run under ``torch.no_grad()``.

    Returns:
        Per-sequence log-likelihood sums of shape (batch,); a 0-d tensor when
        batch == 1.

    Raises:
        ValueError: if ``response_start < 1``.
    """
    if response_start < 1:
        raise ValueError(f"response_start must be >= 1, got {response_start}")
    grad_mode = torch.enable_grad() if requires_grad else torch.no_grad()
    with grad_mode:
        logits = model(input_ids=input_ids).logits
        # Shift by one so logits[:, j] lines up with its target input_ids[:, j + 1].
        log_probs = torch.log_softmax(logits[:, :-1].float(), dim=-1)
        targets = input_ids[:, 1:]
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # token_log_probs[:, j] scores input_ids[:, j + 1], hence the -1 offset.
        return token_log_probs[:, response_start - 1:].sum(dim=-1).squeeze(0)

# Score only the response tokens, i.e. everything after the prompt.
prompt_len = input_ids.shape[-1]
# The policy's log-prob keeps its graph so the PPO loss can be back-propagated;
# the reference model is frozen and scored without autograd.
logprob_policy = get_log_prob_sum(
    policy, query_response, response_start=prompt_len, requires_grad=True
)
logprob_ref = get_log_prob_sum(policy_ref, query_response, response_start=prompt_len)

# reward score
# Decoded to text and re-tokenised, presumably so a separate reward model with
# its own tokenizer can be swapped in later.
with torch.no_grad():
    reward_inputs = tokenizer(
        tokenizer.decode(query_response[0], skip_special_tokens=True),
        return_tensors="pt", truncation=True, padding=True
    )
    reward = reward_model(**reward_inputs).logits.squeeze()

In [None]:
# advantage, PPO loss, KL loss
clip_eps = 0.2  # PPO clipping range for the probability ratio
kl_coef = 0.01  # KL-Pen: weight of the KL penalty against the reference model

# SIMPLIFIED: value should come from a value model.
# Using the reward itself as the baseline makes `advantage` exactly zero, so the
# clipped policy term below vanishes until a real value estimate replaces it.
value = reward.detach()
advantage = reward - value

# PPO importance ratio pi_theta / pi_old. The response was sampled from `policy`
# itself, so pi_old is the policy's own log-prob with the graph cut: the ratio
# is 1 in value but still carries gradient through `logprob_policy`. The frozen
# reference model belongs in the KL penalty below, not in this ratio.
logprob_old = logprob_policy.detach()
ratio = torch.exp(logprob_policy - logprob_old)

# PPO-clip loss: negated clipped surrogate (we minimise); mean over the batch.
loss1 = ratio * advantage
loss2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
ppo_clip_loss = -torch.min(loss1, loss2).mean()

# KL loss (optional penalty) between the policy and the frozen reference.
# Half of the squared log-ratio is the "k2" KL estimator.
log_ratio = logprob_policy - logprob_ref
kl_loss = torch.mean(log_ratio**2)

# Loss
ppo_loss = ppo_clip_loss + kl_coef * kl_loss

In [None]:
# Combined objective (clip term + kl_coef * KL penalty); left as the cell's last
# expression so Jupyter renders the tensor.
ppo_loss

tensor(0.0004)

In [None]:
# Human-readable summary of this single PPO step, one field per line.
rule = "=" * 50
report = (
    rule,
    f"[Prompt]: {prompt}",
    f"[Response]: {tokenizer.decode(response_ids[0], skip_special_tokens=True)}",
    f"[Reward]: {reward.item():.4f}",
    f"[Policy logprob]: {logprob_policy.item():.4f} | [Ref logprob]: {logprob_ref.item():.4f}",
    f"[PPO Loss]: {ppo_loss.item():.4f}",
    rule,
)
for row in report:
    print(row)

[Prompt]: I am studying
[Response]:  and my body is so perfect. That means that I feel I have the potential to do what I
[Reward]: 3.6512
[Policy logprob]: -3.2976 | [Ref logprob]: -3.5025
[PPO Loss]: -0.0000
