In [182]:
import torch
import torch.nn as nn
from tokenizers import Tokenizer

from chat_template import chat_template
from model import ChatModel
from settings import ModelSettings
from special_tokens import special_tokens

# All inference in this notebook runs on the CPU; `device` is reused when building
# input tensors and when constructing the model.
device = "cpu"
# Tokenizer loaded from a path relative to the notebook's working directory.
tokenizer = Tokenizer.from_file("tokenizer.json")


@torch.no_grad()
def generate(model, start, max_new_tokens=50, temperature=0.3, top_k=10, argmax=False, stop_tokens=None,
             context_length=None):
    """Autoregressively extend the prompt ``start`` with tokens produced by ``model``.

    Args:
        model: callable mapping a ``(1, seq)`` LongTensor of token ids to logits that
            are indexed as ``[batch, position, vocab]``; only the last position is used.
        start: prompt text, encoded with ``add_special_tokens=False``.
        max_new_tokens: maximum number of tokens to append.
        temperature: softmax temperature for sampling; ``<= 0`` falls back to greedy
            decoding (it would otherwise divide by zero).
        top_k: keep only the ``top_k`` highest logits before choosing; clamped to the
            vocabulary size. ``None`` or ``0`` disables the filter.
        argmax: pick the most likely token instead of sampling.
        stop_tokens: optional collection of token ids; generation stops right after one
            of them is emitted (the stop token stays in the output).
        context_length: number of trailing tokens fed to the model each step. Defaults to
            ``model.max_context_length`` if the model exposes it, else
            ``ModelSettings.max_context_length``.

    Returns:
        The decoded text of the prompt followed by the generated tokens.
    """
    if context_length is None:
        # NOTE(review): the minified model is built with its own max_context_length, which
        # may differ from ModelSettings; prefer the model's value when it exposes one
        # (attribute name assumed — falls back to ModelSettings otherwise).
        context_length = getattr(model, "max_context_length", ModelSettings.max_context_length)

    idx = torch.tensor(
        [tokenizer.encode(start, add_special_tokens=False).ids],
        device=device,
        dtype=torch.long,
    )
    stop_set = set(stop_tokens) if stop_tokens else set()
    greedy = argmax or temperature <= 0

    for _ in range(max_new_tokens):
        # Crop to the context window; only the final position's logits predict the next token.
        idx_cond = idx[:, -context_length:]
        logits = model(idx_cond)[:, -1, :]

        if top_k:
            k = min(top_k, logits.size(-1))
            # Slice with -1: (not -1) so the k-th best logit stays (batch, 1) and broadcasts per row.
            kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if greedy:
            # Softmax is monotonic, so argmax over logits equals argmax over probabilities.
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = nn.functional.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

        # Only pay for .item() (a device sync) when there is something to compare against.
        if stop_set and next_id.item() in stop_set:
            print("Reached stop token")
            break

    return tokenizer.decode(idx[0].tolist())

In [183]:
minified = True

# Arguments identical for both variants; dropout is disabled for inference.
shared_kwargs = dict(
    vocabulary_size=ModelSettings.vocabulary_size,
    dropout=0.0,
    bias=ModelSettings.bias,
    device=device,
)

if minified:
    # Small architecture; presumably must match the shapes of the checkpoint being loaded.
    # NOTE(review): max_context_length=64 here may differ from ModelSettings.max_context_length;
    # anything that crops inputs for this model must use 64.
    size_kwargs = dict(
        embedding_size=64,
        max_context_length=64,
        ff_size_multiplier=2,
        transformer_blocks=4,
        attention_heads=4,
    )
else:
    size_kwargs = dict(
        embedding_size=ModelSettings.embedding_size,
        max_context_length=ModelSettings.max_context_length,
        ff_size_multiplier=ModelSettings.ff_size_multiplier,
        transformer_blocks=ModelSettings.transformer_blocks,
        attention_heads=ModelSettings.attention_heads,
    )

model = ChatModel(**shared_kwargs, **size_kwargs)

using flash attention
using flash attention
using flash attention
using flash attention


In [184]:
# Training step whose checkpoint to evaluate; files are named by zero-padded step number.
step = 999
checkpoint_path = f"instruction_checkpoints/state/{step:05d}.pt"
# Map tensors onto the notebook's configured device instead of hard-coding the CPU.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust;
# consider weights_only=True if the file holds only tensors and plain containers.
state = torch.load(checkpoint_path, map_location=device)

In [185]:
# Parameter names in the checkpoint may carry an "_orig_mod." prefix (presumably from a
# torch.compile()-wrapped model); strip it in place so names match the plain model.
state_dict = state["model"]
unwanted_prefix = '_orig_mod.'
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for old_key in prefixed_keys:
    new_key = old_key[len(unwanted_prefix):]
    state_dict[new_key] = state_dict.pop(old_key)

In [186]:
# Inference only: switch layers that consult `self.training` (e.g. dropout) to eval behavior.
model.eval()
# Strict loading (the default) raises on any missing or unexpected parameter name.
model.load_state_dict(state["model"])

<All keys matched successfully>

In [187]:
# Generation stops when the model emits the end-of-turn or end-of-sequence token.
# Look each special token up directly rather than encoding their concatenation, so a token
# that is not a single vocabulary entry fails loudly instead of turning its sub-word pieces
# into stop tokens.
stop_tokens = [tokenizer.token_to_id(special_tokens[name]) for name in ("end_of_turn", "eos")]
assert None not in stop_tokens, f"special tokens missing from the tokenizer vocabulary: {stop_tokens}"
print(stop_tokens)

[1, 3]


In [188]:
def chat_completion(user_text):
    """Run a single-turn chat with the notebook's `model`.

    Wraps `user_text` as one user message, renders it with `chat_template`, and
    generates up to 200 new tokens, stopping at any id in `stop_tokens`.
    Returns whatever `generate` returns (the decoded text).
    """
    messages = [{"role": "user", "content": user_text}]
    prompt = chat_template(messages)
    return generate(model, prompt, max_new_tokens=200, stop_tokens=stop_tokens)

In [194]:
# Sampling uses torch.multinomial; seed it so re-running the notebook reproduces this sample.
torch.manual_seed(0)
chat_completion("Briefly describe gravity")

Reached stop token


'Briefly describe gravity\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'