In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
#export
from #

## DataBunch/Learner

In [None]:
x_train,y_train,x_valid,y_valid = #
train_ds,valid_ds = #
nh,bs = 50,64
c = #
loss_func = #

Factor the connected pieces of info out of the `fit()` argument list:

`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`

Let's replace it with something that looks like this:

`fit(1, learn)`

This will allow us to tweak what's happening inside the training loop from other places in the code: because the `Learner` object is mutable, changing any of its attributes elsewhere will be seen in our training loop.

In [None]:
#export
class DataBunch():
    #
        #
        
    #
    #
        
    #
    #

In [None]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [None]:
#export
def get_model#
    m = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh##))
    return model, optim.SGD(model.parameters(), lr=lr)

class Learner():
    #
        #

In [None]:
learn = Learner(*get_model(data), loss_func, data)

In [None]:
def fit#
    for epoch#
        #
        for#
            #
            #
            #
            #

        #
        with#
            #
            for#
                #
                #
                #
        #
        print(epoch, tot_loss##)
    return #

In [None]:
loss,acc = fit(1, learn)

0 tensor(0.1830) tensor(0.9452)


## CallbackHandler

This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out:

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(epochs):
        for b in train_dl: one_batch(*b)
```

Add callbacks so we can remove complexity from the loop, and make it flexible:

In [None]:
def one_batch(xb, yb, cb):
    """Run one training step on a single batch, consulting `cb` at each stage.

    Each hook on `cb` returns a truthy value to let the step carry on:
    a falsy `begin_batch` or `after_loss` abandons the batch, a falsy
    `after_backward` skips the optimizer step, and a falsy `after_step`
    skips zeroing the gradients. The model, loss function and optimizer
    are read from `cb.learn`.
    """
    if not cb.begin_batch(xb, yb):
        return

    loss_fn = cb.learn.loss_func
    preds = cb.learn.model(xb)
    loss = loss_fn(preds, yb)
    if not cb.after_loss(loss):
        return

    loss.backward()

    take_step = cb.after_backward()
    if take_step:
        cb.learn.opt.step()

    # `after_step` is consulted even when the optimizer step was skipped.
    clear_grads = cb.after_step()
    if clear_grads:
        cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    """Feed every `(xb, yb)` batch yielded by `dl` through `one_batch`.

    `cb.do_stop()` is polled after each batch; a truthy answer ends the
    pass over `dl` early.
    """
    for inputs, targets in dl:
        one_batch(inputs, targets, cb)
        halt = cb.do_stop()
        if halt:
            break

def fit(epochs, learn, cb):
    """Train `learn` for `epochs` epochs, letting `cb` steer every stage.

    Nothing runs (not even `cb.after_fit`) when `cb.begin_fit(learn)` is
    falsy. Within each epoch a falsy `cb.begin_epoch(epoch)` skips that
    epoch entirely; otherwise the training dataloader is processed, then,
    if `cb.begin_validate()` is truthy, the validation dataloader is
    processed with gradient tracking disabled. The loop ends early when
    `cb.do_stop()` is truthy or `cb.after_epoch()` is falsy, and
    `cb.after_fit()` is called once the loop is over.
    """
    if not cb.begin_fit(learn):
        return

    for epoch_idx in range(epochs):
        if cb.begin_epoch(epoch_idx):
            all_batches(learn.data.train_dl, cb)

            if cb.begin_validate():
                with torch.no_grad():
                    all_batches(learn.data.valid_dl, cb)

            # `do_stop` is checked first; `after_epoch` only runs if it is falsy.
            if cb.do_stop():
                break
            if not cb.after_epoch():
                break

    cb.after_fit()

In [None]:
class Callback():
    def begin_fit#
        #
        #
    def after_fit#
    def begin_epoch#
        #
        #
    def begin_validate#
    def after_epoch#
    def begin_batch#
        #
        #
    def after_loss#
        #
        #
    def after_backward#
    def after_step#

In [None]:
class CallbackHandler():
    def __init__#
        #

    def begin_fit#
        self.learn,#
        self.learn.stop = #
        res = #
        #
        #

    def after_fit#
        #
        #
        #
    
    def begin_epoch#
        #
        #
        #
        #
        #

    def begin_validate#
        #
        #
        #
        #
        #

    def after_epoch#
        #
        #
        #
    
    def begin_batch#
        #
        #
        #

    def after_loss#
        #
        #
        #

    def after_backward#
        #
        #
        #

    def after_step#
        #
        #
        #
    
    def do_stop#
        try:     #
        finally: #

In [None]:
class TestCallback#
    def begin_fit#
        #
        #
        #
        
    def after_step#
        #
        print(self.n_iters)
        #
        return #

In [None]:
fit(1, learn, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!

## Runner

In [None]:
#export
import re

# First pass: break before a capitalised word (an upper-case letter followed by lower-case ones).
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# Second pass: break between a lower-case letter or digit and a following upper-case letter.
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Underscores are inserted at the word boundaries found by the two
    patterns above, then the whole string is lower-cased.
    """
    with_word_breaks = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', with_word_breaks)
    return fully_split.lower()

class Callback():
    #
    def set_runner#
    def __getattr__#
    #
    def name#
        #
        return camel2snake#

This first callback is responsible for switching the model back and forth between training and validation mode, as well as maintaining a count of the iterations, or the percentage of iterations elapsed in the epoch.

In [None]:
#export
class TrainEvalCallback#
    def begin_fit#
        self.run.n_epochs=0.
        self.run.n_iter=0
    
    def after_batch#
        #
        #
        #
        
    def begin_epoch#
        #
        #self.model.train()
        #

    def begin_validate#
        #
        #

We'll also re-create our TestCallback - but note this doesn't actually work right yet (can you see why?) We'll fix it in notebook 05b.

In [None]:
# Not working!
class TestCallback(Callback):
    """Callback whose `after_step` hook reports True once `n_iter` reaches 10.

    NOTE(review): the accompanying notebook text says this version does not
    work correctly yet and is fixed in a later notebook.
    """
    _order = 1

    def after_step(self):
        "Return True when `self.n_iter` is at least 10, otherwise None."
        return True if self.n_iter >= 10 else None

In [None]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [None]:
TrainEvalCallback()#

'train_eval'

In [None]:
#export
from ## import *

def listify(o):
    if o is None: #
    if isinstance(o, list): #
    if isinstance(o, str): #
    if isinstance(o, Iterable): #
    return #

In [None]:
#export
class Runner():
    def __init__#
        #
        for#
            #
            #
            #
        self.stop,self.cbs = #

    #
    def opt#
    #
    def model#
    #
    def loss_func#
    #
    def data#

    def one_batch#
        self.xb,self.yb = xb,yb
        #
        self.pred = self.model(self.xb)
        #
        self.loss = self.loss_func(self.pred, self.yb)
        #
        self.loss.backward()
        #
        self.opt.step()
        #
        self.opt.zero_grad()

    def all_batches#
        self.iters =#
        for #
            #
            self.one_batch(xb, yb)
            #
        self.stop=False

    def fit#
        #

        try:
            #
            #
            for epoch in range(epochs):
                #
                #

                #
                    #
                #
            
        finally:
            #
            #

    def __call__#
        #
            #f = getattr(cb, cb_name, None)
            #
        #

Third callback: how to compute metrics.

In [None]:
#export
class AvgStats():
    def __init__#
    
    def reset#
        #
        #
        
    #
    def all_stats#
    #
    def avg_stats#
    
    def __repr__#
        if not self.count: return ""
        return #

    def accumulate#
        #
        #
        #
        #for
            #

class AvgStatsCallback#
    def __init__#
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)
        
    def begin_epoch#
        #
        #
        
    def after_loss#
        #
        #
    
    def after_epoch#
        #
        print(#

In [None]:
learn = Learner(*get_model(data), loss_func, data)

In [None]:
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)

In [None]:
run.fit(2, learn)

train: [0.31685572265625, tensor(0.9033)]
valid: [0.15482904052734375, tensor(0.9553)]
train: [0.143599248046875, tensor(0.9567)]
valid: [0.116745849609375, tensor(0.9652)]


In [None]:
# Unpack the averaged validation stats gathered by the `stats` callback
# (the loss followed by the single `accuracy` metric it was built with).
loss,acc = stats.valid_stats.avg_stats
# Sanity check: validation accuracy should be above 90% after this run.
assert acc>0.9
# Display both values as the cell output.
loss,acc

(0.116745849609375, tensor(0.9652))

In [None]:
#export
from ## import partial

In [None]:
acc_cbf = #

In [None]:
run = Runner(cb_funcs=acc_cbf)

In [None]:
run.fit(1, learn)

train: [0.108607373046875, tensor(0.9666)]
valid: [0.131622998046875, tensor(0.9607)]


Using Jupyter means we can get tab-completion even for dynamic code like this! :)

In [None]:
run.avg_stats.valid_stats.avg_stats

[0.131622998046875, tensor(0.9607)]

## Export

In [None]:
!python notebook2script.py 04_callbacks.ipynb