# Learner, Callbacks, Metrics and friends

Implementing the Learner and its associated classes: callbacks, context managers, metrics, etc.

Start of a build of convenience and experimentation tools, including lots of potential overcomplication from J. Howard.

## Context managers

In [30]:
class rnd_ctx:
    """Demo class-based context manager.

    Prints a labelled message on entry and exit, and suppresses any
    exception raised inside the ``with`` body (after reporting it).
    """
    def __init__(self, msg):
        self.msg = msg  # label echoed on entry and exit

    def __enter__(self):
        print('in', self.msg)
        # no explicit return: `with rnd_ctx(...) as v` binds v to None

    def __exit__(self, exc_type, exc_value, traceback):
        raised = exc_type is not None
        if raised:
            print(f'swallowing exception on exit: {exc_type} with value "{exc_value}"')
        print('exiting', self.msg)
        # A truthy return value tells Python to suppress the in-flight exception
        return True


In [31]:
# Demo: rnd_ctx.__exit__ returns True, so the RuntimeError raised below is
# reported and then suppressed instead of propagating out of the cell.
with rnd_ctx('this my manager'):
    print('code in context')
    raise RuntimeError('aint having this ish')

# Still reached, because the exception was swallowed by the context manager
print('not in contextmanager')

in this my manager
code in context
swallowing exception on exit: <class 'RuntimeError'> with value "aint having this ish"
exiting this my manager
not in contextmanager


In [32]:
from contextlib import contextmanager

In [73]:
@contextmanager
def another_ctx(msg):
    """Generator-based context manager that yields the list ``[0, 1, 2, 3, 4]``.

    Prints a setup message on entry and a teardown message on exit, both
    labelled with *msg*. Teardown sits in a ``finally`` so it prints even if
    the ``with`` body raises; the exception itself still propagates (it is
    not suppressed).

    Args:
        msg: label included in the setup/teardown messages.
    """
    # Previously `msg` was accepted but never used; it now labels the output.
    print('setting up context:', msg)
    vals = list(range(5))
    try:
        yield vals
    finally:
        print('exiting context:', msg)

In [74]:
with another_ctx('this another context') as vals:
    # echo each yielded value on its own line
    for item in vals:
        print(item)

setting up context
0
1
2
3
4
exiting context


### Decorator

In [75]:
def inplace(f):
    """Decorator: call ``f(b, ...)`` for its side effects on *b*, then return *b*.

    Lets a function written to mutate its first argument in place (and
    return nothing) be used where a function that returns the transformed
    object is expected.

    Args:
        f: callable whose first positional argument is mutated in place.
           Any extra positional/keyword arguments are forwarded unchanged.

    Returns:
        A wrapper with ``f``'s name and docstring that returns the (mutated)
        first argument.
    """
    from functools import wraps

    @wraps(f)  # preserve f's __name__/__doc__ for introspection and debugging
    def _f(b, *args, **kwargs):
        f(b, *args, **kwargs)  # f's own return value is deliberately discarded
        return b
    return _f

In [76]:
def testing(b): b[0] = None

In [77]:
# Undecorated `testing` mutates the list but returns None, so the cell shows no output
testing([0, 1, 2, 3])

In [78]:
@inplace
def testing(b):
    """Set ``b[0]`` to ``None``; ``@inplace`` makes the mutated *b* the return value."""
    b[0] = None

In [79]:
# With @inplace the mutated list is returned, so the cell displays it
testing([0, 1, 2, 3])

[None, 1, 2, 3]

## TODO: move the code above to below the main content

In [1]:
#|export
import math,torch,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F

from miniai.conv import *

from fastprogress import progress_bar,master_bar

In [2]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn,tensor
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.conv import *
import logging
from fastcore.test import test_close

In [4]:
# Compact tensor printing: 2 decimals, wide lines, no scientific notation
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Fix torch's RNG seed so later random operations are reproducible
torch.manual_seed(1)
# Default matplotlib image colormap to grayscale
mpl.rcParams['image.cmap'] = 'gray'

In [6]:
# Ignore logging calls at WARNING severity and below
# NOTE(review): presumably to quiet library chatter (e.g. dataset loading) — confirm
logging.disable(logging.WARNING)

## Learner

In [8]:
# Column names used as batch keys: x = input image column, y = target label column
x, y = 'image', 'label'
name = 'fashion_mnist'
# Downloads on first run, then loads the dataset's splits (progress output below)
dsd = load_dataset(name)

Downloading builder script:   0%|          | 0.00/4.83k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.13k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.85k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/26.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.42M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [13]:
@inplace
def transformi(b):
    """Replace every item of ``b[x]`` with its tensor form flattened to 1-D.

    ``b`` is a batch dict keyed by column name; ``@inplace`` makes the
    mutated dict the return value.
    """
    b[x] = [TF.to_tensor(img).flatten() for img in b[x]]

In [14]:
bs = 1024  # batch size
# Attach `transformi` as the dataset transform (images become flat 784-element tensors)
tds = dsd.with_transform(transformi)

In [15]:
# Build train/valid DataLoaders from the dataset dict (miniai helper)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train
# Peek at one training batch to sanity-check shapes
xb, yb = next(iter(dt))
xb.shape, yb.shape

(torch.Size([1024, 784]), torch.Size([1024]))

## Basic Callbacks Learner

In [17]:
class CancelFitException(Exception):
    """Raise from a callback to abort the whole `Learner.fit` run."""


class CancelBatchException(Exception):
    """Raise from a callback to skip the rest of the current batch."""


class CancelEpochException(Exception):
    """Raise from a callback to skip the rest of the current epoch pass."""

In [19]:
class Callback:
    """Base class for Learner callbacks.

    Subclasses define any event methods they care about (e.g. ``before_fit``,
    ``after_batch``), each taking the learner as its only argument.
    """
    # run_cbs sorts callbacks by this value, ascending: lower runs first
    order = 0

In [24]:
def run_cbs(cbs, method_nm, learn=None):
    """Dispatch event *method_nm* to every callback in *cbs* that defines it.

    Callbacks are visited in ascending ``order`` (ties keep their list
    order, since ``sorted`` is stable); those lacking the method are
    skipped. Each handler is called as ``handler(learn)``.
    """
    by_order = attrgetter('order')
    for callback in sorted(cbs, key=by_order):
        handler = getattr(callback, method_nm, None)
        if handler is None:
            continue
        handler(learn)

In [25]:
class CompletionCB(Callback):
    """Count ``after_batch`` events during a fit and print the total at the end."""

    def before_fit(self, learn):
        # reset per fit so repeated fits don't accumulate counts
        self.count = 0

    def after_batch(self, learn):
        self.count = self.count + 1

    def after_fit(self, learn):
        total = self.count
        print(f'Completed {total} batches')

In [43]:
cbs = [CompletionCB()]
# Drive the callback by hand, firing events in the order a fit would
for event in ('before_fit', 'after_batch', 'after_fit'):
    run_cbs(cbs, event)

Completed 1 batches


In [49]:
class Learner():
    """Minimal callback-driven training loop.

    ``fit`` creates a fresh optimizer, then for every epoch runs one training
    pass over ``dls.train`` followed by one validation pass over
    ``dls.valid``. Callbacks are notified through ``run_cbs`` at
    ``before_fit``, ``before_epoch``, ``before_batch``, ``after_batch``,
    ``after_epoch`` and ``after_fit``. A callback may raise
    CancelBatchException / CancelEpochException / CancelFitException to skip
    the rest of the current batch / epoch pass / fit; the skipped work
    includes that scope's ``after_*`` event.

    Args:
        model: module to train; called as ``model(batch[0])``.
        dls: object with ``.train`` and ``.valid`` iterables of batches,
            where ``batch[0]`` is the input and ``batch[1]`` the target.
        loss_func: called as ``loss_func(preds, batch[1])``.
        lr: learning rate handed to ``opt_func``.
        cbs: list of callback objects; ``None`` means no callbacks.
        opt_func: optimizer factory called as ``opt_func(params, lr)``.

    During ``fit``, callbacks can read: ``n_epochs``, ``epochs``, ``opt``,
    ``epoch``, ``dl``, ``iter``, ``batch``, ``preds`` and ``loss``.
    """
    def __init__(self, model, dls, loss_func, lr, cbs=None, opt_func=optim.SGD):
        self.model, self.dls, self.loss_func, self.lr, self.opt_func = model, dls, loss_func, lr, opt_func
        # `None` default avoids requiring (or sharing) a callback list
        self.cbs = [] if cbs is None else cbs

    def one_batch(self):
        """Forward + loss on ``self.batch``; in training mode also step the optimizer."""
        training = self.model.training
        # Only record an autograd graph when we will backprop through it;
        # the validation pass never calls backward(), so skip the graph there.
        with torch.set_grad_enabled(training):
            self.preds = self.model(self.batch[0])
            self.loss = self.loss_func(self.preds, self.batch[1])
        if training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()

    def one_epoch(self, train):
        """Run one pass over the train (``train=True``) or valid loader."""
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback('before_epoch')
            for self.iter, self.batch in enumerate(self.dl):
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException: pass
            self.callback('after_epoch')
        except CancelEpochException: pass

    def fit(self, n_epochs):
        """Train for *n_epochs* epochs, each a training pass then a validation pass."""
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        # Fresh optimizer per fit, built from the model's current parameters
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                self.one_epoch(True)
                self.one_epoch(False)
            self.callback('after_fit')
        except CancelFitException: pass

    def callback(self, method_nm):
        """Dispatch event *method_nm* to all callbacks, passing this learner."""
        run_cbs(self.cbs, method_nm, self)

In [50]:
m = 28 * 28  # flattened input size (784, matching xb.shape[1])
nh = 50      # hidden-layer width


def get_model():
    """Build a fresh MLP: Linear(m -> nh), ReLU, Linear(nh -> 10)."""
    layers = [nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10)]
    return nn.Sequential(*layers)

In [51]:
model = get_model()
# One epoch (training pass + validation pass) with the default optim.SGD at lr=0.2;
# CompletionCB prints how many after_batch events fired
learn = Learner(model, dls, F.cross_entropy, 0.2, cbs=[CompletionCB()])
learn.fit(1)

Completed 64 batches


In [56]:
class SingleBatchCB(Callback):
    """Abort the whole fit right after the first batch (quick smoke test)."""
    # Runs after order-0 callbacks (e.g. CompletionCB) for the same event
    order = 1

    def after_batch(self, learn):
        raise CancelFitException()

In [57]:
# Reuses the already-trained `model`. SingleBatchCB raises CancelFitException after the
# first batch, so fit() skips after_fit and CompletionCB prints nothing.
learn = Learner(model, dls, F.cross_entropy, 0.2, cbs=[CompletionCB(), SingleBatchCB()])
learn.fit(1)

## Metrics

In [59]:
class Metric:
    """Running average of a per-batch quantity, weighted by per-batch counts.

    ``add`` stores ``calc(inp, targ)`` with weight ``n``; ``value`` returns
    ``sum(val_i * n_i) / sum(n_i)`` as a tensor. The base ``calc`` is the
    identity, so a plain Metric averages the raw values it is given.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value and weight."""
        self.vals = []
        self.ns = []

    def add(self, inp, targ=None, n=1):
        """Record ``calc(inp, targ)`` (also kept as ``self.last``) with weight *n*."""
        latest = self.calc(inp, targ)
        self.last = latest
        self.vals.append(latest)
        self.ns.append(n)

    @property
    def value(self):
        weights = tensor(self.ns)
        weighted_total = (tensor(self.vals) * weights).sum()
        return weighted_total / weights.sum()

    def calc(self, inps, targs):
        # identity by default; subclasses override to compute a per-batch statistic
        return inps

In [60]:
class Accuracy(Metric):
    """Per-batch fraction of positions where ``inps`` equals ``targs``.

    NOTE(review): compares ``inps`` directly, so it expects predicted labels,
    not raw logits — apply argmax first when feeding model outputs.
    """
    def calc(self, inps, targs):
        matches = inps == targs
        return matches.float().mean()

In [62]:
acc = Accuracy()
# Batch 1: 3 of 6 match (0.5); batch 2: 2 of 5 match (0.4)
acc.add(tensor([0, 1, 2, 0, 1, 2]), tensor([0, 1, 1, 2, 1, 0]))
acc.add(tensor([1, 1, 2, 0, 1]), tensor([0, 1, 1, 2, 1]))
# n defaults to 1, so both batches get equal weight despite different sizes:
# (0.5 + 0.4) / 2 = 0.45. Pass n=len(targs) for a per-item average.
acc.value

tensor(0.45)

In [63]:
# Base Metric's calc is the identity, so this is a count-weighted mean of raw values
loss = Metric()
loss.add(0.69, n=32)
loss.add(0.9, n=2)
# Compare with the same weighted mean computed by hand
loss.value, (0.69 * 32 + 0.9 * 2) / (32 + 2)

(tensor(0.70), 0.7023529411764705)