In [2]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [3]:
#export
from exp.nb_03 import *

## DataBunch/Learner

In [4]:
# Load the four data splits and wrap each (inputs, targets) pair in a Dataset.
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid) 
# nh is the hidden-layer width used by get_model; bs is passed to get_dls
# (presumably the batch size -- confirm against nb_03).
nh,bs = 50,64
# Number of output activations: largest label value plus one.
c = y_train.max() + 1
loss_func = F.cross_entropy

Factor the connected pieces of info out of the `fit()` argument list:

`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`

Let's replace it with something that looks like this:

`fit(1, learn)`

This will allow us to tweak what's happening inside the training loop from other places in the code: because the `Learner` object is mutable, changing any of its attributes elsewhere will be seen by our training loop.

In [5]:
get_dls??

In [6]:
#export
class DataBunch():
    """Hold the training loader, the validation loader and `c` in one object.

    `c` is stored as given; `get_model` uses it as the size of the model's
    output layer.
    """

    def __init__(self, train_dl, valid_dl, c):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        """The `dataset` attribute of the training loader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """The `dataset` attribute of the validation loader."""
        return self.valid_dl.dataset

In [7]:
# get_dls' return value is unpacked into the first two DataBunch arguments
# (train_dl, valid_dl); c is passed through as the third.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [8]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a one-hidden-layer classifier and an SGD optimizer for it.

    Args:
        data: object exposing `train_ds.x` (its second dimension is the input
            width) and `c` (converted with `int()` to the output width).
        lr: learning rate handed to `optim.SGD`.
        nh: number of hidden units. This used to be read from a notebook-level
            global, which is not available once the function is exported; the
            default of 50 matches the value the notebook used.

    Returns:
        `(model, optimizer)` tuple.
    """
    m = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, int(data.c)))
    return model, optim.SGD(model.parameters(), lr=lr)

class Learner():
    """Plain container for the four things the training loop needs.

    Stores `model`, `opt`, `loss_func` and `data` as attributes of the same
    names; no validation or copying is performed.
    """

    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [9]:
# get_model returns (model, optimizer), which fill Learner's first two arguments.
learn = Learner(*get_model(data), loss_func, data)

In [10]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, validating after each one.

    Args:
        epochs: number of passes over `learn.data.train_dl`.
        learn: object exposing `model`, `opt`, `loss_func` and `data`
            (with `train_dl` and `valid_dl`).

    Returns:
        `(loss, accuracy)` of the last epoch, each averaged over the number of
        validation batches.
    """
    for epoch in range(epochs):
        learn.model.train()
        for xb, yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval()
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb, yb in learn.data.valid_dl:
                pred = learn.model(xb)
                # Use the learner's own loss function; the previous code picked up
                # whatever global `loss_func` happened to be defined in the notebook.
                tot_loss += learn.loss_func(pred, yb)
                tot_acc += accuracy(pred, yb)
        # Averages are per batch, not per sample.
        nv = len(learn.data.valid_dl)
        print('2o3', epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv

In [11]:
# Train for one epoch; fit returns the validation loss and accuracy averaged over batches.
loss,acc = fit(1, learn)

2o3 0 tensor(0.2891) tensor(0.9174)


## CallbackHandler

This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out:

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(epochs):
        for b in train_dl: one_batch(*b)
```

Add callbacks so we can remove complexity from the loop and make it flexible:

In [12]:
def one_batch(xb, yb, cb):
    """Run one batch, letting `cb` veto each stage.

    `cb.begin_batch` and `cb.after_loss` abort the rest of the batch when they
    return a falsy value; `cb.after_backward` and `cb.after_step` only gate the
    optimizer step and the gradient reset respectively.
    """
    if cb.begin_batch(xb, yb):
        loss_fn = cb.learn.loss_func
        loss = loss_fn(cb.learn.model(xb), yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                cb.learn.opt.step()
            if cb.after_step():
                cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    """Feed every `(xb, yb)` pair from `dl` to `one_batch`, stopping early if `cb.do_stop()` says so."""
    for batch in dl:
        xb, yb = batch
        one_batch(xb, yb, cb)
        if cb.do_stop():
            break

def fit(epochs, learn, cb):
    """Callback-driven training loop.

    Nothing runs (not even `after_fit`) when `cb.begin_fit` returns a falsy
    value. An epoch is skipped entirely when `cb.begin_epoch` is falsy. The
    validation pass runs under `torch.no_grad()` and only if
    `cb.begin_validate()` is truthy.
    """
    if cb.begin_fit(learn):
        for epoch in range(epochs):
            if cb.begin_epoch(epoch):
                all_batches(learn.data.train_dl, cb)

                if cb.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.valid_dl, cb)
                # after_epoch is only consulted when no stop was requested.
                keep_going = not cb.do_stop() and cb.after_epoch()
                if not keep_going:
                    break
        cb.after_fit()

In [13]:
def one_batch(xb, yb, ch):
    """Process a single batch under the control of the callback handler `ch`.

    Returns early (doing nothing further) when `ch.begin_batch` or
    `ch.after_loss` is falsy. The optimizer step and the gradient reset are
    each performed only if the corresponding handler hook is truthy.
    """
    if not ch.begin_batch(xb, yb):
        return
    loss_fn = ch.learn.loss_func
    batch_loss = loss_fn(ch.learn.model(xb), yb)
    if not ch.after_loss(batch_loss):
        return
    batch_loss.backward()
    do_step = ch.after_backward()
    if do_step:
        ch.learn.opt.step()
    do_zero = ch.after_step()
    if do_zero:
        ch.learn.opt.zero_grad()

def all_batches(dl, ch):
    """Run `one_batch` on each item of `dl` until it is exhausted or `ch.do_stop()` is truthy."""
    for inputs, targets in dl:
        one_batch(inputs, targets, ch)
        stop_requested = ch.do_stop()
        if stop_requested:
            return

def fit(epochs, learn, ch):
    """Train and validate `learn` for up to `epochs` epochs, driven by handler `ch`.

    Same loop as the previous cell, with the handler argument named `ch`.
    `ch.after_fit()` is called after the loop ends (normally or via a stop),
    but not when `ch.begin_fit` refuses to start.
    """
    started = ch.begin_fit(learn)
    if not started:
        return
    for epoch in range(epochs):
        if not ch.begin_epoch(epoch):
            continue
        all_batches(learn.data.train_dl, ch)

        if ch.begin_validate():
            with torch.no_grad():
                all_batches(learn.data.valid_dl, ch)
        if ch.do_stop():
            break
        if not ch.after_epoch():
            break
    ch.after_fit()

In [18]:
class Callback():
    """Base callback: records what it is given and lets training continue.

    Every hook returns True, except `after_loss`, which returns whether the
    learner is currently in training mode (`learn.in_train`).
    """

    def begin_fit(self, learn):
        self.learn = learn
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    def begin_batch(self, xb, yb):
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        self.loss = loss
        # A falsy result makes one_batch skip backward/step for this batch.
        return bool(self.learn.in_train)

    def after_backward(self):
        return True

    def after_step(self):
        return True

In [19]:
class CallbackHandler():
    """Fan each training-loop event out to a list of callbacks.

    For every event the callbacks are called in order and their results are
    combined with `and`: as soon as one returns a falsy value, the remaining
    callbacks are not called for that event and the falsy value is returned.

    Args:
        cbs: callbacks to dispatch to; `None` (the default) means no callbacks.
    """

    def __init__(self, cbs=None):
        # Iterating over None would raise, so fall back to an empty list.
        self.cbs = cbs if cbs else []

    def _dispatch(self, event, *args):
        """Call `event` on each callback, short-circuiting on the first falsy result."""
        res = True
        for cb in self.cbs: res = res and getattr(cb, event)(*args)
        return res

    def begin_fit(self, learn):
        self.learn = learn
        self.learn.stop = False
        return self._dispatch('begin_fit', learn)

    def after_fit(self):
        return self._dispatch('after_fit')

    def begin_epoch(self, epoch):
        self.epoch, self.learn.in_train = epoch, True
        self.learn.model.train()
        # Return the combined result so a callback can actually skip the epoch
        # (this used to return True unconditionally).
        return self._dispatch('begin_epoch', epoch)

    def begin_validate(self):
        self.learn.in_train = False
        self.learn.model.eval()
        return self._dispatch('begin_validate')

    def after_epoch(self):
        return self._dispatch('after_epoch')

    def begin_batch(self, xb, yb):
        return self._dispatch('begin_batch', xb, yb)

    def after_loss(self, loss):
        res = self._dispatch('after_loss', loss)
        # Never let the backward pass run outside training, even when there are
        # no callbacks (or none of them checks `in_train`): validation runs
        # under torch.no_grad(), where backward() would fail.
        return res and self.learn.in_train

    def after_backward(self):
        return self._dispatch('after_backward')

    def after_step(self):
        return self._dispatch('after_step')

    def do_stop(self):
        """Return the learner's `stop` flag and reset it to False."""
        stop, self.learn.stop = self.learn.stop, False
        return stop

In [22]:
class TestCallback(Callback):
    """Demo callback: counts optimizer steps and requests a stop at the tenth."""

    def begin_fit(self, learn):
        self.n_iters = 0
        self.learn = learn
        return True

    def after_step(self):
        self.n_iters = self.n_iters + 1
        print(self.n_iters)
        reached_limit = self.n_iters >= 10
        if reached_limit:
            # The handler's do_stop() reads (and resets) this flag after the batch.
            self.learn.stop = True
        return True

In [24]:
# One epoch with TestCallback, which sets learn.stop once it has counted 10 steps,
# so the printed counter ends at 10.
fit(1, learn, ch=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!

## Runner

In [25]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Two passes insert underscores (first before a capitalised word, then
    between a lowercase letter or digit and a capital), after which the
    result is lower-cased.
    """
    partially_split = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', partially_split)
    return fully_split.lower()

class Callback():
    """Base class for Runner callbacks.

    Attributes that are not found on the callback itself are looked up on the
    runner registered through `set_runner`. Only reads are delegated:
    assigning `self.x = ...` creates an attribute on the callback, so state
    that belongs to the runner must be written as `self.run.x = ...`.
    """
    # Callbacks with a lower _order are meant to run first.
    _order = 0

    def set_runner(self, run):
        """Register the runner whose attributes this callback can read."""
        self.run = run

    def __getattr__(self, k):
        # Only reached when normal lookup fails. If `run` itself is missing
        # (set_runner not called yet), fail cleanly instead of recursing.
        if k == 'run': raise AttributeError(k)
        return getattr(self.run, k)

    def name(self):
        """Snake-case class name without a trailing 'Callback' ('callback' for the base class).

        NOTE(review): kept as a regular method to preserve the existing
        interface; confirm whether callers expect attribute-style access.
        """
        cname = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(cname or 'callback')

This first callback is responsible for switching the model back and forth between training and validation mode, as well as maintaining a count of the iterations, or the percentage of iterations elapsed in the epoch.

In [None]:
#export
class TrainEvalCallback(Callback):
    """Switch the model between train and eval mode and keep the runner's counters.

    Maintains on the runner: `n_iter` (training batches seen), `n_epochs`
    (fractional epochs elapsed) and `in_train`. All writes go through
    `self.run` because the base class only delegates attribute reads.

    NOTE(review): the method bodies were blank in the notebook and have been
    reconstructed from the accompanying text; confirm against the intended
    solution.
    """

    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0

    def after_batch(self):
        # Validation batches do not count as training progress.
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        # Must be written on the runner; `self.n_iter += 1` would create a
        # shadow attribute on this callback and leave the runner's count at 0.
        self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

We'll also re-create our TestCallback - but note this doesn't actually work right yet (can you see why?) We'll fix it in notebook 05b.

In [None]:
# Not working!
class TestCallback(Callback):
    """Return True from `after_step` once `n_iter` has reached 10.

    The surrounding notebook text marks this version as not working yet;
    it is deliberately left that way here.
    """
    _order = 1

    def after_step(self):
        reached_limit = self.n_iter >= 10
        if reached_limit:
            return True

In [None]:
# Check the CamelCase -> snake_case conversion on a callback class name.
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [None]:
TrainEvalCallback()#

'train_eval'

In [None]:
#export
from typing import *

def listify(o):
    """Coerce `o` to a list.

    - `None` becomes an empty list.
    - A list is returned as-is (same object, not a copy).
    - A string becomes a one-element list (it is not split into characters).
    - Any other iterable is materialised with `list()`.
    - Anything else is wrapped in a one-element list.
    """
    if o is None: return []
    if isinstance(o, list): return o
    # Strings are iterable, so they must be handled before the Iterable check.
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [None]:
#export
class Runner():
    """Training loop that owns the loop state and dispatches events to callbacks.

    Calling the runner with an event name (`self('begin_batch')`) invokes the
    method of that name on every callback that defines it, in `_order` order.
    A callback returning a truthy value means "stop here": the current stage
    (batch, epoch or the whole fit) is cut short.

    Args:
        cbs: a callback instance or list of instances.
        cb_funcs: a zero-argument callable (or list of them) that builds a
            callback; each built callback is also exposed as an attribute of
            the runner, named after the callback's `name`.

    NOTE(review): most method bodies were blank in the notebook and have been
    reconstructed from the way the class is used and from the recorded
    outputs; confirm against the intended solution.
    """

    def __init__(self, cbs=None, cb_funcs=None):
        # Copy so that appending below never mutates the caller's list.
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs):
            cb = cbf()
            # `name` may be a plain method or a property; accept both.
            cb_name = cb.name
            if callable(cb_name): cb_name = cb_name()
            setattr(self, cb_name, cb)
            cbs.append(cb)
        # TrainEvalCallback maintains in_train / n_iter / n_epochs on the runner.
        self.in_train = False
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    @property
    def opt(self): return self.learn.opt
    @property
    def model(self): return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self): return self.learn.data

    def one_batch(self, xb, yb):
        """Run one batch; any callback returning a truthy value ends it early.

        The backward pass and optimizer step are skipped outside training.
        'after_batch' always fires, even when the batch is cut short.
        """
        self.xb,self.yb = xb,yb
        try:
            if self('begin_batch'): return
            self.pred = self.model(self.xb)
            if self('after_pred'): return
            self.loss = self.loss_func(self.pred, self.yb)
            # Validation runs under no_grad, where backward() would fail.
            if self('after_loss') or not self.in_train: return
            self.loss.backward()
            if self('after_backward'): return
            self.opt.step()
            if self('after_step'): return
            self.opt.zero_grad()
        finally:
            self('after_batch')

    def all_batches(self, dl):
        """Run every batch of `dl`, stopping early if `self.stop` gets set."""
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
        self.stop=False

    def fit(self, epochs, learn):
        """Train and validate `learn` for up to `epochs` epochs.

        'after_fit' is always dispatched and `self.learn` is cleared
        afterwards, even if an exception interrupts training.
        """
        self.epochs,self.learn = epochs,learn

        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break

        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Dispatch event `cb_name`; return True as soon as a callback returns truthy."""
        for cb in sorted(self.cbs, key=lambda x: x._order):
            # Callbacks that do not implement the event are skipped.
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False

Third callback: how to compute metrics.

In [None]:
#export
class AvgStats():
    """Running, sample-weighted totals of the loss and a list of metrics.

    Args:
        metrics: a metric function or list of them, each called as
            `metric(pred, yb)`.
        in_train: whether these statistics describe the training pass; only
            used to label the printed representation.

    NOTE(review): the method bodies were blank in the notebook and have been
    reconstructed from usage and the recorded output format; confirm against
    the intended solution.
    """

    def __init__(self, metrics, in_train):
        self.metrics,self.in_train = listify(metrics),in_train

    def reset(self):
        """Clear all totals; call at the start of each epoch."""
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """`[total loss as a Python float] + per-metric totals`; valid once a batch has been accumulated."""
        return [self.tot_loss.item()] + self.tot_mets

    @property
    def avg_stats(self):
        """`all_stats` divided by the number of samples seen."""
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the current batch of `run` (its `xb`, `yb`, `pred`, `loss`) to the totals."""
        # Weight by batch size so the averages are per sample, not per batch.
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
    """Collect average loss and metrics for the training and validation passes of each epoch.

    NOTE(review): the method bodies were blank in the notebook and have been
    reconstructed from usage and the recorded outputs; confirm against the
    intended solution.
    """

    def __init__(self, metrics):
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        # Metrics must not become part of the autograd graph.
        with torch.no_grad(): stats.accumulate(self.run)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

In [None]:
learn = Learner(*get_model(data), loss_func, data)

In [None]:
# Track accuracy alongside the loss, and hand the callback instance to the Runner.
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)

In [None]:
run.fit(2, learn)

train: [0.31685572265625, tensor(0.9033)]
valid: [0.15482904052734375, tensor(0.9553)]
train: [0.143599248046875, tensor(0.9567)]
valid: [0.116745849609375, tensor(0.9652)]


In [None]:
# avg_stats holds the averaged loss followed by the metric(s) for the last
# validation pass; sanity-check that the model actually learned something.
loss,acc = stats.valid_stats.avg_stats
assert acc>0.9
loss,acc

(0.116745849609375, tensor(0.9652))

In [None]:
#export
from functools import partial

In [None]:
# Factory that builds an AvgStatsCallback tracking accuracy; Runner calls it with
# no arguments when it is passed as `cb_funcs`.
acc_cbf = partial(AvgStatsCallback, accuracy)

In [None]:
run = Runner(cb_funcs=acc_cbf)

In [None]:
run.fit(1, learn)

train: [0.108607373046875, tensor(0.9666)]
valid: [0.131622998046875, tensor(0.9607)]


Using Jupyter means we can get tab-completion even for dynamic code like this! :)

In [None]:
run.avg_stats.valid_stats.avg_stats

[0.131622998046875, tensor(0.9607)]

## Export

In [None]:
!python notebook2script.py 04_callbacks.ipynb