### 1. Load Model
- Gemma 1B

In [262]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-3-1b-it"
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

print(f"Using device: {device}")
print(f"Model name: {model_name}")

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Using device: mps
Model name: google/gemma-3-1b-it


In [263]:
import numpy as np

def generate_response(prompt, num_return_sequences=2):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs, 
        max_new_tokens=10,
        temperature=0.5,      # sample generations
        return_dict_in_generate=True,
        output_scores=True,
        num_return_sequences=num_return_sequences,
    )
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )

    input_length = inputs.input_ids.shape[1]
    generated_tokens = outputs.sequences[:, input_length:]
    for i in range(generated_tokens.shape[0]):
        sum_logits = 0
        for tok, score in zip(generated_tokens[i], transition_scores[i]):
            # | token | token string | logits | probability
            score = score.cpu().numpy()
            sum_logits += score
            print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score:.4f} | {np.exp(score):.2%}")

        response = tokenizer.decode(generated_tokens[i], skip_special_tokens=True)
        print(f"Response: {response}")
        print(f"Probability of the response: {np.exp(sum_logits):.2%}")
    
    return response

print(generate_response("Tell me a joke."))


|   108 | 

       | 0.0000 | 100.00%
| 11355 | Why      | 0.0000 | 100.00%
|  1602 |  did     | 0.0000 | 100.00%
|   506 |  the     | 0.0000 | 100.00%
| 30979 |  bicycle | -1.0309 | 35.67%
|  3798 |  fall    | 0.0000 | 100.00%
|  1024 |  over    | 0.0000 | 100.00%
| 236881 | ?        | 0.0000 | 100.00%
|   108 | 

       | -0.0537 | 94.77%
| 17574 | Because  | 0.0000 | 100.00%
Response: 

Why did the bicycle fall over?

Because
Probability of the response: 33.80%
|   108 | 

       | 0.0000 | 100.00%
| 11355 | Why      | 0.0000 | 100.00%
|  1602 |  did     | 0.0000 | 100.00%
|   506 |  the     | 0.0000 | 100.00%
| 55134 |  scare   | -0.4411 | 64.33%
| 47129 | crow     | 0.0000 | 100.00%
|  3345 |  win     | 0.0000 | 100.00%
|   614 |  an      | 0.0000 | 100.00%
|  8054 |  award   | 0.0000 | 100.00%
| 236881 | ?        | 0.0000 | 100.00%
Response: 

Why did the scarecrow win an award?
Probability of the response: 64.33%


Why did the scarecrow win an award?


In [93]:
from transformers.generation.configuration_utils import GenerationConfig

gc = GenerationConfig(
    do_sample=True
)

In [264]:
sample_input = "Tell me a joke."
input_pt = tokenizer(sample_input, return_tensors="pt").to(device)
input_len = input_pt.input_ids.shape[1]

sample_seqs =model.generate(
    **input_pt,
    max_new_tokens=10,
    temperature=0.5,
    return_dict_in_generate=True,
    output_scores=True,
    num_return_sequences=1,
)

transition_scores = model.compute_transition_scores(
    sample_seqs.sequences, sample_seqs.scores, normalize_logits=True
)

for i in range(sample_seqs.sequences.shape[0]):
    print(f"Sequence {i+1}: {tokenizer.decode(sample_seqs.sequences[i], skip_special_tokens=True)}")
    print(f"Probability: {torch.exp(transition_scores[i].sum()):.2%}")
    print()


Sequence 1: Tell me a joke.

Why did the bicycle fall over?
...
Probability: 0.79%



In [265]:
for i in range(sample_seqs.sequences.shape[0]):  # no. of sequences
    sum_log_prob = 0
    seq_str = ""
    for j in range(sample_seqs.sequences[i].shape[0]-input_len):  # output time step
        # max_idx = torch.argmax(sample_seqs.scores[j][i])  # list of score per time stamp
        max_idx = sample_seqs.sequences[i][j+input_len] # sampled token
        log_probs = torch.log_softmax(sample_seqs.scores[j][i], dim=-1)
        sum_log_prob += log_probs[max_idx]
        gen_token = tokenizer.decode(max_idx)
        seq_str += gen_token
        print(f"\"{gen_token}\"\t| token {max_idx} prob: {torch.exp(log_probs[max_idx]):.2%} | log prob: {log_probs[max_idx]:.4}")
    print(f"Sequence {i+1} log prob: {sum_log_prob:.4}")
    print(f"Sequence {i+1} prob: {torch.exp(sum_log_prob):.2%}")
    print(f"Sequence {i+1} string: {seq_str}")
    print("-"*100)





"

"	| token 108 prob: 100.00% | log prob: 0.0
"Why"	| token 11355 prob: 100.00% | log prob: 0.0
" did"	| token 1602 prob: 100.00% | log prob: 0.0
" the"	| token 506 prob: 100.00% | log prob: 0.0
" bicycle"	| token 30979 prob: 35.67% | log prob: -1.031
" fall"	| token 3798 prob: 100.00% | log prob: 0.0
" over"	| token 1024 prob: 100.00% | log prob: 0.0
"?"	| token 236881 prob: 100.00% | log prob: 0.0
"
"	| token 107 prob: 5.23% | log prob: -2.951
"..."	| token 1390 prob: 42.42% | log prob: -0.8576
Sequence 1 log prob: -4.84
Sequence 1 prob: 0.79%
Sequence 1 string: 

Why did the bicycle fall over?
...
----------------------------------------------------------------------------------------------------


-------------------

### 3. Reward function

- Reward function for wordle


In [3]:
def reward_wordle(y_true: str, y_pred: str) -> float:
    """
    Reward function to calculate reward for wordle guess.
    Heuristic:
    if y_pred is invalid (!=5 or non-alpha) : 0
    for each position correct : 0.1 x n
    for each alpha correct in wrong position: 0.05 x n
    if extact match : 1
    """
    y_true = y_true.strip().upper()
    y_pred = y_pred.strip().upper()
    # Check validity
    if len(y_pred) != 5 or not y_pred.isalpha():
        return 0.0
    if y_pred == y_true:
        return 1.0
    reward = 0.0
    # Count correct positions
    for i in range(5):
        if y_pred[i] == y_true[i]:
            reward += 0.1
    # Count correct letters in wrong positions
    # To avoid double-counting, mark matched positions
    true_counts = {}
    pred_counts = {}
    for i in range(5):
        if y_pred[i] != y_true[i]:
            true_counts[y_true[i]] = true_counts.get(y_true[i], 0) + 1
            pred_counts[y_pred[i]] = pred_counts.get(y_pred[i], 0) + 1
    for letter in pred_counts:
        if letter in true_counts:
            reward += 0.05 * min(pred_counts[letter], true_counts[letter])
    return reward


In [37]:
def extract_guess_from_text(text):
    """
    Extracts the guess from between <guess> and </guess> tags in the input text.
    There should be only one pair of <guess>...</guess> tags.
    Returns the guess as a string, stripped of whitespace.
    Raises ValueError if the tags are missing or if there are multiple pairs.
    """
    import re
    matches = re.findall(r'<guess>(.*?)</guess>', text, re.DOTALL)
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one <guess>...</guess> tag, found {len(matches)}.")
    return matches[0].strip()


ttt = """<think> some dumb thinking </think>
<guess> HELLO </guess>
"""
extract_guess_from_text(ttt)


'HELLO'

In [4]:
print(reward_wordle("CRANE", "CRANE"))  # exact match, expect 1.0
print(reward_wordle("CRANE", "CRONY"))  # 3 correct positions (C, R, N), expect 0.3
print(reward_wordle("CRANE", "CANER"))  # 1 correct position (C), 4 correct letters in wrong positions, expect 0.1 + 0.05*4 = 0.3
print(reward_wordle("CRANE", "PLANT"))  # 1 correct position (A), 2 correct letters in wrong positions (N, E), expect 0.1 + 0.05*2 = 0.2
print(reward_wordle("CRANE", "ABCDE"))  # 1 correct position (E), 2 correct letters in wrong positions (C, A), expect 0.1 + 0.05*2 = 0.2
print(reward_wordle("CRANE", "12345"))  # invalid guess, expect 0.0
print(reward_wordle("CRANE", "CRAN"))   # invalid guess (length 4), expect 0.0


1.0
0.30000000000000004
0.3
0.2
0.2
0.0
0.0


----------------------
### 4. GRPO

- implement GRPO loss function


In [362]:
def print_grad_sum(model, verbose=False):
    grads_norms = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # if grad is None, assume 0
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            grads_norms.append(param.grad.norm().item())
            if verbose:
                print(f"{name} grad: {param.grad.norm().item():.6f}")
    if len(grads_norms) > 0:
        grads_norms = np.array(grads_norms) ** 2
        model_grad_norm = np.sqrt(grads_norms.sum())
        print(f"Model Gradient Norm ----> {model_grad_norm:.6f}")
    else:
        print("No trainable parameters found")

print_grad_sum(model, verbose=True)

model.layers.25.self_attn.q_proj.weight grad: 0.065693
model.layers.25.self_attn.k_proj.weight grad: 0.056480
model.layers.25.self_attn.v_proj.weight grad: 0.149816
model.layers.25.self_attn.o_proj.weight grad: 0.114383
model.layers.25.self_attn.q_norm.weight grad: 0.000832
model.layers.25.self_attn.k_norm.weight grad: 0.001109
model.layers.25.mlp.gate_proj.weight grad: 0.186269
model.layers.25.mlp.up_proj.weight grad: 0.162070
model.layers.25.mlp.down_proj.weight grad: 0.946478
model.layers.25.input_layernorm.weight grad: 0.001677
model.layers.25.post_attention_layernorm.weight grad: 0.000125
model.layers.25.pre_feedforward_layernorm.weight grad: 0.013359
model.layers.25.post_feedforward_layernorm.weight grad: 0.000112
Model Gradient Norm ----> 1.000000


In [295]:
for name, param in model.named_parameters():
    print(f"{name}, trainable: {param.requires_grad}, grad: {param.grad.norm().item():.6f}, no. of params: {param.numel()}")

model.embed_tokens.weight, trainable: False, grad: 0.000000, no. of params: 301989888
model.layers.0.self_attn.q_proj.weight, trainable: False, grad: 0.000000, no. of params: 1179648
model.layers.0.self_attn.k_proj.weight, trainable: False, grad: 0.000000, no. of params: 294912
model.layers.0.self_attn.v_proj.weight, trainable: False, grad: 0.000000, no. of params: 294912
model.layers.0.self_attn.o_proj.weight, trainable: False, grad: 0.000000, no. of params: 1179648
model.layers.0.self_attn.q_norm.weight, trainable: False, grad: 0.000000, no. of params: 256
model.layers.0.self_attn.k_norm.weight, trainable: False, grad: 0.000000, no. of params: 256
model.layers.0.mlp.gate_proj.weight, trainable: False, grad: 0.000000, no. of params: 7962624
model.layers.0.mlp.up_proj.weight, trainable: False, grad: 0.000000, no. of params: 7962624
model.layers.0.mlp.down_proj.weight, trainable: False, grad: 0.000000, no. of params: 7962624
model.layers.0.input_layernorm.weight, trainable: False, grad:

In [300]:
# set only last layer of model to trainable
# set model params to non-trainable

for name, param in model.named_parameters():
    if "model.layers.25" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

print_grad_sum(model, verbose=True)


model.layers.25.self_attn.q_proj.weight grad: 0.000000
model.layers.25.self_attn.k_proj.weight grad: 0.000000
model.layers.25.self_attn.v_proj.weight grad: 0.000000
model.layers.25.self_attn.o_proj.weight grad: 0.000000
model.layers.25.self_attn.q_norm.weight grad: 0.000000
model.layers.25.self_attn.k_norm.weight grad: 0.000000
model.layers.25.mlp.gate_proj.weight grad: 0.000000
model.layers.25.mlp.up_proj.weight grad: 0.000000
model.layers.25.mlp.down_proj.weight grad: 0.000000
model.layers.25.input_layernorm.weight grad: 0.000000
model.layers.25.post_attention_layernorm.weight grad: 0.000000
model.layers.25.pre_feedforward_layernorm.weight grad: 0.000000
model.layers.25.post_feedforward_layernorm.weight grad: 0.000000
Grad norm mean ----> 0.000000


In [354]:
def get_model_log_prob(model, seq_ids, output_masks):
    """
    Calculate model sequence log probability of OUTPUT given INPUT: P(output | input)
    
    Args:
        model: Model to use for inference
        seq_ids: Tensor of shape (B, seq_len) containing token ids of generated sequence belonging to same input
        output_masks: Tensor of shape (B, seq_len) containing mask for ouptput positions. (input and padding positions in seq_ids are set to 0)
    
    Returns:
        Tensor of shape (B,) containing log probabilities for each outputsequence in batch
        
    Note:
        Uses model.forward() to get logits for sequence, then uses log_softmax to get log probabilities.
        Using model.forward() helps track gradients of output tensor for backprop.
    """
    # opmask: 0 0 0 0 0 1 1 1 1 1 1     (masks of output positions, ignore padding tokens)
    # input : i n p u t o u t p u t     (model input seq)
    # output: n p u t o u t p u t -     (get logit for output tokens, ie. model probablity)
    # masks:  0 0 0 0 1 1 1 1 1 1 0     (left shift by 1 to get logits for output tokens)
    
    attention_mask = torch.ones(1, seq_ids.shape[-1]).to(device)
    # forward pass
    output = model(
        input_ids=seq_ids,
        attention_mask=attention_mask,
    )
    logits = output.logits
    token_log_prob = torch.log_softmax(logits, dim=-1)              # (B, seq_len, V)  -> log_prob for each token in seq
    output_indices = torch.roll(seq_ids, shifts=-1, dims=-1)        # (B, seq_len) -> left shift by 1 to get indices for output tokens
    token_log_prob = token_log_prob.gather(
        index=output_indices.unsqueeze(-1),                         # (B, seq_len, 1)  -> (B, seq_len, 1) 
        dim=-1
    )                                                               # (B, seq_len, 1)  -> log_prob for seq tokens
    token_log_prob = token_log_prob.squeeze()                                   # (B, seq_len)
    masks = torch.roll(output_masks, shifts=-1, dims=-1)            # left shift by 1
    masks[:, -1] = 0                                                # pad last token with 0 as it is not output token
    output_token_log_prob = token_log_prob * masks                  # (B, seq_len)  -> log_prob for output tokens     
    output_log_prob = output_token_log_prob.sum(dim=-1)             # (B, 1)  -> sum of log_prob for output tokens
    return output_log_prob



def get_group_relative_reward_advantage(rewards: torch.Tensor) -> torch.Tensor:
    """
    Calculate reward advantage of sequence with "Group Relevative" reward function.
    
    Args:
        rewards: Tensor of shape (B,) containing reward values for generated sequences belonging to same input
    
    Returns:
        Tensor of shape (B,) containing reward advantage based on group relative reward function.
    """
    mean_reward = rewards.mean()
    std_reward = rewards.std()
    advantages = (rewards - mean_reward) / std_reward
    return advantages


def get_kl_divergence(curr_log_prob, ref_log_prob):
    """
    Calculate KL divergence between model and ref_model for sequence given input.

    Args:
        curr_log_prob: Tensor of shape (B,) containing log probabilities for current model
        ref_log_prob: Tensor of shape (B,) containing log probabilities for reference model

    Returns:
        Tensor of shape (B,) containing KL divergence for each outputsequence in batch
    """
    kl_div = torch.exp(ref_log_prob - curr_log_prob) - (ref_log_prob - curr_log_prob) - 1
    return kl_div
    


def grpo(curr_model, old_model, ref_model, seq_ids, output_masks, rewards, ep=0.2, beta=0.1):
    """
    GRPO loss function

    Args:
        curr_model: current policy model - trained model weights
        old_model: old policy model - old model weights from previous iteration
        ref_model: Reference model base model - pretrained model weights
        seq_ids: Tensor of shape (B, seq_len) containing token ids of generated sequence belonging to same input
        output_masks: Tensor of shape (B, seq_len) containing mask for ouptput positions. (input and padding positions in seq_ids are set to 0)
        rewards: Tensor of shape (B,) containing reward values for generated sequences belonging to same input
        ep: clipping threshold for reward
        beta: KL divergence regularization parameter

    Returns:
        Tensor of shape (1,) containing GRPO loss for the batch
    """
    curr_log_prob = get_model_log_prob(curr_model, seq_ids, output_masks)   # (B,) 
    with torch.no_grad():
        # no grad for old and ref models
        old_log_prob = get_model_log_prob(old_model, seq_ids, output_masks)
        ref_log_prob = get_model_log_prob(ref_model, seq_ids, output_masks)
    # policy objective
    adv = get_group_relative_reward_advantage(rewards)                     
    ratio = torch.exp(curr_log_prob - old_log_prob)                        
    unclipped = ratio * adv                                                
    clipped = torch.clamp(ratio, 1 - ep, 1 + ep) * adv                     
    policy_obj = torch.min(unclipped, clipped)                             
    # kl divergence
    kl_div = get_kl_divergence(curr_log_prob, ref_log_prob)                
    # grpo loss
    grpo_loss = policy_obj - (beta * kl_div)                                # (B,)
    grpo_loss = - grpo_loss.mean()                                          # (1,)
    return grpo_loss


Test if get prod produces correct output

In [355]:
# test get_model_prob
input_len = input_pt.input_ids.shape[1]
print(f"sample shape: {sample_gen.sequences.shape}")
output_masks = torch.ones(sample_gen.sequences.shape).to(device)
output_masks[:, :input_len] = 0 # mask input tokens
# TODO: mask padding tokens

log_prob = get_model_log_prob(model, sample_gen.sequences, output_masks)
print(f"Prob: {torch.exp(log_prob)}")
print(f"Log prob: {log_prob}")
print(f"log_prob requires grad: {log_prob.requires_grad}")

sample shape: torch.Size([8, 38])
Prob: tensor([8.6934e-07, 0.0000e+00, 2.0975e-05, 4.8292e-04, 8.7455e-05, 0.0000e+00,
        4.8703e-12, 0.0000e+00], device='mps:0', grad_fn=<ExpBackward0>)
Log prob: tensor([ -13.9555, -344.4367,  -10.7722,   -7.6357,   -9.3444, -218.2783,
         -26.0479, -455.4019], device='mps:0', grad_fn=<SumBackward1>)
log_prob requires grad: True


Test gradients of masked position is zero; we only want to learning from output token, not the input and padding tokens

In [271]:
def test_masked_gradients(model):
    """Test if gradients for masked positions are zero"""
    
    # Clear existing gradients
    model.zero_grad()
    model.train()
    
    # Forward pass with sequences that require gradients
    sample_input = "Tell me a joke."
    input_pt = tokenizer(sample_input, return_tensors="pt").to(device)
    input_len = input_pt.input_ids.shape[1]
    sample_gen = model.generate(
        **input_pt,
        max_new_tokens=10,
        temperature=0.5,
        return_dict_in_generate=True,
        output_scores=True,
    )
    output_masks = torch.ones(sample_gen.sequences.shape).to(device)
    output_masks[:, :input_len] = 0
    output_masks.requires_grad = True
    output_masks.retain_grad()
    sequences_with_grad = sample_gen.sequences
    p, log_prob, logits = get_model_log_prob(model, sequences_with_grad, output_masks)
    log_prob.retain_grad()
    logits.retain_grad()
    p.retain_grad()
    print(f"Output probability: {p}")
    print(f"log prob: {log_prob}")

    # some simple loss function like grpo
    loss = - (p / (input_len + 1) * 0.21)
    loss.retain_grad()
    print(f"loss: {loss}")

    # Grads before backward pass
    print(f"Logit grads before backward pass: {logits.grad}")
    print(f"Log_prob grads before backward pass: {log_prob.grad}")
    
    # Backward pass
    print("-"*100)
    loss.backward()
    print("Backward pass done")
    
    print(f"Loss grad: {loss.grad}")
    print(f"Prob grad: {p.grad}")
    print(f"log_prob grad: {log_prob.grad}")
    # Check if gradients for input positions are zero
    logits_grad = logits.grad
    print(f"Logits grad shape: {logits_grad.shape}")
    print(f"Logits grad: \n{logits_grad}")

    # Check input token gradients (should be zero)
    input_token_grads = logits_grad[:, :input_len-1]
    print(f"Input token grads: {input_token_grads}")
    print(f"Input token grads shape: {input_token_grads.shape}")
    print(f"Max absolute gradient for input tokens: {input_token_grads.abs().max().item():.10f}")
    print(f"Sum of input token gradients: {input_token_grads.sum().item():.10f}")
    print(f"Are all input token gradients zero? {torch.allclose(input_token_grads, torch.zeros_like(input_token_grads), atol=1e-10)}")
    
    # Check output token gradients (should be non-zero)
    output_token_grads = logits_grad[:, input_len-1:]
    print(f"Output token grads shape: {output_token_grads.shape}")
    print(f"Output token grads: {output_token_grads}")
    print(f"Max absolute gradient for output tokens: {output_token_grads.abs().max().item():.10f}")
    print(f"Number of non-zero output gradients: {(output_token_grads.abs() > 1e-10).sum().item()}")
    
    return logits_grad

# Test with your data
_ = test_masked_gradients(model)

Output probability: tensor([0.3907], device='mps:0', grad_fn=<ExpBackward0>)
log prob: tensor([-1.2838e+01, -1.1584e+00, -2.1250e+00, -5.1828e+00, -1.2865e+00,
        -3.1996e-02, -1.5039e-03, -6.3487e-02, -1.4876e-04, -8.3659e-01,
        -3.5763e-07, -5.0759e-03, -6.6636e-05, -5.2772e-04, -3.0656e-04,
        -3.9765e+01], device='mps:0', grad_fn=<SqueezeBackward0>)
loss: tensor([-0.0117], device='mps:0', grad_fn=<NegBackward0>)
Logit grads before backward pass: None
Log_prob grads before backward pass: None
----------------------------------------------------------------------------------------------------
Backward pass done
Loss grad: tensor([1.], device='mps:0')
Prob grad: tensor([-0.0300], device='mps:0')
log_prob grad: tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0117, -0.0117, -0.0117,
        -0.0117, -0.0117, -0.0117, -0.0117, -0.0117, -0.0117, -0.0117, -0.0000],
       device='mps:0')
Logits grad shape: torch.Size([1, 16, 262144])
Logits grad: 
tensor([[[0.0000e+

In [333]:
# test get_kl_divergence for same model
# test get_model_prob
sample_input = "Tell me a joke."
input_pt = tokenizer(sample_input, return_tensors="pt").to(device)
input_len = input_pt.input_ids.shape[1]
sample_gen = model.generate(
        **input_pt,
        max_new_tokens=10,
        temperature=0.5,
        return_dict_in_generate=True,
        output_scores=True,
)

print(f"sample shape: {sample_gen.sequences.shape}")
input_len = input_pt.input_ids.shape[1]
output_masks = torch.ones(sample_gen.sequences.shape).to(device)
output_masks[:, :input_len] = 0 # mask input tokens

log_prob = get_model_log_prob(model, sample_gen.sequences, output_masks)
print(f"Log prob: {log_prob}")
print(f"Prob: {torch.exp(log_prob)}")

kl_div = get_kl_divergence(log_prob, log_prob)
print(f"KLD: {kl_div}")
# Assert that KL divergence is zero when comparing same model to itself
assert kl_div == 0.0, f"KL divergence should be zero for same model, got {kl_div}"
print("✅ KL divergence is zero as expected for same model comparison")



sample shape: torch.Size([1, 16])
Log prob: tensor([-1.5540], device='mps:0', grad_fn=<SumBackward1>)
Prob: tensor([0.2114], device='mps:0', grad_fn=<ExpBackward0>)
KLD: tensor([0.], device='mps:0', grad_fn=<SubBackward0>)
✅ KL divergence is zero as expected for same model comparison


- Sample sequences and prepare data for model

In [347]:
# Generate sample sequences for single input
num_return_sequences = 8
sample_input = "Tell me a joke."
input_pt = tokenizer(sample_input, return_tensors="pt").to(device)
input_len = input_pt.input_ids.shape[1]
print(f"input_len: {input_len}")
sample_gen = model.generate(
        **input_pt,
        num_return_sequences=num_return_sequences,
        max_new_tokens=32,
        temperature=1,
        return_dict_in_generate=True,
)
print(f"Shape of sample_gen.sequences: {sample_gen.sequences.shape}")
input_len = input_pt.input_ids.shape[1]
print(f"sample_gen.sequences:")
print(sample_gen.sequences)

# Set input and padding tokens to zero in output_masks
output_masks = torch.ones(sample_gen.sequences.shape).to(device)
output_masks[:, :input_len] = 0 # mask input tokens
pad_token_id = tokenizer.pad_token_id
print(f"pad_token_id: {pad_token_id}")
if pad_token_id is not None:
    padding_mask = (sample_gen.sequences == pad_token_id)
    output_masks[padding_mask] = 0
print(f"output_masks:")
print(output_masks)

# decode the sequences
for i in range(sample_gen.sequences.shape[0]):
    print(i, tokenizer.decode(sample_gen.sequences[i], skip_special_tokens=True))

input_len: 6
Shape of sample_gen.sequences: torch.Size([8, 38])
sample_gen.sequences:
tensor([[     2,  54593,    786,    496,  31481, 236761,    108,  11355,   1602,
            506,  12480,   4071,    506,  37421, 236881,    108,   2021,    974,
            531,    506,   1032,  13307, 236888,    108,   7243,    108,  19058,
         236764,    600, 236789, 236751,    496,   1535,    886, 236888, 236743,
            108,   3689],
        [     2,  54593,    786,    496,  31481, 236761,    108,  11355,   1602,
            506,  55134,  47129,   3345,    614,   8054, 236881,    108,   1390,
           8468,    668,    691,  15647,    528,    914,   2135, 236888,    107,
            106,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [     2,  54593,    786,    496,  31481, 236761,    108,  11355,   1602,
            506,  30979,   3798,   1024, 236881,    108,  17574,    625,    691,
           1156,  20718, 236888,    108,   7243,    

During the first step of training, curr, old, ref models are the same.
```
ratio_policy = 1.
Thus, loss = mean(adv) = 0
```
But, the gradient will still be NON-zero, ensuring that the model will learn.

In [363]:
# test grpo loss function
model.zero_grad()
print_grad_sum(model)
rewards = torch.randn(num_return_sequences).to(device)
print("⚠️ When current, old and ref models are same 👉 GRPO loss should be equal to negative mean of advantages")
adv = get_group_relative_reward_advantage(rewards)
print(f"rewards: {rewards}")
print(f"adv: {adv}")
print(f"advantages mean: {-adv.mean():.6f}")

grpo_loss = grpo(model, model, model, sample_gen.sequences, output_masks, rewards)
print(f"GRPO loss: {grpo_loss:.6f}")

# assert
assert grpo_loss == -adv.mean(), f"GRPO loss should be equal to mean of advantages, got {grpo_loss} and {adv.mean()}"
assert abs(grpo_loss) < 1e-6, f"GRPO loss should be approximately zero, got {grpo_loss}"

print("✅ GRPO loss is equal to negative mean of advantages")
print("✅ GRPO loss is close to zero, during first step of training")

# check if grads are NON-ZERO
grpo_loss.backward()
print_grad_sum(model, verbose=True)
print("✅ However, gradients are non-zero, even though loss is zero")

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"After clipping gradients:")
print_grad_sum(model, verbose=True)
print("✅ Gradients are clipped to 1.0")

Model Gradient Norm ----> 0.000000
⚠️ When current, old and ref models are same 👉 GRPO loss should be equal to negative mean of advantages
rewards: tensor([-2.0928, -0.9464, -1.2513, -0.4060, -2.0444,  0.5433,  0.7662,  1.6418],
       device='mps:0')
adv: tensor([-1.1895, -0.3473, -0.5713,  0.0497, -1.1540,  0.7472,  0.9109,  1.5543],
       device='mps:0')
advantages mean: -0.000000
GRPO loss: -0.000000
✅ GRPO loss is equal to negative mean of advantages
✅ GRPO loss is close to zero, during first step of training
model.layers.25.self_attn.q_proj.weight grad: 112.187042
model.layers.25.self_attn.k_proj.weight grad: 113.542786
model.layers.25.self_attn.v_proj.weight grad: 298.359863
model.layers.25.self_attn.o_proj.weight grad: 207.357681
model.layers.25.self_attn.q_norm.weight grad: 1.726742
model.layers.25.self_attn.k_norm.weight grad: 2.308616
model.layers.25.mlp.gate_proj.weight grad: 359.957367
model.layers.25.mlp.up_proj.weight grad: 310.338593
model.layers.25.mlp.down_proj.weigh

------------------
### 5. Training Loop

GRPO training loop
```
for sample in dataset
    1.  generations -> model.generate(sample, n=8)
        batch generations togethers with padding
        create output token mask, 
    2.  reward, advantage --> score (generations)
        loss -> grpo; seq_prob; kl
    3.  backprop -> loss.backward() 
        optimizer.step()
```

In [124]:
import torch.optim as optim

adam = optim.Adam(model.parameters())
adam.zero_grad()

In [125]:
model.eval()
adam.zero_grad()

sample_input = "Tell me a joke."
input_pt = tokenizer(sample_input, return_tensors="pt").to(device)

# sample generation (auto-regressive)
sample_gen = model.generate(
    **input_pt,
    max_new_tokens=10,
    temperature=0.5,
    return_dict_in_generate=True,
    output_scores=True,
)
print(f"Output sequence (len:{sample_gen.sequences[0].shape[0]})--> {tokenizer.decode(sample_gen.sequences[0], skip_special_tokens=True)}")
input_len = input_pt.input_ids.shape[1]
print(f"input_len: {input_len}")
seq_len = sample_gen.sequences[0].shape[0]
gen_seq = sample_gen.sequences[0][input_len:]
print(f"Generated sequence (len:{gen_seq.shape[0]})--> {tokenizer.decode(gen_seq, skip_special_tokens=True)}")

# model forward pass - to get logits for the generated sequence
attention_mask = torch.ones(1, seq_len)
output_logits = model(
    input_ids=sample_gen.sequences,
    attention_mask=attention_mask,
)
# sum of log_probs for ids in sample_gen.sequences[0]
gen_seq_logits = output_logits.logits[:,input_len-1:-1,:]   # ***IMP*** logits for generated sequence input sequence left shifted by 1
gen_seq_scores = torch.log_softmax(gen_seq_logits, dim=-1)
gen_seq_ids = sample_gen.sequences.unsqueeze(-1)[:,input_len:,:]    # ids of generated sequence
gen_seq_log_prob = gen_seq_scores.gather(dim=2, index=gen_seq_ids).squeeze()

sum_gen_seq_log_prob = gen_seq_log_prob.sum()
gen_seq_prob = torch.exp(sum_gen_seq_log_prob)
print(f"sum_gen_seq_log_prob: {sum_gen_seq_log_prob:.4}")
print(f"gen_seq_prob: {gen_seq_prob:.2%}")
# print prob of each token in gen_seq
for i in range(gen_seq.shape[0]):
    print(f"\"{tokenizer.decode(gen_seq[i])}\"\t| token {gen_seq[i]} prob: {torch.exp(gen_seq_log_prob[i]):.2%} | log prob: {gen_seq_log_prob[i]:.4}")
gen_seq_prob


Output sequence (len:16)--> Tell me a joke.

Why did the scarecrow win an award?
input_len: 6
Generated sequence (len:10)--> 

Why did the scarecrow win an award?
sum_gen_seq_log_prob: -0.9397
gen_seq_prob: 39.07%
"

"	| token 108 prob: 96.85% | log prob: -0.032
"Why"	| token 11355 prob: 99.85% | log prob: -0.001504
" did"	| token 1602 prob: 93.85% | log prob: -0.06349
" the"	| token 506 prob: 99.99% | log prob: -0.0001488
" scare"	| token 55134 prob: 43.32% | log prob: -0.8366
"crow"	| token 47129 prob: 100.00% | log prob: -3.576e-07
" win"	| token 3345 prob: 99.49% | log prob: -0.005076
" an"	| token 614 prob: 99.99% | log prob: -6.664e-05
" award"	| token 8054 prob: 99.95% | log prob: -0.0005277
"?"	| token 236881 prob: 99.97% | log prob: -0.0003066


tensor(0.3907, device='mps:0', grad_fn=<ExpBackward0>)

"

"	| token 108 prob: 100.00%
"Why"	| token 11355 prob: 100.00%
" did"	| token 1602 prob: 100.00%
" the"	| token 506 prob: 100.00%
" bicycle"	| token 30979 prob: 35.67%
" fall"	| token 3798 prob: 100.00%
" over"	| token 1024 prob: 100.00%
"?"	| token 236881 prob: 100.00%
"

"	| token 108 prob: 94.77%
"Because"	| token 17574 prob: 100.00%
Sequence 1 log prob: -1.085
Sequence 1 prob: 33.80%
Sequence 1 string: 

Why did the bicycle fall over?

Because

In [126]:
# calculate loss
loss = - (gen_seq_prob/0.1 * 1)
print(f"loss: {loss:.4}")

# calculate gradient
# loss.backward()

# print grad sum
print_grad_sum(model)



loss: -3.907
model.embed_tokens.weight grad: 0.0
model.layers.0.self_attn.q_proj.weight grad: 0.0
model.layers.0.self_attn.k_proj.weight grad: 0.0
model.layers.0.self_attn.v_proj.weight grad: 0.0
model.layers.0.self_attn.o_proj.weight grad: 0.0
model.layers.0.self_attn.q_norm.weight grad: 0.0
model.layers.0.self_attn.k_norm.weight grad: 0.0
model.layers.0.mlp.gate_proj.weight grad: 0.0
model.layers.0.mlp.up_proj.weight grad: 0.0
model.layers.0.mlp.down_proj.weight grad: 0.0
model.layers.0.input_layernorm.weight grad: 0.0
model.layers.0.post_attention_layernorm.weight grad: 0.0
model.layers.0.pre_feedforward_layernorm.weight grad: 0.0
model.layers.0.post_feedforward_layernorm.weight grad: 0.0
model.layers.1.self_attn.q_proj.weight grad: 0.0
model.layers.1.self_attn.k_proj.weight grad: 0.0
model.layers.1.self_attn.v_proj.weight grad: 0.0
model.layers.1.self_attn.o_proj.weight grad: 0.0
model.layers.1.self_attn.q_norm.weight grad: 0.0
model.layers.1.self_attn.k_norm.weight grad: 0.0
model

In [260]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy


# Dummy language model (simplified to linear for toy test)
class ToyLM(nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        hidden = self.embed(input_ids)
        logits = self.decoder(hidden)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

# ---- SETUP ---- #
vocab_size = 10
seq_len = 5
hidden_dim = 16
G = 8  # Group size

model = ToyLM(vocab_size, hidden_dim)
model_copy = copy.deepcopy(model)

# Create G sampled outputs (sequences)
input_ids = torch.randint(0, vocab_size, (G, seq_len))  # [G, T]

# Forward pass
log_probs = model(input_ids)  # [G, T, V]
token_log_probs = log_probs.gather(2, input_ids.unsqueeze(-1)).squeeze(-1)  # log P(o_i|q)
# Sequence log-probs
seq_log_probs = token_log_probs.sum(dim=1)  # [G]

copy_log_probs = model_copy(input_ids)
copy_token_log_probs = copy_log_probs.gather(2, input_ids.unsqueeze(-1)).squeeze(-1)
copy_seq_log_probs = copy_token_log_probs.sum(dim=1)

print(f"seq_log_probs: {seq_log_probs}")
print(f"copy_seq_log_probs: {copy_seq_log_probs}")

# Compute advantages (simulate standardized rewards)
# rewards = torch.tensor([1.0, 0.5, -0.2, 0.7])
rewards = torch.randn(G)
adv = (rewards - rewards.mean()) / rewards.std()  # [G]
print(f"adv mean: {adv.mean():.6f}")

# GRPO loss component: - log_prob * advantage (REINFORCE-style)
ratio = torch.exp(seq_log_probs - copy_seq_log_probs)
print(f"ratio: {ratio}")
print(f"ratio mean: {ratio.mean():.6f}")
loss = - (ratio  * adv)
loss = loss.mean()
print(f"loss: {loss}")
print(f"loss: {loss:.6f}")
loss.backward()

# ---- TEST ---- #
grads = []
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.norm().item():.6f}, size: {param.grad.shape}")
        grads.append(param.grad.norm().item())

print("Gradient norms:", grads)

print(f"""
Conclusion:
During the first step of training, curr, old, ref models are the same. So the loss (aka. learning) will be zero.
But, the gradient will still be non-zero, ensuring that the model will learn.
""")

seq_log_probs: tensor([-11.4805, -12.5969, -11.2116,  -9.4501,  -9.7526, -11.2985, -11.6984,
        -11.4296], grad_fn=<SumBackward1>)
copy_seq_log_probs: tensor([-11.4805, -12.5969, -11.2116,  -9.4501,  -9.7526, -11.2985, -11.6984,
        -11.4296], grad_fn=<SumBackward1>)
adv mean: 0.000000
ratio: tensor([1., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<ExpBackward0>)
ratio mean: 1.000000
loss: 0.0
loss: 0.000000
embed.weight: 0.345921, size: torch.Size([10, 16])
decoder.weight: 2.347969, size: torch.Size([10, 16])
decoder.bias: 0.629673, size: torch.Size([10])
Gradient norms: [0.34592145681381226, 2.3479692935943604, 0.6296733021736145]

Conclusion:
During the first step of training, curr, old, ref models are the same. So the loss (aka. learning) will be zero.
But, the gradient will still be non-zero, ensuring that the model will learn.

