In [1]:
from model import ChatModel
from settings import ModelSettings

# Build the network from the shared hyperparameters in ModelSettings.
# Dropout is the one argument pinned explicitly (to 0.0) instead of being
# read from ModelSettings.
model = ChatModel(
    **{
        field: getattr(ModelSettings, field)
        for field in (
            "vocabulary_size",
            "embedding_size",
            "max_context_length",
            "ff_size_multiplier",
            "transformer_blocks",
            "attention_heads",
            "bias",
        )
    },
    dropout=0.0,
)

using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention


In [2]:
from torchinfo import summary

# Display the module hierarchy together with per-layer parameter counts
summary(model)

Layer (type:depth-idx)                             Param #
ChatModel                                          --
├─Embedding: 1-1                                   --
│    └─Embedding: 2-1                              12,000,000
│    └─Embedding: 2-2                              256,000
├─Dropout: 1-2                                     --
├─Sequential: 1-3                                  --
│    └─TransformerBlock: 2-3                       --
│    │    └─Sequential: 3-1                        1,003,000
│    │    └─Sequential: 3-2                        2,003,500
│    └─TransformerBlock: 2-4                       --
│    │    └─Sequential: 3-3                        1,003,000
│    │    └─Sequential: 3-4                        2,003,500
│    └─TransformerBlock: 2-5                       --
│    │    └─Sequential: 3-5                        1,003,000
│    │    └─Sequential: 3-6                        2,003,500
│    └─TransformerBlock: 2-6                       --
│    │    └─Sequential

In [3]:
import torch
import torch.nn as nn

# Put the model in evaluation mode before the step-by-step walkthrough
model.eval()

# Toy dimensions for the walkthrough: vocabulary size comes from the shared
# settings, batch size and sequence length are kept deliberately tiny
vocabulary_size = ModelSettings.vocabulary_size
batch_size, context_length = 2, 3

In [4]:
# Draw a (batch_size, context_length) grid of random token ids in [0, vocabulary_size)
token_ids = torch.randint(
    low=0,
    high=vocabulary_size,
    size=(batch_size, context_length),
)
print(token_ids.shape)

torch.Size([2, 3])


In [5]:
# Keep at most the trailing max_context_length tokens of each sequence
token_ids = token_ids[..., -ModelSettings.max_context_length :]
print(token_ids.shape)

torch.Size([2, 3])


In [6]:
# Raw next-token scores (logits) for every position of every sequence.
# These are not probabilities yet; the notebook normalizes them with softmax in a later cell.
logits = model(token_ids)
print(logits.shape)

torch.Size([2, 3, 24000])


In [7]:
# Only the scores at the final position of each sequence predict the next token
logits = logits[:, -1]
print(logits.shape)

torch.Size([2, 24000])


In [8]:
# Turn each sequence's logits into a probability distribution over the vocabulary
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([2, 24000])


In [9]:
# Sample one next-token id per sequence, drawn at random in proportion to its
# probability (this is sampling, not an argmax over the distribution)
next_id = probs.multinomial(num_samples=1)
print(next_id.shape)

torch.Size([2, 1])


In [10]:
# NOTE(review): this imports from the package's inner `tokenizers.tokenizers` module;
# the documented path appears to be `from tokenizers import Tokenizer` — confirm and consider switching.
from tokenizers.tokenizers import Tokenizer

# Load the tokenizer definition from a JSON file resolved relative to the current working directory
tokenizer = Tokenizer.from_file("tokenizer.json")

In [11]:
# Device on which generate() allocates the token tensor. The model is never moved
# explicitly in the visible cells, so this is assumed to match where the model
# lives — TODO confirm if a GPU is ever used.
device = "cpu"


@torch.no_grad()
def generate(model, start, max_new_tokens=50):
    """Sample a text continuation of ``start`` from ``model``.

    The prompt is encoded with the notebook-level ``tokenizer`` and placed on
    the notebook-level ``device``. Tokens are then produced one at a time:
    the model is run on the most recent ``ModelSettings.max_context_length``
    tokens, the scores at the last position are normalized with softmax, and
    the next token is drawn at random from that distribution.

    Args:
        model: Callable mapping a ``(1, T)`` tensor of token ids to a
            ``(1, T, vocabulary)`` tensor of scores. It is switched to
            evaluation mode and left in that mode.
        start: Prompt text. Must encode to at least one token.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The decoded string of the prompt tokens followed by the sampled tokens.

    Raises:
        ValueError: If ``start`` encodes to zero tokens, since there would be
            no last position to predict from.
    """
    model.eval()

    prompt_ids = tokenizer.encode(start).ids
    if not prompt_ids:
        # Without this guard the failure surfaces later as an obscure
        # out-of-bounds index on the model output.
        raise ValueError(f"prompt {start!r} encodes to zero tokens; nothing to condition on")

    # Leading batch dimension of 1: a single sequence is generated per call.
    idx = torch.tensor([prompt_ids], device=device, dtype=torch.long)

    # Loop-invariant, so look it up once.
    max_context = ModelSettings.max_context_length

    for _ in range(max_new_tokens):
        # The model only ever sees the most recent max_context tokens.
        idx_cond = idx[:, -max_context:]
        logits = model(idx_cond)
        # Only the final position predicts the token that comes next.
        logits = logits[:, -1, :]
        probs = nn.functional.softmax(logits, dim=-1)
        # Random draw proportional to probability (not a greedy argmax).
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return tokenizer.decode(idx[0].tolist())

In [12]:
# Sample a continuation of the prompt "hello" (generate() defaults to 50 new tokens) and print it.
# NOTE(review): no checkpoint is loaded in the visible cells, so the sampled text
# presumably comes from randomly initialised weights — confirm that is intended.
print(generate(model, "hello"))

hello Lankaacia separation therapamoto Sen distingu publicly patients 188ral mutuallyvik equally Accessed Determ Elevwoman Advertising townsometimes Josh Scot chem beneath retros chemistry Krish addragestanding lengthsProject MemoirCativar refugeesrong Liberalsistent emerged Wu Wiley Basquecriptocate Shapart Atari titled
