In [65]:
import torch
import torch.nn as nn
from tokenizers import Tokenizer
from settings import ModelSettings
from model import ChatModel
from special_tokens import special_tokens

# Target device handed to the model constructor and used by generate() for prompt tensors.
device = "cpu"
# Tokenizer used to encode prompts and decode generated ids throughout the notebook.
tokenizer = Tokenizer.from_file("tokenizer.json")


@torch.no_grad()
def generate(
    model,
    start,
    max_new_tokens=50,
    temperature=0.3,
    top_k=10,
    argmax=False,
    stop_tokens=None,
    max_context_length=None,
):
    """Autoregressively extend the prompt ``start`` with tokens from ``model``.

    The prompt is encoded with the module-level ``tokenizer`` (without adding
    special tokens). Then up to ``max_new_tokens`` ids are appended one at a
    time. At each step only the last ``max_context_length`` ids are fed to the
    model, and only the logits of the final position are used.

    Args:
        model: Callable taking a ``(1, T)`` long tensor of token ids and
            returning logits indexable as ``[:, -1, :]``. The last axis is
            treated as the token-id axis.
        start: Prompt text. It must encode to at least one token.
        max_new_tokens: Maximum number of tokens to append.
        temperature: Softmax temperature used when sampling. Values ``<= 0``
            fall back to greedy decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` largest logits before choosing a token.
            Clamped to the vocabulary size; ``None`` or ``0`` disables it.
        argmax: If True, always pick the highest-scoring token (greedy).
        stop_tokens: Optional collection of token ids. Generation stops right
            after one of them is produced; the stop token stays in the output.
        max_context_length: Context window given to the model. Defaults to
            ``ModelSettings.max_context_length``. Pass the model's own value
            when it was built with a different window.

    Returns:
        The decoded text of the prompt plus all generated tokens.

    Raises:
        ValueError: If ``start`` encodes to zero tokens.
    """
    if max_context_length is None:
        max_context_length = ModelSettings.max_context_length
    # A non-positive temperature cannot be used as a divisor; treat it as greedy.
    greedy = argmax or temperature <= 0

    idx = torch.tensor(
        [tokenizer.encode(start, add_special_tokens=False).ids],
        device=device,
        dtype=torch.long,
    )
    if idx.size(1) == 0:
        raise ValueError("start must encode to at least one token")

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -max_context_length:]
        logits = model(idx_cond)[:, -1, :]

        if top_k:
            k = min(top_k, logits.size(-1))
            # Keep a trailing dim (shape (B, 1)) so each row is compared with
            # its own k-th largest logit rather than broadcasting across the
            # vocabulary axis.
            kth_largest = torch.topk(logits, k).values[:, -1:]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))

        if greedy:
            # softmax is monotonic, so the argmax of the logits is the argmax of the probs.
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = nn.functional.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

        if stop_tokens and idx[0, -1].item() in stop_tokens:
            print("Reached stop token")
            break

    return tokenizer.decode(idx[0].tolist())

In [39]:
minified = False  # True -> build the small architecture defined below instead of ModelSettings

# Hyper-parameters that differ between the full and the minified variant.
if minified:
    # NOTE(review): this variant has a 64-token context window; make sure the
    # generation code never feeds it more than 64 positions.
    architecture = dict(
        embedding_size=64,
        max_context_length=64,
        ff_size_multiplier=2,
        transformer_blocks=4,
        attention_heads=4,
    )
else:
    architecture = dict(
        embedding_size=ModelSettings.embedding_size,
        max_context_length=ModelSettings.max_context_length,
        ff_size_multiplier=ModelSettings.ff_size_multiplier,
        transformer_blocks=ModelSettings.transformer_blocks,
        attention_heads=ModelSettings.attention_heads,
    )

# Arguments shared by both variants (dropout is 0.0 either way).
model = ChatModel(
    vocabulary_size=ModelSettings.vocabulary_size,
    dropout=0.0,
    bias=ModelSettings.bias,
    device=device,
    **architecture,
)

using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention


In [40]:
step = 100_000  # checkpoint step to load
# File names are the step number zero-padded to at least 5 digits.
checkpoint_path = f"output_100k/pre_checkpoints/state/{step:05d}.pt"
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))

In [41]:
# Remove the "_orig_mod." prefix from checkpoint keys (presumably present because
# the model was wrapped by torch.compile when it was saved) so the keys match the
# plain model. Keys are renamed in place, so state["model"] reflects the change.
unwanted_prefix = '_orig_mod.'
state_dict = state["model"]
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for old_key in prefixed_keys:
    state_dict[old_key[len(unwanted_prefix):]] = state_dict.pop(old_key)

In [42]:
# Switch to inference mode before generating (disables dropout and any other
# train-only behaviour), then load the checkpoint weights. The load result stays
# the last expression so the cell still displays the key-matching report.
model.eval()
model.load_state_dict(state["model"])

<All keys matched successfully>

In [43]:
# Plain-text continuation (no chat special tokens) of a short prompt.
continuation = generate(model, "Hello", max_new_tokens=200)
print(continuation)

Hello, I'm so happy to see you!"

The little girl smiled and said, "Thank you, Mr. Bear! I love you!"

The little girl and Mr. Bear hugged each other and the little girl said, "I love you too, Mr. Bear!"
Once upon a time, there was a little girl named Lily. She loved to play outside in the sun. One day, she saw a big, scary dog. The dog was very big and had sharp teeth. Lily was scared and didn't know what to do.

But then, she saw a man walking by. He saw the dog and said, "Don't worry, little one. I'll help you." The man took the dog away and said, "You're safe now."

Lily was happy and thanked the man. She went back outside to play with her toys. She was glad that the man helped her. From that day on, she always remembered


In [44]:
# Plain-text continuation (no chat special tokens) of a story opening.
continuation = generate(model, "Once upon a time", max_new_tokens=200)
print(continuation)

Once upon a time, there was a little girl named Lily. She loved to play outside in the sun. One day, she saw a big, scary dog. The dog was very loud and scary.

Lily was scared and didn't know what to do. She wanted to run away, but the dog was too fast. She tried to run away, but the dog was too strong.

Lily's mom came outside and saw what was happening. She told Lily that the dog was just a big, scary dog. She said that the dog was just a big, scary dog.

Lily was happy that the dog was gone. She went back inside and told her mom about the scary dog. Her mom said that the dog was just a big, scary dog. Lily was happy that she got to see a scary dog.
Once upon a time, there was a little girl named Lily. She loved to play outside in the sun. One day, she saw a big,


In [54]:
# Ids that should end an assistant reply: the end-of-turn and end-of-sequence markers.
stop_sequence = special_tokens["end_of_turn"] + special_tokens["eos"]
stop_tokens = tokenizer.encode(stop_sequence, add_special_tokens=False).ids
print(stop_tokens)

[1, 3]


In [74]:
# Chat-formatted prompt, following the turn layout of the tokenized training
# samples shown at the end of the notebook: bos, user turn, end-of-turn,
# newline, then the assistant marker.
chat_prompt = (
    special_tokens["bos"]
    + special_tokens["user"]
    + "Can foxes fit down rabbit burrows?"
    + special_tokens["end_of_turn"]
    + "\n"
    + special_tokens["assistant"]
)
reply = generate(model, chat_prompt, 200, stop_tokens=stop_tokens)
print(reply)

Reached stop token
Can foxes fit down rabbit burrows?
Foxes are small, but they can fit in the same size and shape. They are also very small, but they can fit in the same shape and shape.


In [73]:
# Same chat layout as the previous prompt, with a one-word user message.
chat_prompt = (
    special_tokens["bos"]
    + special_tokens["user"]
    + "Hello"
    + special_tokens["end_of_turn"]
    + "\n"
    + special_tokens["assistant"]
)
reply = generate(model, chat_prompt, 200, stop_tokens=stop_tokens)
print(reply)

Reached stop token
Hello
Hello, hello! I'm Sarah. I'm a little girl. I'm looking for my mommy. I'm looking for my mommy. I can't find her anywhere!"

Sarah was so excited to see her mommy. She ran to her mommy and asked, "Mommy, where's my mommy?"

Mommy replied, "I don't know, Sarah. I'm sorry I can't find my mommy."

Sarah was sad and asked, "Why can't I find my mommy?"

Mommy said, "I don't know, Sarah. I think I can't help you."

Sarah was so upset that she started to cry. She wanted to find her mommy, but she couldn't. She was so upset that she ran away from the house.

Mommy was so sad that she never saw Sarah again. She was so sad that she never saw Sarah again.



In [71]:
from datasets import load_from_disk

# Show the first 5 tokenized samples as token strings to check the chat template.
ds = load_from_disk("tokenized_data/robots_test").take(5)
for sample in ds:
    print([tokenizer.id_to_token(token_id) for token_id in sample["tokens"]])

['<|bos|>', '<|system|>', 'A', 'ster', 'Ġis', 'Ġa', 'Ġchat', 'bot', 'Ġwho', 'Ġanswers', 'Ġquestions', 'Ġwith', 'Ġrh', 'ym', 'es', '.', '<|endofturn|>', 'Ċ', '<|user|>', 'Where', 'Ġdid', 'Ġchocolate', 'Ġoriginate', '?', '<|endofturn|>', 'Ċ', '<|assistant|>', 'Ch', 'ocolate', 'Ġis', 'Ġ4', '000', 'Ġyears', 'Ġold', '/', 'Mex', 'ico', 'Ġis', 'Ġwhere', 'Ġit', 'Ġwas', 'Ġfirst', 'Ġsold', '<|endofturn|>', 'Ċ', '<|user|>', 'Where', 'Ġwas', 'Ġmilk', 'Ġchocolate', 'Ġinvented', '?', '<|endofturn|>', 'Ċ', '<|assistant|>', 'Sw', 'itzerland', 'Ġwas', 'Ġthe', 'Ġfirst', 'Ġto', 'Ġadd', 'Ġmilk', '/', 'To', 'Ġmake', 'Ġtheir', 'Ġchocolate', 'Ġsmooth', 'Ġas', 'Ġsilk', '<|endofturn|>', 'Ċ', '<|user|>', 'What', 'Ġare', 'Ġsome', 'Ġgood', 'Ġdess', 'erts', 'Ġthat', 'Ġuse', 'Ġchocolate', '?', '<|endofturn|>', 'Ċ', '<|assistant|>', 'P', 'ie', ',', 'Ġt', 'art', ',', 'Ġcookies', ',', 'Ġand', 'Ġcake', '/', 'Ch', 'ocolate', 'Ġis', 'Ġgreat', 'Ġto', 'Ġb', 'ake', '<|endofturn|>', 'Ċ', '<|eos|>']
['<|bos|>', '<|user|>', 'W