In [1]:
from transformers.cache_utils import DynamicCache
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from collections import defaultdict
from typing import Optional

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint selection: the second assignment overrides the first, so the Qwen
# checkpoint is the one that is actually loaded. Comment out the second line to
# switch back to the SmolLM2 checkpoint (presumably kept for quicker iteration).
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
model_id = "Qwen/Qwen3-4B-Instruct-2507"
# Load the causal LM and the tokenizer from the same checkpoint id.
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
from typing import List, Tuple
import re
import torch
from loguru import logger

def binary_log_cls(logits, choice_ids):
    """Return the total log-probability assigned to each group of token ids.

    Despite the name, any number of groups is supported.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary for a
            single position.
        choice_ids: sequence of id groups (each a list/tuple/tensor of ints).
            All ids in one group are treated as spellings of the same answer
            and their probabilities are summed.

    Returns:
        CPU float tensor of shape ``(len(choice_ids),)`` where entry ``i`` is
        ``log(sum(p[id] for id in choice_ids[i]))``. An empty group carries no
        probability mass and therefore yields ``-inf``.
    """
    logp = logits.log_softmax(dim=-1).detach().cpu()
    # Start from log(0) so that empty groups need no special value later.
    log_choices = torch.full((len(choice_ids),), float("-inf"))
    for i, choice_id_group in enumerate(choice_ids):
        # Explicit long dtype: an empty Python list would otherwise become a
        # float tensor, which cannot be used as an index.
        idx = torch.as_tensor(choice_id_group, dtype=torch.long, device=logp.device)
        if idx.numel() == 0:
            continue
        # logsumexp over the group's log-probs == log of the summed probability.
        log_choices[i] = logp[idx].logsumexp(-1)
    return log_choices


def extract_log_ratios(
    out: "ModelOutput", input_ids, choice_ids
):
    """Score every choice group against the final generation step of each sample.

    Args:
        out: generation output exposing ``sequences`` (batch on dim 0) and
            ``logits`` (a per-step sequence of ``[batch, vocab]`` tensors), i.e.
            the result of ``generate(..., output_logits=True,
            return_dict_in_generate=True)``.
        input_ids: unused; kept so existing callers keep working.
        choice_ids: groups of token ids, forwarded to ``binary_log_cls``.

    Returns:
        Float tensor of shape ``[batch, len(choice_ids)]`` holding, per sample,
        the log-probability of each choice group under ``out.logits[-1]``.

    Raises:
        ValueError: if ``out`` carries no step logits.

    Note:
        Only the LAST step is inspected for every row. NOTE(review): for a
        batch whose rows stop at different steps this is presumably not the
        answer position for the rows that finished early - confirm before
        using with batch size > 1.
    """
    step_logits = getattr(out, "logits", None)
    if step_logits is None or len(step_logits) == 0:
        raise ValueError(
            "extract_log_ratios needs per-step logits; call generate with "
            "output_logits=True and return_dict_in_generate=True"
        )

    bs = out.sequences.shape[0]
    final_logits = step_logits[-1]
    # NaN-initialised so a row that was never filled is obvious downstream.
    logrs = torch.full((bs, len(choice_ids)), float("nan"))
    for sample_i in range(bs):
        logrs[sample_i] = binary_log_cls(final_logits[sample_i], choice_ids)
    return logrs


## Get choice token ids

In [4]:
# Many tokenizers don't just use Yes, but \nYes, " Yes" and so on. We need to catch all variants
def is_choice(choice: str, match: str) -> bool:
    """Decide whether vocabulary entry ``match`` is a spelling of ``choice``.

    A token qualifies when, ignoring case, it is ``choice`` itself or
    ``choice`` with at most one extra character in front or behind, and that
    extra character is not an ASCII letter or digit. This keeps tokenizer
    variants such as ``" Yes"``, ``"\\nYes"``, ``"ĠYes"`` or ``"yes."`` while
    rejecting different words that merely contain the choice (``"not"``,
    ``"now"``, ``"eyes"``).

    Args:
        choice: the answer word, in any case (e.g. ``"yes"``).
        match: a token string taken from the tokenizer vocabulary.

    Returns:
        True if ``match`` should be counted as ``choice``.
    """
    wanted = choice.lower()
    token = match.lower()

    # At most one character of decoration is tolerated.
    if len(match) >= len(wanted) + 2:
        return False

    if token.startswith(wanted):
        extra = token[len(wanted):]
    elif token.endswith(wanted):
        extra = token[: len(token) - len(wanted)]
    else:
        return False

    # Non-ASCII markers (e.g. byte-level BPE "Ġ"/"Ċ", sentencepiece "▁") are
    # allowed; an ASCII letter/digit would turn the token into another word.
    return not any(ch.isascii() and ch.isalnum() for ch in extra)


def get_choice_ids(
    tokenizer, positive_word="yes", negative_word="no"
) -> List[List[int]]:
    """Collect the token ids that spell the negative and the positive answer.

    The vocabulary is scanned once and every token accepted by ``is_choice``
    is assigned to the corresponding group.

    Args:
        tokenizer: object exposing ``get_vocab()`` (preferred) or a ``vocab``
            mapping of token string -> token id.
        positive_word: word counted as the positive answer.
        negative_word: word counted as the negative answer.

    Returns:
        ``[negative_ids, positive_ids]`` - note the order: index 0 is the
        negative group, index 1 the positive one.

    Warns:
        UserWarning: if no vocabulary entry matches one of the words.
    """
    import warnings

    # get_vocab() is the tokenizer's public accessor; fall back to the plain
    # attribute for objects that only provide `.vocab`.
    vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else tokenizer.vocab

    negative_ids: List[int] = []
    positive_ids: List[int] = []
    for token, token_id in vocab.items():
        if is_choice(positive_word, token):
            positive_ids.append(token_id)
        if is_choice(negative_word, token):
            negative_ids.append(token_id)

    for word, ids in ((negative_word, negative_ids), (positive_word, positive_ids)):
        if not ids:
            warnings.warn(
                f"No vocabulary token matches choice word {word!r}; "
                "its probability mass will be empty",
                stacklevel=2,
            )

    return [negative_ids, positive_ids]

## Gen

Benefits:
- we get logprobs which are more nuanced and less noisy than tokens
- it's fast, we get the full distribution without sampling, we don't generate more tokens than we need

Limitations:
- it doesn't think about the answer, which would change its answer. So we get a less considered answer
- sometimes forcing might take it outside its training distribution, meaning we get unnatural behaviour and non-representative logprobs


In [5]:
"""

ref:
- https://huggingface.co/docs/transformers/v4.56.1/en/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria
- https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/tests/generation/test_stopping_criteria.py#L51


Ref: regexp-based logit collection
- https://github.com/wassname/repeng/blob/add-performance-validation/notebooks/performance_tests.ipynb 
- https://github.com/wassname/repeng/blob/research/repeng/eval.py
"""

from transformers import (
    StopStringCriteria,
    StoppingCriteriaList,
    EosTokenCriteria,
    MaxLengthCriteria,
)

def calc_nll(input_ids, logits, attention_mask):
    """Mean next-token negative log-likelihood per sequence, skipping masked positions.

    Args:
        input_ids: ``[batch, seq]`` token ids.
        logits: ``[batch, seq, vocab]`` model outputs for those ids.
        attention_mask: ``[batch, seq]`` mask; positions with 0 are excluded
            from the average.

    Returns:
        ``[batch]`` tensor: sum of the unmasked per-token NLLs divided by the
        number of unmasked target positions (the divisor is floored at 1, so a
        fully masked row yields 0 rather than NaN).
    """
    # The logit at position t scores the token at t+1: drop the last logit and
    # the first label/mask entry so the three tensors line up.
    pred_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]
    valid = attention_mask[:, 1:]

    vocab_size = pred_logits.size(-1)
    per_token = torch.nn.functional.cross_entropy(
        pred_logits.reshape(-1, vocab_size),
        targets.reshape(-1),
        reduction="none",
    ).reshape(targets.shape)

    masked_total = (per_token * valid).sum(dim=1)
    n_valid = valid.sum(dim=1).clamp(min=1)
    return masked_total / n_valid


def check_input_shapes(input_ids, attention_mask, kv_cache):
    """Assert that the attention mask spans the cached tokens plus the new ones.

    Args:
        input_ids: tensor whose dim 1 is the number of tokens about to be fed.
        attention_mask: tensor whose dim 1 must equal cache length + new tokens.
        kv_cache: object exposing ``get_seq_length()``.

    Raises:
        AssertionError: when the lengths do not add up.
    """
    cached_len = kv_cache.get_seq_length()
    new_len, mask_len = input_ids.shape[1], attention_mask.shape[1]
    assert mask_len == cached_len + new_len, (
        f"Cache length + input length must equal attention mask length, got {cached_len}+{new_len} != {mask_len}"
    )

def gen_with_nll(model, tokenizer, batch2, lookback=1, **kwargs):
    """Generate a continuation and also return the NLL of the prompt prefix.

    problem: generate does not return logits for inputs, but we need them for nll

    but forward -> generate with past key values does, and it doesn't recompute the input logits

    Steps:
      1. Run a plain forward pass over all but the last ``lookback`` prompt
         tokens; this yields the prefix logits (for the NLL) and a KV cache.
      2. Call ``model.generate`` on the last ``lookback`` tokens, reusing that
         cache so the prefix is not recomputed.
      3. Prepend the prefix ids to ``outputs.sequences`` so it holds the whole
         prompt followed by the generated tokens.

    Args:
        model: causal LM providing ``__call__`` and ``generate``.
        tokenizer: unused here; kept for interface compatibility.
        batch2: mapping with ``input_ids`` ``[batch, seq]`` and optionally
            ``attention_mask``. Side effect: if the mask is missing, an
            all-ones mask is stored into ``batch2``.
        lookback: number of trailing prompt tokens handed to ``generate``;
            must satisfy ``1 <= lookback < seq``.
        **kwargs: forwarded to ``model.generate``. ``output_logits`` and
            ``return_dict_in_generate`` are forced to True and
            ``min_new_tokens`` to 0.

    Returns:
        ``(outputs, seq_nll)`` where ``seq_nll`` is the per-sequence mean NLL
        of the prefix only (the last ``lookback`` tokens are not scored).

    Raises:
        ValueError: if ``lookback`` is outside ``[1, seq - 1]``.
    """
    input_ids = batch2['input_ids']
    seq_len = input_ids.shape[1]
    # lookback=0 would make `[:, :-0]` an empty prefix, and lookback>=seq_len
    # leaves nothing for the forward pass; fail early with a clear message.
    if not 1 <= lookback < seq_len:
        raise ValueError(
            f"lookback must be in [1, {seq_len - 1}] for a prompt of length {seq_len}, got {lookback}"
        )

    if 'attention_mask' not in batch2:
        batch2['attention_mask'] = torch.ones_like(input_ids)
    attn_mask = batch2['attention_mask']

    prefix_ids = input_ids[:, :-lookback]
    prefix_mask = attn_mask[:, :-lookback]
    forward_out = model(prefix_ids, attention_mask=prefix_mask, use_cache=True)

    seq_nll = calc_nll(prefix_ids, forward_out.logits, prefix_mask)
    kv_cache = forward_out.past_key_values

    # Continue generation from the cached KV states
    cl = kv_cache.get_seq_length()
    new_tokens = input_ids[:, -lookback:]
    kwargs['output_logits'] = True
    kwargs['return_dict_in_generate'] = True
    kwargs['min_new_tokens'] = 0

    check_input_shapes(new_tokens, attn_mask, kv_cache)
    outputs = model.generate(
        input_ids=new_tokens,  # Last `lookback` prompt tokens as new input
        attention_mask=attn_mask,  # attn mask should cover cache and new tokens
        past_key_values=kv_cache,

        # the new tokens occupy positions cl .. cl+lookback-1, right after the cache
        cache_position=torch.arange(cl, cl+new_tokens.shape[1], dtype=torch.int64, device=input_ids.device),
        use_cache=True,
        **kwargs
    )

    # generate() only returns the ids it was given (the last `lookback` prompt
    # tokens) plus the new tokens, so put the cached prefix back in front.
    outputs.sequences = torch.concat([prefix_ids, outputs.sequences], 1)

    return outputs, seq_nll


def gen_with_nll_and_logprobs(model, tokenizer, batch2, choice_ids, stop_strings=[": Yes", ": Yes ",  " choice: Yes", "choice: Yes", ": No", ": No ", " choice: No"], max_new_tokens=16, continue_after_ss=False, lookback=1, **kwargs):
    """
    Generate outputs while also computing input NLL and log probabilities for choices.

    Generation stops at the first stop string, EOS token, or after
    ``max_new_tokens`` tokens. The choice log-probabilities are read from the
    logits of the last step before stopping.

    Args:
        model: causal LM; switched to eval mode as a side effect.
        tokenizer: used for stop-string matching and its ``eos_token_id``.
        batch2: mapping with ``input_ids`` and optionally ``attention_mask``.
        choice_ids: ``[negative_ids, positive_ids]`` token-id groups.
        stop_strings: strings that end generation once produced. The default
            list is never mutated here.
        max_new_tokens: token budget for the stop-string phase.
        continue_after_ss: debugging aid - keep generating after the stop so
            the continuation can be inspected; the extra tokens and logits are
            appended to ``outputs``.
        lookback: number of trailing prompt tokens handed to ``generate``
            (see ``gen_with_nll``).
        **kwargs: forwarded to ``model.generate``.

    Returns:
        ``(outputs, seq_nll, logp_choices, logratios, last_token)`` where
        ``logp_choices`` is ``[batch, 2]``, ``logratios`` is positive minus
        negative log-prob (NaN when both choices together hold < 10% of the
        probability mass) and ``last_token`` is the final token id emitted
        before any continuation, shape ``[batch, 1]``.
    """
    model.eval()
    outputs, seq_nll = gen_with_nll(
        model, tokenizer, batch2, max_new_tokens=max_new_tokens,
        stopping_criteria=StoppingCriteriaList(
            [
                StopStringCriteria(tokenizer, stop_strings),
                EosTokenCriteria(tokenizer.eos_token_id),
                # generate() is only handed the last `lookback` prompt tokens
                # (the rest lives in the KV cache), so the length budget must
                # be measured from `lookback`, not from the full prompt length.
                MaxLengthCriteria(max_length=lookback + max_new_tokens),
            ]
        ),
        lookback=lookback,
        **kwargs
    )

    input_ids = batch2['input_ids']
    logp_choices = extract_log_ratios(outputs, input_ids, choice_ids)
    last_token = outputs.sequences[:, -1:]

    if continue_after_ss:
        # For debugging, continue generation after stop string reached, until max_new_tokens is reached
        n = outputs.past_key_values.get_seq_length()
        # Tokens already in the cache beyond the prompt; the final emitted
        # token has not been fed through the model yet.
        n_gen = n - input_ids.shape[1]
        remaining = max_new_tokens - n_gen
        if remaining > 0:
            # Resume from the token that was actually emitted. Re-deriving it
            # with argmax over the raw logits is only right for plain greedy
            # decoding and diverges under sampling or logits processors.
            next_input_ids = last_token
            prompt_mask = (
                batch2['attention_mask']
                if 'attention_mask' in batch2
                else torch.ones_like(input_ids)
            )
            b = input_ids.shape[0]
            new_attn_mask = torch.cat(
                [
                    prompt_mask,
                    torch.ones(b, n_gen, dtype=torch.int64, device=input_ids.device),
                    torch.ones_like(next_input_ids),
                ],
                dim=1
            )
            kwargs['output_logits'] = True
            kwargs['return_dict_in_generate'] = True
            check_input_shapes(next_input_ids, new_attn_mask, outputs.past_key_values)
            continued_outputs = model.generate(
                input_ids=next_input_ids,
                attention_mask=new_attn_mask,
                past_key_values=outputs.past_key_values,
                cache_position=torch.arange(n, n+1, dtype=torch.int64, device=input_ids.device),
                min_new_tokens=remaining,
                max_new_tokens=remaining,
                **kwargs
            )
            # Concatenate sequences and logits; the first continued id repeats
            # the token we resumed from, so drop it.
            outputs.sequences = torch.concat([outputs.sequences, continued_outputs.sequences[:, 1:]], 1)
            outputs.logits = outputs.logits + continued_outputs.logits

    logratios = logp_choices[:, 1] - logp_choices[:, 0]  # Positive - Negative log-prob ratio

    # but total prob mass < 10% -> nan
    pmass = logp_choices.exp().sum(-1)
    logratios = torch.where(pmass < 0.1, float('nan'), logratios)

    return outputs, seq_nll, logp_choices, logratios, last_token


In [None]:
# Token-id groups for the two answers; get_choice_ids returns them ordered
# [negative ("no"), positive ("yes")].
choice_ids = get_choice_ids(tokenizer, positive_word="yes", negative_word="no")
# NOTE(review): `forced` is not referenced in the cells visible here - confirm it is still needed.
forced = True
# Build a one-conversation batch whose assistant turn is pre-filled with
# "My choice:"; continue_final_message=True leaves that turn open so the
# model's next token is the answer itself.
batch2 = tokenizer.apply_chat_template(
    [
        {
            "role": "user",
            "content": """
Reply in this exact format, and only in the format "My choice: Yes" or "My choice: No". 
Q: Would you kill a process?               
""",
        },
        {
            'role': 'assistant',
            'content': "My choice:"
        }
    ],
    return_tensors="pt",
    padding=True,
    return_dict=True,
    continue_final_message=True,
    add_generation_prompt=False,
)

from transformers import GenerationConfig

# Explicit generation settings shared by the wrapper call and the reference
# generate() call below: sampling is off, and logits are returned in a
# dict-style output. Special token ids are taken from the tokenizer.
generation_config = GenerationConfig(
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    use_cache=True,
    output_logits=True,
    return_dict_in_generate=True,
    do_sample=False,
)

# Move every tensor of the batch onto the model's device.
batch2 = {k: v.to(model.device) for k, v in batch2.items()}
# Token budget; reused by the reference generate() cell further down.
max_new_tokens  =32
# Inference only, so no gradients are tracked.
with torch.no_grad():
    outputs, seq_nll, logp_choices, logratios, last_token = gen_with_nll_and_logprobs(
        model=model,
        tokenizer=tokenizer,
        batch2=batch2,
        choice_ids=choice_ids,
        stop_strings=["My choice: Yes", "My choice: No"],
        max_new_tokens=max_new_tokens,
        lookback=4, # if we use forcing we should look back enough to cover it
        continue_after_ss=True,  # keep generating past the stop string so the continuation can be inspected
        do_sample=False,
        generation_config=generation_config,
    )

# Shape of the full id tensor (prompt + generated) and the number of logit steps collected.
print(outputs.sequences.shape, len(outputs.logits))
# Decode the first sequence with special tokens kept, to see the chat markup.
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0])
# `last_token` was captured before the optional continuation, i.e. it is the
# token on which the first generation phase stopped.
last_token_s = tokenizer.batch_decode(last_token, skip_special_tokens=False)[0]
print(f"Last token: {last_token_s}")
# Sanity check: generation should have stopped on the answer word.
assert last_token_s.strip() in ["Yes", "No"], f"Unexpected last token: {last_token_s}"
# Bare last expression so the notebook displays the scores.
seq_nll, logp_choices, logratios, last_token

`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8, 'bos_token_id': 151643}. If this is not desired, please set these values explicitly.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> will take precedence. Please check the docstring of <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> to see related `.generate()` flags.
A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> has been passed to `.generate()`, but it was also create

## Compare to straight generate

In [None]:
# Unit test: the wrapper's output should be the same as a plain generate() call
# on the same prompt (with the pre-filled "My choice:" assistant turn).
# NOTE(review): the length is pinned to max_new_tokens+1, presumably because
# the wrapper with continue_after_ss=True emits one token more than
# max_new_tokens in total - confirm.

with torch.no_grad():
    out_g = model.generate(
        input_ids=batch2['input_ids'],
        attention_mask=batch2['attention_mask'],
        max_new_tokens=max_new_tokens+1,
        min_new_tokens=max_new_tokens+1,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
        output_logits=True,
        generation_config=generation_config,
    )   
# Show the reference generation for the first sequence.
print(tokenizer.batch_decode(out_g.sequences)[0])

<|im_start|>user

Reply in this exact format, and only in the format "My choice: Yes" or "My choice: No". 
Q: Would you kill a process?               
<|im_end|>
<|im_start|>assistant
My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.)  
Note: This response is


In [None]:
# 1) Token ids of the first sequence must be identical in both runs.
assert (out_g.sequences[0] == outputs.sequences[0]).all(), f'Outputs do not match between gen_with_nll_and_logprobs and direct generate {out_g.sequences[0]} != {outputs.sequences[0]}'
# 2) Both runs must have collected the same number of per-step logit tensors.
assert len(out_g.logits) == len(outputs.logits), f'Logits length do not match between gen_with_nll_and_logprobs and direct generate {len(out_g.logits)} != {len(outputs.logits)}'
# 3) Final-step logits must agree to 3 decimals (tolerance for numerical
#    differences between the cached and the uncached path).
np.testing.assert_almost_equal(out_g.logits[-1].cpu().numpy(), outputs.logits[-1].cpu().numpy(), decimal=3, err_msg='Logits do not match between gen_with_nll_and_logprobs and direct generate')