In [None]:

from transformers.cache_utils import DynamicCache
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from collections import defaultdict
from typing import Optional

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Single source of truth for the checkpoint so model and tokenizer always match.
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [None]:

# Build the prompt batch as a dict with `input_ids` and `attention_mask`, both [batch, seq].
messages = [{'role': 'user', 'content': 'The eifel tower is in'}]

if getattr(tokenizer, 'chat_template', None):
    # Chat-tuned checkpoints: render through the template. `return_dict=True` requests
    # the mapping form (ids + mask) that the `.items()` call below relies on.
    encoded = tokenizer.apply_chat_template(messages, return_tensors='pt', return_dict=True)
else:
    # Base checkpoints such as GPT-2 ship no chat template (apply_chat_template raises),
    # so tokenize the raw text. No `padding=True`: a single sequence needs no padding and
    # GPT-2 has no pad token, which makes the tokenizer raise when padding is requested.
    encoded = tokenizer(messages[0]['content'], return_tensors='pt')

# `atleast_2d` adds the batch axis only when it is missing; unconditionally indexing with
# `[None]` would turn an already batched [1, seq] tensor into an invalid 3-D input.
batch2 = {k: torch.atleast_2d(v).to(model.device) for k, v in encoded.items()}


In [None]:

# Extra keyword arguments forwarded to `model.generate`. Assigned here so the cell runs on
# a fresh kernel (the name was otherwise never defined in the visible notebook).
kwargs = {'max_new_tokens': 32}

# Forward pass over the prompt; keep the KV cache so generation can resume from it.
forward_out = model(**batch2, use_cache=True)
logits = forward_out.logits  # [b, s, vocab]
past_key_values = forward_out.past_key_values

# Greedy choice of the first new token, shaped [b, 1] for any batch size.
# (argmax over raw logits equals argmax over log-probabilities.)
# NOTE: this first token is always greedy, even if `kwargs` enables sampling.
next_input_ids = logits[:, -1].argmax(-1)[:, None]
# The mask must cover the cached prompt positions plus the one new token.
new_attn_mask = torch.cat(
    [batch2['attention_mask'], torch.ones_like(next_input_ids)],
    dim=1
)

# Shift logits and labels for NLL: predict token t from tokens 0..t-1
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = batch2['input_ids'][:, 1:].contiguous()

# Mask from the attention mask rather than `tokenizer.pad_token_id`: GPT-2 has no pad
# token (the id is None), and comparing a tensor against None does not yield a tensor.
# A position counts only when both the target token and the token it is predicted
# from are real, which is correct for left- and right-padded batches alike.
attn = batch2['attention_mask']
shift_mask = (attn[:, 1:] * attn[:, :-1]).float()

# Compute NLL per token
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
token_nll = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1)
).view(shift_labels.size())

# Average NLL per sequence (excluding masked positions)
seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

# Continue generation from the cached KV states.
# The cache holds all `n` prompt positions, so the only new input is the token just
# predicted, placed at position `n`.
input_ids = batch2['input_ids']
n = past_key_values.get_seq_length()
outputs = model.generate(
    input_ids=next_input_ids,  # first generated token; the prompt itself lives in the cache
    attention_mask=new_attn_mask,  # prompt mask + 1 for the new token
    past_key_values=past_key_values,
    cache_position=torch.arange(n, n+1, dtype=torch.long, device=input_ids.device),
    output_logits=True,
    output_scores=True,
    return_dict_in_generate=True,
    **kwargs
)

# `generate` only returns what it was given (the single seed token) plus the continuation,
# so prepend the *whole* prompt; slicing off its last token would drop it from the result.
outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
# TODO: outputs.scores / outputs.logits still cover only the tokens produced inside
# `generate`; they are not yet merged with `forward_out.logits`.
# outputs.scores = #TODO
# outputs.logits = torch.concat([forward_out.logits, torch.stack(outputs.logits, 1)], 1)

# FIXME: sanity check against a plain (uncached) generate call; lengths should line up.
# outg2 = model.generate(
#     input_ids=batch2['input_ids'],  # full prompt
#     attention_mask=batch2['attention_mask'],
#     output_logits=True,
#     output_scores=True,
#     return_dict_in_generate=True,
#     **kwargs
# )
# print(outputs.sequences.shape, outg2.sequences.shape, len(outputs.logits), len(outg2.logits))

