Adapted from the minGPT demo notebook: https://github.com/karpathy/minGPT/blob/master/demo.ipynb

In [None]:
# IPython autoreload: re-import edited modules (e.g. random_neural_net_models)
# before executing code, so library changes are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import dataclass

import torch
from torch.utils.data.dataloader import DataLoader

import random_neural_net_models.mingpt.model as gpt_model
import random_neural_net_models.mingpt.sorter as sorter
import random_neural_net_models.mingpt.trainer as gpt_trainer
import random_neural_net_models.mingpt.utils as gpt_utils
import random_neural_net_models.utils as utils

# notebook-wide logger used by the cells below
logger = utils.get_logger("nb")

# fix the random seed via the project helper so runs are reproducible;
# the datasets below are built with the same seed value
gpt_utils.set_seed(3407)

In [None]:
# build the train and test splits of the sorting task, both seeded with 3407
# (an example instance is printed in the next cell)
train_dataset = sorter.SortDataset(gpt_utils.SETS.train, seed=3407)
test_dataset = sorter.SortDataset(gpt_utils.SETS.test, seed=3407)

In [None]:
# print an example instance of the dataset: the input token and the target
# token at each position of the first training example
x, y = train_dataset[0]
for a, b in zip(x, y):
    logger.info(f"x: {int(a)}, y: {int(b)}")

In [None]:
# a "gpt-nano" GPT whose vocabulary size and block (context) size are taken
# from the training dataset
model_config = gpt_model.GPT.get_config(
    model_type="gpt-nano",
    vocab_size=train_dataset.get_vocab_size(),
    block_size=train_dataset.get_block_size(),
)
model = gpt_model.GPT(model_config)

In [None]:
# training hyper-parameters for a short demo run
train_config = gpt_trainer.Trainer.get_config(
    learning_rate=5e-4,  # the model we're using is so small that we can go a bit faster
    max_iters=100,
    num_workers=0,  # presumably forwarded to the DataLoader (0 = load in the main process) — confirm
)

# trainer for `model` on `train_dataset` using the config above
trainer = gpt_trainer.Trainer(train_config, model, train_dataset)

In [None]:
def batch_end_callback(trainer: gpt_trainer.Trainer, log_every: int = 10):
    """Log iteration timing and training loss every ``log_every`` iterations.

    Registered below as the trainer's ``on_batch_end`` callback.

    Args:
        trainer: the running trainer; its ``iter_num``, ``iter_dt`` (seconds,
            logged in ms) and ``loss`` are read.
        log_every: positive logging interval in iterations. The previous
            hard-coded interval of 100 equalled ``max_iters`` in this notebook,
            so the loss was logged only once per run.
    """
    if trainer.iter_num % log_every == 0:
        logger.info(
            f"iter_dt {trainer.iter_dt * 1000:.2f} ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}"
        )


trainer.set_callback("on_batch_end", batch_end_callback)

In [None]:
# run training with the configuration and on_batch_end callback set up above
trainer.run()

In [None]:
# now let's perform some evaluation
# switch to eval mode (for a torch nn.Module this disables dropout); the
# trailing ";" suppresses the notebook's echo of the return value
model.eval();

In [None]:
@dataclass
class EvalResult:
    """Outcome of evaluating the model on one dataset split.

    Attributes:
        n_correct: number of sequences whose generated continuation matched
            the target exactly.
        n_total: number of sequences evaluated.
    """

    n_correct: int
    n_total: int

    @property
    def pct_correct(self) -> float:
        """Fraction of correct sequences; 0.0 when nothing was evaluated."""
        # guard against ZeroDivisionError for an empty split / zero batch budget
        return self.n_correct / self.n_total if self.n_total else 0.0

    def __str__(self) -> str:
        return f"final score: {self.n_correct:_d} / {self.n_total:_d} = {self.pct_correct:.2%} correct"


def eval_split(
    trainer: "gpt_trainer.Trainer",
    dataset: "sorter.SortDataset",
    max_batches: "int | None" = None,
    batch_size: int = 100,
    max_mistakes_to_print: int = 3,
) -> EvalResult:
    """Let the trainer's model sort every example of ``dataset`` and score it.

    Each example's first ``dataset.length`` input tokens are used as the prompt.
    The model greedily generates the same number of tokens, which are compared
    with the last ``dataset.length`` target tokens.

    Args:
        trainer: provides the evaluated ``model`` and the ``device`` batches are
            moved to.
        dataset: split to evaluate; ``dataset.length`` is the number of tokens
            to sort per example.
        max_batches: evaluate at most this many batches; ``None`` evaluates the
            whole split and ``0`` evaluates nothing.
        batch_size: number of examples per batch.
        max_mistakes_to_print: log at most this many wrong predictions, enough
            to get a sense of the failure modes without flooding the log.

    Returns:
        EvalResult with the number of exactly-correct sequences and the total
        number of sequences evaluated.
    """
    import itertools

    # use the model and sequence length belonging to the arguments, not the
    # notebook globals (the length previously came from train_dataset)
    model = trainer.model
    n = dataset.length
    n_correct = 0
    n_total = 0
    mistakes_printed = 0
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=False)

    # islice(loader, None) yields every batch; islice(loader, 0) yields none
    for x, y in itertools.islice(loader, max_batches):
        x = x.to(trainer.device)
        y = y.to(trainer.device)

        # isolate the input pattern alone and the expected sorted output
        _input = x[:, :n]
        _solution = y[:, -n:]

        # let the model fill in n tokens using greedy argmax, not sampling,
        # then keep only the generated part
        _inference = model.generate(_input, n, do_sample=False)
        _solution_candidate = _inference[:, n:]

        # a sequence is correct only if every generated token matches
        correct = (_solution == _solution_candidate).all(dim=1).cpu()
        n_correct += int(correct.sum())
        n_total += correct.numel()

        wrong_rows = [i for i, ok in enumerate(correct.tolist()) if not ok]
        for i in wrong_rows[: max(0, max_mistakes_to_print - mistakes_printed)]:
            mistakes_printed += 1
            logger.info(
                f"GPT claims that {_input[i].tolist()} sorted is {_solution_candidate[i].tolist()}"
                f" but actually is {_solution[i].tolist()}"
            )

    return EvalResult(n_correct, n_total)


# run a lot of examples from both train and test through the model and verify the output correctness
# (50 batches per split; at a batch size of 100 that is up to 5,000 sequences each)
# gradients are not needed for evaluation, so autograd is disabled
with torch.no_grad():
    train_score = eval_split(trainer, train_dataset, max_batches=50)
    logger.info(f"train - {train_score}")
    test_score = eval_split(trainer, test_dataset, max_batches=50)
    logger.info(f"test - {test_score}")

In [None]:
# let's run a fixed, hand-picked sequence through the model as well
n = train_dataset.length
_input = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long, device=trainer.device)

# the prompt length must match the dataset's sequence length
assert _input[0].nelement() == n

with torch.no_grad():
    # greedy decoding (no sampling) of n new tokens after the prompt
    _inference = model.generate(_input, n, do_sample=False)

# sort along the sequence axis so the reference keeps the (1, n) shape of
# _input and _solution_candidate; their logged lists then nest the same way
_solution = torch.sort(_input, dim=1).values
_solution_candidate = _inference[:, n:]

In [None]:
# show the prompt, the model's prediction and the torch.sort reference, and
# whether every predicted token matches the reference
logger.info(f"input sequence  : {_input.tolist()}")
logger.info(f"predicted sorted: {_solution_candidate.tolist()}")
logger.info(f"actual sort     : {_solution.tolist()}")
logger.info(
    f"matches         : {bool((_solution == _solution_candidate).all())}"
)