In [1]:
import torch as t
import numpy as np
import wandb
from bidict import bidict
from tqdm import tqdm
from pathlib import Path
from datasets import DatasetDict
from dataclasses import dataclass
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader
from othello_gpt.model.nanoGPT import GPTConfig, GPT
from othello_gpt.data.vis import plot_game, move_id_to_text
from othello_gpt.data.generate import generate_dataset
from typing import List

In [2]:
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
device

device(type='mps')

In [3]:
root_dir = Path().cwd().parent.parent.parent
data_dir = root_dir / "data"
n_games = 1000000
size = 6
PAD_TOKEN_ID = -1

nw_middle_id = (size // 2 - 1) * size + (size // 2 - 1)
initial_squares = set([nw_middle_id, nw_middle_id + 1, nw_middle_id + size, nw_middle_id + size + 1])
all_squares = [i for i in range(size * size) if i not in initial_squares]
# id_to_token_id_map = bidict({square_id: token_id for token_id, square_id in enumerate([-1, size * size] + all_squares)})
id_to_token_id_map = bidict({square_id: token_id for token_id, square_id in enumerate([PAD_TOKEN_ID] + all_squares)})

def tokenize(history):
    return {"input_ids": [id_to_token_id_map[i] for i in history]}

def decode(token_ids):
    return {"square_ids": [id_to_token_id_map.inverse[i] for i in token_ids]}

In [4]:
dataset_dict_path = data_dir / f"othello_{n_games}_{size}"

if dataset_dict_path.exists():
    dataset_dict = DatasetDict.load_from_disk(dataset_dict_path)
else:
    dataset = generate_dataset(n_games, size)
    dataset_dict = dataset.train_test_split(test_size=0.1)
    dataset_dict.save_to_disk(dataset_dict_path)

dataset_dict["train"] = dataset_dict["train"].filter(lambda x: size*size not in x["histories"])
dataset_dict["test"] = dataset_dict["test"].filter(lambda x: size*size not in x["histories"])
dataset_dict["train"] = dataset_dict["train"].map(lambda x: tokenize(x["histories"]))
dataset_dict["test"] = dataset_dict["test"].map(lambda x: tokenize(x["histories"]))

plot_game(dataset_dict["test"][0], subplot_size=180, n_cols=8)

100%|██████████| 1000000/1000000 [32:10<00:00, 517.88it/s] 


Saving the dataset (0/21 shards):   0%|          | 0/900000 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/521555 [00:00<?, ? examples/s]

Map:   0%|          | 0/58090 [00:00<?, ? examples/s]

In [5]:
cfg = GPTConfig(
    # block_size=(size * size - 4) * 2 - 1,
    block_size=(size * size - 4) - 1,
    # vocab_size=size * size - 4 + 2,  # pass and pad
    vocab_size=size * size - 4 + 1,  # pad
    n_layer=8,
    n_head=8,
    n_embd=128,
    dropout= 0.0,
    bias=False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster,
)
display(cfg)
model = GPT(cfg).to(device)

GPTConfig(block_size=31, vocab_size=33, n_layer=8, n_head=8, n_embd=128, dropout=0.0, bias=False)

number of parameters: 1.58M


In [6]:
@dataclass
class TransformerTrainingArgs:
    batch_size = 32
    epochs = 8
    max_steps_per_epoch = 5120
    lr = 5e-4
    weight_decay = 1e-3
    wandb_project: str | None = "othello-gpt"
    wandb_name: str | None = None

args = TransformerTrainingArgs()

In [7]:
def pad_batch(batch: List[List[int]], max_len: int = cfg.block_size+1, pad_token_id: int = PAD_TOKEN_ID) -> Int[Tensor, "batch max_len"]:
    padded_batch = t.full((len(batch), max_len), pad_token_id)
    for i, seq in enumerate(batch):
        padded_batch[i, -len(seq):] = t.tensor(seq)
    return padded_batch

In [8]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: GPT):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(dataset_dict["train"]["input_ids"], batch_size=args.batch_size, shuffle=True, pin_memory=True, collate_fn=pad_batch)
        self.test_loader = DataLoader(dataset_dict["test"]["input_ids"], batch_size=args.batch_size, shuffle=False, pin_memory=True, collate_fn=pad_batch)

    def training_step(self, batch: Int[Tensor, "batch seq"]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        _, loss = model(batch[:, :-1], batch[:, 1:])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            batch = batch.to(device)
            logits, _ = self.model(batch[:, :-1], batch[:, 1:])
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == batch[:, 1:]).sum().item()
            total_samples += batch.size(0) * (batch.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch.to(device))
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()

trainer = TransformerTrainer(args, model)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.

Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 158.20it/s] [01:50<12:20, 48.41it/s]
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 158.74it/s]960 [03:51<10:42, 47.84it/s] 
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 162.78it/s]960 [05:51<09:02, 47.15it/s]  
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 162.72it/s]960 [08:03<07:08, 47.79it/s]  
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 158.94it/s]960 [10:03<05:17, 48.29it/s]  
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 163.08it/s]960 [11:53<03:33, 47.94it/s]  
Evaluating: 100%|██████████| 1816/1816 [00:11<00:00, 162.07it/s]960 [13:53<01:47, 47.70it/s]  
Evaluating: 100%|██████████| 1816/1816 [00:10<00:00, 170.13it/s]/s]                         



Run (i5teoyyv) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.



0,1
accuracy,▁▄▅▆▇▇██
train_loss,█▇▆▅▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.25659
train_loss,1.71158


Epoch 8, loss: 1.712, accuracy: 0.256: : 40968it [16:15, 42.00it/s]


In [9]:
weights_dir = data_dir / "weights"
weights_dir.mkdir(exist_ok=True)
t.save(model, weights_dir / f"othello_{n_games}_{size}")
# model = t.load(weights_dir / f"othello_{n_games}_{size}", weights_only=False)

In [10]:
test_game = dataset_dict["test"][0]
tokens = t.tensor(tokenize(test_game["histories"])["input_ids"]).unsqueeze(0).to(device)
logits, loss = model(tokens[:, :-1], tokens[:, 1:])
probs = logits.softmax(-1)

n_moves = probs.shape[1]
test_probs = test_game.copy()
prob_boards = t.full((n_moves, size, size), fill_value=0.0, device=device)
for i in range(n_moves):
    prob_boards[i].flatten()[all_squares] = probs[0, i, 1:] / probs[0, i, 1:].sum()

test_pred = test_game.copy()
test_pred["boards"] = prob_boards.detach().cpu().numpy()
plot_game(test_game)
plot_game(test_pred, reversed=False, textcolor="pink", hovertext=test_pred["boards"])

In [53]:
import plotly.graph_objects as go

logits = model(t.tensor([[9]], device=device))[0].detach().cpu()
board = t.zeros((size, size))
board.flatten()[all_squares] = logits[0, 0, 1:].softmax(-1)

fig = go.Figure()
fig.add_trace(
    go.Heatmap(
        z=board,
        colorscale="gray",
        x=list("ABCDEF"),
        y=list(range(1, 7)),
        xgap=0.2,
        ygap=0.2,
    )
)
fig.update_yaxes(
    showline=True,
    linecolor="black",
    linewidth=1,
    mirror=True,
    constrain="domain",
    autorange="reversed",
)

fig.update_xaxes(
    showline=True,
    linecolor="black",
    linewidth=1,
    mirror=True,
    scaleanchor="y",
    scaleratio=1,
    constrain="domain",
)

fig.update_layout(
    width=400,
    height=300,
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()

In [41]:
logits[:, :, 1:].softmax(-1)

tensor([[[4.0881e-05, 4.5969e-01, 3.6625e-03, 5.5796e-04, 5.3879e-05,
          4.2068e-05, 1.4708e-03, 1.0757e-02, 1.3805e-05, 8.3180e-04,
          8.8014e-02, 2.8514e-05, 1.0948e-05, 1.8218e-01, 1.6795e-03,
          2.8497e-04, 8.8212e-05, 2.3984e-03, 1.8577e-03, 5.8417e-04,
          3.4191e-05, 1.5968e-01, 1.5283e-03, 8.2403e-02, 1.3654e-03,
          1.0934e-04, 3.6540e-05, 6.7129e-06, 4.0382e-05, 4.1219e-06,
          4.8387e-04, 6.2467e-05]]])

In [39]:
test_game["histories"]

[9,
 10,
 11,
 5,
 16,
 17,
 25,
 13,
 6,
 2,
 7,
 12,
 18,
 27,
 34,
 26,
 1,
 30,
 24,
 33,
 28,
 22,
 3,
 8,
 29,
 35,
 23,
 31,
 19,
 32,
 4,
 0]