In [1]:
from model import ChatModel
from settings import ModelSettings

# Seed torch's global RNG before building the model so that any random weight
# initialisation inside ChatModel -- and every later torch.randint /
# torch.multinomial call in this notebook -- is reproducible under
# Restart Kernel -> Run All.
import torch

SEED = 0
torch.manual_seed(SEED)

# NOTE(review): ChatModel is called positionally, so this argument order must
# match its constructor signature exactly -- confirm against model.py.
model = ChatModel(
    ModelSettings.vocabulary_size,
    ModelSettings.embedding_size,
    ModelSettings.embedding_dropout,
    ModelSettings.attention_dropout,
    ModelSettings.max_context_length,
    ModelSettings.ff_size_multiplier,
    ModelSettings.ff_dropout,
    ModelSettings.transformer_blocks,
    ModelSettings.attention_heads
)

In [2]:
from torchinfo import summary

# Layer-by-layer parameter breakdown of the model built in the first cell.
summary(model=model)

Layer (type:depth-idx)                                  Param #
ChatModel                                               --
├─Sequential: 1-1                                       --
│    └─Embedding: 2-1                                   --
│    │    └─Embedding: 3-1                              18,432,000
│    │    └─Embedding: 3-2                              768,000
│    └─Dropout: 2-2                                     --
│    └─Sequential: 2-3                                  --
│    │    └─TransformerBlock: 3-3                       7,085,568
│    │    └─TransformerBlock: 3-4                       7,085,568
│    │    └─TransformerBlock: 3-5                       7,085,568
│    │    └─TransformerBlock: 3-6                       7,085,568
│    │    └─TransformerBlock: 3-7                       7,085,568
│    │    └─TransformerBlock: 3-8                       7,085,568
│    │    └─TransformerBlock: 3-9                       7,085,568
│    │    └─TransformerBlock: 3-10              

In [3]:
import torch
import torch.nn as nn

# Evaluation mode for the step-by-step walkthrough below, so dropout is disabled
# (assuming the Dropout layers in the summary are torch.nn.Dropout).
# Note: eval() does NOT turn off autograd; that needs torch.no_grad().
model.eval()

# Toy batch: 2 sequences of 3 token ids each, drawn from the full vocabulary.
batch_size, context_length = 2, 3
vocabulary_size = ModelSettings.vocabulary_size

In [4]:
# A (batch_size, context_length) batch of random token ids in [0, vocabulary_size).
token_ids = torch.randint(vocabulary_size, (batch_size, context_length))
print(token_ids.shape)

torch.Size([2, 3])


In [5]:
# Keep only the most recent ModelSettings.max_context_length tokens of every
# sequence -- the longest window the model accepts. The printed shape shows
# whether anything was actually cut.
token_ids = token_ids[..., -ModelSettings.max_context_length:]
print(token_ids.shape)

torch.Size([2, 3])


In [6]:
# Raw, unnormalised next-token scores (logits -- NOT probabilities yet) for
# every position of every sequence: shape (batch_size, context_length, vocab).
# model.eval() only changes layer behaviour such as dropout; torch.no_grad() is
# what stops autograd from recording this forward pass, saving memory.
with torch.no_grad():
    logits = model(token_ids)
print(logits.shape)

torch.Size([2, 3, 24000])


In [7]:
# Only the final position's scores predict the *next* token, so drop the other
# positions: (batch_size, context_length, vocab) -> (batch_size, vocab).
logits = logits[:, -1]
print(logits.shape)

torch.Size([2, 24000])


In [8]:
# Softmax across the vocabulary axis turns each sequence's logits into a
# probability distribution whose entries sum to 1.
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([2, 24000])


In [9]:
# Randomly SAMPLE one next-token id per sequence, weighted by `probs`
# (multinomial sampling). This does not necessarily pick the most likely token;
# greedy decoding would be `probs.argmax(dim=-1, keepdim=True)`, which gives
# the same (batch_size, 1) shape.
next_id = torch.multinomial(probs, num_samples=1)
print(next_id.shape)

torch.Size([2, 1])


In [10]:
from tokenizers.tokenizers import Tokenizer

# Serialized tokenizer definition; the relative path resolves against the
# kernel's current working directory.
TOKENIZER_PATH = "tokenizer.json"
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

In [13]:
device = "cpu"


@torch.no_grad()
def generate(
    model,
    start,
    max_new_tokens=50,
    *,
    text_tokenizer=None,
    context_length=None,
    temperature=1.0,
    greedy=False,
):
    """Autoregressively extend ``start`` with ``max_new_tokens`` new tokens.

    Runs without autograd (``torch.no_grad``) and with the model temporarily in
    eval mode. The model's previous train/eval mode is restored afterwards, so
    the function is safe to call in the middle of a training loop.

    Args:
        model: module mapping a (1, T) long tensor of token ids to logits of
            shape (1, T, vocab_size).
        start: prompt text. It must encode to at least one token when
            ``max_new_tokens > 0``.
        max_new_tokens: number of tokens to append (>= 0).
        text_tokenizer: object providing ``encode(str).ids`` and
            ``decode(list[int])``. Defaults to the notebook-level ``tokenizer``.
        context_length: maximum number of most recent tokens fed to the model.
            Defaults to ``ModelSettings.max_context_length``.
        temperature: the logits are divided by this (> 0) before softmax.
            Values < 1 sharpen the distribution and values > 1 flatten it.
            Ignored when ``greedy`` is True.
        greedy: take the highest-scoring token instead of sampling.

    Returns:
        The decoded prompt followed by the generated tokens.

    Raises:
        ValueError: if ``max_new_tokens`` is negative, ``temperature`` is not
            positive (when sampling), or ``start`` encodes to no tokens while
            generation was requested.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if context_length is None:
        context_length = ModelSettings.max_context_length
    if max_new_tokens < 0:
        raise ValueError(f"max_new_tokens must be >= 0, got {max_new_tokens}")
    if not greedy and temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    prompt_ids = text_tokenizer.encode(start).ids
    if not prompt_ids and max_new_tokens > 0:
        # An empty sequence would make logits[:, -1, :] fail inside the loop.
        raise ValueError("start must encode to at least one token")

    # Build the prompt on the same device as the model's weights. Parameter-less
    # models fall back to the notebook-level `device`. This keeps a model that
    # was moved to a GPU from hitting a device-mismatch error.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=target_device)

    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Feed at most the last `context_length` tokens to the model.
            idx_cond = idx[:, -context_length:]
            # Only the last position's scores predict the next token.
            logits = model(idx_cond)[:, -1, :]
            if greedy:
                next_id = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = nn.functional.softmax(logits / temperature, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    return text_tokenizer.decode(idx[0].tolist())

In [14]:
# Sample a 50-token continuation of a short prompt from the model built above.
sample_text = generate(model, start="hello", max_new_tokens=50)
print(sample_text)

hello very Wan combustionscreen 119 Murderording ammunition Kot MIT God extinction tasked breakingnder jurist cogn bayloo collaborator Luxfielder theoretical Marc classic superf installationoria 1798 191founder arranger noddedoster hot Ident aggressive Mess device Clarke Startsychitchomansasc piv currentsالTC congressional
