In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#export
from nb_003a import *

ModuleNotFoundError: No module named 'dataclasses'

In [None]:
DATA_PATH = Path('data')
PATH = DATA_PATH/'cifar10'

In [None]:
data_mean,data_std = map(tensor, ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]))
cifar_norm = normalize_tfm(mean=data_mean,std=data_std)

In [None]:
tfms = [flip_lr_tfm(p=0.5),
        pad_tfm(padding=4),
        crop_tfm(size=32, row_pct=(0,1.), col_pct=(0,1.))]

In [None]:
bs = 64

In [None]:
train_ds = FilesDataset.from_folder(PATH/'train', classes=['airplane','dog'])
valid_ds = FilesDataset.from_folder(PATH/'test', classes=['airplane','dog'])
data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, valid_tfm=[], num_workers=4)
len(data.train_dl), len(data.valid_dl)

## Training loop so far

In [None]:
def loss_batch(model, xb, yb, loss_fn, opt=None):
    """Compute the loss of `model` on one batch and optionally take an optimizer step.

    When `opt` is given, the loss is back-propagated, the optimizer is stepped and the
    gradients are zeroed. Returns `(loss.item(), len(xb))`.
    """
    preds = model(xb)
    loss = loss_fn(preds, yb)
    training = opt is not None
    if training:
        # Backward pass, weight update, then reset the gradients for the next batch.
        loss.backward()
        opt.step()
        opt.zero_grad()
    batch_size = len(xb)
    return loss.item(), batch_size

In [None]:
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl):
    """Plain training loop: train on `train_dl`, then print the validation loss, per epoch.

    The validation loss is the average of the per-batch losses on `valid_dl`, weighted by
    the size of each batch.
    """
    for epoch in tnrange(epochs):
        # Training phase.
        model.train()
        for xb, yb in train_dl:
            batch_loss, _ = loss_batch(model, xb, yb, loss_fn, opt)
            if train_dl.progress_func is not None:
                train_dl.gen.set_postfix_str(batch_loss)

        # Validation phase: no gradients needed, no optimizer passed.
        model.eval()
        with torch.no_grad():
            results = [loss_batch(model, xb, yb, loss_fn) for xb, yb in valid_dl]
            losses, nums = zip(*results)
        weighted = np.multiply(losses, nums)
        val_loss = np.sum(weighted) / np.sum(nums)

        print(epoch, val_loss)

In [None]:
class Learner():
    """Pairs a data object with a model and knows how to train it.

    `data` must expose `device`, `train_dl` and `valid_dl`; the model is moved to
    `data.device` on construction.
    """
    def __init__(self, data, model):
        model_on_device = model.to(data.device)
        self.data = data
        self.model = model_on_device

    def fit(self, epochs, lr, opt_fn=optim.SGD):
        """Train for `epochs` epochs with cross-entropy loss and an `opt_fn` optimizer at `lr`."""
        optimizer = opt_fn(self.model.parameters(), lr=lr)
        fit(epochs, self.model, F.cross_entropy, optimizer,
            self.data.train_dl, self.data.valid_dl)

## Model

In [None]:
def conv_layer(ni, nf, ks=3, stride=1):
    """Build a bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.1) block.

    `ni`/`nf` are the input/output channel counts; the convolution is padded by `ks//2`.
    """
    conv = nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2)
    norm = nn.BatchNorm2d(nf)
    activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(conv, norm, activation)

class ResLayer(nn.Module):
    """Residual block: a 1x1 conv to `ni//2` channels, a 3x3 conv back to `ni`, added to the input."""
    def __init__(self, ni):
        super().__init__()
        half = ni // 2
        self.conv1 = conv_layer(ni, half, ks=1)
        self.conv2 = conv_layer(half, ni, ks=3)

    def forward(self, x):
        branch = self.conv2(self.conv1(x))
        return x + branch

class Darknet(nn.Module):
    """Stack of conv/residual groups followed by average pooling and a linear classifier."""

    def make_group_layer(self, ch_in, num_blocks, stride=1):
        """Return a list: one conv layer doubling the channels, then `num_blocks` ResLayers."""
        ch_out = ch_in * 2
        group = [conv_layer(ch_in, ch_out, stride=stride)]
        for _ in range(num_blocks):
            group.append(ResLayer(ch_out))
        return group

    def __init__(self, num_blocks, num_classes, nf=32):
        """`num_blocks` lists the number of ResLayers per group; `nf` is the width after the stem conv."""
        super().__init__()
        layers = [conv_layer(3, nf, ks=3, stride=1)]
        for i, nb in enumerate(num_blocks):
            # Every group uses stride 2 except the second one (i == 1), which uses stride 1.
            group_stride = 1 if i == 1 else 2
            layers.extend(self.make_group_layer(nf, nb, stride=group_stride))
            nf *= 2
        layers.extend([nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)])
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [None]:
model = Darknet([1, 2, 4, 6, 3], num_classes=10, nf=16)

## Setting hyperparameters easily

We want an optimizer with an easy way to set hyperparameters: they're all properties and we define custom setters to handle the different names in pytorch optimizers.

In [None]:
#export
class HPOptimizer():
    """Wrap a PyTorch optimizer and expose its hyper-parameters as properties.

    `lr`, `mom`, `beta` and `wd` can be read and assigned directly; the setters translate
    them to the key names used by the wrapped optimizer (`momentum`, `betas`, `alpha`,
    `weight_decay`) and write them to every param group.

    params:  parameters handed to `opt_fn`.
    opt_fn:  callable building the optimizer, called as `opt_fn(params, init_lr)`.
    init_lr: initial learning rate.
    true_wd: when True, weight decay is applied directly to the weights in `step`
             and the optimizer's own `weight_decay` is forced to 0.

    A hyper-parameter the wrapped optimizer doesn't have (e.g. `beta` for SGD) reads as
    None instead of raising AttributeError.
    """

    def __init__(self, params, opt_fn, init_lr, true_wd=False):
        self.opt = opt_fn(params, init_lr)
        self._lr, self.true_wd = init_lr, true_wd
        self.opt_keys = list(self.opt.param_groups[0].keys())
        self.opt_keys.remove('params')
        # Defaults for hyper-parameters the optimizer may not define, so the property
        # getters below never hit a missing attribute.
        self._mom, self._beta, self._wd = None, None, None
        self.read_defaults()

    #Pytorch optimizer methods
    def step(self):
        #Does the weight decay outside of the optimizer step for AdamW
        if self.true_wd:
            if self._wd:
                # Nothing to decay when wd is unset (None) or 0.
                for pg in self.opt.param_groups:
                    for p in pg['params']:
                        p.data.mul_(1 - self._wd * pg['lr'])
            self.set_val('weight_decay', 0)
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

    #Hyperparameters as properties
    @property
    def lr(self): return self._lr

    @lr.setter
    def lr(self, val):
        self.set_val('lr', val)
        self._lr = val

    @property
    def mom(self): return self._mom

    @mom.setter
    def mom(self, val):
        # Momentum is `momentum` in SGD-like optimizers and the first of `betas` in Adam-like ones.
        if 'momentum' in self.opt_keys: self.set_val('momentum', val)
        elif 'betas' in self.opt_keys:  self.set_val('betas', (val, self._beta))
        self._mom = val

    @property
    def beta(self): return self._beta

    @beta.setter
    def beta(self, val):
        # Beta is the second of `betas` in Adam-like optimizers and `alpha` in RMSprop-like ones.
        if 'betas' in self.opt_keys:    self.set_val('betas', (self._mom,val))
        elif 'alpha' in self.opt_keys:  self.set_val('alpha', val)
        self._beta = val

    @property
    def wd(self): return self._wd

    @wd.setter
    def wd(self, val):
        # With true_wd the decay is applied in `step`, so the optimizer's own value is left alone.
        if not self.true_wd: self.set_val('weight_decay', val)
        self._wd = val

    #Helper functions
    def read_defaults(self):
        """Read the initial hyper-parameter values from the first param group."""
        first_group = self.opt.param_groups[0]
        if 'momentum' in self.opt_keys: self._mom = first_group['momentum']
        if 'alpha' in self.opt_keys: self._beta = first_group['alpha']
        if 'betas' in self.opt_keys: self._mom,self._beta = first_group['betas']
        if 'weight_decay' in self.opt_keys: self._wd = first_group['weight_decay']

    def set_val(self, key, val):
        """Write `val` under `key` in every param group of the wrapped optimizer."""
        for pg in self.opt.param_groups: pg[key] = val

In [None]:
opt_fn = partial(optim.Adam, betas=(0.95,0.99))

In [None]:
opt = HPOptimizer(model.parameters(), opt_fn, 1e-2)

In [None]:
opt.lr, opt.mom, opt.wd, opt.beta

In [None]:
opt.lr=0.2

In [None]:
opt.lr, opt.mom, opt.wd, opt.beta

Now that it's easy to set and change the HP in the optimizer, we need a scheduler to change it. To keep the training loop as readable as possible we don't want to handle all of this stuff inside it so we'll use callbacks. 

In [None]:
class Callback():
    """Base class for training-loop callbacks; every hook is a no-op by default.

    The hooks whose return value is used by the training loop (`on_batch_begin` and
    `on_backward_begin`) return their input unchanged, so a subclass that does not
    override them cannot break a loop that reassigns from their result.
    """
    def on_train_begin(self):
        "To initialize constants in the callback."
        pass
    def on_epoch_begin(self):
        "At the beginning of each epoch."
        pass
    def on_batch_begin(self, xb, yb):
        """To set HP before the step is done. A look at the input can be useful (set the lr
        depending on the seq_len in RNNs, or for reg_functions called in on_backward_begin).
        Returns xb, yb (which can allow us to modify the input at that step if needed)."""
        return xb, yb
    def on_backward_begin(self, loss, out):
        """Called after the forward pass and the loss has been computed, but before the back propagation.
        Passes the loss and the output of the model.
        Returns the loss (which can allow us to modify it, for instance for reg functions)."""
        return loss
    def on_backward_end(self):
        """Called after the back propagation had been done (and the gradients computed) but
        before the step of the optimizer. Useful for true weight decay in AdamW."""
        pass
    def on_step_end(self):
        "Called after the step of the optimizer but before the gradients are zeroed (not sure this one is useful)."
        pass
    def on_batch_end(self, loss):
        "Called at the end of the batch."
        pass
    def on_epoch_end(self, val_loss):
        "Called at the end of an epoch."
        pass
    def on_train_end(self):
        "Useful for cleaning up things and saving files/models."
        pass

The idea is to have a callback between every line of the training loop, that way everything we need to add will be treated there and not inside.

In [None]:
def loss_batch(model, xb, yb, loss_fn, opt=None, callbacks=[]):
    """Compute the loss on one batch and, when `opt` is given, run a training step with callbacks.

    During a training step each callback may replace the loss in `on_backward_begin`, and is
    notified after the backward pass and after the optimizer step.
    Returns `(loss.item(), len(xb))`.
    """
    out = model(xb)
    loss = loss_fn(out, yb)

    if opt is None:
        return loss.item(), len(xb)

    for cb in callbacks:
        loss = cb.on_backward_begin(loss, out)
    loss.backward()
    for cb in callbacks:
        cb.on_backward_end()
    opt.step()
    for cb in callbacks:
        cb.on_step_end()
    opt.zero_grad()

    return loss.item(), len(xb)

In [None]:
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl, callbacks=[]):
    """Training loop with a callback hook around each stage.

    Each callback's `on_batch_begin` must return `(xb, yb)`, since the result is reassigned.
    The validation loss is the batch-size-weighted mean of the per-batch losses.

    NOTE(review): `callbacks` is not forwarded to `loss_batch` here, so `on_backward_begin`,
    `on_backward_end` and `on_step_end` are not triggered by this loop.
    """
    def notify(hook, *args):
        # Call the same hook on every callback, ignoring the return values.
        for cb in callbacks: getattr(cb, hook)(*args)

    notify('on_train_begin')
    for epoch in tnrange(epochs):
        model.train()
        notify('on_epoch_begin')
        for xb, yb in train_dl:
            for cb in callbacks: xb, yb = cb.on_batch_begin(xb, yb)
            loss, _ = loss_batch(model, xb, yb, loss_fn, opt)
            if train_dl.progress_func is not None:
                train_dl.gen.set_postfix_str(loss)
            notify('on_batch_end', loss)

        model.eval()
        with torch.no_grad():
            results = [loss_batch(model, xb, yb, loss_fn) for xb, yb in valid_dl]
            losses, nums = zip(*results)
        weighted = np.multiply(losses, nums)
        val_loss = np.sum(weighted) / np.sum(nums)

        notify('on_epoch_end', val_loss)
        print(epoch, val_loss)
    notify('on_train_end')

First callback: updating the progress bar and printing the validation loss can both be done in a single callback. We'll also keep track of the losses and hyper-parameters during training for future plots (lr_finder, plot of the LR/mom schedule).

In [None]:
class Recorder(Callback):
    """Callback that logs losses, learning rates and momentums, and reports progress.

    `opt` must expose `lr` and `mom`; `train_dl`, when given, is used to show the current
    loss on its progress bar.
    """

    def __init__(self, opt, train_dl=None):
        self.opt = opt
        self.train_dl = train_dl

    def on_train_begin(self):
        # Reset all the logs at the start of training.
        self.epoch = 0
        self.losses, self.val_losses = [], []
        self.lrs, self.moms = [], []

    def on_batch_begin(self, xb, yb):
        # Record the hyper-parameters in use for this batch; the batch itself is untouched.
        self.lrs.append(self.opt.lr)
        self.moms.append(self.opt.mom)
        return xb, yb

    def on_batch_end(self, loss):
        self.losses.append(loss)
        dl = self.train_dl
        if dl is None or dl.progress_func is None:
            return
        dl.gen.set_postfix_str(loss)

    def on_epoch_end(self, val_loss):
        self.val_losses.append(val_loss)
        print(self.epoch, val_loss)
        self.epoch += 1

In [None]:
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl=None, callbacks=[]):
    """Training loop with callbacks; validation is skipped when `valid_dl` is None.

    `on_epoch_end` receives the batch-size-weighted mean validation loss, or None when
    there is no validation set. Reporting is left entirely to the callbacks.

    NOTE(review): `callbacks` is not forwarded to `loss_batch` here, so `on_backward_begin`,
    `on_backward_end` and `on_step_end` are not triggered by this loop.
    """
    def notify(hook, *args):
        # Call the same hook on every callback, ignoring the return values.
        for cb in callbacks: getattr(cb, hook)(*args)

    notify('on_train_begin')
    for epoch in tnrange(epochs):
        model.train()
        notify('on_epoch_begin')
        for xb, yb in train_dl:
            for cb in callbacks: xb, yb = cb.on_batch_begin(xb, yb)
            loss, _ = loss_batch(model, xb, yb, loss_fn, opt)
            notify('on_batch_end', loss)

        val_loss = None
        if valid_dl is not None:
            model.eval()
            with torch.no_grad():
                results = [loss_batch(model, xb, yb, loss_fn) for xb, yb in valid_dl]
                losses, nums = zip(*results)
            weighted = np.multiply(losses, nums)
            val_loss = np.sum(weighted) / np.sum(nums)
        notify('on_epoch_end', val_loss)

    notify('on_train_end')

In [None]:
class Learner():
    """Pairs a data object with a model; trains with cross-entropy and SGD by default.

    `loss_fn` and `opt_fn` are plain attributes and can be replaced before calling `fit`.
    """
    def __init__(self, data, model):
        model_on_device = model.to(data.device)
        self.data = data
        self.model = model_on_device
        self.loss_fn = F.cross_entropy
        self.opt_fn = optim.SGD

    def fit(self, epochs, lr):
        """Train for `epochs` epochs at learning rate `lr`, recording stats in `self.recorder`."""
        self.opt = HPOptimizer(self.model.parameters(), self.opt_fn, init_lr=lr)
        self.recorder = Recorder(self.opt, self.data.train_dl)
        fit(epochs, self.model, self.loss_fn, self.opt,
            self.data.train_dl, self.data.valid_dl, callbacks=[self.recorder])

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)

In [None]:
learn.fit(2,0.1)

This is all very well but what if someone forgets to return xb,yb or the loss in the callbacks that can change it? To be more convenient and make the code of the training loop cleaner, we'll create a class to handle the callbacks.

In [None]:
class Callback():
    """Base class for training-loop callbacks; every hook does nothing and returns None."""
    def on_train_begin(self):
        "To initialize constants in the callback."
    def on_epoch_begin(self):
        "At the beginning of each epoch."
    def on_batch_begin(self, xb, yb):
        """To set HP before the step is done. A look at the input can be useful (set the lr
        depending on the seq_len in RNNs, or for reg_functions called in on_backward_begin).
        Returns xb, yb (which can allow us to modify the input at that step if needed)."""
    def on_loss_begin(self, out):
        """Called after the forward pass but before the loss has been computed.
        Passes the output of the model.
        Returns the output (which can allow us to modify it)."""
    def on_backward_begin(self, loss):
        """Called after the forward pass and the loss has been computed, but before the back propagation.
        Passes the loss of the model.
        Returns the loss (which can allow us to modify it, for instance for reg functions)."""
    def on_backward_end(self):
        """Called after the back propagation had been done (and the gradients computed) but
        before the step of the optimizer. Useful for true weight decay in AdamW."""
    def on_step_end(self):
        "Called after the step of the optimizer but before the gradients are zeroed (not sure this one is useful)."
    def on_batch_end(self, loss):
        "Called at the end of the batch."
    def on_epoch_end(self, val_loss):
        "Called at the end of an epoch."
    def on_train_end(self):
        "Useful for cleaning up things and saving files/models."

In [None]:
class CallbackHandler():
    """Dispatch each training-loop event to a list of callbacks.

    For the hooks that may modify a value (`on_batch_begin`, `on_loss_begin`,
    `on_backward_begin`), a callback returning None leaves the value untouched, so
    callbacks that forget to return it cannot break the loop. `on_batch_end` and
    `on_epoch_end` return True when at least one callback returned a truthy value.
    """

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def __call__(self, cb_name, *args, **kwargs):
        """Call `on_<cb_name>` on every callback and return the list of results."""
        return [getattr(cb, f'on_{cb_name}')(*args, **kwargs) for cb in self.callbacks]

    def on_train_begin(self): self('train_begin')
    def on_epoch_begin(self): self('epoch_begin')

    def on_batch_begin(self, xb, yb):
        for cb in self.callbacks:
            a = cb.on_batch_begin(xb,yb)
            if a is not None: xb,yb = a
        return xb,yb

    def on_loss_begin(self, out):
        for cb in self.callbacks:
            a = cb.on_loss_begin(out)
            if a is not None: out = a
        return out

    def on_backward_begin(self, loss):
        for cb in self.callbacks:
            a = cb.on_backward_begin(loss)
            if a is not None: loss = a
        return loss

    def on_backward_end(self):        self('backward_end')
    def on_step_end(self):            self('step_end')
    # The builtin `any` is used on the already-built list of results: every callback still
    # runs, and the answer is always a real bool (np.any on a list containing None builds an
    # object array and, on older NumPy versions, returns one of the objects instead).
    def on_batch_end(self, loss):     return any(self('batch_end', loss))
    def on_epoch_end(self, val_loss): return any(self('epoch_end', val_loss))
    def on_train_end(self):           self('train_end')

In [None]:
def loss_batch(model, xb, yb, loss_fn, opt=None, cb_handler=CallbackHandler([])):
    """Compute the loss on one batch and, when `opt` is given, run a full training step.

    `cb_handler` sees the model output before the loss is computed (`on_loss_begin`) and,
    during a training step, the loss before the backward pass (`on_backward_begin`); both
    may replace the value they receive. Returns `(loss.item(), len(xb))`.
    """
    out = cb_handler.on_loss_begin(model(xb))
    loss = loss_fn(out, yb)

    if opt is None:
        return loss.item(), len(xb)

    loss = cb_handler.on_backward_begin(loss)
    loss.backward()
    cb_handler.on_backward_end()
    opt.step()
    cb_handler.on_step_end()
    opt.zero_grad()

    return loss.item(), len(xb)

In [None]:
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl=None, callbacks=[]):
    """Training loop in which all callbacks go through a single `CallbackHandler`.

    A truthy `on_batch_end` interrupts the current epoch; a truthy `on_epoch_end` stops
    training. `on_epoch_end` receives the batch-size-weighted mean validation loss, or
    None when `valid_dl` is None.
    """
    handler = CallbackHandler(callbacks)
    handler.on_train_begin()

    for epoch in tnrange(epochs):
        model.train()
        handler.on_epoch_begin()

        for xb, yb in train_dl:
            xb, yb = handler.on_batch_begin(xb, yb)
            loss, _ = loss_batch(model, xb, yb, loss_fn, opt, handler)
            stop_epoch = handler.on_batch_end(loss)
            if stop_epoch: break

        val_loss = None
        if valid_dl is not None:
            model.eval()
            with torch.no_grad():
                results = [loss_batch(model, xb, yb, loss_fn) for xb, yb in valid_dl]
                losses, nums = zip(*results)
            weighted = np.multiply(losses, nums)
            val_loss = np.sum(weighted) / np.sum(nums)
        stop_training = handler.on_epoch_end(val_loss)
        if stop_training: break

    handler.on_train_end()

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)

In [None]:
learn.fit(2,0.1)

Now we can do a 1cycle scheduler pretty easily.

In [None]:
#export
def annealing_no(start, end, pct):
    "No annealing: always return `start`, whatever `end` and `pct` are."
    return start
def annealing_linear(start, end, pct):
    "Linear interpolation from `start` (at `pct=0`) to `end` (at `pct=1`)."
    span = end - start
    return start + pct * span
def annealing_exponential(start, end, pct):
    "Exponential interpolation from `start` (at `pct=0`) to `end` (at `pct=1`); `start` must be non-zero."
    ratio = end / start
    return start * ratio ** pct

In [None]:
#export
def is_tuple(x):
    "Return True when `x` is a tuple (or an instance of a tuple subclass)."
    return isinstance(x, tuple)

In [None]:
#export
class Stepper():
    """Schedule a hyper-parameter over `num_it` iterations.

    vals:   a `(start, end)` tuple, or a single value (then `start=vals` and `end=0`).
    num_it: number of iterations of the schedule.
    ft:     annealing function called as `ft(start, end, pct)`; when None, it is
            `annealing_linear` for a tuple and `annealing_no` for a single value.
    """

    def __init__(self, vals, num_it, ft=None):
        has_range = is_tuple(vals)
        self.start, self.end = vals if has_range else (vals, 0)
        #The one-liner below doesn't work because the conditional expression binds tighter than
        #the comma: its right-hand side parses as `((vals[0],vals[1]) if is_tuple(vals) else vals), 0`.
        #(self.start,self.end) = (vals[0],vals[1]) if is_tuple(vals) else vals,0
        self.num_it = num_it
        if ft is None:
            ft = annealing_linear if has_range else annealing_no
        self.ft = ft
        self.n = 0

    def step(self):
        "Advance by one iteration and return the scheduled value."
        self.n += 1
        pct = self.n/self.num_it
        return self.ft(self.start, self.end, pct)

    def is_done(self):
        "True once `num_it` steps have been taken."
        return self.n >= self.num_it

    def init_val(self):
        "Value to use before the first step."
        return self.start

In [None]:
class OneCycleScheduler(Callback):
    """Callback driving the learning rate and momentum through a three-phase cycle.

    With `n = len(learn.data.train_dl) * epochs` iterations in total:
      1. `int(n*(1-pct_end)/2)` iterations: lr goes linearly from `lr_max/div_factor` to
         `lr_max`, momentum from `moms[0]` to `moms[1]`;
      2. the same number of iterations: both go back the other way;
      3. `int(n*pct_end)` iterations: lr goes from `lr_max/div_factor` down to
         `lr_max/(div_factor*100)`, momentum stays at `moms[0]`.

    Once the last phase is over the callback asks the training loop to stop (both at the
    end of the batch and at the end of the epoch), so training for more iterations than
    the schedule covers no longer indexes past the last phase.
    """

    def __init__(self, learn, lr_max, epochs, moms=(0.95,0.85), div_factor=10, pct_end=0.1):
        self.learn = learn
        a = int(len(learn.data.train_dl) * epochs * (1 - pct_end) / 2)
        b = int(len(learn.data.train_dl) * epochs * pct_end)
        self.lr_scheds = [Stepper((lr_max/div_factor, lr_max), a),
                          Stepper((lr_max, lr_max/div_factor), a),
                          Stepper((lr_max/div_factor, lr_max/(div_factor*100)), b)]
        self.mom_scheds = [Stepper(moms, a), Stepper((moms[1], moms[0]), a), Stepper(moms[0], b)]
        # Index of the current phase; also set here so that `_finished` is usable at any time.
        self.idx_s = 0

    def _finished(self):
        "True once every phase of the schedule has been completed."
        return self.idx_s >= len(self.lr_scheds)

    def on_train_begin(self):
        self.opt = self.learn.opt
        self.opt.lr, self.opt.mom = self.lr_scheds[0].init_val(), self.mom_scheds[0].init_val()
        self.idx_s = 0

    def on_batch_end(self, loss):
        # Schedule already over: keep the last values and keep asking to stop.
        if self._finished(): return True
        self.opt.lr = self.lr_scheds[self.idx_s].step()
        self.opt.mom = self.mom_scheds[self.idx_s].step()
        if self.lr_scheds[self.idx_s].is_done():
            self.idx_s += 1
            if self._finished(): return True

    def on_epoch_end(self, val_loss):
        # Returning True here stops training, not just the current epoch.
        return self._finished()

In [None]:
class Learner():
    """Pairs a data object with a model; `fit` accepts extra callbacks.

    `loss_fn` and `opt_fn` default to cross-entropy and SGD and can be replaced before `fit`.
    """
    def __init__(self, data, model):
        model_on_device = model.to(data.device)
        self.data = data
        self.model = model_on_device
        self.loss_fn = F.cross_entropy
        self.opt_fn = optim.SGD

    def fit(self, epochs, lr, callbacks=[]):
        """Train for `epochs` epochs at `lr`; the recorder always runs before `callbacks`."""
        self.opt = HPOptimizer(self.model.parameters(), self.opt_fn, init_lr=lr)
        self.recorder = Recorder(self.opt, self.data.train_dl)
        all_callbacks = [self.recorder] + callbacks
        fit(epochs, self.model, self.loss_fn, self.opt,
            self.data.train_dl, self.data.valid_dl, callbacks=all_callbacks)

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)
sched = OneCycleScheduler(learn, 0.1, 5)

In [None]:
learn.fit(5,0.1,callbacks=[sched])

In [None]:
iterations = list(range(len(learn.recorder.lrs)))
fig, axs = plt.subplots(1,2, figsize=(12,4))
axs[0].plot(iterations, learn.recorder.lrs)
axs[1].plot(iterations, learn.recorder.moms)

Or a LR Finder

In [None]:
class LRFinder(Callback):
    """Callback that grows the learning rate exponentially from `start_lr` to `end_lr`.

    Training is stopped after `num_it` iterations, or as soon as a batch loss exceeds four
    times the best loss seen so far.
    """

    def __init__(self, learn, start_lr=1e-5, end_lr=10, num_it=200):
        self.learn = learn
        self.sched = Stepper((start_lr, end_lr), num_it, annealing_exponential)

    def on_train_begin(self):
        self.opt = self.learn.opt
        self.opt.lr = self.sched.init_val()
        self.stop = False
        self.first = True
        self.best_loss = 0.

    def on_batch_end(self, loss):
        # The first loss is always taken as the best one so far.
        improved = self.first or loss < self.best_loss
        if improved:
            self.first = False
            self.best_loss = loss
        self.opt.lr = self.sched.step()
        should_stop = self.sched.is_done() or loss > 4*self.best_loss
        if not should_stop:
            return None
        self.stop = True
        return True

    def on_epoch_end(self, val_loss):
        # Propagate the stop decided in on_batch_end so that the whole training ends.
        return self.stop

In [None]:
class Learner():
    """Pairs a data object with a model; supports training and a learning-rate finder.

    `loss_fn` and `opt_fn` default to cross-entropy and SGD and can be replaced before `fit`.
    """
    def __init__(self, data, model):
        model_on_device = model.to(data.device)
        self.data = data
        self.model = model_on_device
        self.loss_fn = F.cross_entropy
        self.opt_fn = optim.SGD

    def fit(self, epochs, lr, callbacks=[], val=True):
        """Train for `epochs` epochs at `lr`; validation is skipped when `val` is False."""
        self.opt = HPOptimizer(self.model.parameters(), self.opt_fn, init_lr=lr)
        self.recorder = Recorder(self.opt, self.data.train_dl)
        all_callbacks = [self.recorder] + callbacks
        valid_dl = self.data.valid_dl if val else None
        fit(epochs, self.model, self.loss_fn, self.opt, self.data.train_dl,
            valid_dl, callbacks=all_callbacks)

    def lr_find(self, start_lr=1e-5, end_lr=10, num_it=200):
        """Run the LR finder for at most `num_it` iterations, without validation."""
        finder = LRFinder(self, start_lr, end_lr, num_it)
        # Enough epochs to cover num_it batches.
        n_epochs = int(np.ceil(num_it/len(self.data.train_dl)))
        self.fit(n_epochs, start_lr, callbacks=[finder], val=False)

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)

In [None]:
learn.lr_find()

In [None]:
fig, ax = plt.subplots(1,1)
ax.plot(learn.recorder.lrs, learn.recorder.losses)
ax.set_xscale('log')

That's a bit shaky since we plot the raw loss and not a smoothed version of it. Let's change the recorder so it records a moving average of the loss. We'll also add the plotting functions inside it.

In [None]:
class Recorder(Callback):
    """Callback that records a smoothed training loss, the validation losses and the
    hyper-parameters, and can plot them.

    The training loss is an exponentially weighted moving average (factor `beta`) of the
    batch losses, divided by `1 - beta**n` to correct the bias of the first iterations.
    `opt` must expose `lr` and `mom`; `train_dl`, when given, is used to show the smoothed
    loss on its progress bar.
    """
    beta = 0.98

    def __init__(self, opt, train_dl=None):
        self.opt,self.train_dl = opt,train_dl

    def on_train_begin(self):
        self.epoch,self.n,self.avg_loss = 0,0,0.
        self.losses,self.val_losses,self.lrs,self.moms = [],[],[],[]

    def on_batch_begin(self, xb, yb):
        self.lrs.append(self.opt.lr)
        self.moms.append(self.opt.mom)
        return xb, yb

    def on_backward_begin(self, loss):
        #We record the loss here before any other callback has a chance to modify it.
        self.n += 1
        self.avg_loss = self.beta * self.avg_loss + (1-self.beta) * loss.item()
        self.smooth_loss = self.avg_loss / (1 - self.beta ** self.n)
        self.losses.append(self.smooth_loss)
        if self.train_dl is not None and self.train_dl.progress_func is not None:
            self.train_dl.gen.set_postfix_str(self.smooth_loss)

    def on_epoch_end(self, val_loss):
        self.val_losses.append(val_loss)
        print(self.epoch, val_loss)
        self.epoch += 1

    def plot_lr(self, show_moms=False):
        "Plot the recorded learning rates (and the momentums when `show_moms` is True)."
        # Use this recorder's own data, not whatever the global `learn` happens to hold.
        iterations = list(range(len(self.lrs)))
        if show_moms:
            fig, axs = plt.subplots(1,2, figsize=(12,4))
            axs[0].plot(iterations, self.lrs)
            axs[1].plot(iterations, self.moms)
        else: plt.plot(iterations, self.lrs)

    def plot(self, skip_start=10, skip_end=5):
        "Plot loss against learning rate (log scale), skipping the first and last few points."
        # A slice ending at -0 would be empty, hence the special case for skip_end == 0.
        lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]
        losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]
        fig, ax = plt.subplots(1,1)
        ax.plot(lrs, losses)
        ax.set_xscale('log')

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

To grasp the potential of callbacks, here's a full example:

In [None]:
class EyeOfSauron(Callback):
    """Annotated example callback showing what each hook has access to.

    Every hook only stores what it receives; the comments explain what could be done there.
    """

    def __init__(self, learn):
        #By passing the learner, this callback will have access to everything:
        #All the inputs/outputs as they go, the losses, but also the data loaders, the optimizer.
        self.learn = learn

        #At any time:
        #Changing self.learn.data.train_dl or self.learn.data.valid_dl will change them inside the fit function
        #(we just need to pass the data object to the fit function and not data.train_dl/data.valid_dl)
        #Changing self.learn.opt.opt (We have an HPOptimizer on top of the actual optimizer) will change it
        #inside the fit function.
        #Changing self.learn.data or self.learn.opt directly WILL NOT change the data or the optimizer inside the fit function.

    def on_train_begin(self):
        #Here we can initialize anything we need.
        self.opt = self.learn.opt
        #The optimizer has now been initialized. We can change any hyper-parameters by typing
        #self.opt.lr = new_lr, self.opt.mom = new_mom, self.opt.wd = new_wd or self.opt.beta = new_beta

    def on_epoch_begin(self):
        #This is not technically useful since we have on_train_begin for epoch 0 and on_epoch_end for all the other epochs
        #yet it makes writing code that needs to be done at the beginning of every epoch easy.
        pass

    def on_batch_begin(self, xb, yb):
        #If we need to access anything inside the input or target for here or later, we can just save them.
        self.xb,self.yb = xb,yb
        #Here is the perfect place to prepare everything before the model is called.
        #Example: change the values of the hyperparameters (if we don't do it on_batch_end instead)

        #If we return something, that will be the new value for xb,yb.

    def on_loss_begin(self, out):
        #If we need the output of the model for here or later, we can just save it.
        self.out = out
        #Here is the place to run some code that needs to be executed after the output has been computed but before the
        #loss computation.
        #Example: putting the output back in FP32 when training in mixed precision.

        #If we return something, that will be the new value for the output.

    def on_backward_begin(self, loss, out=None):
        #If we need the loss of the model for here or later, we can just save it.
        #The callback handler only passes the loss; `out` is optional and was already saved in on_loss_begin.
        self.raw_loss = loss
        if out is not None: self.out = out
        #Here is the place to run some code that needs to be executed after the loss has been computed but before the
        #gradient computation.
        #Example: reg_fn in RNNs.

        #If we return something, that will be the new value for loss. Since the recorder is always called first,
        #it will have the raw loss.

    def on_backward_end(self):
        #Here is the place to run some code that needs to be executed after the gradients have been computed but
        #before the optimizer is called.
        #Example: deal with weight_decay in AdamW
        pass

    def on_step_end(self):
        #Here is the place to run some code that needs to be executed after the optimizer step but before the gradients
        #are zeroed
        #Example: can't think of any but maybe someone will need this one day.
        pass

    def on_batch_end(self, loss):
        #We get the loss again, this time it's the version modified by all the callbacks, so depending on our
        #needs it might be best to use self.raw_loss
        self.loss = loss
        #Here is the place to run some code that needs to be executed after a batch is fully done.
        #Example: change the values of the hyperparameters (if we don't do it on_batch_begin instead)

        #If we return true, the current epoch is interrupted (example: lr_finder stops the training when the loss explodes)

    def on_epoch_end(self, val_loss):
        #We get the validation loss (TODO: and metrics)
        self.val_loss = val_loss
        #Here is the place to run some code that needs to be executed at the end of an epoch.
        #Example: Save the model if we have a new best validation loss/metric.

        #If we return true, the training stops (example: early stopping)

    def on_train_end(self):
        #Here is the place to tidy everything.
        #Examples: save log_files, load best model found during training
        pass

Final fit function

In [None]:
def loss_batch(model, xb, yb, loss_fn, opt=None, cb_handler=CallbackHandler([])):
    """Compute the loss on one batch and, when `opt` is given, run a full training step.

    `cb_handler` sees the model output before the loss is computed (`on_loss_begin`) and,
    during a training step, the loss before the backward pass (`on_backward_begin`); both
    may replace the value they receive. Returns `(loss.item(), len(xb))`.
    """
    out = cb_handler.on_loss_begin(model(xb))
    loss = loss_fn(out, yb)

    if opt is None:
        return loss.item(), len(xb)

    loss = cb_handler.on_backward_begin(loss)
    loss.backward()
    cb_handler.on_backward_end()
    opt.step()
    cb_handler.on_step_end()
    opt.zero_grad()

    return loss.item(), len(xb)

In [None]:
def fit(epochs, model, loss_fn, opt, data, callbacks=[]):
    """Training loop taking the whole `data` object rather than the two data loaders.

    `data.train_dl` and `data.valid_dl` are read from `data` at every epoch, so a callback
    replacing them on the object is taken into account. A truthy `on_batch_end` interrupts
    the current epoch; a truthy `on_epoch_end` stops training. `on_epoch_end` receives the
    batch-size-weighted mean validation loss, or None when `data` has no validation loader.
    """
    handler = CallbackHandler(callbacks)
    handler.on_train_begin()

    for epoch in tnrange(epochs):
        model.train()
        handler.on_epoch_begin()

        for xb, yb in data.train_dl:
            xb, yb = handler.on_batch_begin(xb, yb)
            loss, _ = loss_batch(model, xb, yb, loss_fn, opt, handler)
            stop_epoch = handler.on_batch_end(loss)
            if stop_epoch: break

        val_loss = None
        if getattr(data, 'valid_dl', None) is not None:
            model.eval()
            with torch.no_grad():
                results = [loss_batch(model, xb, yb, loss_fn) for xb, yb in data.valid_dl]
                losses, nums = zip(*results)
            weighted = np.multiply(losses, nums)
            val_loss = np.sum(weighted) / np.sum(nums)
        stop_training = handler.on_epoch_end(val_loss)
        if stop_training: break

    handler.on_train_end()

The idea is that one thing is entirely done in a callback so that it's easily read. For instance, let's go back to the LRFinder: on top of running the fit function with exponentially growing lrs, it needs to handle some preparation and clean-up, and all this code should be in the callback, not in the lr_find function.

In [None]:
class LRFinder(Callback):
    """Callback that grows the learning rate exponentially from `start_lr` to `end_lr`.

    Training is stopped after `num_it` iterations, or as soon as the recorder's smoothed
    loss exceeds four times the best batch loss seen so far. The validation loader is set
    aside while training runs and put back in `on_train_end`.
    """
    #TODO: add model.save in init or on_train_begin and model.load in on_train_end.

    def __init__(self, learn, start_lr=1e-5, end_lr=10, num_it=200):
        self.learn = learn
        self.sched = Stepper((start_lr, end_lr), num_it, annealing_exponential)

    def on_train_begin(self):
        #To avoid validating if the train_dl has less than num_it batches, we put aside the valid_dl and remove it
        #during the call to fit. This is done here rather than in __init__ so that merely building the callback
        #has no side effect on the learner's data.
        self.valid_dl = self.learn.data.valid_dl
        self.learn.data.valid_dl = None
        self.opt = self.learn.opt
        self.opt.lr = self.sched.init_val()
        self.stop,self.first,self.best_loss = False,True,0.

    def on_batch_end(self, loss):
        # The first loss is always taken as the best one so far.
        if self.first or loss < self.best_loss:
            self.first = False
            self.best_loss = loss
        self.opt.lr = self.sched.step()
        if self.sched.is_done() or self.learn.recorder.smooth_loss > 4*self.best_loss:
            #We use the smoothed loss to decide on the stopping since it's less shaky.
            self.stop=True
            return True

    def on_epoch_end(self, val_loss): return self.stop

    def on_train_end(self):
        #Clean up and put back the valid_dl in its place.
        self.learn.data.valid_dl = self.valid_dl

In [None]:
class Learner():
    """Pairs a data object with a model; supports training and a learning-rate finder.

    `data` must expose `device` and `train_dl` (and optionally `valid_dl`); the model is
    moved to `data.device`. `loss_fn` and `opt_fn` default to cross-entropy and SGD and can
    be replaced before calling `fit`.
    """
    def __init__(self, data, model):
        self.data,self.model = data,model.to(data.device)
        self.loss_fn, self.opt_fn = F.cross_entropy, optim.SGD

    def fit(self, epochs, lr, callbacks=None):
        """Train for `epochs` epochs at learning rate `lr`.

        A fresh optimizer and recorder are created on each call; the recorder always runs
        before the user `callbacks`. The `callbacks` argument itself is never modified.
        """
        self.opt = HPOptimizer(self.model.parameters(), self.opt_fn, init_lr=lr)
        self.recorder = Recorder(self.opt, self.data.train_dl)
        # Build a new list instead of inserting into the argument: inserting would mutate the
        # caller's list (or a shared default) and pile up one stale recorder per call.
        all_callbacks = [self.recorder] + list(callbacks or [])
        fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=all_callbacks)

    def lr_find(self, start_lr=1e-5, end_lr=10, num_it=200):
        """Run the LR finder for at most `num_it` iterations."""
        cb = LRFinder(self, start_lr, end_lr, num_it)
        # Enough epochs to cover num_it batches.
        a = int(np.ceil(num_it/len(self.data.train_dl)))
        self.fit(a, start_lr, callbacks=[cb])

In [None]:
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
len(learn.data.valid_dl)

Here are the tests for change of optimizers/dataloaders.

Changing directly opt.opt or data.train_dl/data.valid_dl changes the corresponding item in the fit function.

In [None]:
data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms+[cifar_norm], valid_tfm=[cifar_norm], num_workers=0)
data1 = DataBunch.create(train_ds, valid_ds, bs=32, train_tfm=tfms+[cifar_norm], valid_tfm=[cifar_norm], num_workers=0)

In [None]:
class CbTest():
    """Test callback that swaps the learner's dataloaders for those of `new_data`."""

    def __init__(self, learn, new_data):
        self.learn = learn
        self.new_data = new_data

    def call_me(self):
        # Rebind the dataloaders on the existing data object; the object itself is kept.
        current, replacement = self.learn.data, self.new_data
        current.train_dl = replacement.train_dl
        current.valid_dl = replacement.valid_dl

In [None]:
# Start the test from the first DataBunch.
learn.data = data

In [None]:
# Callback that will swap in data1's dataloaders when call_me() is invoked.
cb = CbTest(learn, data1)

In [None]:
def test(data, cb):
    x,y = next(iter(data.train_dl))
    print(x.size())
    cb.call_me()
    x,y = next(iter(data.train_dl))
    print(x.size())

In [None]:
# `test` holds a reference to the data object; the second printed size shows whether
# swapping the dataloaders on that object is visible through the reference.
test(learn.data, cb)

In [None]:
# Give the learner an HPOptimizer built from optim.SGD (third argument 1e-2, presumably init_lr).
learn.opt = HPOptimizer(model.parameters(), optim.SGD, 1e-2)

In [None]:
class CbTest():
    """Test callback that replaces the optimizer wrapped inside the learner's `opt`."""

    def __init__(self, learn, new_opt):
        self.learn = learn
        self.new_opt = new_opt

    def call_me(self):
        # Keep the wrapper object and change only its inner `opt` attribute.
        wrapper = self.learn.opt
        wrapper.opt = self.new_opt

In [None]:
# Note: new_opt here is the optim.Adam class itself, not an optimizer instance.
cb = CbTest(learn, optim.Adam)

In [None]:
def test(opt, cb):
    print(opt.opt)
    cb.call_me()
    print(opt.opt)

In [None]:
# The wrapper object is shared, so the second print shows whether the inner swap is visible.
test(learn.opt,cb)

Changing `opt` or `data` directly (i.e. rebinding the attribute on the learner) doesn't change anything inside the fit function.

In [None]:
# Recreate the two DataBunches (batch size bs vs 32) for the second round of tests.
data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms+[cifar_norm], valid_tfm=[cifar_norm], num_workers=0)
data1 = DataBunch.create(train_ds, valid_ds, bs=32, train_tfm=tfms+[cifar_norm], valid_tfm=[cifar_norm], num_workers=0)

In [None]:
class CbTest():
    """Test callback that rebinds the learner's whole `data` attribute to `new_data`."""

    def __init__(self, learn, new_data):
        self.learn = learn
        self.new_data = new_data

    def call_me(self):
        # Replace the data object itself; references to the old object are untouched.
        setattr(self.learn, 'data', self.new_data)

In [None]:
# Start the second test from the first DataBunch.
learn.data = data

In [None]:
# Callback that will rebind learn.data to data1 when call_me() is invoked.
cb = CbTest(learn, data1)

In [None]:
def test(data, cb):
    x,y = next(iter(data.train_dl))
    print(x.size())
    cb.call_me()
    x,y = next(iter(data.train_dl))
    print(x.size())

In [None]:
# `test` receives the old data object; rebinding learn.data inside the callback does not affect it.
test(learn.data, cb)

In [None]:
# Here learn.opt is bound directly to the optim.SGD class (no HPOptimizer wrapper).
learn.opt = optim.SGD

In [None]:
class CbTest():
    """Test callback that rebinds the learner's `opt` attribute to `new_opt`."""

    def __init__(self, learn, new_opt):
        self.learn = learn
        self.new_opt = new_opt

    def call_me(self):
        # Replace the attribute on the learner; holders of the old value are untouched.
        setattr(self.learn, 'opt', self.new_opt)

In [None]:
# Callback that will rebind learn.opt to the optim.Adam class.
cb = CbTest(learn, optim.Adam)

In [None]:
def test(opt, cb):
    print(opt)
    cb.call_me()
    print(opt)

In [None]:
# `test` receives the old value of learn.opt; rebinding the attribute does not change what it prints.
test(learn.opt,cb)

## Metrics

Let's add validation metrics.

In [None]:
#export
from typing import Callable, List

In [None]:
#export
@dataclass
class Learner():
    """Bundle the data, model, loss function, optimizer constructor and metrics used for training."""

    data: DataBunch
    model: nn.Module
    loss_fn: Callable = F.cross_entropy
    opt_fn: Callable = optim.SGD
    # Forwarded to `fit` as its `metrics` argument.
    metrics: List = None
    # Forwarded to HPOptimizer as its `true_wd` argument.
    true_wd: bool = False
    # The model is moved to the device reported by the data object.
    def __post_init__(self): self.model = self.model.to(self.data.device)

    def fit(self, epochs, lr, wd=0., callbacks=None):
        """Train for `epochs` epochs at learning rate `lr` and weight decay `wd`.

        A fresh `HPOptimizer` and `Recorder` are created on every call; the recorder
        always runs before the user-supplied `callbacks`.
        """
        self.opt = HPOptimizer(self.model.parameters(), self.opt_fn, init_lr=lr, true_wd=self.true_wd)
        self.opt.wd = wd
        self.recorder = Recorder(self.opt, self.data.train_dl)
        # Build a new list instead of inserting into the caller's: reusing the same
        # callbacks list across calls would otherwise stack one stale Recorder per call.
        callbacks = [self.recorder] + list(callbacks or [])
        fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=callbacks, metrics=self.metrics)

    def lr_find(self, start_lr=1e-5, end_lr=10, num_it=200):
        """Run the LR finder for `num_it` iterations between `start_lr` and `end_lr`."""
        cb = LRFinder(self, start_lr, end_lr, num_it)
        # Enough epochs to cover num_it batches; the callback stops training earlier if needed.
        a = int(np.ceil(num_it/len(self.data.train_dl)))
        self.fit(a, start_lr, callbacks=[cb])

In [None]:
#export
def loss_batch(model, xb, yb, loss_fn, opt=None, cb_handler=None):
    """Compute the loss on one batch and, if `opt` is given, take an optimizer step.

    `cb_handler`, when provided, is called at each stage (batch begin, loss begin,
    backward begin/end, step end, batch end) and may replace the inputs, the model
    output and the loss.

    Returns `(loss value, batch size, stop)` where `stop` is whatever
    `cb_handler.on_batch_end` returned, or False when there is no handler.
    """
    if cb_handler is not None: xb, yb = cb_handler.on_batch_begin(xb, yb)
    out = model(xb)
    if cb_handler is not None: out = cb_handler.on_loss_begin(out)
    loss = loss_fn(out, yb)

    if opt is not None:
        if cb_handler is not None: loss = cb_handler.on_backward_begin(loss)
        loss.backward()
        if cb_handler is not None: cb_handler.on_backward_end()
        opt.step()
        if cb_handler is not None: cb_handler.on_step_end()
        opt.zero_grad()

    loss_val = loss.item()
    # `stop` must always be bound: without a handler nothing can request a stop.
    stop = cb_handler.on_batch_end(loss_val) if cb_handler is not None else False

    return loss_val, len(xb), stop

In [None]:
#export
def _validate(model, valid_dl, loss_fn, metrics=None):
    """Return `[val_loss, *metric values]` averaged over `valid_dl`, or None if it yields no batch.

    Every value is weighted by its batch size so a smaller last batch doesn't skew the averages.
    """
    fns = [loss_fn] + list(metrics or [])
    rows,nums = [],[]
    with torch.no_grad():
        for xb,yb in valid_dl:
            out = model(xb)
            rows.append([float(f(out, yb)) for f in fns])
            nums.append(len(xb))
    if not rows: return None
    return [np.sum(np.multiply(col, nums)) / np.sum(nums) for col in zip(*rows)]

def fit(epochs, model, loss_fn, opt, data, callbacks=None, metrics=None):
    """Train `model` for `epochs` epochs on `data.train_dl`, validating on `data.valid_dl` if present.

    `callbacks` are driven through a `CallbackHandler`; a callback can end the current
    epoch from `on_batch_end` and the whole training from `on_epoch_end`.
    `metrics` is an optional list of functions `f(out, yb)` evaluated on the validation
    set. `on_epoch_end` receives `[val_loss, *metric values]`, or None when there is no
    validation data.
    """
    cb_handler = CallbackHandler(callbacks)
    cb_handler.on_train_begin()

    for epoch in tnrange(epochs):
        model.train()
        cb_handler.on_epoch_begin()

        for xb,yb in data.train_dl:
            loss,_,stop = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)
            if stop: break

        if hasattr(data,'valid_dl') and data.valid_dl is not None:
            model.eval()
            cb_handler.on_valid_begin()
            # Validation runs without the callback handler, so schedulers don't step here.
            val_metrics = _validate(model, data.valid_dl, loss_fn, metrics)
        else: val_metrics=None

        if cb_handler.on_epoch_end(val_metrics): break

    cb_handler.on_train_end()

In [None]:
#export
class Callback():
    """Base class for training callbacks: every hook is a no-op that returns None."""

    def on_train_begin(self, **kwargs):
        """To initialize constants in the callback."""

    def on_epoch_begin(self, **kwargs):
        """At the beginning of each epoch."""

    def on_batch_begin(self, **kwargs):
        """To set HP before the step is done.

        A look at the input can be useful (set the lr depending on the seq_len in RNNs,
        or for reg_functions called in on_backward_begin).
        An override may return xb, yb to modify the input at that step if needed.
        """

    def on_valid_begin(self, **kwargs):
        """If a valid_dl has been specified, called at the start of validation."""

    def on_loss_begin(self, **kwargs):
        """Called after the forward pass but before the loss has been computed.

        Receives the output of the model; an override may return a modified output.
        """

    def on_backward_begin(self, **kwargs):
        """Called after the forward pass and the loss computation, but before the back propagation.

        Receives the loss of the model; an override may return a modified loss
        (for instance for reg functions).
        """

    def on_backward_end(self, **kwargs):
        """Called after the back propagation (gradients computed) but before the optimizer step.

        Useful for true weight decay in AdamW.
        """

    def on_step_end(self, **kwargs):
        """Called after the optimizer step but before the gradients are zeroed."""

    def on_batch_end(self, **kwargs):
        """Called at the end of the batch."""

    def on_epoch_end(self, **kwargs):
        """Called at the end of an epoch."""

    def on_train_end(self, **kwargs):
        """Useful for cleaning up things and saving files/models."""

In [None]:
class MetricCallback(Callback):
    """Callback that computes a metric; subclasses override `name` and `metric`."""

    # Both methods take `self` so they can be called on an instance and read its state.
    def name(self): pass
        #provide a user friendly name for the metric for display purposes, i.e. Accuracy
    def metric(self): pass
        #return the metric value, i.e. 0.9832

In [None]:
#export
class SmoothenValue():
    """Exponentially weighted moving average of the values passed to `add_value`.

    `smooth` holds the moving average divided by `1 - beta**n`; it only exists
    after the first call to `add_value`.
    """

    def __init__(self, beta):
        self.beta = beta
        self.n = 0
        self.mov_avg = 0

    def add_value(self, val):
        """Fold `val` into the moving average and refresh `smooth`."""
        self.n += 1
        decay = self.beta
        self.mov_avg = decay * self.mov_avg + (1 - decay) * val
        debias = 1 - decay ** self.n
        self.smooth = self.mov_avg / debias

In [None]:
#export
class CallbackHandler():
    """Dispatch training events to a list of callbacks and keep the shared training state.

    The state lives in `self.state_dict` (created in `on_train_begin`) and is passed
    as keyword arguments to every callback hook.
    """
    # Smoothing factor of the running loss average.
    beta = 0.98

    def __init__(self, callbacks):
        self.callbacks = callbacks if callbacks is not None else []
        self.smoothener = SmoothenValue(self.beta)

    def __call__(self, cb_name):
        """Call `on_<cb_name>` on every callback with the current state; return the list of results."""
        return [getattr(cb, f'on_{cb_name}')(**self.state_dict) for cb in self.callbacks]

    def on_train_begin(self):
        self.state_dict = {'epoch': 0, 'iteration': 0, 'num_batch': 0, 'batch_type': 'train'}
        self('train_begin')

    def on_epoch_begin(self):
        self.state_dict['num_batch'] = 0
        self.state_dict['batch_type'] = 'train'
        self('epoch_begin')

    def on_batch_begin(self, xb, yb):
        """Record the batch; a callback returning a non-None value replaces `(xb, yb)`."""
        self.state_dict['last_input'], self.state_dict['last_target'] = xb, yb
        for cb in self.callbacks:
            a = cb.on_batch_begin(**self.state_dict)
            if a is not None: xb,yb = a
        return xb,yb

    def on_valid_begin(self, **kwargs):
        self.state_dict['batch_type'] = 'valid'
        # Notify the callbacks too, so their on_valid_begin hook actually fires.
        self('valid_begin')

    def on_loss_begin(self, out):
        """Record the model output; a callback returning a non-None value replaces it."""
        self.state_dict['last_output'] = out
        for cb in self.callbacks:
            a = cb.on_loss_begin(**self.state_dict)
            if a is not None: out = a
        return out

    def on_backward_begin(self, loss):
        """Update the smoothed loss; a callback returning a non-None value replaces the loss."""
        self.smoothener.add_value(loss.item())
        self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
        for cb in self.callbacks:
            a = cb.on_backward_begin(**self.state_dict)
            if a is not None: loss = a
        return loss

    def on_backward_end(self):        self('backward_end')
    def on_step_end(self):            self('step_end')

    def on_batch_end(self, loss):
        """Return True if any callback asked to stop; then advance the counters.

        `loss` is accepted for the caller's convenience but not used here.
        """
        # Builtin any() treats the None returned by most callbacks as False and yields a plain bool.
        stop = any(self('batch_end'))
        self.state_dict['iteration'] += 1
        self.state_dict['num_batch'] += 1
        return stop

    def on_epoch_end(self, val_metrics):
        """Store `val_metrics` as `last_metrics`; return True if any callback asked to stop."""
        self.state_dict['last_metrics'] = val_metrics
        stop = any(self('epoch_end'))
        self.state_dict['epoch'] += 1
        return stop

    def on_train_end(self): self('train_end')

In [None]:
class Accuracy(MetricCallback):
    """Metric callback accumulating the fraction of correct predictions on validation batches."""

    def __init__(self):
        self.total,self.correct = 0,0

    def name(self):
        return "Accuracy"

    def metric(self):
        """Return correct/total for the batches seen so far, or 0 if none were seen."""
        if self.total > 0: return self.correct / float(self.total)
        else: return 0

    def on_epoch_begin(self, **kwargs):
        # Counters restart at every epoch.
        self.total,self.correct = 0,0

    def on_loss_begin(self, batch_type, last_target, last_output, **kwargs):
        # **kwargs absorbs the other state entries the handler passes (epoch, iteration, ...).
        if batch_type == 'valid':
            # Index of the largest output along dim 1 is taken as the prediction.
            preds = torch.max(last_output, dim=1)[1]
            self.correct += torch.nonzero(preds == last_target).size(0)
            self.total += len(last_target)

In [None]:
#export
class Recorder(Callback):
    """Callback recording learning rates, momentums, smoothed losses and validation metrics."""

    def __init__(self, opt, train_dl=None):
        self.opt,self.train_dl = opt,train_dl

    def on_train_begin(self, **kwargs):
        # All records restart at the beginning of each training run.
        self.losses,self.val_losses,self.lrs,self.moms,self.metrics = [],[],[],[],[]

    def on_batch_begin(self, **kwargs):
        self.lrs.append(self.opt.lr)
        self.moms.append(self.opt.mom)

    def on_backward_begin(self, smooth_loss, **kwargs):
        #We record the loss here before any other callback has a chance to modify it.
        self.losses.append(smooth_loss)
        if self.train_dl is not None and self.train_dl.progress_func is not None:
            self.train_dl.gen.set_postfix_str(smooth_loss)

    def on_epoch_end(self, epoch, last_metrics, **kwargs):
        """Store the validation loss (first entry) and any further metrics, then print them."""
        if last_metrics is not None:
            self.val_losses.append(last_metrics[0])
            if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])
            print(epoch, *last_metrics)

    def plot_lr(self, show_moms=False):
        """Plot the recorded learning rates (and momentums if `show_moms`) against the iteration."""
        # Use this recorder's own data rather than whatever a global `learn` points to.
        iterations = list(range(len(self.lrs)))
        if show_moms:
            fig, axs = plt.subplots(1,2, figsize=(12,4))
            axs[0].plot(iterations, self.lrs)
            axs[1].plot(iterations, self.moms)
        else: plt.plot(iterations, self.lrs)

    def plot(self, skip_start=10, skip_end=5):
        """Plot the recorded losses against the learning rates on a log x-axis.

        The first `skip_start` and last `skip_end` points are left out.
        """
        lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]
        losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]
        fig, ax = plt.subplots(1,1)
        ax.plot(lrs, losses)
        ax.set_xscale('log')

In [None]:
#export
class LRFinder(Callback):
    """Callback stepping the learning rate through a schedule until it ends or the loss blows up."""
    #TODO: add model.save in init or on_train_begin and model.load in on_train_end.

    def __init__(self, learn, start_lr=1e-5, end_lr=10, num_it=200):
        self.learn = learn
        self.sched = Stepper((start_lr, end_lr), num_it, annealing_exponential)
        #To avoid validating if the train_dl has less than num_it batches, the valid_dl is
        #put aside here and removed for the duration of the call to fit.
        data = learn.data
        self.valid_dl = data.valid_dl
        data.valid_dl = None

    def on_train_begin(self, **kwargs):
        self.opt = self.learn.opt
        self.opt.lr = self.sched.init_val()
        self.stop = False
        self.best_loss = 0.

    def on_batch_end(self, iteration, smooth_loss, **kwargs):
        """Track the best smoothed loss, step the schedule and return True when training should stop."""
        first_batch = iteration == 0
        if first_batch or smooth_loss < self.best_loss:
            self.best_loss = smooth_loss
        self.opt.lr = self.sched.step()
        finished = self.sched.is_done()
        if not finished:
            #We use the smoothed loss to decide on the stopping since it's less shaky.
            finished = smooth_loss > 4*self.best_loss
        if finished:
            self.stop = True
            return True

    def on_epoch_end(self, **kwargs):
        return self.stop

    def on_train_end(self, **kwargs):
        #Clean up and put back the valid_dl in its place.
        saved_dl = self.valid_dl
        self.learn.data.valid_dl = saved_dl

In [None]:
#export
class OneCycleScheduler(Callback):
    """Callback driving the learning rate and momentum through three consecutive schedules.

    The first two schedules each last `a` iterations and the third `b` iterations, where
    `a` and `b` are computed from the number of training batches, `epochs` and `pct_end`.
    """

    def __init__(self, learn, lr_max, epochs, moms=(0.95,0.85), div_factor=10, pct_end=0.1):
        self.learn = learn
        n = len(learn.data.train_dl) * epochs
        a = int(n * (1 - pct_end) / 2)
        b = int(n * pct_end)
        # Stepper tuples are presumably (start, end) values — TODO confirm against Stepper.
        self.lr_scheds = [Stepper((lr_max/div_factor, lr_max), a),
                          Stepper((lr_max, lr_max/div_factor), a),
                          Stepper((lr_max/div_factor, lr_max/(div_factor*100)), b)]
        self.mom_scheds = [Stepper(moms, a), Stepper((moms[1], moms[0]), a), Stepper(moms[0], b)]

    def on_train_begin(self, **kwargs):
        self.opt = self.learn.opt
        self.opt.lr, self.opt.mom = self.lr_scheds[0].init_val(), self.mom_scheds[0].init_val()
        self.idx_s = 0

    def on_batch_end(self, **kwargs):
        """Step the current schedules; return True once every schedule has been completed."""
        # int() truncation can make a+a+b smaller than the number of batches actually run:
        # without this guard a batch after the last schedule would index past the list.
        if self.idx_s >= len(self.lr_scheds): return True
        self.opt.lr = self.lr_scheds[self.idx_s].step()
        self.opt.mom = self.mom_scheds[self.idx_s].step()
        if self.lr_scheds[self.idx_s].is_done():
            self.idx_s += 1
            if self.idx_s >= len(self.lr_scheds): return True

    def on_epoch_end(self, **kwargs):
        # Stop training, not just the current epoch, once the schedules are exhausted.
        return self.idx_s >= len(self.lr_scheds)

In [None]:
#export
def accuracy(out, yb):
    """Return the fraction of rows of `out` whose largest entry (along dim 1) is at index `yb`."""
    _, preds = torch.max(out, dim=1)
    hits = (preds == yb).float()
    return hits.mean()

In [None]:
import pdb

In [None]:
# Fresh model and learner (dataclass version) using `accuracy` as validation metric.
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
learn = Learner(data, model)
learn.metrics = [accuracy]
# One-cycle schedule with lr_max=0.1, sized for 5 epochs of the current train_dl.
sched = OneCycleScheduler(learn, 0.1, 5)

In [None]:
# Run the learning-rate finder with its default range (1e-5 to 10) and 200 iterations.
learn.lr_find()

In [None]:
# Plot the recorded smoothed loss against the learning rate (log x-axis).
learn.recorder.plot()

In [None]:
# Train for 5 epochs with the one-cycle callback; the scheduler sets opt.lr itself in
# on_train_begin, so the 1e-2 passed here is only the optimizer's initial value.
learn.fit(5, 1e-2, callbacks=[sched])