In [9]:
import numpy as np
import torch
import torch.nn as nn

from settings import ModelSettings

# Toy "document": consecutive integers 0..1099 standing in for a stream of token ids.
context_length = ModelSettings.max_context_length
document = np.arange(0, 1100)
print(document)

[   0    1    2 ... 1097 1098 1099]


In [10]:
# Example: the targets are the inputs shifted one position to the right,
# so y[t] is the element that follows x[t] in the document.
example_context_length = 10
x, y = (
    document[0:example_context_length],
    document[1:1 + example_context_length],
)
print(x, y, sep="\n")

[0 1 2 3 4 5 6 7 8 9]
[ 1  2  3  4  5  6  7  8  9 10]


In [None]:
# Checkpoint file for the model weights, and the directory that holds the
# tokenized train.bin / test.bin splits.
model_path, data_dir = "mini_model_training.pth", "tokenized_data"

In [11]:
import os

# Sampling configuration for this mini run.
device = "cpu"  # device the sampled batches are moved to
batch_size = 16  # number of sequences per batch
block_size = 128  # tokens per sequence (smaller context size)


def get_mini_batch(split, max_tokens=1000):
    """Sample a random batch of (input, target) token windows from a tokenized split.

    Reads ``train.bin`` (when ``split == 'train'``) or ``test.bin`` (any other
    value) from ``data_dir`` as a flat array of ``uint16`` values and draws
    ``batch_size`` random windows of ``block_size`` tokens from it.

    Args:
        split: ``'train'`` selects ``train.bin``; anything else selects ``test.bin``.
        max_tokens: Only the first ``max_tokens`` tokens of the file are sampled
            from. Pass ``None`` to sample from the whole file.

    Returns:
        Tuple ``(x, y)`` of ``int64`` tensors of shape ``(batch_size, block_size)``
        on ``device``; ``y`` holds the same windows as ``x`` shifted one token
        to the right.

    Raises:
        ValueError: If the usable part of the file has fewer than
            ``block_size + 1`` tokens, so no (input, target) window fits.
    """
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    file_name = 'train.bin' if split == 'train' else 'test.bin'
    data = np.memmap(os.path.join(data_dir, file_name), dtype=np.uint16, mode='r')

    usable_tokens = len(data) if max_tokens is None else min(len(data), max_tokens)
    # Exclusive upper bound for start offsets: the target window ends one token
    # after the input window, so it needs block_size + 1 tokens in total.
    max_start = usable_tokens - block_size
    if max_start <= 0:
        raise ValueError(
            f"{file_name} provides {usable_tokens} usable tokens, "
            f"but at least {block_size + 1} are required for block_size={block_size}"
        )

    # Plain Python ints index the memmap directly.
    starts = torch.randint(max_start, (batch_size,)).tolist()
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts])

    if str(device).startswith('cuda'):
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [12]:
# Smoke-test the sampler: draw one batch and show the shapes of inputs and targets.
x, y = get_mini_batch(split="test")
print(*(batch.shape for batch in (x, y)))

torch.Size([16, 128]) torch.Size([16, 128])


In [13]:
from model import ChatModel
from settings import ModelSettings

# Model for the mini run. Its context length is the sampler's block_size, so it
# matches the windows produced by get_mini_batch.
model = ChatModel(
    vocabulary_size=ModelSettings.vocabulary_size,
    embedding_size=256,
    max_context_length=block_size,
    ff_size_multiplier=4,
    transformer_blocks=6,
    attention_heads=8,
    dropout=0.0,
    bias=True,
    residual_scaling=True,
    weight_tying=True
).to(device)  # batches are moved to `device`; keep the parameters on the same one

In [14]:
from optimizer import get_optim_groups

optim_groups = get_optim_groups(model)

# AdamW hyper-parameters:
#   betas[0]: decay rate of the running average of the gradient, higher=slower to adapt
#   betas[1]: decay rate of the running average of the squared gradient, higher=slower to adapt
#   eps: small constant added to the denominator for numerical stability
optimizer = torch.optim.AdamW(params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8)


In [15]:
# Optional (currently disabled): resume from a previously saved checkpoint.
# if os.path.exists(model_path):
#     model.load_state_dict(
#         torch.load(
#             model_path,
#             weights_only=True,
#             map_location=torch.device('cpu') if device == "cpu" else None
#         )
#     )
#     print("Model weights loaded")

In [16]:
from math import exp

best_loss = float("inf")
patience = 5  # number of evaluations without improvement to wait before stopping
min_delta = 0.02  # minimum loss decrease that counts as an improvement
patience_counter = 0
grad_clip = 1.0  # maximum gradient norm passed to clip_grad_norm_

# Make sure the model is in training mode: an earlier eval() call (e.g. while
# sampling text) would otherwise persist when this cell is re-run.
model.train()

for step in range(50000):
    # NOTE(review): batches come from the "test" split — presumably intentional
    # for this mini overfitting run; confirm before using it for real training.
    xb, yb = get_mini_batch("test")

    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    # Prevent too large gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    # Every 10 steps: report the current batch loss and run the early-stopping check.
    if step % 10 == 0:
        loss_num = loss.item()
        perplexity = exp(loss_num)
        print(f"step {step}, loss {loss_num:.4f}, perplexity {perplexity:.2f}")
        if best_loss - loss_num > min_delta:
            # Improved by more than min_delta: checkpoint and reset the counter.
            best_loss = loss_num
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

step 0, loss 10.1259, perplexity 24981.02
step 10, loss 8.4881, perplexity 4856.44
step 20, loss 7.4178, perplexity 1665.44
step 30, loss 6.2573, perplexity 521.80


KeyboardInterrupt: 

In [None]:
# Import Tokenizer from the package's public namespace rather than from the
# internal `tokenizers.tokenizers` extension module.
from tokenizers import Tokenizer

# Load the serialized tokenizer from tokenizer.json (path relative to the working directory).
tokenizer = Tokenizer.from_file("tokenizer.json")

In [None]:
@torch.no_grad()
def generate(model, start, max_new_tokens=50, context_length=None):
    """Sample a continuation of ``start`` from ``model``, one token at a time.

    Args:
        model: Module called as ``model(idx)`` on a ``(1, T)`` tensor of token
            ids. It may return the logits directly or a tuple whose first
            element is the logits (the training call returns ``(logits, loss)``).
        start: Prompt text; it is encoded with the module-level ``tokenizer``.
        max_new_tokens: Number of tokens to sample and append to the prompt.
        context_length: Maximum number of trailing tokens fed to the model per
            step. Defaults to ``block_size``, the context length the notebook's
            model is constructed with.

    Returns:
        The decoded text of the prompt ids followed by the sampled ids.

    Raises:
        ValueError: If ``start`` encodes to no tokens.
    """
    if context_length is None:
        context_length = block_size

    prompt_ids = tokenizer.encode(start).ids
    if not prompt_ids:
        raise ValueError("`start` must encode to at least one token")
    idx = torch.tensor([prompt_ids], device=device, dtype=torch.long)

    # Remember the current mode so sampling does not silently leave the model
    # in eval mode for a later training run.
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Never feed more tokens than the model's context can hold.
            idx_cond = idx[:, -context_length:]
            output = model(idx_cond)
            logits = output[0] if isinstance(output, (tuple, list)) else output
            # Only the distribution at the last position decides the next token.
            probs = nn.functional.softmax(logits[:, -1, :], dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    finally:
        model.train(was_training)

    return tokenizer.decode(idx[0].tolist())

In [None]:
# Seed generation with the first token of the first sequence in a random test batch.
start_token_id = get_mini_batch("test")[0][0, 0].item()
start_text = tokenizer.decode([start_token_id])
print(generate(model=model, start=start_text))