In [1]:

from transformers.cache_utils import DynamicCache
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from collections import defaultdict
from typing import Optional

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint identifier shared by the tokenizer and the model so the two
# always come from the same repository.
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"

# The two loads are independent of each other; both read from `model_id`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

In [None]:

# Render a single-turn chat prompt to tensors. `return_dict=True` gives both
# `input_ids` and `attention_mask`, which the later cells index by key.
# NOTE(review): `add_generation_prompt` is not passed here, so the rendered
# prompt presumably ends after the user turn -- confirm this is intended.
batch2 = tokenizer.apply_chat_template(
    [
        {
            'role': 'user',
            'content': 'Reply with Choice: Yes or Choice: No. Q: Would you steal a car?',
        }
    ],
    return_tensors='pt',
    padding=True,
    return_dict=True,
)

# Move every tensor onto the model's device before it is fed to the model.
batch2 = {key: tensor.to(model.device) for key, tensor in batch2.items()}

# Displayed as the cell result: the shape of each tensor in the batch.
{key: tensor.shape for key, tensor in batch2.items()}

{'input_ids': torch.Size([1, 33]), 'attention_mask': torch.Size([1, 33])}

In [28]:
%%timeit
# Baseline: one plain `generate` call over the whole prompt. The token budget
# is 32 + 1 so the number of returned logits matches the cached variant in the
# next cell (one token from the forward pass plus 32 generated tokens).
new_token_budget = 32 + 1
outg2 = model.generate(
    input_ids=batch2['input_ids'],            # full prompt token ids
    attention_mask=batch2['attention_mask'],  # mask covering the prompt
    max_new_tokens=new_token_budget,
    min_new_tokens=new_token_budget,          # min == max forces a fixed length
    return_dict_in_generate=True,
    output_scores=True,
    output_logits=True,
)
print(outg2.sequences.shape, len(outg2.logits))
# 1.54 s ± 34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
1.58 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%%time
# Single forward pass over the prompt. Keeping the KV cache lets `generate`
# resume from it below instead of re-encoding the prompt, and the forward pass
# gives us logits for every prompt position (which `generate` does not return).
forward_out = model(**batch2, use_cache=True)
logits = forward_out.logits  # [b, s, vocab]
past_key_values = forward_out.past_key_values

# Greedy pick of the first new token, shaped [b, 1]. Indexing with [:, None]
# (rather than [None], which gives [1, b]) keeps the batch dimension first for
# any batch size. log_softmax is monotonic, so argmax over the raw logits
# selects the same token.
next_input_ids = logits[:, -1].argmax(-1)[:, None]
new_attn_mask = torch.cat(
    [batch2['attention_mask'], torch.ones_like(next_input_ids)],
    dim=1
)

# Shift logits and labels for NLL: predict token t from tokens 0..t-1
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = batch2['input_ids'][:, 1:].contiguous()

# Mask padding using the attention mask rather than comparing labels with
# `tokenizer.pad_token_id`: the pad id can coincide with a real token that
# appears in the prompt (e.g. when pad is aliased to the end-of-turn/EOS
# token), and it can also be None. The attention mask marks real tokens
# directly.
shift_mask = batch2['attention_mask'][:, 1:].float()
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
token_nll = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1)
).view(shift_labels.size())

# Average NLL per sequence (excluding padding); clamp avoids division by zero
seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

# Continue generation from the cached KV states.
# The cache already holds all `n` prompt positions, so the freshly chosen
# token is fed as the only new input and sits at cache position `n`.
input_ids = batch2['input_ids']
n = past_key_values.get_seq_length()
outputs = model.generate(
    input_ids=next_input_ids,  # first generated token, not the prompt
    attention_mask=new_attn_mask,  # prompt mask extended by one for that token
    past_key_values=past_key_values,
    cache_position=torch.arange(n, n+1, dtype=torch.long, device=input_ids.device),
    output_logits=True,
    output_scores=True,
    return_dict_in_generate=True,
    max_new_tokens=32,
    min_new_tokens=32,
)


print(outputs.sequences.shape, len(outputs.logits))
# `generate` only saw `next_input_ids`, so its sequences do NOT include the
# prompt ids and its logits do not include the step that produced
# `next_input_ids`. Prepend both so the result lines up with a plain
# `generate` call over the full prompt.
outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
outputs.logits = (forward_out.logits[:, -1],) + outputs.logits
print(outputs.sequences.shape, len(outputs.logits))
# 1.56 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
1.56 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
%%timeit
def generate_with_input_logits(model, tokenizer, batch2, **kwargs):
    """
    Generate a continuation and also return the prompt's per-sequence NLL.

    Problem: `generate` does not return logits for the input tokens, but we
    need them to compute the NLL of the prompt. A plain forward pass does
    return them, and its KV cache can be handed to `generate` so the prompt
    is not recomputed. This helper does both.

    Args:
        model: causal LM; called as `model(**batch2, use_cache=True)` and via
            `model.generate(...)`.
        tokenizer: unused; kept so existing call sites keep working. Padding
            is identified from `batch2['attention_mask']`, not from
            `tokenizer.pad_token_id`.
        batch2: mapping with `input_ids` and `attention_mask` tensors of shape
            [b, s]; it is unpacked into the model's forward call.
        **kwargs: forwarded to `model.generate` (e.g. `max_new_tokens`).

    Returns:
        (outputs, seq_nll) where `outputs` is the object returned by
        `generate` with two fields patched:
          * `sequences`: the prompt ids concatenated in front of what
            `generate` returned;
          * `logits`: the last-position logits of the forward pass prepended
            to the per-step logits from `generate`.
        `seq_nll` has shape [b]: the mean negative log-likelihood of each
        prompt over its non-padding target tokens.
    """
    forward_out = model(**batch2, use_cache=True)
    logits = forward_out.logits  # [b, s, vocab]
    past_key_values = forward_out.past_key_values
    input_ids = batch2['input_ids']
    attention_mask = batch2['attention_mask']

    # Greedy first new token, shaped [b, 1]. log_softmax is monotonic, so the
    # argmax over raw logits picks the same token.
    next_input_ids = logits[:, -1].argmax(-1)[:, None]
    new_attn_mask = torch.cat(
        [attention_mask, torch.ones_like(next_input_ids)],
        dim=1
    )

    # Shift logits and labels for NLL: predict token t from tokens 0..t-1
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]

    # Mask padding with the attention mask instead of comparing labels with
    # the pad token id: the pad id may coincide with a real token present in
    # the prompt (e.g. pad aliased to EOS), which would silently drop real
    # tokens from the average, and it may be None.
    shift_mask = attention_mask[:, 1:].float()
    token_nll = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction='none',
    ).view(shift_labels.size())

    # Average NLL per sequence (excluding padding); clamp avoids 0/0
    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    # Continue generation from the cached KV states: the cache already covers
    # the `n` prompt positions, so the new token is placed at position `n`.
    n = past_key_values.get_seq_length()
    outputs = model.generate(
        input_ids=next_input_ids,  # first generated token, not the prompt
        attention_mask=new_attn_mask,  # prompt mask extended by one
        past_key_values=past_key_values,
        cache_position=torch.arange(n, n+1, dtype=torch.long, device=input_ids.device),
        output_logits=True,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs
    )

    # `generate` only saw `next_input_ids`, so it does not return the prompt
    # ids or the logits that produced the first token; prepend both.
    outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
    outputs.logits = (forward_out.logits[:, -1],) + outputs.logits

    return outputs, seq_nll


# Same budget as the cached cell above: 32 generated tokens on top of the
# token chosen from the forward pass.
out3 = generate_with_input_logits(
    model,
    tokenizer,
    batch2,
    max_new_tokens=32,
    min_new_tokens=32,
)
generation = out3[0]  # first element is the patched generate output
print(generation.sequences.shape, len(generation.logits))

torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
torch.Size([1, 66]) 33
1.6 s ± 26.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
