In [1]:
import sys
import os
from pathlib import Path
# Project root = parent of the directory the notebook kernel is running in.
project_dir = Path.cwd().parent
# Put the project root on sys.path so the `cs336_basics` package is importable.
# Guarded so that re-running this cell does not keep stacking duplicate entries.
if project_dir.as_posix() not in sys.path:
    sys.path.insert(0, project_dir.as_posix())

from tqdm import trange, tqdm
import numpy as np
import torch
from cs336_basics.nn_336 import TransformerLM
from cs336_basics.train import AdamW, cross_entropy, save_checkpoint, load_data, get_lr_cosine_schedule, gradient_clipping

# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension. Mode 2 reloads all modules before each
# cell runs, so edits to the project sources are picked up without a kernel restart.
# NOTE(review): re-running this cell prints "The autoreload extension is already
# loaded"; `%reload_ext autoreload` would avoid that message.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Open the encoded training tokens as a read-only memory map (mmap_mode="r"),
# so the array is not read into memory up front.
tokens = np.load(
    project_dir / "data" / "TinyStoriesV2-GPT4-train-encoded.npy",
    mmap_mode="r",
)

In [3]:
def train(
    model: TransformerLM,
    optimizer: AdamW,
    epochs,
    batch_size,
    T_w,
    T_c,
    checkpoint_dir=None,
    save_every=10,
):
    """Run ``epochs`` optimisation steps on batches drawn from ``tokens``.

    An "epoch" here is a single batch / optimizer step, not a full pass over
    the data. Step numbers are absolute: they continue from
    ``optimizer.state["epoch"]``, which is read once when the call starts.

    Relies on the notebook-level globals ``tokens`` and ``device``.

    Args:
        model: Language model. Called as ``model(x)``; its ``max_seq_len``
            attribute is passed to ``load_data`` as the sequence length.
        optimizer: Optimizer whose ``state`` carries the bookkeeping entries
            ``"losses"`` (a list, appended to in place), ``"epoch"`` (step
            offset) and ``"init_lr"`` (peak learning rate for the schedule).
        epochs: Number of steps to run in this call.
        batch_size: Batch size forwarded to ``load_data``.
        T_w: Forwarded to ``get_lr_cosine_schedule`` (presumably the number
            of warm-up steps; confirm against the helper).
        T_c: Forwarded to ``get_lr_cosine_schedule`` (presumably the number
            of cosine-annealing steps; confirm against the helper).
        checkpoint_dir: Directory for checkpoints. Created if missing. When
            ``None`` no checkpoints are written.
        save_every: A checkpoint is written whenever the absolute step number
            is a multiple of this value.

    Returns:
        None. Per-step losses are appended to ``optimizer.state["losses"]``.
    """
    losses = optimizer.state["losses"]
    if checkpoint_dir is not None:
        checkpoint_dir = Path(checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Read the resume offset once. Re-reading it every iteration and adding the
    # loop index on top would double-count steps if anything (e.g. the
    # checkpoint helper) updates optimizer.state["epoch"] during the run.
    # NOTE(review): this function never writes optimizer.state["epoch"] itself;
    # confirm which component advances it between calls.
    start_epoch = optimizer.state["epoch"]
    final_epoch = start_epoch + epochs

    for step in trange(epochs):
        epoch = start_epoch + step + 1

        # Schedule runs between init_lr / 10 and init_lr, indexed by the
        # absolute step number.
        lr = get_lr_cosine_schedule(
            epoch,
            (optimizer.state["init_lr"] / 10, optimizer.state["init_lr"]),
            T_w=T_w,
            T_c=T_c,
        )
        # Apply the scheduled rate to every parameter group, not just the first.
        for group in optimizer.param_groups:
            group["lr"] = lr

        x, target = load_data(tokens, batch_size, model.max_seq_len, device=device)
        pred = model(x)
        # Flatten all leading dimensions so each position is one classification example.
        loss = cross_entropy(pred.view(-1, pred.shape[-1]), target.view(-1))

        optimizer.zero_grad()
        loss.backward()
        # NOTE(review): gradient clipping is currently disabled.
        # gradient_clipping(model.parameters())
        optimizer.step()

        # Drop the batch references before the next batch is loaded.
        del x, target

        loss_value = loss.item()
        losses.append(loss_value)

        # Check for a checkpoint directory first so that save_every is only
        # used as a divisor when checkpointing is actually enabled.
        if checkpoint_dir is not None and epoch % save_every == 0:
            save_checkpoint(
                model,
                optimizer,
                epoch,
                (checkpoint_dir / f"checkpoint_epoch_{epoch}.pt").as_posix(),
            )

        # Report against the absolute final step so resumed runs read e.g.
        # [101/200] rather than [101/100].
        tqdm.write(f"----\nEpoch [{epoch}/{final_epoch}], Loss: {loss_value:.4f}\n")

In [4]:
# Transformer language model: 4 layers, width 512, 16 attention heads,
# context length 256 over a 10k-token vocabulary.
# NOTE(review): Theta is presumably the RoPE base; confirm against TransformerLM.
model = TransformerLM(
    vocab_size=10_000,
    max_seq_len=256,
    d_model=512,
    d_ff=1344,
    num_layers=4,
    num_heads=16,
    Theta=10_000,
    device=device,
)

# AdamW over all model parameters; lr here is the value the optimizer is
# constructed with.
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)

In [5]:
# Single-step smoke run. The checkpoint directory is created, but with
# save_every=10000 step 1 does not trigger a checkpoint write.
train(
    model=model,
    optimizer=optimizer,
    epochs=1,
    batch_size=64,
    T_w=20,
    T_c=100,
    save_every=10000,
    checkpoint_dir=project_dir.joinpath("checkpoints"),
)

                                     

----
Epoch [1/1], Loss: 9.2103



100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
