In [29]:
import os

import numpy as np
import torch
import torch.nn as nn

from settings import ModelSettings

Mode settings


In [30]:
# When True, the notebook uses the small model, short schedule and
# test.bin-only sampling intended for quick local runs.
minified = True
# When True, data and checkpoints are read from / written to the
# Google Drive mount under /content/drive/MyDrive.
colab = False
# Step number of a previously saved checkpoint to resume from, or None to
# start training from scratch.
checkpoint: int | None = None
# When True, the model is wrapped with torch.compile after construction.
# NOTE(review): this name shadows the builtin `compile` for the rest of the notebook.
compile = True

Paths

In [31]:
# Dataset and checkpoint roots: Google Drive when running on Colab,
# paths relative to the working directory otherwise.
data_dir = "/content/drive/MyDrive" if colab else "tokenized_data"
checkpoint_dir = "/content/drive/MyDrive/pre_checkpoints" if colab else "pre_checkpoints"

# Checkpoints are split into lightweight metadata ("info") and the heavy
# model/optimizer/scaler tensors ("state").
info_dir = f"{checkpoint_dir}/info"
state_dir = f"{checkpoint_dir}/state"

General settings

In [32]:
if minified:
    # Training data (small windows and batches for quick runs)
    block_size = 64
    batch_size = 8

    # Learning
    max_iters = 600  # total number of training iterations
    learning_rate = 6e-3
    min_lr = 6e-4
    warmup_steps = 60
    eval_iters = 2
    eval_interval = 10
else:
    # Training data (full context length of the model)
    block_size = ModelSettings.max_context_length
    batch_size = 32

    # Learning
    max_iters = 600000  # total number of training iterations
    learning_rate = 6e-4
    min_lr = 6e-5
    warmup_steps = 2000
    eval_iters = 200
    eval_interval = 2000

# Settings shared by both modes
lr_decay_steps = max_iters  # should be ~= max_iters per Chinchilla
grad_clip = 1.0

Hardware settings

In [33]:
# Allow TF32 for both matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Prefer the GPU when one is visible to PyTorch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# device_type is the coarse device family handed to torch.amp APIs;
# autocast is only switched on for CUDA.
device_type = "cuda" if "cuda" in device else "cpu"
autocast_enabled = (device_type == "cuda")
print(device)

cpu


Training data stream

In [34]:
def get_batch(split, minified_token_limit=5000):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: 'train' reads train.bin (or test.bin when ``minified`` is set);
            any other value reads test.bin.
        minified_token_limit: in minified mode, windows are drawn only from
            the first this-many tokens of the file. The limit is clamped to
            the file length so that short files still yield full windows.

    Returns:
        A tuple ``(x, y)`` of int64 tensors with shape
        ``(batch_size, block_size)`` placed on ``device``; ``y`` is ``x``
        shifted one token ahead.
    """
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    filename = 'train.bin' if (split == 'train' and not minified) else 'test.bin'
    data = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode='r')

    # Never sample past the end of the file: a window that runs off the end
    # would be shorter than block_size and break torch.stack below.
    usable_tokens = min(minified_token_limit, len(data)) if minified else len(data)
    ix = torch.randint(usable_tokens - block_size, (batch_size,)).tolist()

    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [35]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

Gradient scaler for FP16 mixed precision

In [36]:
# Gradient scaler used by the training loop around backward() and the
# optimizer step; it is created for the same device family that autocast uses.
scaler = torch.amp.GradScaler(device=device_type)

Model settings

In [37]:
from model import ChatModel
from settings import ModelSettings

# Only the architecture sizes differ between the two modes; everything else
# is passed identically below.
if minified:
    architecture = dict(
        embedding_size=64,
        ff_size_multiplier=2,
        transformer_blocks=4,
        attention_heads=4,
    )
else:
    architecture = dict(
        embedding_size=ModelSettings.embedding_size,
        ff_size_multiplier=ModelSettings.ff_size_multiplier,
        transformer_blocks=ModelSettings.transformer_blocks,
        attention_heads=ModelSettings.attention_heads,
    )

# The model's context length always equals the training block size, and
# dropout is disabled for pre-training.
model = ChatModel(
    vocabulary_size=ModelSettings.vocabulary_size,
    max_context_length=block_size,
    dropout=0.0,
    bias=ModelSettings.bias,
    device=device,
    **architecture,
)
model = model.to(device)

# Optionally wrap the model with torch.compile.
if compile:
    model = torch.compile(model)

Optimizer

In [38]:
from optimizer import get_optim_groups

# Parameter groups are produced by the project's optimizer helper.
optim_groups = get_optim_groups(model)

# The lr given here is only an initial placeholder: the training loop
# overwrites every param group's lr with get_lr(step) on each iteration.
adamw_options = {
    "lr": 3e-4,
    "betas": (0.9, 0.95),
    "eps": 1e-8,
}
optimizer = torch.optim.AdamW(optim_groups, **adamw_options)

Checkpointer

In [39]:
# Make sure both checkpoint sub-directories exist before anything is saved.
for checkpoint_subdir in (info_dir, state_dir):
    os.makedirs(checkpoint_subdir, exist_ok=True)
from datetime import datetime


def save_checkpoint(
        step,
        model,
        optimizer,
        scaler,
        train_loss,
        val_loss,
        learning_rate
):
    """Write a training checkpoint for ``step`` as two files.

    ``{state_dir}/{step:05d}.pt`` holds the model, optimizer and scaler state
    dicts (the scaler entry is None when ``scaler`` is falsy), and
    ``{info_dir}/{step:05d}.pt`` holds the losses, a timestamp and the run
    settings needed to interpret the checkpoint.
    """
    file_name = f"{step:05d}.pt"

    # Heavy tensors: everything needed to resume optimisation.
    state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scaler=scaler.state_dict() if scaler else None,
    )

    # Lightweight metadata describing this checkpoint.
    info = dict(
        train_loss=train_loss,
        val_loss=val_loss,
        time=datetime.now().isoformat(),
        block_size=block_size,
        batch_size=batch_size,
        eval_interval=eval_interval,
        step=step,
        learning_rate=learning_rate,
    )

    torch.save(state, f"{state_dir}/{file_name}")
    torch.save(info, f"{info_dir}/{file_name}")

Load training state

In [40]:
def load_checkpoint(step: int):
    """Restore model, optimizer and scaler state from the checkpoint for ``step``.

    Args:
        step: step number of the checkpoint; the file read is
            ``{state_dir}/{step:05d}.pt``.
    """
    # map_location remaps saved tensors onto the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one.
    state = torch.load(f"{state_dir}/{step:05d}.pt", map_location=device)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])

    # save_checkpoint stores None when it was called without a scaler; skip
    # missing/empty scaler state instead of handing it to load_state_dict.
    scaler_state = state.get("scaler")
    if scaler_state:
        scaler.load_state_dict(scaler_state)
    print(f"Loaded checkpoint {step}")


# Resume only when a checkpoint step was configured in the mode settings;
# with checkpoint = None training starts from the freshly built model.
if checkpoint is not None:
    load_checkpoint(checkpoint)

Learning rate scheduler

In [41]:
import math


def get_lr(step):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    * ``step < warmup_steps``: ramps linearly towards ``learning_rate``.
    * ``step > lr_decay_steps``: stays at ``min_lr``.
    * otherwise: cosine interpolation from ``learning_rate`` down to ``min_lr``.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return learning_rate * (step + 1) / (warmup_steps + 1)

    past_decay = step > lr_decay_steps
    if past_decay:
        return min_lr

    # Fraction of the decay phase completed so far, in [0, 1].
    progress = (step - warmup_steps) / (lr_decay_steps - warmup_steps)
    assert 0 <= progress <= 1
    # Cosine weight goes from 1 (start of decay) to 0 (end of decay).
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_weight * (learning_rate - min_lr)

Training loop

In [42]:
# A checkpoint named N is written after the optimizer update of step N, so a
# resumed run continues with step N + 1 instead of repeating step N (which
# would also overwrite the checkpoint that was just loaded).
start_step = 0 if checkpoint is None else checkpoint + 1

for step in range(start_step, max_iters):

    # determine and set the learning rate for this iteration
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.zero_grad()

    xb, yb = get_batch("train")

    # Forward pass; FP16 autocast is only used when it was enabled for this device.
    if autocast_enabled:
        with torch.amp.autocast(dtype=torch.float16, device_type=device_type):
            logits, loss = model(xb, yb)
    else:
        logits, loss = model(xb, yb)

    # Backward on the scaled loss, then unscale so clipping sees true gradient norms.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)

    # Clip to the configured threshold rather than a hard-coded constant.
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    scaler.step(optimizer)
    scaler.update()

    # Periodically estimate train/val loss and write a checkpoint.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        save_checkpoint(
            step=step,
            model=model,
            optimizer=optimizer,
            scaler=scaler,
            train_loss=losses["train"],
            val_loss=losses["val"],
            learning_rate=lr
        )

step 0: train loss 10.0898, val loss 10.0815
step 10: train loss 9.5862, val loss 9.6375
step 20: train loss 8.2541, val loss 8.2854
step 30: train loss 6.7647, val loss 6.7371
step 40: train loss 6.4303, val loss 6.4374
step 50: train loss 6.3996, val loss 6.3657
step 60: train loss 6.3121, val loss 6.4734
step 70: train loss 6.1462, val loss 6.2023
step 80: train loss 6.1210, val loss 6.1800
step 90: train loss 5.9747, val loss 6.0395
step 100: train loss 5.7022, val loss 5.7250
step 110: train loss 5.7560, val loss 5.9357
step 120: train loss 5.5485, val loss 5.6725
step 130: train loss 5.4132, val loss 5.2164
step 140: train loss 5.1278, val loss 5.1207
step 150: train loss 4.8249, val loss 5.0904
step 160: train loss 4.6409, val loss 4.8761
step 170: train loss 4.2857, val loss 4.5573
step 180: train loss 4.2839, val loss 4.5829
step 190: train loss 4.0262, val loss 4.2305
step 200: train loss 3.8650, val loss 3.8589
step 210: train loss 3.5512, val loss 4.0321
step 220: train los

Test the model

In [43]:
from tokenizers.tokenizers import Tokenizer

# Load the tokenizer from tokenizer.json (path relative to the working
# directory); it encodes prompts and decodes generated token ids below.
tokenizer = Tokenizer.from_file("tokenizer.json")


@torch.no_grad()
def generate(model, start, max_new_tokens=50):
    model.eval()
    idx = torch.tensor([tokenizer.encode(start).ids], device=device, dtype=torch.long)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ModelSettings.max_context_length:]
        logits = model(idx_cond)
        logits = logits[:, -1, :]
        probs = nn.functional.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return tokenizer.decode(idx[0].tolist())


# Seed generation with the first token of a random batch drawn through the
# non-'train' branch of get_batch, then print the sampled continuation.
sample_inputs, _ = get_batch("test")
start_token_id = sample_inputs[0, 0].item()
start_text = tokenizer.decode([start_token_id])
generated_text = generate(model, start_text)
print(generated_text)

,Write several rh 1 in arts centre in the impressive roughlyican Centre is 15 called the City and 700 AD. In century AD. In temple above it has Al size, as silk
What are 5 debts is one
