In [1]:
import torch
import torch.nn as nn
from tokenizers import Tokenizer
from settings import ModelSettings
from model import ChatModel
from special_tokens import special_tokens

tokenizer = Tokenizer.from_file("tokenizer.json")
device = "cpu"


@torch.no_grad()
def generate(model, start, max_new_tokens=50, temperature=0.3, top_k=10, argmax=False, stop_tokens=None):
    idx = torch.tensor(
        [tokenizer.encode(start, add_special_tokens=False).ids],
        device=device,
        dtype=torch.long,
    )

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ModelSettings.max_context_length :]
        logits = model(idx_cond)
        logits = logits[:, -1, :]

        top_logits, top_pos=torch.topk(logits,top_k)
        logits=torch.where(
            logits<top_logits[:,-1],
            input=torch.tensor(float("-inf")),
            other=logits
        )

        probs = nn.functional.softmax(logits / temperature, dim=-1)


        if argmax:
            next_id = torch.argmax(probs, dim=-1, keepdim=True)
        else:
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

        if stop_tokens and idx[0][-1].tolist() in stop_tokens:
            print("Reached stop token")
            break

    return tokenizer.decode(idx[0].tolist())

In [2]:
minified=False 

if not minified:
    model = ChatModel(
        vocabulary_size=ModelSettings.vocabulary_size,
        embedding_size=ModelSettings.embedding_size,
        max_context_length=ModelSettings.max_context_length,
        ff_size_multiplier=ModelSettings.ff_size_multiplier,
        transformer_blocks=ModelSettings.transformer_blocks,
        attention_heads=ModelSettings.attention_heads,
        dropout=0.0,
        bias=ModelSettings.bias,
        device=device,
    )
else:
    model = ChatModel(
        vocabulary_size=ModelSettings.vocabulary_size,
        embedding_size=64,
        max_context_length=64,
        ff_size_multiplier=2,
        transformer_blocks=4,
        attention_heads=4,
        dropout=0.0,
        bias=ModelSettings.bias,
        device=device,
    )

using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention


In [3]:
step=100_000
state = torch.load(f"output_100k/pre_checkpoints/state/{step:05d}.pt",map_location=torch.device('cpu'))

In [4]:
state_dict=state["model"]
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

In [5]:
model.load_state_dict(state["model"])

<All keys matched successfully>

In [6]:
print(generate(model, "Hello",200))

Hello, I'm so happy to see you!"

The little girl smiled and said, "Thank you!"

The little girl was so excited that she ran off to tell her mom. She was so happy to be able to help the little girl.
Once upon a time there was a little girl named Jane. She was three years old and loved to explore.

One day, Jane was walking in the park when she saw a big tree. She wanted to climb it, but it was too high for her to reach.

"Mommy, can you help me?" Jane asked.

"Yes, I can help you," said Mommy.

Mommy climbed up the tree and Jane was so happy to be able to climb. She felt so proud of herself for being so brave.

"Thank you Mommy!" Jane said.

Mommy smiled and said, "You're welcome, Jane. I'm glad I could help you."



In [7]:
print(generate(model, "Once upon a time",200))

Once upon a time, there was a little girl named Lily. She loved to play outside in the garden. One day, she saw a big, scary shadow. It was so big that she couldn't see very well. 

Lily's mommy told her that it was just a shadows. But Lily didn't understand why it was so scary. She thought it was just a big, scary shadow. 

Later that day, Lily went to the store with her mommy. She saw a toy that she really wanted. She asked her mommy if she could buy it. Her mommy said yes and they bought it. 

Lily was so happy that she could buy the toy she wanted. She played with it all day and showed it to all her friends. They all thought it was so cool. From that day on, Lily never felt scared of big, scary shadow again.
Once upon a time, there was a little girl named Lily. She loved to play outside in the sun and


In [8]:
stop_tokens= tokenizer.encode(special_tokens["end_of_turn"]+ special_tokens["eos"],add_special_tokens=False).ids
print(stop_tokens)

[1, 3]


In [9]:
print(
    generate(
        model,
        special_tokens["bos"]
        +special_tokens["user"]
        + "Can foxes fit down rabbit burrows?"
        + special_tokens["end_of_turn"]
        + "\n"
        + special_tokens["assistant"],
        200,
        stop_tokens=stop_tokens
    )
)

Can foxes fit down rabbit burrows?
Foxes are a great way to get rid of the foxes. They can be very fast, but they can also be very difficult to catch. Some foxes have a lot of teeth and can be very fast, but they can also be very fast. Foxes are very fast and can be very fast. They can be very fast, but they can also be very fast. Foxes are very fast and can be very fast. They can be very fast, but they can also be very fast. Foxes are very fast and can be very fast. They can be very fast, but they can also be very fast. Foxes are very fast, but they can also be very fast. Foxes are very fast, but they can also be very fast. Foxes are very fast, but they can be very fast. Foxes are very fast, but they can be very fast. Foxes are very fast, but they can be very fast. Foxes are very fast


In [10]:
print(
    generate(
        model,
        special_tokens["bos"]
        +special_tokens["user"]
        + "Hello"
        + special_tokens["end_of_turn"]
        + "\n"
        + special_tokens["assistant"],
        200,
        stop_tokens=stop_tokens
    )
)

Reached stop token
Hello
Hello, I'm looking for a friend. What's your name?"


In [11]:
from datasets import load_from_disk

ds=load_from_disk("tokenized_data/robots_test").take(5)
for m in ds:
    print(list(map(tokenizer.id_to_token,m["tokens"])))

FileNotFoundError: [Errno 2] No such file or directory: '/home/peter/PycharmProjects/autoregressive_playground/instruction_following/tokenized_data/robots_test/data-00000-of-00001.arrow'