# Story Teller

In [50]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

Next, we configure the code by setting constants and hyper-parameters

In [61]:
K_NAME = 'namer'
K_HF_CARD_URL = 'url'
K_CHAT_PROMPT = 'is_chat_prompt'

model_configs = {
    'gpt2': {
        K_NAME: 'gpt2',
        K_HF_CARD_URL: 'https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct',
        K_CHAT_PROMPT: False
    },
    'llama-3-8b-instruct': {
        K_NAME: 'meta-llama/Meta-Llama-3-8B-Instruct',
        K_HF_CARD_URL: '',
        K_CHAT_PROMPT: True
    }
}

config = 'gpt2'

model_id = model_configs[config][K_NAME]
print(f'Using model: {model_id}')

Using model: gpt2


In [52]:
INPUT_TEXT = 'I enjoy walking with my cute dog'

In [63]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [62]:
# add the EOS token as PAD token to avoid warnings
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    pad_token_id=tokenizer.eos_token_id)

model.device.type

'cpu'

In [59]:
# encode context the generation is conditioned on
if model_configs[config][K_CHAT_PROMPT]:
    messages = [
        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
        {"role": "user", "content": INPUT_TEXT},
    ]
    tokens = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
else:
    tokens = tokenizer(INPUT_TEXT, return_tensors='pt').to(model.device.type)
tokens

tensor([[ 1639,   389,   257, 25868,  8537, 13645,   508,  1464, 20067,   287,
         25868,  2740,     0, 50256,    40,  2883,  6155,   351,   616, 13779,
          3290, 50256]])

In [60]:
# generate 40 new tokens
output = model.generate(**tokens, max_new_tokens=40, output_attentions=True)
output

TypeError: transformers.generation.utils.GenerationMixin.generate() argument after ** must be a mapping, not Tensor

In [None]:
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[0], skip_special_tokens=True))