In [None]:
import os

import numpy as np
import torch
import torch.nn as nn

# --- Data / batching -------------------------------------------------------
block_size = 128  # smaller context size
batch_size = 16
data_dir = "tokenized_data"

# --- Learning-rate settings ------------------------------------------------
learning_rate = 6e-4
min_lr = 6e-5
lr_decay_iters = 600000  # should be ~= max_iters per Chinchilla

# --- Backend / device ------------------------------------------------------
# Allow TF32 both for CUDA matmuls and for cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# `device` is what tensors are moved to; `device_type` is the coarse
# "cuda"/"cpu" family derived from it.
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cuda" if "cuda" in device else "cpu"

def get_mini_batch(split, max_tokens=1000):
    """Sample a random mini-batch of next-token-prediction examples.

    Args:
        split: ``'train'`` reads ``train.bin``; any other value reads
            ``test.bin`` (both inside ``data_dir``).
        max_tokens: Only the first ``max_tokens`` tokens of the file are
            sampled from. The default of 1000 keeps the previous hard-coded
            behaviour; pass ``None`` to sample from the whole file.

    Returns:
        ``(x, y)``: two int64 tensors of shape ``(batch_size, block_size)``
        on ``device``. ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: If the usable part of the file is not longer than
            ``block_size``, so no full window plus its target fits.
    """
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    filename = 'train.bin' if split == 'train' else 'test.bin'
    data = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode='r')

    usable_tokens = len(data) if max_tokens is None else min(len(data), max_tokens)
    # Exclusive upper bound for start offsets: the target window ends at
    # start + block_size + 1, which must stay inside the usable range.
    num_starts = usable_tokens - block_size
    if num_starts <= 0:
        raise ValueError(
            f"{filename}: {usable_tokens} usable tokens is not enough for "
            f"block_size={block_size}"
        )

    starts = torch.randint(num_starts, (batch_size,)).tolist()
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts])

    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [None]:
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)


def _state_dict_or_passthrough(component):
    """Return ``component.state_dict()``; pass through ``None`` or an already-extracted state."""
    if component is None:
        return None
    extract = getattr(component, "state_dict", None)
    return extract() if callable(extract) else component


def _as_python_scalar(value):
    """Turn a one-element tensor into a plain Python number; leave anything else untouched."""
    if torch.is_tensor(value) and value.numel() == 1:
        return value.item()
    return value


def save_checkpoint(step, model, optimizer, scheduler, scaler, train_loss, val_loss):
    """Write a training checkpoint to ``<checkpoint_dir>/checkpoint_<step>.pt``.

    Args:
        step: Training step; stored in the checkpoint and used in the file name.
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer"``.
        scheduler: LR scheduler, an already-extracted state dict, or ``None``.
        scaler: Gradient scaler, an already-extracted state dict, or ``None``.
        train_loss: Training loss; a one-element tensor is stored as a number
            so the file does not carry a tensor with autograd history.
        val_loss: Validation loss, handled like ``train_loss``.
    """
    checkpoint = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # Both are optional; missing components are stored as None.
        "scheduler": _state_dict_or_passthrough(scheduler),
        "scaler": _state_dict_or_passthrough(scaler),
        "train_loss": _as_python_scalar(train_loss),
        "val_loss": _as_python_scalar(val_loss),
    }

    torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_{step}.pt"))

In [None]:
from model import ChatModel
from settings import ModelSettings

# Small transformer; the context window matches the batch sampler's block_size.
model = ChatModel(
    vocabulary_size=ModelSettings.vocabulary_size,
    embedding_size=256,
    embedding_dropout=0.0,
    attention_dropout=0.0,
    max_context_length=block_size,
    ff_size_multiplier=4,
    ff_dropout=0.0,
    transformer_blocks=6,
    attention_heads=8
)
# get_mini_batch() places its tensors on `device`, and no other visible cell
# moves the model, so put the parameters on the same device here. This is done
# before the optimizer is created so the optimizer sees the moved parameters.
model.to(device)

In [None]:
from optimizer import get_optim_groups

# Parameter groups are built by the project's `get_optim_groups` helper.
optim_groups = get_optim_groups(model)

# AdamW over those groups with fixed hyperparameters.
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8)


In [None]:
def _latest_checkpoint_step():
    """Return the highest step number among ``checkpoint_<step>.pt`` files in ``checkpoint_dir``."""
    prefix, suffix = "checkpoint_", ".pt"
    steps = [
        int(name[len(prefix):-len(suffix)])
        for name in os.listdir(checkpoint_dir)
        if name.startswith(prefix)
        and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdigit()
    ]
    if not steps:
        raise FileNotFoundError(f"no checkpoint_<step>.pt files found in {checkpoint_dir!r}")
    return max(steps)


def load_checkpoint(step=None, scheduler=None, scaler=None):
    """Restore training state from ``<checkpoint_dir>/checkpoint_<step>.pt``.

    The module-level ``model`` and ``optimizer`` are always restored. Scheduler
    and scaler state are restored only when both the object and its saved state
    are present, because either may be absent or saved as ``None``.

    Args:
        step: Step of the checkpoint to load. When omitted, the checkpoint
            with the highest step number on disk is used, rather than
            whatever value a leftover ``step`` loop variable happens to hold.
        scheduler: Scheduler to restore; defaults to a module-level
            ``scheduler`` if one exists.
        scaler: Gradient scaler to restore; defaults to a module-level
            ``scaler`` if one exists.

    Returns:
        The step number recorded in the checkpoint.

    Raises:
        FileNotFoundError: If the requested checkpoint, or any checkpoint when
            ``step`` is omitted, does not exist.
    """
    if step is None:
        step = _latest_checkpoint_step()

    path = os.path.join(checkpoint_dir, f"checkpoint_{step}.pt")
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device)

    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    if scheduler is None:
        scheduler = globals().get("scheduler")
    if scaler is None:
        scaler = globals().get("scaler")
    for component, key in ((scheduler, "scheduler"), (scaler, "scaler")):
        saved_state = checkpoint.get(key)
        if component is not None and saved_state is not None:
            component.load_state_dict(saved_state)

    return checkpoint.get("step", step)


In [None]:
from math import exp

# Early-stopping and optimisation settings.
best_loss = float("inf")
patience = 5  # number of evaluations to wait
min_delta = 0.02  # minimum improvement
patience_counter = 0
grad_clip = 1.0
max_steps = 50000
log_interval = 10  # log / checkpoint / early-stop check every N steps

# Make sure the model is in training mode when this cell is (re-)run.
model.train()

for step in range(max_steps):
    # Optimise on the training split; the test split must not be trained on.
    xb, yb = get_mini_batch("train")

    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    # Prevent too large gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    if step % log_interval != 0:
        continue

    # NOTE: this is the loss of the current training mini-batch, not a
    # validation loss, so the early-stopping signal is noisy.
    loss_num = loss.item()
    perplexity = exp(loss_num)
    print(f"step {step}, loss {loss_num:.4f}, perplexity {perplexity:.2f}")

    if best_loss - loss_num > min_delta:
        best_loss = loss_num
        patience_counter = 0
        # Hand over the objects themselves: save_checkpoint() extracts the
        # state dicts. The loss is passed as a plain float, not the live tensor.
        # NOTE(review): `scheduler` and `scaler` are not created in any visible
        # cell; they must be defined (scaler may be None) before this cell runs.
        save_checkpoint(
            step=step,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            train_loss=loss_num,
            val_loss=0,
        )
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered")
        break

In [None]:
from tokenizers.tokenizers import Tokenizer

# Load the tokenizer from its serialized JSON file.
tokenizer_path = "tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)

In [None]:
@torch.no_grad()
def generate(model, start, max_new_tokens=50, context_length=None):
    """Sample a continuation of ``start`` from ``model``, one token at a time.

    Each step feeds the most recent ``context_length`` tokens to the model,
    applies softmax to the logits of the last position, and draws the next
    token from that distribution.

    Args:
        model: Callable module returning logits, or a tuple whose first item
            is the logits, for a ``(1, T)`` tensor of token ids.
        start: Prompt text; it must encode to at least one token.
        max_new_tokens: Number of tokens to sample.
        context_length: Maximum number of tokens fed to the model. Defaults to
            ``block_size``, the context length the notebook's model is built with.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``start`` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(start).ids
    if not prompt_ids:
        raise ValueError("`start` must encode to at least one token")

    if context_length is None:
        context_length = block_size

    # Put the prompt wherever the model's parameters are, so generation works
    # regardless of which device the model currently sits on.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([prompt_ids], device=target_device, dtype=torch.long)

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_length:]
            output = model(idx_cond)
            # The training loop unpacks `logits, loss = model(x, y)`, so the
            # model may return a tuple here as well.
            logits = output[0] if isinstance(output, (tuple, list)) else output
            logits = logits[:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    finally:
        # Do not silently leave a model that was training in eval mode.
        model.train(was_training)

    return tokenizer.decode(idx[0].tolist())

In [None]:
# Use the first token of the first sequence in a random test batch as the prompt.
sample_inputs = get_mini_batch("test")[0]
start_token_id = sample_inputs[0, 0].item()
start_text = tokenizer.decode([start_token_id])

sample_output = generate(model, start_text)
print(sample_output)