In [None]:
#| default_exp learner

In [16]:
#| export
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch import optim
from typing import Mapping, Any, Tuple, List, Union
from copy import copy

In [2]:
#| export
class DataSet:
    """Minimal map-style dataset that pairs inputs `x` with targets `y` by index."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        # The length is taken from the inputs only; `y` is not checked against it.
        return len(self.x)

    def __getitem__(self, i):
        # `i` is forwarded unchanged to both containers, so slices work if they do.
        sample, target = self.x[i], self.y[i]
        return sample, target

In [3]:
import os
import tiktoken

# The corpus path is resolved relative to the directory the notebook runs from.
cwd = os.getcwd()

# Tokenizer for the "gpt2" encoding, used to turn the raw text into token ids.
denc = tiktoken.get_encoding("gpt2")
input_file = f"{cwd}/fast-nanogpt/input.txt"
# Pass the encoding explicitly: open()'s default is locale-dependent, so the
# same file could decode differently (or fail) on another machine.
# NOTE(review): assumes input.txt is UTF-8 -- confirm against the data source.
with open(input_file, encoding="utf-8") as f:
    text = f.read()

In [4]:
# B sequences of T tokens each.
B, T = 4, 32
# Only the first 1000 characters are tokenized for this small experiment.
tokens = denc.encode(text[:1000])
# Take B*T + 1 tokens so that the targets can be the inputs shifted by one.
buf = torch.tensor(tokens[: B * T + 1]).to('cuda')
x, y = buf[:-1].view(B, T), buf[1:].view(B, T)

x[0], y[0]

(tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,
          3285,   502,  2740,    13,   198,   198,  3237,    25,   198,  5248,
           461,    11,  2740,    13,   198,   198,  5962, 22307,    25,   198,
          1639,   389], device='cuda:0'),
 tensor([22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,  3285,
           502,  2740,    13,   198,   198,  3237,    25,   198,  5248,   461,
            11,  2740,    13,   198,   198,  5962, 22307,    25,   198,  1639,
           389,   477], device='cuda:0'))

In [5]:
# Wrap the (x, y) tensors in a dataset and batch them with a plain DataLoader.
ds = DataSet(x, y)
dl = DataLoader(ds, batch_size=4)
iterdl = iter(dl)

# NOTE: this loop rebinds the notebook-level `x` and `y` to the last batch seen.
for i, batch in enumerate(iterdl):
    x, y = batch
    print("batch", i)
    print("x.shape, y.shape", x.shape, y.shape)

batch 0
x.shape, y.shape torch.Size([4, 32]) torch.Size([4, 32])


In [6]:
#| export
class DataLoaders:
    """Hold a training loader (`.train`) and an optional validation loader (`.valid`).

    Only the first two positional loaders are used; any further ones are ignored.
    `.valid` is None when no validation loader is supplied.

    Raises:
        ValueError: if no loader is given at all.
    """

    def __init__(self, *dls):
        if not dls:
            raise ValueError("DataLoaders needs at least a training DataLoader")
        # Pad with None so a lone training loader does not fail to unpack.
        self.train, self.valid = (*dls, None)[:2]

    @classmethod
    def from_dd(cls, datasets, batch_size, **kwargs):
        """Build one `DataLoader` per dataset, in order (train first, then valid).

        Args:
            datasets: iterable of datasets; a `None` entry means "no loader for
                this split" and is passed through as None rather than being
                wrapped in a DataLoader that would only fail once iterated.
            batch_size: batch size shared by every loader.
            **kwargs: forwarded unchanged to every `DataLoader`.
        """
        loaders = [
            None if ds is None else DataLoader(ds, batch_size=batch_size, **kwargs)
            for ds in datasets
        ]
        return cls(*loaders)

In [7]:
# Only a training split exists here, so the validation slot is left as None.
datasets = [ds, None]
dls = DataLoaders.from_dd(datasets, batch_size=4)
for x, y in dls.train:
    print(x.shape, y.shape)

torch.Size([4, 32]) torch.Size([4, 32])


In [8]:
from tinyai.model import get_model

# Build the project's model, then move it to the GPU that holds the batch tensors.
model = get_model()
model = model.to("cuda")

Overfit one batch: train repeatedly on the same single batch as a sanity check that the loss can be driven down.

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

for i in range(50):
    optimizer.zero_grad()
    # Forward pass: the model is called with inputs and targets and returns (logits, loss).
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()

    # Logged on every step; guard with e.g. `if i % 10 == 0:` to thin the output.
    print(f"step {i} loss {loss.item()}")

step 0 loss 10.887170791625977
step 1 loss 7.397668838500977
step 2 loss 6.109092712402344
step 3 loss 5.146890163421631
step 4 loss 4.083423137664795
step 5 loss 3.532369613647461
step 6 loss 3.068758487701416
step 7 loss 2.708280563354492
step 8 loss 2.4744410514831543
step 9 loss 2.2717630863189697
step 10 loss 2.0909435749053955
step 11 loss 1.949676752090454
step 12 loss 1.8371838331222534
step 13 loss 1.7458043098449707
step 14 loss 1.66878080368042
step 15 loss 1.6040171384811401
step 16 loss 1.5466166734695435
step 17 loss 1.499731183052063
step 18 loss 1.4627355337142944
step 19 loss 1.4300384521484375
step 20 loss 1.4030559062957764
step 21 loss 1.3814430236816406
step 22 loss 1.3627195358276367
step 23 loss 1.3466485738754272
step 24 loss 1.333757996559143
step 25 loss 1.3225020170211792
step 26 loss 1.3118934631347656
step 27 loss 1.3042244911193848
step 28 loss 1.2981325387954712
step 29 loss 1.2910606861114502
step 30 loss 1.2848173379898071
step 31 loss 1.280657529830932

In [12]:
#| export
from operator import attrgetter
from functools import partial

In [32]:
#| export
class CancelFitException(Exception):
    """Raise inside a `with_cbs("fit")` section to abandon the rest of the fit."""


class CancelBatchException(Exception):
    """Raise inside a `with_cbs("batch")` section to abandon the rest of the batch."""


class CancelEpochException(Exception):
    """Raise inside a `with_cbs("epoch")` section to abandon the rest of the epoch."""

In [13]:
#| export
class Callback:
    """Base class for callbacks.

    `order` is the sort key used by `run_cbs`: callbacks with a lower value run first.
    """

    order = 0


def run_cbs(cbs, method_nm, learn=None):
    """Call `method_nm(learn)` on each callback that defines it, in ascending `order`.

    Callbacks lacking the method (or exposing it as None) are skipped silently.
    """
    ordered = sorted(cbs, key=attrgetter("order"))
    for cb in ordered:
        hook = getattr(cb, method_nm, None)
        if hook is None:
            continue
        hook(learn)

In [30]:
# | export
class with_cbs:
    """Decorator that brackets a method with `before_`/`after_`/`cleanup_` callbacks.

    `with_cbs("batch")` makes the decorated method run as:
    `o.callback("before_batch")`, the method itself, `o.callback("after_batch")`,
    and finally `o.callback("cleanup_batch")`, which runs even if something raised.

    Raising `Cancel<Nm>Exception` (e.g. `CancelBatchException`) anywhere in the
    before/body/after sequence is swallowed, skipping whatever remains of it.
    That exception class is looked up by name in this module's globals, so it
    must be defined in the same module as this decorator.

    The wrapper returns None; the wrapped method's return value is discarded.
    """

    def __init__(self, nm):
        self.nm = nm

    def __call__(self, f):
        # Imported locally so the decorator does not depend on the import cell.
        from functools import wraps

        @wraps(f)  # keep the wrapped method's __name__/__doc__ for debugging and help()
        def _f(o, *args, **kwargs):
            try:
                o.callback(f"before_{self.nm}")
                f(o, *args, **kwargs)
                o.callback(f"after_{self.nm}")
            except globals()[f"Cancel{self.nm.title()}Exception"]:
                # Cancellation is a control-flow signal, not an error.
                pass
            finally:
                o.callback(f"cleanup_{self.nm}")

        return _f


class Learner:
    """Callback-driven training loop.

    The learner only sequences events (fit -> epoch -> batch) and dispatches
    them to its callbacks. The actual work -- `predict`, `backward`, `step`,
    `zero_grad` -- is not implemented here: `__getattr__` turns those names
    into callback dispatches, so a callback such as `TrainCB` must provide them.

    Args:
        model: the model to train; `model.train(mode)` and `model.parameters()`
            are called on it.
        dls: object exposing `.train` and `.valid` iterables of batches.
        lr: default learning rate, used when `fit` is not given one.
        cbs: optional iterable of callbacks attached for the learner's lifetime.
        opt_func: optimizer factory called as `opt_func(params, lr)`; a falsy
            value skips optimizer creation.
    """

    def __init__(
        self,
        model,
        dls=(0,),
        lr=0.1,
        cbs=None,
        opt_func=optim.SGD,
    ):
        self.model = model
        self.dls = dls
        # There is no loss function here: the `predict` callback is expected to
        # set both `preds` and `loss` on the learner.
        self.lr = lr
        # Copy the list: `fit` appends to and removes from `self.cbs`, which must
        # not mutate (or alias) the list object owned by the caller.
        self.cbs = list(cbs) if cbs else []
        self.opt_func = opt_func

    @with_cbs("batch")
    def _one_batch(self):
        """Run the callback sequence for the batch currently in `self.batch`."""
        self.predict()
        self.callback("after_predict")
        # No separate get_loss step; `after_loss` is still fired so callbacks
        # that hook it keep working.
        self.callback("after_loss")
        if self.training:
            self.backward()
            self.callback("after_backward")
            self.step()
            self.callback("after_step")
            self.zero_grad()

    @with_cbs("epoch")
    def _one_epoch(self):
        """Iterate `self.dl`, exposing the index as `self.iter` and the data as `self.batch`."""
        for self.iter, self.batch in enumerate(self.dl):
            self._one_batch()

    def one_epoch(self, training=True):
        """Run one pass over the training loader (or the validation loader if not `training`)."""
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs("fit")
    def _fit(self, train, valid):
        """Run the train and/or validation pass for every epoch in `self.epochs`."""
        for self.epoch in self.epochs:
            if train:
                self.one_epoch(training=True)
            if valid:
                # The validation pass runs with gradient tracking disabled.
                with torch.no_grad():
                    self.one_epoch(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        """Train for `n_epochs`.

        Args:
            n_epochs: number of epochs to run.
            train: whether to run the training pass each epoch.
            valid: whether to run the validation pass each epoch.
            cbs: extra callbacks attached only for the duration of this call.
            lr: learning rate for this call; defaults to `self.lr`.
        """
        # Snapshot the extra callbacks. Without this, passing `self.cbs` itself
        # would append to the list being iterated (never terminating), and a
        # one-shot iterable would be exhausted before the cleanup below.
        cbs = [] if cbs is None else list(cbs)
        self.cbs.extend(cbs)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None:
                lr = self.lr
            if self.opt_func:
                # A fresh optimizer is created on every call to fit.
                self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            # Always detach the temporary callbacks, even if training raised.
            for cb in cbs:
                self.cbs.remove(cb)

    def __getattr__(self, name):
        # Only reached for attributes not found normally: expose the training
        # steps as callables that dispatch to the callbacks.
        if name in ("predict", "get_loss", "backward", "step", "zero_grad"):
            return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm):
        """Invoke `method_nm` on every attached callback that defines it, in `order`."""
        run_cbs(self.cbs, method_nm, self)

    @property
    def training(self):
        """Mirror of `model.training`."""
        return self.model.training

In [27]:
# | export
# Device preference order: Apple MPS, then CUDA, otherwise CPU.
default_device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)


def to_device(x, device=default_device):
    """Recursively move every tensor found in `x` to `device`.

    Args:
        x: a tensor, a mapping, another iterable container (list, tuple, ...),
            or a leaf value. Containers may be nested arbitrarily.
        device: target device; defaults to `default_device`.

    Returns:
        A tensor for tensor input; a plain `dict` for any mapping; a container
        of the same type for other iterables (built as `type(x)(iterable)`).
        Objects exposing a callable `.to` are moved with it. Strings, bytes and
        non-iterable leaves (numbers, None, ...) are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, Mapping):
        # Recurse instead of calling `v.to` directly so that values which are
        # themselves lists, tuples or dicts of tensors are handled too.
        return {k: to_device(v, device) for k, v in x.items()}
    mover = getattr(x, "to", None)
    if callable(mover):
        # Tensor-like objects that know how to move themselves.
        return mover(device)
    if isinstance(x, (str, bytes)) or not hasattr(x, "__iter__"):
        # Nothing to move; iterating a str would also recurse forever.
        return x
    return type(x)(to_device(o, device) for o in x)

In [39]:
#| export
class DeviceCB(Callback):
    """Move the model to `device` when fitting starts and each batch to it before it is processed."""

    def __init__(self, device=default_device):
        self.device = device

    def before_fit(self, learn):
        net = learn.model
        # Models without a `.to` method are left where they are.
        if not hasattr(net, "to"):
            return
        net.to(self.device)

    def before_batch(self, learn):
        moved = to_device(learn.batch, device=self.device)
        learn.batch = moved

In [43]:
#| export
class TrainCB(Callback):
    """Supplies the training steps that `Learner` dispatches to its callbacks."""

    def predict(self, learn):
        # The model is called with the unpacked batch and must return (preds, loss).
        outputs = learn.model(*learn.batch)
        learn.preds, learn.loss = outputs
        print("epoch", learn.epoch, "batch", learn.iter, "loss", learn.loss.item())

    def backward(self, learn):
        # Backpropagate from the loss stored by `predict`.
        learn.loss.backward()

    def step(self, learn):
        # Apply one optimizer update.
        learn.opt.step()

    def zero_grad(self, learn):
        # Clear the gradients through the optimizer.
        learn.opt.zero_grad()

In [44]:
# Train for a single epoch with AdamW; validation is skipped.
cbs = [TrainCB(), DeviceCB()]
lrn = Learner(model, dls=dls, lr=3e-4, cbs=cbs, opt_func=optim.AdamW)
lrn.fit(n_epochs=1, valid=False)

epoch 0 batch 0 loss 3.884068012237549
