In [1]:
from math import exp

import numpy as np

from settings import ModelSettings

# Maximum context length as configured in the project's ModelSettings.
context_length = ModelSettings.max_context_length

# Toy corpus: the integers 0..1099 stand in for a sequence of token ids.
document = np.arange(0, 1100)
print(document)

[   0    1    2 ... 1097 1098 1099]


In [2]:
# Example: the targets are the inputs shifted one position ahead,
# i.e. x[t] == document[t] and y[t] == document[t + 1].
example_context_length = 10
x = document[0:example_context_length]
y = document[1:1 + example_context_length]
print(x, y, sep="\n")

[0 1 2 3 4 5 6 7 8 9]
[ 1  2  3  4  5  6  7  8  9 10]


In [3]:
import torch
import os

block_size = 128  # smaller context size
batch_size = 4
device = "cpu"
data_dir = "tokenized_data"


def get_mini_batch(split, max_tokens=1000):
    """Sample a random batch of input/target windows from a tokenized split.

    The batch geometry and placement come from the module-level constants
    ``batch_size``, ``block_size``, ``device`` and ``data_dir``.

    Args:
        split: ``'train'`` reads ``train.bin``; any other value reads ``test.bin``.
        max_tokens: Only the first ``max_tokens`` tokens of the file are sampled
            from (default 1000, the cap this notebook has always used). Pass
            ``None`` to sample from the whole file.

    Returns:
        Tuple ``(x, y)`` of int64 tensors, each of shape
        ``(batch_size, block_size)`` and placed on ``device``. ``y`` holds the
        same windows as ``x`` shifted one token ahead.

    Raises:
        ValueError: If the usable part of the file holds fewer than
            ``block_size + 1`` tokens, so no input/target pair fits.
    """
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    file_name = 'train.bin' if split == 'train' else 'test.bin'
    data = np.memmap(os.path.join(data_dir, file_name), dtype=np.uint16, mode='r')

    usable_tokens = len(data) if max_tokens is None else min(len(data), max_tokens)
    # A window starting at i needs data[i : i + block_size + 1], so valid
    # starts are 0 .. usable_tokens - block_size - 1.
    n_starts = usable_tokens - block_size
    if n_starts <= 0:
        raise ValueError(
            f"{file_name} provides {usable_tokens} usable tokens, but at least "
            f"{block_size + 1} are required for block_size={block_size}"
        )

    starts = torch.randint(n_starts, (batch_size,)).tolist()
    # uint16 on disk -> int64 in memory, as the original conversion did.
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts])

    if torch.device(device).type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [4]:
# Draw one batch from the "test" split and show the shapes of inputs and targets.
x, y = get_mini_batch("test")
print(*(tensor.shape for tensor in (x, y)))

torch.Size([4, 128]) torch.Size([4, 128])


In [5]:
from model import ChatModel
from settings import ModelSettings

# The model's context length is tied to block_size, the window length that
# get_mini_batch samples, so the two cannot drift apart when block_size changes.
model = ChatModel(
    vocabulary_size=ModelSettings.vocabulary_size,
    embedding_size=128,
    embedding_dropout=0.1,
    attention_dropout=0.1,
    max_context_length=block_size,
    ff_size_multiplier=2,
    ff_dropout=0.1,
    transformer_blocks=4,
    attention_heads=8
)

In [6]:
from optimizer import get_optim_groups

# Parameter groups come from the project's get_optim_groups helper; AdamW is
# then configured with an explicit learning rate, betas and epsilon.
optim_groups = get_optim_groups(model)
optimizer = torch.optim.AdamW(params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8)


In [7]:
# Resume from a previously saved checkpoint when one exists; otherwise the
# model keeps the weights it was constructed with.
model_path = "mini_model_training.pth"
if os.path.exists(model_path):
    # weights_only=True limits unpickling to tensors and plain containers;
    # map_location places the loaded tensors on the configured device.
    state_dict = torch.load(model_path, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    print("Model weights loaded")
else:
    print(f"No checkpoint found at {model_path}; starting from freshly initialised weights")

Mode weights loaded


In [11]:
# Training configuration.
max_steps = 50000
eval_interval = 2  # log / check early stopping every N steps
patience = 5  # number of evaluations to wait
min_delta = 0.01  # minimum improvement

best_loss = float("inf")
patience_counter = 0

# Anomaly detection (torch.autograd.set_detect_anomaly) is a debugging aid that
# slows every backward pass, so it is not enabled for regular training runs.

# NOTE(review): this replaces the AdamW instance built from get_optim_groups in
# an earlier cell (different lr, no parameter groups) -- confirm which is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Put the model in training mode explicitly so dropout is active even if an
# earlier cell (e.g. text generation) left it in eval mode.
model.train()

for step in range(max_steps):
    # NOTE(review): batches are drawn from the "test" split -- confirm this is
    # intended for the mini experiment rather than "train".
    xb, yb = get_mini_batch("test")

    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % eval_interval != 0:
        continue

    loss_num = loss.item()
    perplexity = exp(loss_num)
    print(f"step {step}, loss {loss_num:.4f}, perplexity {perplexity:.2f}")

    if best_loss - loss_num > min_delta:
        # Improved by more than min_delta: checkpoint and reset the patience budget.
        best_loss = loss_num
        patience_counter = 0
        torch.save(model.state_dict(), model_path)
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered")
        break

step 0, loss 0.9828, perplexity 2.67
step 2, loss 1.1287, perplexity 3.09
step 4, loss 1.1760, perplexity 3.24
step 6, loss 1.1727, perplexity 3.23
step 8, loss 1.0478, perplexity 2.85
step 10, loss 0.9526, perplexity 2.59
step 12, loss 1.0008, perplexity 2.72
step 14, loss 0.9968, perplexity 2.71
step 16, loss 0.8524, perplexity 2.35
step 18, loss 1.1124, perplexity 3.04
step 20, loss 1.0477, perplexity 2.85
step 22, loss 0.8813, perplexity 2.41
step 24, loss 0.9239, perplexity 2.52
step 26, loss 1.0363, perplexity 2.82
Early stopping triggered


In [None]:
import torch
import torch.nn as nn

device = "cpu"


@torch.no_grad()
def generate(model, start, max_new_tokens=50, context_length=None, text_tokenizer=None):
    """Sample a continuation of ``start`` from ``model`` one token at a time.

    Args:
        model: Module called as ``model(idx)`` with a ``(1, T)`` long tensor. It
            may return the logits tensor directly or a tuple/list whose first
            element is the logits (as it does when called with targets).
        start: Prompt text to encode and continue.
        max_new_tokens: Number of tokens to sample.
        context_length: Maximum number of trailing tokens fed to the model at
            each step. Defaults to the notebook's ``block_size``, the context
            length the model in this notebook is built and trained with.
        text_tokenizer: Object providing ``encode(text).ids`` and
            ``decode(list_of_ids)``. Defaults to the module-level ``tokenizer``.
            NOTE(review): no cell visible in this notebook defines ``tokenizer``;
            it must be created before calling this without ``text_tokenizer``.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer
    window = block_size if context_length is None else context_length

    # Generation runs in eval mode (dropout off); the previous mode is restored
    # afterwards so a later training cell is not silently left in eval mode.
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([tok.encode(start).ids], device=device, dtype=torch.long)

        for _ in range(max_new_tokens):
            # Keep only the most recent `window` tokens as conditioning context.
            idx_cond = idx[:, -window:]
            output = model(idx_cond)
            logits = output[0] if isinstance(output, (tuple, list)) else output
            # Only the distribution at the last position is used to sample.
            logits = logits[:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    finally:
        model.train(was_training)

    return tok.decode(idx[0].tolist())

In [None]:
# Sample a continuation of the prompt "hello" from the model and display it.
sample_text = generate(model, "hello")
print(sample_text)