In [1]:
from dataclasses import dataclass
from pathlib import Path

import huggingface_hub as hf
import numpy as np
import torch as t
import wandb
from datasets import load_dataset
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from othello_gpt.data.vis import plot_game
from othello_gpt.model.nanoGPT import GPT, GPTConfig
from othello_gpt.util import pad_batch

In [2]:
# Pick the compute backend in order of preference: MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device_name = "mps"
elif t.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"

device = t.device(device_name)
device

device(type='mps')

In [3]:
# Project root is three directory levels above the notebook's working directory.
root_dir = Path.cwd().parent.parent.parent
data_dir = root_dir / "data"

n_games = 1_000_000
size = 6  # board side length; used below to derive block_size and vocab_size

In [4]:
# Log in to the Hugging Face Hub with the token stored in `root_dir / "secret.txt"`,
# then log in to Weights & Biases.
# strip() removes surrounding whitespace: a trailing newline left by a text editor
# would otherwise be sent as part of the token.
hf.login(token=(root_dir / "secret.txt").read_text().strip())
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Load every split of the "awonga/othello-gpt" dataset.
dataset_dict = load_dataset("awonga/othello-gpt")

# Plot the first game of the test split as a quick visual sanity check.
first_test_game = dataset_dict["test"][0]
plot_game(first_test_game, subplot_size=180, n_cols=8)

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/21 [00:00<?, ?it/s]

In [6]:
class HubGPT(GPT, hf.PyTorchModelHubMixin):
    """`GPT` combined with the Hugging Face Hub mixin; adds no members of its own."""


# Number of squares on the size x size board minus four
# (presumably the four pre-filled centre squares — TODO confirm).
n_playable_squares = size * size - 4

cfg = GPTConfig(
    # Context length: one less than the number of playable squares.
    # (A variant that also encodes passes would use n_playable_squares * 2 - 1.)
    block_size=n_playable_squares - 1,
    # One token per playable square plus a pad token.
    # (The pass-aware variant would need one more token for "pass".)
    vocab_size=n_playable_squares + 1,
    n_layer=8,
    n_head=8,
    n_embd=128,
    dropout=0.0,
    bias=True,
)
display(cfg)
model = HubGPT(cfg).to(device)

GPTConfig(block_size=31, vocab_size=33, n_layer=8, n_head=8, n_embd=128, dropout=0.0, bias=True)

number of parameters: 1.59M


In [9]:
@dataclass
class TransformerTrainingArgs:
    batch_size = 32
    epochs = 8
    max_steps_per_epoch = 10000
    lr = 5e-4
    weight_decay = 1e-3
    wandb_project: str | None = "othello-gpt"
    wandb_name: str | None = None

args = TransformerTrainingArgs()

In [8]:
class TransformerTrainer:
    """
    Next-token training loop for a `GPT` model with wandb logging.

    Reads the notebook-level globals `dataset_dict` (for the "train" and "test"
    splits) and `device` (where batches are moved before the forward pass).
    """

    def __init__(self, args: TransformerTrainingArgs, model: GPT):
        """
        Store the model and arguments, build the AdamW optimizer and both data loaders.

        Args:
            args: hyperparameters (`lr`, `weight_decay`, `batch_size`, ...).
            model: the model to train; must already live on `device`.
        """
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0  # global optimizer-step counter, used as the wandb step

        def collate_fn(batch):
            # block_size + 1 so that slicing off one token for inputs / targets leaves
            # block_size positions. Presumably pads every game to that length —
            # verify against `pad_batch`.
            return pad_batch(batch, model.config.block_size + 1)

        self.train_loader = DataLoader(dataset_dict["train"]["input_ids"], batch_size=args.batch_size, shuffle=True, pin_memory=True, collate_fn=collate_fn)
        self.test_loader = DataLoader(dataset_dict["test"]["input_ids"], batch_size=args.batch_size, shuffle=False, pin_memory=True, collate_fn=collate_fn)

    def training_step(self, batch: Int[Tensor, "batch seq"]) -> Float[Tensor, ""]:
        """
        Run one optimizer step on `batch` and log the loss to wandb.

        `batch` is a 2-D tensor of token ids: all but the last column are the
        inputs and all but the first column are the next-token targets.

        Returns:
            The scalar loss tensor for this batch.
        """
        # Use the trainer's own model, not whatever global happens to be named `model`.
        _, loss = self.model(batch[:, :-1], batch[:, 1:])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        # Log a plain float rather than the tensor itself.
        wandb.log({"train_loss": loss.item()}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Compute next-token accuracy over the test loader and log it to wandb.

        The model is put in eval mode for the duration of the call and its previous
        train/eval mode is restored afterwards.

        NOTE(review): every target position counts toward the accuracy, including any
        positions that `pad_batch` filled with padding — confirm whether those should
        be masked out.

        Returns:
            Fraction of target positions predicted correctly, or NaN if the test
            loader yields no batches.
        """
        was_training = self.model.training
        self.model.eval()
        total_correct, total_samples = 0, 0

        try:
            for batch in tqdm(self.test_loader, desc="Evaluating"):
                batch = batch.to(device)
                targets = batch[:, 1:]
                # Targets are passed exactly as in training; presumably the model only
                # returns logits for every position when they are given — verify
                # against GPT.forward.
                logits, _ = self.model(batch[:, :-1], targets)
                predicted_tokens = logits.argmax(dim=-1)
                total_correct += (predicted_tokens == targets).sum().item()
                total_samples += targets.numel()
        finally:
            # Without this, every epoch after the first evaluation would train in eval mode.
            self.model.train(was_training)

        accuracy = total_correct / total_samples if total_samples else float("nan")
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Train for `self.args.epochs` epochs, evaluating after each one.

        Each epoch performs at most `self.args.max_steps_per_epoch` optimizer steps.
        Initialises the wandb run and always closes it, including when training is
        interrupted or fails (the run is then finished with a non-zero exit code).
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        try:
            for epoch in range(self.args.epochs):
                self.model.train()
                # zip with range caps the epoch at exactly max_steps_per_epoch steps,
                # matching the progress-bar total (the old `i >= max` check after the
                # step ran one step too many).
                for _, batch in zip(range(self.args.max_steps_per_epoch), self.train_loader):
                    loss = self.training_step(batch.to(device))
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")

                accuracy = self.evaluate()
        except BaseException:
            # Includes KeyboardInterrupt: close the run instead of leaving it dangling.
            wandb.finish(exit_code=1)
            raise
        else:
            wandb.finish()
        finally:
            progress_bar.close()

# Train the model with the arguments defined above, then upload it to the
# "awonga/othello-gpt" repository on the Hugging Face Hub.
trainer = TransformerTrainer(args=args, model=model)
trainer.train()
model.push_to_hub(repo_id="awonga/othello-gpt")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch 1, loss: 2.609, accuracy: nan:   0%|          | 302/80000 [00:11<45:33, 29.16it/s] 

KeyboardInterrupt: 

