In [1]:
from transformers import GPT2LMHeadModel, AutoTokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
encoded = tokenizer("The capital of France ", return_tensors="pt")
import time
st = time.perf_counter()
generate_output = model.generate(**encoded, use_cache=True, return_dict_in_generate=True, max_new_tokens=50)
print(f"Inference time: {time.perf_counter()-st:.3f}")
print(generate_output.sequences[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Inference time: 1.741
tensor([ 464, 3139,  286, 4881,  220, 1849,  271,  262, 3139,  286,  262, 4141,
        2066,   13,  383, 4141, 2066,  318,  257, 1181,  286,  262, 1242, 2422,
         290, 3034, 1080,   13,  383, 4141, 2066,  318,  257, 1181,  286,  262,
        1242, 2422,  290, 3034, 1080,   13,  383, 4141, 2066,  318,  257, 1181,
         286,  262, 1242, 2422,  290, 3034, 1080])


In [3]:
model_config={
    "use_cache":True,
    "return_dict_in_generate":True,
    "max_new_tokens":1,
}
print(encoded)
output = model.generate(**encoded, **model_config)
print(output.sequences[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


{'input_ids': tensor([[ 464, 3139,  286, 4881,  220]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
tensor([ 464, 3139,  286, 4881,  220, 1849])


In [4]:

old_update = model._update_model_kwargs_for_generation
extracted = {}
import types
def new_func(self,*args, **kwargs):
    extracted["past_key_values"] = args[0]["past_key_values"]
    return old_update(*args, **kwargs)

model._update_model_kwargs_for_generation = types.MethodType(new_func, model)

In [5]:

output = model.generate(**encoded, **model_config)
print(len(extracted["past_key_values"]))
print(len(extracted["past_key_values"][0]))
print(extracted["past_key_values"][0][0].size())
print(output.sequences[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


12
2
torch.Size([1, 12, 5, 64])
tensor([ 464, 3139,  286, 4881,  220, 1849])


In [6]:
import torch
encoded = {
    "input_ids": output.sequences,
    "attention_mask": torch.concat((encoded["attention_mask"], torch.ones((1,1), dtype=torch.int64)), dim=1),
    "past_key_values": extracted["past_key_values"],
}
# print(encoded)
output = model.generate(**encoded, **model_config)
print(output.sequences[0])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([ 464, 3139,  286, 4881,  220, 1849,  271])


In [7]:
encoded = tokenizer("The capital of France ", return_tensors="pt")
for _ in range(50):
    output = model.generate(**encoded, **model_config)
    encoded = {
        "input_ids": output.sequences,
        "attention_mask": torch.concat((encoded["attention_mask"], torch.ones((1,1), dtype=torch.int64)), dim=1),
        "past_key_values": extracted["past_key_values"],
    }
    
print(output.sequences[0])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end gene

tensor([ 464, 3139,  286, 4881,  220, 1849,  271,  262, 3139,  286,  262, 4141,
        2066,   13,  383, 4141, 2066,  318,  257, 1181,  286,  262, 1242, 2422,
         290, 3034, 1080,   13,  383, 4141, 2066,  318,  257, 1181,  286,  262,
        1242, 2422,  290, 3034, 1080,   13,  383, 4141, 2066,  318,  257, 1181,
         286,  262, 1242, 2422,  290, 3034, 1080])


In [8]:
assert all(generate_output.sequences[0] == output.sequences[0])