In [36]:
#| default_exp nb05

In [47]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from pathlib import Path
from operator import itemgetter

from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from fastcore.test import test_close
from fastprogress import progress_bar
import fastcore.all as fc

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

from datasets import load_dataset,load_dataset_builder
import logging

logging.disable(logging.WARNING)

In [64]:
from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler
from transformers import default_data_collator
from collections.abc import Mapping

## Learner

In [49]:
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)

  0%|          | 0/2 [00:00<?, ?it/s]

In [50]:
def inplace(f):
    def _f(b):
        f(b)
        return b
    return _f

In [51]:
@inplace
def transformi(b): b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]

In [52]:
bs = 128
tds = dsd.with_transform(transformi)

In [53]:
def collate_dict(ds):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

In [56]:
class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True):
        return cls(*[DataLoader(ds, batch_size, collate_fn=collate_dict(ds) if as_tuple else default_collate) for ds in dd.values()])

In [57]:
dls = DataLoaders.from_dd(tds, bs)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]

(torch.Size([128, 784]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))

In [58]:
def to_cuda(x):
    if isinstance(x, Mapping): return {k:v.cuda() for k,v in x.items()}
    return type(x)(o.cuda() for o in x)

In [72]:
class Learner:
    def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        self.xb,self.yb = to_cuda(self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
        with torch.no_grad(): self.calc_stats()
    
    def calc_stats(self):
        acc = (self.preds.argmax(dim=1)==self.yb).float().sum()
        self.accs.append(acc)
        n = len(self.xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        for self.num,self.batch in enumerate(dl): self.one_batch()
        n = sum(self.ns)
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
    
    def fit(self, n_epochs):
        self.accs,self.losses,self.ns = [],[],[]
        model.cuda()
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            self.one_epoch(False)

In [73]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [74]:
learn = Learner(model, dls, F.cross_entropy, lr=0.001)
learn.fit(1)

0 True 1.3652537760416668 0.6493833333333333


0 False 1.3915636160714286 0.6653714285714286


In [89]:
def identity(*args):
    if not args: return
    x,*args = args
    return (x,)+args if args else x

In [75]:
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

In [140]:
class Learner(fc.GetAttr):
    _default='model'

    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD):
        fc.store_attr()
        for cb in cbs: cb.learner = self
    
    def with_cbs(self, nm, f, *args, **kwargs):
        ex = globals()[f'Cancel{nm.title()}Exception']
        try:
            self.callback(f'before_{nm}')
            f(*args, **kwargs)
            self.callback(f'after_{nm}')
        except ex: pass

    def one_batch(self):
        self.xb,self.yb = self.batch
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.training:
            self.loss.backward()
            self.opt.step()

    def dl(self): return self.dls.train if self.training else self.dls.valid

    def one_epoch(self):
        for self.iter,self.batch in enumerate(self.dl()): self.with_cbs('batch', self.one_batch)
    
    def fit(self, n_epochs): self.with_cbs('fit', self._fit, n_epochs)

    def _fit(self, n_epochs):
        self.opt = self.opt_func(self.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.train()
            self.with_cbs('epoch', self.one_epoch)
            self.eval()
            self.with_cbs('epoch', self.one_epoch)
        
    def callback(self,name):
        for cb in self.cbs: getattr(cb,name,identity)()

`GetAttr` is a fastai class that implements Python's standard `__getattr__` and `__dir__` methods for you, such that any time you try to access an attribute that doesn't exist, it passes the request along to whatever you have defined as `_default`.

In [141]:
class Callback(fc.GetAttr): _default='learner'

In [142]:
class CudaCB(Callback):
    def before_fit(self): self.model.cuda()
    def before_batch(self): self.learner.batch = to_cuda(self.batch)

In [143]:
class ProgressCallback(Callback):
    def before_epoch(self):
        self.pbar = progress_bar(self.dl(), leave=False)
        self.pbar.update(0)
    def after_epoch(self): self.pbar.on_iter_end()
    def after_batch(self): self.pbar.update(self.iter+1)

In [144]:
class TrackResultsCB(Callback):
    def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]
        
    def after_epoch(self):
        n = sum(self.ns)
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
        
    def after_batch(self):
        yb = self.batch['label'] if isinstance(self.batch, Mapping) else self.batch[1]
        acc = (self.preds.argmax(dim=1)==yb).float().sum()
        self.accs.append(acc)
        n = len(xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

In [145]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [146]:
cbs = [CudaCB(),TrackResultsCB(), ProgressCallback()]
learn = Learner(model, dls, F.cross_entropy, lr=0.001, cbs=cbs)
learn.fit(1)

0 True 1.4361602197577958 0.6483375533049041


0 False 1.617226226420342 0.7176621835443038


## Dict model

In [35]:
x,y = 'image','labels'

In [8]:
def data_loader(ds, batch_size, as_tuple=True):
    kw = {'collate_fn':collate_dict(ds)} if as_tuple else {}
    return DataLoader(ds, batch_size=batch_size, **kw)

In [213]:
dls = DataLoaders.from_dd(tds, bs, as_tuple=False)

In [214]:
b = next(iter(dls.train))
b

{'image': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'label': tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4, 3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5,
         9, 2, 7, 3, 0, 3, 3, 3, 7, 2, 2, 6, 6, 8, 3, 3, 5, 0, 5, 5, 0, 2, 0, 0, 4, 1, 3, 1, 6, 3, 1, 4, 4, 6, 1, 9, 1, 3, 5, 7, 9, 7, 1, 7,
         9, 9, 9, 3, 2, 9, 3, 6, 4, 1, 1, 8, 8, 0, 1, 1, 6, 8, 1, 9, 7, 8, 8, 9, 6, 6, 3, 1, 5, 4, 6, 7, 5, 5, 9, 2, 2, 2, 7, 6])}

In [219]:
class FashionMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(m,nh)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(nh,10)

    def forward(self, b):
        xb = b[x]
        yb = b[y]
        pred = self.l2(self.relu(self.l1(xb)))
        return {'preds':pred, 'loss':F.cross_entropy(pred, yb)}

In [220]:
model = FashionMLP()
learn = Learner(model, dls, identity, lr=0.001, cbs=cbs)
learn.fit(1)

0 True 1.5404933774903384 0.6351112739872068


0 False 2.156612927400613 0.7089596518987342
