In [6]:
import numpy as np
import torch
from typing import Optional, List

In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
model_id = "Qwen/Qwen3-4B-Instruct-2507"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Get choice token ids

In [8]:
def is_choice(choice: str, match: str) -> bool:
    """Check if token string matches choice word (with whitespace/newline variants)."""
    return (match.lower().endswith(choice) or match.lower().startswith(choice)) and len(match) < len(choice) + 2


def get_choice_ids(tokenizer, positive_word="yes", negative_word="no") -> List[List[int]]:
    """Get token IDs for Yes/No choices - returns [negative_ids, positive_ids]."""
    positive_choices = {k: v for k, v in tokenizer.vocab.items() if is_choice(positive_word, k)}
    negative_choices = {k: v for k, v in tokenizer.vocab.items() if is_choice(negative_word, k)}
    return [list(negative_choices.values()), list(positive_choices.values())]

## Simple generation with choice extraction

Benefits:
- Get logprobs which are more nuanced than sampled tokens
- Fast - get full distribution without sampling
- Clean flow: format prompt to "My choice:" → forward → extract logprobs → optionally continue

Approach:
- Format message ending at "My choice:"
- Run forward pass to get logits and NLL
- Extract choice logprobs from last token position
- Optionally continue generation using KV cache + argmax tokens

**Tokenization caveat**: We generate only 1 token, but different models tokenize differently. Some might need [" ", "Yes"], ["Ye", "s"], or ["\nYes"]. If your choices have low prob mass (< 10% of max token prob), you'll get NaN - edit the message format in `apply_chat_template` to match your model's tokenization.

In [11]:
def calc_nll(input_ids, logits, attention_mask):
    """Calculate per-sequence NLL from input_ids and logits."""
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    shift_mask = attention_mask[:, 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    token_nll = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    ).view(shift_labels.size())

    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)
    return seq_nll


def get_choice_logprobs(logits_last, choice_ids):
    """
    Extract log probabilities for each choice group.
    
    Args:
        logits_last: [b, vocab] logits for the last token position
        choice_ids: [n_choices, n_ids_per_choice] token IDs for each choice
    
    Returns:
        logp_choices: [b, n_choices] log probabilities
    """
    logp = logits_last.log_softmax(dim=-1)  # [b, vocab]
    b = logp.shape[0]
    logp_choices = torch.zeros(b, len(choice_ids), device=logp.device)
    
    for i, choice_id_group in enumerate(choice_ids):
        choice_id_group = torch.tensor(choice_id_group, device=logp.device)
        # Sum probabilities for all variants of this choice (e.g., "Yes", " Yes", "\nYes")
        logp_choice = logp[:, choice_id_group].logsumexp(-1)  # [b]
        logp_choices[:, i] = logp_choice
    
    return logp_choices


def gen_with_choices(model, tokenizer, input_ids, attention_mask, choice_ids, continue_n_tokens=0):
    """
    Generate one token and extract choice logprobs, optionally continue generating.
    
    Args:
        model: LLM
        tokenizer: tokenizer
        input_ids: [b, s] input tokens ending at the choice point (e.g., "My choice:")
        attention_mask: [b, s]
        choice_ids: [n_choices, n_ids_per_choice] token IDs for each choice
        continue_n_tokens: if >0, continue generating this many more tokens using KV cache
    
    Returns:
        outputs: generation output with sequences, logits, past_key_values
        seq_nll: [b] NLL for input sequence
        logp_choices: [b, n_choices] log probabilities for each choice
        logratios: [b] log(P(positive)/P(negative))
    """
    model.eval()
    
    # Forward pass on inputs to get NLL and KV cache
    out = model(input_ids, attention_mask=attention_mask, use_cache=True)
    seq_nll = calc_nll(input_ids, out.logits, attention_mask)
    
    # Extract choice logprobs from last position
    logp_choices = get_choice_logprobs(out.logits[:, -1], choice_ids)  # [b, n_choices]
    
    # Calculate log ratio (assuming [negative, positive] order)
    logratios = logp_choices[:, 1] - logp_choices[:, 0]
    
    # Mark as nan if choices are less than 10% of max probable token
    # This covers cases where model isn't confident in any choice (e.g., wrong tokenization)
    maxp = out.logits[:, -1].log_softmax(-1).max(-1)[0].exp()  # [b]
    pmass = logp_choices.exp().sum(-1)  # [b]
    logratios = torch.where(pmass < 0.1 * maxp, float('nan'), logratios)
    
    # Start with just the input
    sequences = input_ids
    logits_list = [out.logits]
    kv_cache = out.past_key_values
    
    # Optionally continue generation
    if continue_n_tokens > 0:
        for _ in range(continue_n_tokens):
            # Get next token from previous logits
            next_token = out.logits[:, -1].log_softmax(-1).argmax(-1, keepdim=True)  # [b, 1]
            
            # Update attention mask
            b = input_ids.shape[0]
            attention_mask = torch.cat([
                attention_mask, 
                torch.ones(b, 1, dtype=torch.long, device=input_ids.device)
            ], dim=1)
            
            # Continue from KV cache
            cache_len = kv_cache.get_seq_length()
            out = model(
                next_token,
                attention_mask=attention_mask,
                past_key_values=kv_cache,
                cache_position=torch.arange(cache_len, cache_len + 1, dtype=torch.long, device=input_ids.device),
                use_cache=True
            )
            
            sequences = torch.cat([sequences, next_token], dim=1)
            logits_list.append(out.logits)
            kv_cache = out.past_key_values
    
    # Package output similar to generate()
    class Output:
        pass
    outputs = Output()
    outputs.sequences = sequences
    outputs.logits = logits_list
    outputs.past_key_values = kv_cache
    
    return outputs, seq_nll, logp_choices, logratios

In [None]:
choice_ids = get_choice_ids(tokenizer, positive_word="yes", negative_word="no")

batch = tokenizer.apply_chat_template(
    [
        {
            "role": "user",
            "content": "Reply in this exact format: 'My choice: Yes' or 'My choice: No'. Q: Would you kill a process?",
        },
        {
            'role': 'assistant',
            'content': "My choice:"
        }
    ],
    return_tensors="pt",
    padding=True,
    return_dict=True,
    continue_final_message=True,
    add_generation_prompt=False,
)

batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
    outputs, seq_nll, logp_choices, logratios = gen_with_choices(
        model=model,
        tokenizer=tokenizer,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        choice_ids=choice_ids,
        continue_n_tokens=10,  # Continue for 10 more tokens
    )

print(f"Sequences shape: {outputs.sequences.shape}, Logits: {len(outputs.logits)}")
print(f"Generated: {tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]}")
print(f"NLL: {seq_nll.item():.3f}")
print(f"Logp choices (No, Yes): {logp_choices[0].tolist()}")
print(f"Log ratio (Yes/No): {logratios.item():.3f}")
print(f"Prob Yes: {logp_choices[0, 1].exp().item():.3f}, Prob No: {logp_choices[0, 0].exp().item():.3f}")

Sequences shape: torch.Size([1, 48]), Logits: 11
Generated: <|im_start|>user
Reply in this exact format: 'My choice: Yes' or 'My choice: No'. Q: Would you kill a process?<|im_end|>
<|im_start|>assistant
My choice: No<|im_end|>
<|endoftext|>Human: Do you have a
NLL: 5.078
Logp choices (No, Yes): [-0.1246250793337822, -2.1441104412078857]
Log ratio (Yes/No): -2.019
Prob Yes: 0.117, Prob No: 0.883


: 