In [1]:
#| default_exp learner

In [2]:
#| export
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch import optim
from typing import Mapping
from copy import copy

In [3]:
#| export
class DataSet:
    """Minimal map-style dataset: item ``i`` is the pair ``(x[i], y[i])``."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        # Length is taken from the inputs; y is assumed to be aligned with x.
        return len(self.x)

    def __getitem__(self, i):
        xi = self.x[i]
        yi = self.y[i]
        return xi, yi

In [4]:
import os
import tiktoken

cwd = os.getcwd()

denc = tiktoken.get_encoding("gpt2")
input_file = os.path.join(cwd, "fast-nanogpt", "input.txt")
# Explicit encoding: open()'s default is locale-dependent and can differ across machines.
with open(input_file, encoding="utf-8") as f:
    text = f.read()

In [5]:
tokens = denc.encode(text[:1000])
B, T = 4, 32
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inputs and next-token targets are the same buffer shifted by one, so B*T+1 tokens are needed.
assert len(tokens) >= B * T + 1, "not enough tokens to build a (B, T) batch"
buf = torch.tensor(tokens[: B * T + 1]).to(device)
x = buf[:-1].view(B, T)
y = buf[1:].view(B, T)

x[0], y[0]

(tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,
          3285,   502,  2740,    13,   198,   198,  3237,    25,   198,  5248,
           461,    11,  2740,    13,   198,   198,  5962, 22307,    25,   198,
          1639,   389], device='cuda:0'),
 tensor([22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,  3285,
           502,  2740,    13,   198,   198,  3237,    25,   198,  5248,   461,
            11,  2740,    13,   198,   198,  5962, 22307,    25,   198,  1639,
           389,   477], device='cuda:0'))

In [6]:
ds = DataSet(x, y)
dl = DataLoader(ds, batch_size=4)

# Distinct loop names so the global x, y built above are not silently rebound.
for i, (xb, yb) in enumerate(dl):
    print("batch", i)
    print("x.shape, y.shape", xb.shape, yb.shape)

batch 0
x.shape, y.shape torch.Size([4, 32]) torch.Size([4, 32])


In [7]:
#| export
class DataLoaders:
    """Pair of training and (optional) validation loaders.

    Args:
        *dls: the training loader first, then optionally the validation loader.
            Any further loaders are ignored.
    """

    def __init__(self, *dls):
        if not dls:
            raise ValueError("DataLoaders needs at least a training loader")
        self.train = dls[0]
        # A missing validation loader is represented as None instead of failing to unpack.
        self.valid = dls[1] if len(dls) > 1 else None

    @classmethod
    def from_dd(cls, datasets, batch_size, **kwargs):
        """Build from datasets, wrapping each one in a ``DataLoader``.

        ``None`` entries (e.g. no validation set) stay ``None`` rather than
        becoming a ``DataLoader`` over ``None``. Extra ``kwargs`` are passed to
        every ``DataLoader``.
        """
        loaders = [
            None if ds is None else DataLoader(ds, batch_size=batch_size, **kwargs)
            for ds in datasets
        ]
        return cls(*loaders)

In [8]:
dls = DataLoaders.from_dd([ds, None], batch_size=4)
# Distinct loop names so the global x, y are not rebound as a side effect.
for xb, yb in dls.train:
    print(xb.shape, yb.shape)

torch.Size([4, 32]) torch.Size([4, 32])


In [9]:
from tinyai.model import get_model

## Overfit one batch

In [10]:
# Take the (single) training batch explicitly instead of relying on loop
# variables leaked from earlier cells.
xb, yb = next(iter(dls.train))
# Put the model wherever the data lives (GPU if available, else CPU).
model = get_model().to(xb.device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

for i in range(50):
    optimizer.zero_grad()
    # forward the model; its output is unpacked as (logits, loss)
    _, loss = model(xb, yb)
    loss.backward()
    optimizer.step()
    print(f"step {i} loss {loss.item()}")

step 0 loss 10.96385383605957
step 1 loss 7.438758373260498
step 2 loss 5.9899702072143555
step 3 loss 4.816177845001221
step 4 loss 4.449419021606445
step 5 loss 3.7852022647857666
step 6 loss 3.3959953784942627
step 7 loss 2.9343061447143555
step 8 loss 2.6591973304748535
step 9 loss 2.3939435482025146
step 10 loss 2.203979969024658
step 11 loss 2.0710153579711914
step 12 loss 1.916163682937622
step 13 loss 1.8228553533554077
step 14 loss 1.7424712181091309
step 15 loss 1.6542805433273315
step 16 loss 1.5931313037872314
step 17 loss 1.5389171838760376
step 18 loss 1.4855793714523315
step 19 loss 1.4567785263061523
step 20 loss 1.4302048683166504
step 21 loss 1.396011471748352
step 22 loss 1.3761154413223267
step 23 loss 1.3608899116516113
step 24 loss 1.3465158939361572
step 25 loss 1.3361443281173706
step 26 loss 1.3233309984207153
step 27 loss 1.3132715225219727
step 28 loss 1.3076026439666748
step 29 loss 1.2999235391616821
step 30 loss 1.2956291437149048
step 31 loss 1.2898812294

In [11]:
#| export
from operator import attrgetter
from functools import partial

In [12]:
#| export
class CancelFitException(Exception):
    """Raised by a callback to abandon the remainder of the current fit."""


class CancelBatchException(Exception):
    """Raised by a callback to abandon the remainder of the current batch."""


class CancelEpochException(Exception):
    """Raised by a callback to abandon the remainder of the current epoch."""

In [13]:
#| export
class Callback:
    """Base class for Learner callbacks.

    ``order`` controls execution order: lower values run first.
    """

    order = 0


def run_cbs(cbs, method_nm, learn=None):
    """Invoke hook ``method_nm`` on each callback that defines it.

    Callbacks are visited in ascending ``order`` (stable for ties); callbacks
    without the hook are skipped. Each hook is called with ``learn``.
    """
    ordered = sorted(cbs, key=lambda cb: cb.order)
    for cb in ordered:
        hook = getattr(cb, method_nm, None)
        if hook is None:
            continue
        hook(learn)

In [14]:
# | export
class with_cbs:
    """Decorator wrapping a Learner phase (``"batch"``, ``"epoch"``, ``"fit"``) with hooks.

    The decorated method runs as ``before_{nm}`` -> method -> ``after_{nm}``,
    and ``cleanup_{nm}`` always runs afterwards. If ``Cancel{Nm}Exception``
    (looked up by name in this module's globals when an exception propagates)
    is raised anywhere in that sequence, the rest is skipped (including
    ``after_{nm}``) and the exception is swallowed.
    """

    def __init__(self, nm):
        self.nm = nm

    def __call__(self, f):
        from functools import wraps

        # wraps keeps the method's name/docstring instead of exposing "_f".
        @wraps(f)
        def _f(o, *args, **kwargs):
            try:
                o.callback(f"before_{self.nm}")
                res = f(o, *args, **kwargs)
                o.callback(f"after_{self.nm}")
                # Propagate the wrapped method's result (None when cancelled).
                return res
            except globals()[f"Cancel{self.nm.title()}Exception"]:
                pass
            finally:
                o.callback(f"cleanup_{self.nm}")

        return _f


class Learner:
    """Callback-driven training loop.

    The Learner only sequences phases; the work of each batch (``predict``,
    ``get_loss``, ``backward``, ``step``, ``zero_grad``) is dispatched to the
    callbacks via ``__getattr__`` -> ``callback``. ``_fit``, ``_one_epoch`` and
    ``_one_batch`` are wrapped by ``with_cbs`` so callbacks also receive
    ``before_*``/``after_*``/``cleanup_*`` hooks.

    Args:
        model: the model; ``train(mode)``, ``parameters()`` and ``training``
            are used on it.
        dls: object with ``train`` and ``valid`` loaders (e.g. ``DataLoaders``).
        lr: default learning rate handed to ``opt_func`` by ``fit``.
        cbs: callbacks (any iterable); copied so the caller's list is never mutated.
        opt_func: optimizer factory called as ``opt_func(params, lr)``; if falsy,
            no optimizer is created.
    """

    def __init__(
        self,
        model,
        dls=(0,),
        lr=0.1,
        cbs=None,
        opt_func=optim.SGD,
    ):
        self.model = model
        self.dls = dls
        self.lr = lr
        # Own a copy: fit(cbs=...) temporarily appends/removes callbacks, which
        # would otherwise mutate the caller's list (possibly shared by other learners).
        self.cbs = list(cbs) if cbs else []
        self.opt_func = opt_func

    @with_cbs("batch")
    def _one_batch(self):
        # No separate get_loss phase is run here: the predict callback
        # (e.g. TrainCB) sets learn.loss alongside learn.preds.
        self.predict()
        self.callback("after_predict")
        if self.training:
            self.backward()
            self.callback("after_backward")
            self.step()
            self.callback("after_step")
            self.zero_grad()

    @with_cbs("epoch")
    def _one_epoch(self):
        for self.iter, self.batch in enumerate(self.dl):
            self._one_batch()

    def one_epoch(self, training=True):
        """Run one pass over ``dls.train`` (``training=True``) or ``dls.valid``."""
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs("fit")
    def _fit(self, train, valid):
        for self.epoch in self.epochs:
            if train:
                self.one_epoch(training=True)
            if valid:
                with torch.no_grad():
                    self.one_epoch(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        """Train and/or validate for ``n_epochs`` epochs.

        Args:
            n_epochs: number of epochs.
            train: run a training pass each epoch.
            valid: run a validation pass (under ``torch.no_grad``) each epoch.
            cbs: extra callbacks active only for this call; removed afterwards,
                even if fitting raises.
            lr: learning rate for this call; defaults to ``self.lr``.
        """
        # Materialize first: a one-shot iterable (e.g. a generator) would be
        # exhausted while adding and then never removed in the finally block.
        extra_cbs = list(cbs) if cbs is not None else []
        self.cbs.extend(extra_cbs)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None:
                lr = self.lr
            if self.opt_func:
                self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            for cb in extra_cbs:
                self.cbs.remove(cb)

    def __getattr__(self, name):
        # Only called for missing attributes: phase names become
        # "run this hook on every callback"; anything else is a real miss.
        if name in ("predict", "get_loss", "backward", "step", "zero_grad"):
            return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm):
        """Run hook ``method_nm`` on all callbacks, passing this learner."""
        run_cbs(self.cbs, method_nm, self)

    @property
    def training(self):
        """Whether the model is currently in training mode."""
        return self.model.training

In [15]:
# | export
# Preferred accelerator for this machine: Apple MPS, then CUDA, then CPU.
default_device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)


def to_device(x, device=default_device):
    """Recursively move every tensor inside ``x`` to ``device``.

    Handles tensors, mappings (returned as plain ``dict``), namedtuples and
    other iterable containers (rebuilt with their own type). Strings, bytes
    and non-iterable leaves (ints, floats, None, ...) are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, Mapping):
        # Recurse so nested containers inside mapping values are moved as well.
        return {k: to_device(v, device) for k, v in x.items()}
    if isinstance(x, (str, bytes)):
        # Iterating a str yields strs again, which would recurse forever.
        return x
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take the fields positionally, not as one iterable.
        return type(x)(*(to_device(o, device) for o in x))
    try:
        items = iter(x)
    except TypeError:
        return x
    return type(x)(to_device(o, device) for o in items)

In [16]:
#| export
class DeviceCB(Callback):
    """Move the model to ``device`` when fitting starts, and every batch before it is used."""

    def __init__(self, device=default_device):
        self.device = device

    def before_fit(self, learn):
        model = learn.model
        # Only objects exposing `.to` can be moved; anything else is left alone.
        if hasattr(model, "to"):
            model.to(self.device)

    def before_batch(self, learn):
        moved = to_device(learn.batch, device=self.device)
        learn.batch = moved

In [17]:
#| export
class TrainCB(Callback):
    """Callback providing the core per-batch training steps for ``Learner``."""

    def predict(self, learn):
        # The model is called on the unpacked batch and its output is unpacked
        # as (preds, loss), so no separate loss step is needed.
        learn.preds, learn.loss = learn.model(*learn.batch)

    def backward(self, learn):
        learn.loss.backward()

    def step(self, learn):
        learn.opt.step()

    def zero_grad(self, learn):
        learn.opt.zero_grad()

In [18]:
# Fresh model; TrainCB does the forward/backward/step, DeviceCB moves model and batches.
model = get_model()
cbs = [TrainCB(), DeviceCB()]

In [19]:
lrn = Learner(model, dls=dls, lr=3e-4, cbs=cbs, opt_func=optim.AdamW)
lrn.fit(n_epochs=1, valid=False)

In [20]:
class OverfitLearner(Learner):
    """Learner that replays the training loader ``n_repeat`` times per epoch.

    Used to overfit a tiny dataset as a sanity check of the training loop.
    """

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None, n_repeat=50):
        # Stored on the instance because _one_epoch reads it during fitting.
        self.n_repeat = n_repeat
        super().fit(n_epochs, train, valid, cbs, lr)

    @with_cbs("epoch")
    def _one_epoch(self):
        for rep in range(self.n_repeat):
            # Here self.iter counts repeats, not batch positions within the loader.
            self.iter = rep
            for batch in self.dl:
                self.batch = batch
                self._one_batch()


## Overfit one batch with learner

In [21]:
model = get_model()
lrn = OverfitLearner(model, dls=dls, lr=3e-4, cbs=cbs, opt_func=optim.AdamW)
lrn.fit(n_epochs=1, valid=False, n_repeat=50)

## More callbacks

In [22]:
#| export
from torcheval.metrics import Mean

In [23]:
#| export
def to_cpu(x):
    """Recursively detach tensors in ``x`` and move them to CPU.

    Mappings become plain ``dict``s, lists stay lists, and tuples (including
    subclasses) become plain tuples. float16 tensors are upcast to float32.
    Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    """
    if isinstance(x, Mapping):
        return {k: to_cpu(v) for k, v in x.items()}
    if isinstance(x, list):
        return [to_cpu(o) for o in x]
    if isinstance(x, tuple):
        return tuple(to_cpu(list(x)))
    if not isinstance(x, torch.Tensor):
        # Previously raised AttributeError on .detach(); pass such leaves through.
        return x
    res = x.detach().cpu()
    return res.float() if res.dtype == torch.float16 else res


class MetricsCB(Callback):
    """Accumulate metrics and the loss during an epoch and log them when it ends.

    Positional metrics are keyed by their class name, keyword metrics by their
    keyword. A torcheval ``Mean`` for the loss is always added under ``"loss"``.
    Metric objects are used through ``reset()``, ``update(preds, y)`` and
    ``compute()``. Logging goes through ``self._log`` (``print`` by default;
    another callback may replace it).
    """

    def __init__(self, *ms, **metrics):
        for o in ms:
            metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics["loss"] = self.loss = Mean()

    def _log(self, d):
        print(d)

    def before_fit(self, learn):
        # Lets other callbacks find this object via learn.metrics.
        learn.metrics = self

    def before_epoch(self, learn):
        for m in self.all_metrics.values():
            m.reset()

    def after_epoch(self, learn):
        log = {k: f"{v.compute():.3f}" for k, v in self.all_metrics.items()}
        log["epoch"] = learn.epoch
        log["train"] = "train" if learn.model.training else "eval"
        self._log(log)

    def after_batch(self, learn):
        x, y, *_ = to_cpu(learn.batch)
        if self.metrics:
            # Copy predictions to CPU once per batch (not once per metric), and not
            # at all when only the loss is tracked.
            preds = to_cpu(learn.preds)
            for m in self.metrics.values():
                m.update(preds, y)
        # Loss is weighted by len(x), the leading dimension of the input batch.
        self.loss.update(to_cpu(learn.loss), weight=len(x))

In [24]:
#| export
from fastprogress import progress_bar, master_bar
import fastcore.all as fc

In [25]:
# | export
class ProgressCB(Callback):
    """fastprogress bars for epochs and batches, with MetricsCB logs shown as a table.

    With ``plot=True`` (and a MetricsCB present), it also plots the training loss
    per batch and the validation loss per validated epoch.
    """

    # Run after MetricsCB so learn.metrics already exists in before_fit.
    order = MetricsCB.order + 1

    def __init__(self, plot=False):
        self.plot = plot

    def before_fit(self, learn):
        learn.epochs = self.mbar = master_bar(learn.epochs)
        self.first = True
        if hasattr(learn, "metrics"):
            # Redirect MetricsCB's logging into the master bar's table.
            learn.metrics._log = self._log
        self.losses = []
        self.val_losses = []
        # x-position (in training batches) of each recorded validation loss;
        # kept alongside val_losses so both always have the same length.
        self.val_xs = []

    def _log(self, d):
        if self.first:
            self.mbar.write(list(d), table=True)
            self.first = False
        self.mbar.write(list(d.values()), table=True)

    def _update_graph(self):
        """Redraw the training-loss curve, plus the validation curve once it has points."""
        graphs = [[fc.L.range(self.losses), self.losses]]
        if self.val_losses:
            graphs.append([self.val_xs, self.val_losses])
        self.mbar.update_graph(graphs)

    def before_epoch(self, learn):
        learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)

    def after_batch(self, learn):
        learn.dl.comment = f"{learn.loss:.3f}"
        if self.plot and hasattr(learn, "metrics") and learn.training:
            self.losses.append(learn.loss.item())
            self._update_graph()

    def after_epoch(self, learn):
        if self.plot and hasattr(learn, "metrics"):
            if not learn.training:
                self.val_losses.append(learn.metrics.all_metrics["loss"].compute())
                self.val_xs.append((learn.epoch + 1) * len(learn.dls.train))
            self._update_graph()

In [28]:
model = get_model()
# MetricsCB logs per epoch; ProgressCB (ordered after it) renders that log as a table.
cbs = [TrainCB(), DeviceCB(), MetricsCB(), ProgressCB()]
lrn = OverfitLearner(model, dls=dls, lr=3e-4, cbs=cbs, opt_func=optim.AdamW)
lrn.fit(n_epochs=2, valid=False)

loss,epoch,train
2.042,0,train
1.223,1,train
