In [1]:
# Number of sequences per mini-batch; passed as batch_size to both DataLoaders.
BATCH_SIZE = 128


In [2]:
import torch
from pytorch_lightning import Trainer, seed_everything
import pytorch_lightning as pl
from wasabi import msg

# Fix the global random seed up front so repeated runs are comparable.
seed_everything(1337, workers=True)


Global seed set to 1337


1337

In [3]:
# Select the device to use

if torch.backends.mps.is_available():
    # MPS backend is usable: pick it and report the choice.
    msg.good("MPS is available!")
    device = torch.device("mps")
    msg.good(f"Using device: {device}")
else:
    # Report why MPS cannot be used before falling back.
    if torch.backends.mps.is_built():
        msg.fail(
            "MPS not available because the current MacOS version is not 12.3+ "
            "and/or you do not have an MPS-enabled device on this machine."
        )
    else:
        msg.fail(
            "MPS not available because the current PyTorch install was not "
            "built with MPS enabled."
        )
    # use CPU or GPU
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"


[38;5;2m✔ MPS is available![0m
[38;5;2m✔ Using device: mps[0m


In [4]:
from pandas import read_csv

# load the dataset; every cell is coerced to str so it can be split on ","
train = read_csv("data/train.txt", sep="\t", header=None).applymap(str)

# flatten once and measure every combo once; the statistics below reuse both
all_combos = train.values.flatten()
combo_lengths = [len(combo.split(",")) for combo in all_combos]

chars = sorted(set(",".join(all_combos).split(",")))
vocab_size = len(chars)
msg.info("all the unique characters:", ", ".join(chars))
msg.info(f"vocab size: {vocab_size:,}")
msg.info(f"min combo length: {min(combo_lengths)}")
# context window is one token shorter than the shortest combo
BLOCK_SIZE = min(combo_lengths) - 1
msg.info(f"BLOCK_SIZE: {BLOCK_SIZE}")
msg.info(f"max combo length: {max(combo_lengths)}")

# create a tokenizer from characters to integers (and back)
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))
msg.info("stoi:", stoi)
msg.info("itos:", itos)


def encode(s):
    """Encoder: map every token of the sequence ``s`` to its integer id via ``stoi``."""
    token_ids = []
    for token in s:
        token_ids.append(stoi[token])
    return token_ids


def decode(l):
    """Decoder: turn a list of integer ids into a readable combo string.

    Each id is first resolved to its vocabulary token through ``itos``; the
    tokens are then translated to move names and joined with " - ".
    """
    move_names = {
        "1": "Jab",
        "2": "Cross",
        "3": "Left Hook",
        "4": "Right Hook",
        "5": "Left Uppercut",
        "6": "Right Uppercut",
        "1*": "Jab to the Body",
        "2*": "Cross to the Body",
        "3*": "Left Hook to the Body",
        "4*": "Right Hook to the Body",
        "5*": "Left Uppercut to the Body",
        "6*": "Right Uppercut to the Body",
        "7": "Left Slip",
        "8": "Right Slip",
        "9": "Drop",
        "10": "Left Block",
        "11": "Right Block",
        "12": "Left Roll",
        "13": "Right Roll",
        "14": "<CLS>",
        "15": "<SEP>",
    }
    # first pass: ids -> vocabulary tokens
    vocab_tokens = []
    for token_id in l:
        vocab_tokens.append(itos[token_id])
    # second pass: vocabulary tokens -> human-readable move names
    readable = []
    for vocab_token in vocab_tokens:
        readable.append(move_names[vocab_token])
    return " - ".join(readable)


[38;5;4mℹ all the unique characters:[0m
1, 1*, 10, 11, 12, 13, 14, 15, 2, 2*, 3, 3*, 4, 4*, 5, 5*, 6, 6*, 7, 8, 9
[38;5;4mℹ vocab size: 21[0m
[38;5;4mℹ min combo length: 4[0m
[38;5;4mℹ BLOCK_SIZE: 3[0m
[38;5;4mℹ max combo length: 16[0m
[38;5;4mℹ stoi:[0m
{'1': 0, '1*': 1, '10': 2, '11': 3, '12': 4, '13': 5, '14': 6, '15': 7, '2': 8,
'2*': 9, '3': 10, '3*': 11, '4': 12, '4*': 13, '5': 14, '5*': 15, '6': 16, '6*':
17, '7': 18, '8': 19, '9': 20}
[38;5;4mℹ itos:[0m
{0: '1', 1: '1*', 2: '10', 3: '11', 4: '12', 5: '13', 6: '14', 7: '15', 8: '2',
9: '2*', 10: '3', 11: '3*', 12: '4', 13: '4*', 14: '5', 15: '5*', 16: '6', 17:
'6*', 18: '7', 19: '8', 20: '9'}


In [5]:
def token_to_human_readable(token):
    """Concatenate the vocabulary strings for a sequence of token ids.

    Args:
        token: iterable of integer ids, each of which must be a key of ``itos``.

    Returns:
        The corresponding ``itos`` entries joined with no separator.

    Raises:
        KeyError: if an id is not present in ``itos``.
    """
    # The previous local ``mapping`` dict was never read, so it has been dropped.
    return "".join(itos[i] for i in token)


In [6]:
# construct the dataloader
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


class boxerDataset(Dataset):
    """Dataset of encoded combos built from a DataFrame of comma-separated strings.

    Every cell of ``df`` is split on "," and passed through ``encode``; the
    results are flattened into ``self.x``, giving one encoded sequence (a list
    of integer ids) per cell.
    """

    def __init__(self, df):
        """Encode every cell of ``df`` and log the size and first three samples."""
        self.df = df
        self.x = df.applymap(lambda x: encode(x.split(","))).values.flatten()
        msg.info(f"dataset size: {len(self.x):,}")
        msg.info(f"dataset starts with these 3 examples: {self.x[:3]}")

    def __len__(self):
        """Return the number of encoded sequences."""
        # __getitem__ indexes the flattened array, so the length must be taken
        # from it as well; len(self.df) only counts rows and under-reports the
        # dataset whenever the frame has more than one column.
        return len(self.x)

    def __getitem__(self, idx):
        """Return the encoded sequence stored at position ``idx``."""
        return self.x[idx]


def gpt_collate(batch):
    """Collate encoded combos into (input, target) blocks for next-token training.

    To improve data efficiency, a random window of ``BLOCK_SIZE + 1`` tokens is
    drawn from every sequence each time it is batched. The first ``BLOCK_SIZE``
    tokens of the window are the input and the last ``BLOCK_SIZE`` tokens (the
    same window shifted by one) are the target.

    Args:
        batch: list of encoded sequences (lists of integer ids).

    Returns:
        Tuple ``(x, y)`` of int64 tensors of shape ``(len(batch), BLOCK_SIZE)``,
        both moved to ``device``.

    Raises:
        ValueError: if a sequence has fewer than ``BLOCK_SIZE + 1`` tokens.
    """
    block_batch_x = []
    block_batch_y = []

    for seq in batch:
        n_starts = len(seq) - BLOCK_SIZE
        if n_starts < 1:
            # np.random.randint would fail here with an opaque "low >= high";
            # raise the same exception type with an actionable message.
            raise ValueError(
                f"sequence of length {len(seq)} is too short for "
                f"BLOCK_SIZE={BLOCK_SIZE}; need at least {BLOCK_SIZE + 1} tokens"
            )
        start = np.random.randint(0, n_starts)
        # Explicit int64 so the dtype does not depend on the platform's default
        # numpy integer width.
        window = torch.tensor(seq[start : start + BLOCK_SIZE + 1], dtype=torch.long)
        block_batch_x.append(window[:-1])
        block_batch_y.append(window[1:])

    x, y = torch.stack(block_batch_x), torch.stack(block_batch_y)

    return x.to(device), y.to(device)


In [7]:
# Sanity check: encode every training combo and display the first one as token ids.
train.applymap(lambda x: encode(x.split(","))).values.flatten()[0]


[6, 0, 8, 20, 1, 8, 11, 12, 7]

In [8]:
# Wrap the training frame in a Dataset and batch it. gpt_collate crops each
# sequence to a random BLOCK_SIZE window; shuffle=True reorders samples each epoch.
train_dataset = boxerDataset(train)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=gpt_collate
)
msg.info(f"train dataset size: {len(train_dataset):,}")


[38;5;4mℹ dataset size: 8,592[0m
[38;5;4mℹ dataset starts with these 3 examples: [list([6, 0, 8, 20, 1, 8, 11,
12, 7]) list([6, 8, 10, 19, 17, 9, 7])  list([6, 8, 4, 11, 17, 3, 7])][0m
[38;5;4mℹ train dataset size: 8,592[0m


In [9]:
# Load the validation split the same way as the training split and batch it
# without shuffling. gpt_collate still draws a random window per sequence.
val = read_csv("data/validate.txt", sep="\t", header=None).applymap(str)
val_dataset = boxerDataset(val)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=gpt_collate
)
# This is the validation split (data/validate.txt), so label it as such; the
# message previously called it the "test" dataset.
msg.info(f"validation dataset size: {len(val_dataset):,}")


[38;5;4mℹ dataset size: 2,455[0m
[38;5;4mℹ dataset starts with these 3 examples: [list([6, 8, 19, 11, 20, 9, 7])
list([6, 8, 11, 2, 17, 7])  list([6, 0, 5, 0, 9, 19, 9, 7])][0m
[38;5;4mℹ test dataset size: 2,455[0m


In [10]:
from models import boxerGPT
from models.model import GPTConfig

# Start from a default-constructed config and override the fields for this run.
gpt_config = GPTConfig()
gpt_config.block_size = BLOCK_SIZE  # context length: shortest training combo minus one
gpt_config.n_layer = 24
gpt_config.n_head = 24
gpt_config.n_embd = 768
gpt_config.vocab_size = vocab_size  # number of distinct tokens found in the training data
gpt_config.weight_decay = 1e-1
gpt_config.learning_rate = 2e-4
gpt_config.betas = (0.9, 0.999)
# NOTE(review): `device` is a torch.device on the MPS path but a plain str
# ("cuda"/"cpu") on the fallback path -- confirm the model accepts both.
gpt_config.device_type = device

# Build the model and move it to the selected device.
box_gpt_model = boxerGPT.boxerGPT(gpt_config)
_ = box_gpt_model.to(device)  # bound to `_` so the cell does not echo the module repr


number of parameters: 170.13M


In [11]:
# before training the model generates gibberish
# Sample 10 new tokens from a single prompt of token ids [6, 0]; gradients are
# not needed for sampling. The decoded result is printed with msg.fail.
with torch.no_grad():
    generated = box_gpt_model.model.generate(
        torch.from_numpy(np.array([[6, 0]])).to(device), max_new_tokens=10, temperature=1
    )
    generated = generated.cpu().numpy().tolist()
    msg.fail([decode(c) for c in generated])


[38;5;1m✘ ['<CLS> - Jab - Right Uppercut to the Body - Right Hook - Right Hook
to the Body - Left Block - Left Roll - Right Uppercut to the Body - Right Hook -
Right Uppercut to the Body - Right Block - Left Block'][0m


In [12]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
from pytorch_lightning.loggers import WandbLogger
import wandb
import time

# Run name: wall-clock timestamp plus the key hyper-parameters.
trial_name = f"boxerGPT-{time.time()}-block_size-{BLOCK_SIZE}-n_layer-{gpt_config.n_layer}-n_head-{gpt_config.n_head}-n_embd-{gpt_config.n_embd}"

# Log metrics, gradients and the best checkpoint to Weights & Biases.
wandb_logger = WandbLogger(project="boxerGPT", name=trial_name, log_model=True)
wandb_logger.watch(box_gpt_model.model)
wandb_logger.experiment.config.update(gpt_config)

trainer = pl.Trainer(
    max_epochs=100,
    # Follow the device chosen in the device-selection cell ("mps", "cuda" or
    # "cpu") instead of hard-coding "mps", which fails on machines that took
    # the cuda/cpu fallback. torch.device() accepts both a str and a
    # torch.device, and .type yields the bare backend name.
    accelerator=torch.device(device).type,
    devices=1,
    logger=wandb_logger,
    accumulate_grad_batches=4,
    callbacks=[
        # stop once val_loss has not improved for 10 validation runs
        pl.callbacks.EarlyStopping(monitor="val_loss", patience=10),
        # keep only the checkpoint with the best val_loss
        pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
        pl.callbacks.RichModelSummary(),
    ],
)

trainer.fit(
    model=box_gpt_model, train_dataloaders=train_loader, val_dataloaders=val_loader
)
wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33mstatikkkkk[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


using fused AdamW: False


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  tensor = flat.histc(bins=self._num_bins, min=tmin, max=tmax)


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr-AdamW/pg1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr-AdamW/pg2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▆▃▁▁▁▁▃▂▁▂▁▂▂
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,██▇▆▅▄▄▃▃▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,42.0
lr-AdamW/pg1,0.0002
lr-AdamW/pg2,0.0002
train_loss,2.15956
trainer/global_step,730.0
val_loss,2.12306


In [13]:
# Sample from the trained model with two prompts of token ids.
# eval() disables training-only behaviour such as dropout while sampling.
box_gpt_model.eval()
with torch.no_grad():
    # Build the prompts on whatever device the module currently lives on: the
    # trainer may have moved it since it was created, so neither a bare CPU
    # tensor nor the original `device` is guaranteed to match.
    prompts = torch.from_numpy(np.array([[6, 0], [6, 8]])).to(box_gpt_model.device)
    generated = box_gpt_model.model.generate(
        prompts, max_new_tokens=10, temperature=1
    )
    generated = generated.cpu().numpy().tolist()
    msg.good([decode(c) for c in generated])


[38;5;2m✔ ['<CLS> - Jab - Right Roll - Cross to the Body - Left Hook to the
Body - Left Slip - Jab - <SEP> - <SEP> - Cross - Left Roll - Jab', '<CLS> -
Cross - Left Hook to the Body - Cross to the Body - Left Slip - <SEP> - Jab to
the Body - <SEP> - <SEP> - Cross to the Body - Left Hook - <SEP>'][0m


In [14]:
# export the model to ONNX
# The file is written into the wandb run directory and named after the trial.
# The sample input is a single-token batch, [[6]].
box_gpt_model.to_onnx(
    f"{wandb_logger.experiment.dir}/{trial_name}.onnx",
    torch.from_numpy(np.array([[6]])),
    export_params=True,
    input_names=["input"],  # the model's input names
    output_names=["output"],  # the model's output names
    dynamic_axes={
        "input": {1: "seq_length"},  # variable length axes
        "output": {1: "seq_length"},
    },
)

  assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))


In [15]:
import onnxruntime

class boxerGPTONNXInference:
    """Sample boxing combos from an exported boxerGPT ONNX graph via onnxruntime."""

    # Token id that ends a combo. In this notebook's vocabulary the token "15"
    # ("<SEP>") is encoded as 7 (see the printed ``stoi`` mapping).
    SEP_TOKEN_ID = 7

    def __init__(self, onnx_model_path) -> None:
        """Open an inference session on ``onnx_model_path`` and log its I/O names."""
        self.ort_session = onnxruntime.InferenceSession(onnx_model_path)
        msg.info(f"onnx model loaded from {onnx_model_path}")
        msg.info(
            f"onnx model inputs: {[x.name for x in self.ort_session.get_inputs()]}"
        )
        msg.info(
            f"onnx model outputs: {[x.name for x in self.ort_session.get_outputs()]}"
        )

        self.input_name = self.ort_session.get_inputs()[0].name

    def predict(self, x):
        """Run the graph on ``x`` and return its first output flattened to 1-D."""
        return self.ort_session.run(None, {self.input_name: x})[0].flatten()

    @staticmethod
    def softmax(x):
        """Return the softmax of the 1-D array ``x`` as float64 probabilities.

        The maximum is subtracted before exponentiating so large logits do not
        overflow, and the computation is done in float64 so the result sums to
        one tightly enough for ``np.random.multinomial``'s validation (float32
        probabilities can be rejected for summing to slightly more than 1).
        """
        x = np.asarray(x, dtype=np.float64)
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()

    def generate_from_onnx(
        self, idx, max_new_tokens=10, temperature=1.0, top_k=None, truncate=True
    ):
        """
        Take a conditioning sequence of token ids and extend it up to
        max_new_tokens times, feeding each sampled token back into the model.

        Args:
            idx: integer numpy array of shape (1, t) holding the prompt ids; the
                first id is treated as a start marker and is not decoded.
            max_new_tokens: maximum number of tokens to sample.
            temperature: divisor applied to the logits before the softmax.
            top_k: if given, only the k largest logits can be sampled.
            truncate: stop as soon as the end-of-combo token is sampled.

        Returns:
            The decoded combo as a string, without the leading start marker and
            without a trailing end-of-combo token.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.shape[1] <= BLOCK_SIZE else idx[:, -BLOCK_SIZE:]
            # forward the model to get the logits
            # NOTE(review): predict() flattens the whole output, so this is the
            # next-token distribution only if the exported graph emits logits
            # for a single (final) position -- confirm against the model.
            logits = self.predict(idx_cond)
            # scale by desired temperature
            logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                # logits is a numpy array here, so use numpy (the previous
                # torch.topk / .size(-1) calls raised on it).
                k = max(1, min(top_k, logits.shape[-1]))
                kth_largest = np.partition(logits, -k)[-k]
                logits = np.where(logits < kth_largest, -np.inf, logits)
            # apply softmax to convert logits to (normalized) probabilities
            probs = self.softmax(logits)
            # sample from the distribution (one-hot draw -> index)
            idx_next = np.random.multinomial(n=1, pvals=probs)
            idx_next_index = np.argmax(idx_next)
            # append sampled index to the running sequence and continue
            idx = np.concatenate((idx, np.array([[idx_next_index]])), axis=1)

            if truncate and idx_next_index == self.SEP_TOKEN_ID:
                break

        # Drop the leading start marker. Drop the last token only when it is
        # the end-of-combo marker: unconditionally slicing [1:-1] discarded a
        # real move whenever sampling stopped without producing <SEP>.
        tokens = idx.tolist()[0][1:]
        if tokens and tokens[-1] == self.SEP_TOKEN_ID:
            tokens = tokens[:-1]

        # decode the sequence of indices to text
        generated_combo = decode(tokens)

        return generated_combo


In [16]:
# Load the exported ONNX file from the wandb run directory and sample one combo,
# starting from the token ids [6, 0] cast to int64.
boxergpt_onnx = boxerGPTONNXInference(f"{wandb_logger.experiment.dir}/{trial_name}.onnx")
generated = boxergpt_onnx.generate_from_onnx(np.array([[6, 0]]).astype(np.int64))
generated


[38;5;4mℹ onnx model loaded from
./wandb/run-20230226_145947-8kdugau8/files/boxerGPT-1677452386.5899782-block_size-3-n_layer-24-n_head-24-n_embd-768.onnx[0m
[38;5;4mℹ onnx model inputs: ['input'][0m
[38;5;4mℹ onnx model outputs: ['output'][0m


'Jab - Left Block - Cross to the Body - Left Hook to the Body - Drop - Cross'