In [2]:
# Import the necessary packages
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Initialize the model and its matching tokenizer; GPT-2 is used as the example.
# The checkpoint name lives in one constant so the model and the tokenizer are
# always loaded from the same checkpoint and cannot drift apart.
MODEL_NAME = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [4]:
# Convert the prompt into the numeric token ids (plus an attention mask) that
# the model consumes; return_tensors='pt' asks for PyTorch tensors.
tokens = tokenizer(
    ['What should I do on a rainy day without an umbralla?'],
    return_tensors='pt',
)

In [5]:
# Display the tokenizer output: a dict-like mapping with 'input_ids' and
# 'attention_mask' tensors, one row per input sentence.
tokens

{'input_ids': tensor([[ 2061,   815,   314,   466,   319,   257, 37259,  1110,  1231,   281,
         23781,  1671, 30315,    30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [6]:
# Run a single forward pass. `**tokens` passes both 'input_ids' and
# 'attention_mask' as keyword arguments; the result is a model-output object
# whose `.logits` attribute is inspected below. Gradients are not needed for
# inference, so torch.no_grad() skips building the autograd graph.
with torch.no_grad():
    output = model(**tokens)

In [7]:
# For a causal LM, `output.logits` has shape [batch_size, sequence_length, vocab_size]:
# - batch_size: the number of sequences in the batch, which is 1 in this example.
# - sequence_length: one position per input token (14 here); the model emits a
#   prediction at every input position, so this equals the input length.
# - vocab_size: one entry per token in the vocabulary (50257 for GPT-2). The
#   entries are raw, unnormalized scores (logits), not probabilities; applying
#   softmax over this last dimension turns them into a probability
#   distribution over the next token.
output.logits.shape

torch.Size([1, 14, 50257])

In [8]:
# Picking the highest-scoring vocabulary entry at every position gives the
# model's greedy next-token prediction for each prefix of the input.
argmax_tokens = output.logits.argmax(dim=-1)
argmax_tokens

tensor([[  318,   356,   466,    30,   616,  4445,  1110,    30,   257, 25510,
           457,   496,    30,   198]])

input tokens:  2061,   815,   314,   466,   319,   257, 37259,  1110,  1231,   281, 23781,  1671, 30315,  30

output tokens:  318,   356,   466,    30,   616,  4445,  1110,    30,   257, 25510,   457,   496,    30,  198

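Each token in the output is the model's prediction for the *next* token. Causal masking means the prediction at position i can only attend to the input tokens at positions 0..i, so it is the model's guess for the input token at position i + 1. For example, the first output token (318) is predicted after seeing only the first input token (2061).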

In [9]:
# The prediction at the last position has seen the whole prompt, so it is the
# model's single-token continuation of the input.
argmax_tokens[0, -1]

tensor(198)

In [10]:
# For real text generation, call the built-in generate() API instead of
# computing the argmax by hand. do_sample=False makes it greedy, and
# max_new_tokens=8 appends 8 new tokens to the prompt.
# GPT-2 has no dedicated pad token; without an explicit pad_token_id,
# generate() falls back to eos_token_id and prints a warning, so pass it directly.
generated_tokens = model.generate(
    **tokens,
    max_new_tokens=8,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
# generate() returns prompt + continuation, so drop the prompt positions.
prompt_length = tokens['input_ids'].shape[-1]
generated_new_tokens = generated_tokens[0, prompt_length:]
print(f'generated_tokens: {generated_tokens}')
print(f'generated_new_tokens: {generated_new_tokens}')
# Decode the first (and only) sequence explicitly; unpacking the batch with `*`
# would pass any additional sequences to decode() as extra positional options.
tokenizer.decode(generated_tokens[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


generated_tokens: tensor([[ 2061,   815,   314,   466,   319,   257, 37259,  1110,  1231,   281,
         23781,  1671, 30315,    30,   198,   198,   464,  3280,   318,    25,
           645,    13]])
generated_new_tokens: tensor([ 198,  198,  464, 3280,  318,   25,  645,   13])


'What should I do on a rainy day without an umbralla?\n\nThe answer is: no.'

In [15]:
# If you want to implement the generate-and-decode logic yourself, you could do it as follows
def generate_one_new_token(input_tokens, model):
    """Greedily predict one next token and append it to the sequence.

    Args:
        input_tokens: integer token-id tensor of shape (batch_size, seq_len).
        model: a causal LM callable whose output has a ``.logits`` tensor of
            shape (batch_size, seq_len, vocab_size).

    Returns:
        A tuple ``(new_tokens, updated_input_tokens)``: ``new_tokens`` has shape
        (batch_size,) and holds the argmax token at the last position;
        ``updated_input_tokens`` has shape (batch_size, seq_len + 1).

    Also prints the updated sequence length.
    """
    # Pure inference: no need to record the autograd graph.
    with torch.no_grad():
        logits = model(input_tokens).logits
    # Only the last position has seen the whole sequence, so its argmax is the
    # greedy continuation; computing it just there avoids argmaxing every position.
    new_tokens = logits[:, -1, :].argmax(dim=-1)
    updated_input_tokens = torch.cat((input_tokens, new_tokens.unsqueeze(-1)), dim=-1)
    print(f'input tokens length: {len(updated_input_tokens[0])}')
    return new_tokens, updated_input_tokens

updated_input_tokens = tokens['input_ids']
for i in range(0, 8):
    generated_new_token, updated_input_tokens = generate_one_new_token(updated_input_tokens, model)
    print(generated_new_token)

input tokens length: 15
tensor([198])
input tokens length: 16
tensor([198])
input tokens length: 17
tensor([464])
input tokens length: 18
tensor([3280])
input tokens length: 19
tensor([318])
input tokens length: 20
tensor([25])
input tokens length: 21
tensor([645])
input tokens length: 22
tensor([13])
