# Story Teller

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


Next, we configure the notebook by setting constants and hyper-parameters.

In [13]:
# Dictionary keys shared by every entry of `model_configs`.
K_NAME, K_HF_CARD_URL, K_CHAT_PROMPT = 'namer', 'url', 'is_chat_prompt'

# Registry of supported models, indexed by a short alias.
model_configs = {}

# GPT-2: no model-card URL recorded, not flagged as a chat-prompt model.
model_configs['gpt2'] = dict([
    (K_NAME, 'gpt2'),
    (K_HF_CARD_URL, None),
    (K_CHAT_PROMPT, False),
])

# Llama 3 8B Instruct: flagged as a chat-prompt model.
# REF: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
model_configs['llama-3-8b-instruct'] = dict([
    (K_NAME, 'meta-llama/Meta-Llama-3-8B-Instruct'),
    (K_HF_CARD_URL, 'https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct'),
    (K_CHAT_PROMPT, True),
])

# Alias of the model to run; must be one of the keys of `model_configs`.
config = 'llama-3-8b-instruct'

# Resolve the alias to its model identifier and report the selection.
model_id = model_configs[config][K_NAME]
print(f'Using model: {model_id}')

Using model: meta-llama/Meta-Llama-3-8B-Instruct


In [3]:
# Prompt text: tokenized directly in the 'gpt2' branch and sent as the user
# message of the chat in the 'llama-3-8b-instruct' branch.
INPUT_TEXT = 'I enjoy walking with my cute dog'

In [19]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

match(config):
    case 'gpt2':
        tokenizer.pad_token = tokenizer.eos_token

        # add the EOS token as PAD token to avoid warnings
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16)
        
        tokens = tokenizer(INPUT_TEXT, return_tensors='pt').to(model.device.type)

        outputs = model.generate(
            **tokens)
    case 'llama-3-8b-instruct':
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto")

        messages = [
            {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
            {"role": "user", "content": INPUT_TEXT},
        ]

        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = model.generate(
            input_ids,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        
        response = outputs[0][input_ids.shape[-1]:]
        print(tokenizer.decode(response, skip_special_tokens=True))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 26.58it/s]
Some parameters are on the meta device device because they were offloaded to the cpu and disk.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
