In [None]:
#default_exp callback.training

In [None]:
#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.fp16 import *

In [None]:
#hide
from nbdev.showdoc import *
from fastai2.test_utils import *

# Tracking callbacks

> Callbacks that make decisions depending on how a monitored metric/loss behaves

## ShortEpochCallback -

In [None]:
#export
@log_args
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self,pct=0.01,short_valid=True):
        # `pct`: fraction of the batches to run; `short_valid`: whether validation is cut short too
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        # Progress through the current phase, as a fraction of its batches
        progress = self.iter/self.n_iter
        if progress < self.pct: return
        if self.training:
            raise CancelTrainException
        elif self.short_valid:
            raise CancelValidException

In [None]:
# Defaults spelled out: stop after 1% of the training batches, and cut validation short as well
learn = synth_learner()
learn.fit(1, cbs=[ShortEpochCallback(pct=0.01, short_valid=True)])

In [None]:
# Same short training phase, but the validation phase runs in full
learn = synth_learner()
learn.fit(1, cbs=[ShortEpochCallback(pct=0.01, short_valid=False)])

## GradientAccumulation -

In [None]:
# export
@log_args
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    toward_end,run_before=True,MixedPrecision
    _docs = dict(begin_fit="Set counter to 0",
                 after_backward="Skip weight update if we have not seen enough items")

    def __init__(self, n_acc=32):
        # `n_acc`: number of items (summed batch sizes, not batches) to accumulate before a weight update
        store_attr(self, 'n_acc')

    def begin_fit(self):
        self.count = 0

    def after_backward(self):
        # Grow the running item count by the size of the current batch
        self.count += find_bs(self.learn.yb)
        if self.count < self.n_acc:
            raise CancelBatchException() #skip weight update
        # Enough items seen: let this batch update the weights and open a new accumulation window
        self.count = 0

In [None]:
learn = synth_learner()

# Weights are updated after every ~2 batches' worth of items, so training still progresses
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
train_losses = [epoch_vals[0] for epoch_vals in learn.recorder.values]
# ensure train_loss decreased
assert train_losses[-1] < train_losses[0]

# `n_acc` is never reached, so no weight update happens at all
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
valid_losses = [epoch_vals[1] for epoch_vals in learn.recorder.values]
# ensure valid_loss didn't change (same weights)
assert valid_losses[-1] == valid_losses[0]

## BnFreeze

In [None]:
#export
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module)->None:
    "Set bn layers in eval mode for all recursive children of `m`."
    for l in m.children():
        if isinstance(l, bn_types):
            # Only the first parameter (the weight, when present) decides whether the layer is frozen.
            # An `affine=False` layer owns no parameters: `next(l.parameters())` would raise
            # `StopIteration`, so such a layer is treated as non-trainable instead of crashing.
            first_param = next(l.parameters(), None)
            if first_param is None or not first_param.requires_grad:
                l.eval()
        # Recurse so nested submodules are reached; `m` itself is never switched
        set_bn_eval(l)

class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def begin_batch(self):
        # Re-applied before every batch, presumably because the training loop can put the
        # model back in train mode (e.g. at the start of an epoch) — TODO confirm
        model = self.model
        set_bn_eval(model)

`BnFreeze` is useful when you'd like to train two separate models that share a common feature extractor (body) and differ only in the head you attach for transfer learning. <br>

`Learner.freeze()` doesn't suffice here: the `BatchNorm` layers are trainable by default, and their running mean and variance keep being updated from each batch. For the feature extractors to fully match, you need to set `train_bn=False` *and* freeze these running statistics, which is precisely what `BnFreeze` does.

In [None]:
#slow
from fastai2.vision.all import *

# Small MNIST subset; 20% of the images are held out for validation
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path=path, valid_pct=0.2)

We first demonstrate that the running stats no longer match when only `train_bn=False` is used:

In [None]:
#slow
# Two learners on copies of the same DataLoaders, with frozen batchnorm parameters but no BnFreeze
learn1, learn2 = (cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False) for _ in range(2))

for lrn in (learn1, learn2): lrn.fit(1, lr=0.02)

epoch,train_loss,valid_loss,time
0,1.086924,0.464564,00:06


epoch,train_loss,valid_loss,time
0,1.23474,0.445445,00:05


In [None]:
#slow
# SOURCE: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
def models_equal(model_1, model_2, verbose=False):
    "Check that `model_1` and `model_2` have the same `state_dict` keys (in order) and identical tensors."
    items_1 = list(model_1.state_dict().items())
    items_2 = list(model_2.state_dict().items())
    # `zip` would silently stop at the shorter state dict, so compare sizes up front
    if len(items_1) != len(items_2):
        if verbose: print('Models being compared have different architectures')
        raise ValueError(f'State dicts differ in size: {len(items_1)} vs {len(items_2)} entries')
    models_differ = 0
    for (key_1, val_1), (key_2, val_2) in zip(items_1, items_2):
        # A key mismatch means different architectures, even if the tensors happen to be equal
        if key_1 != key_2:
            if verbose: print('Models being compared have different architectures')
            raise ValueError(f'State dict keys differ: {key_1!r} vs {key_2!r}')
        if not torch.equal(val_1, val_2):
            models_differ += 1
            if verbose: print(f'Mismatch found at {key_1}')
    if models_differ == 0 and verbose: print('Models match perfectly')
    return models_differ == 0

In [None]:
#slow
# Compare only the bodies (`model[0]`): each learner gets its own freshly created head, so comparing
# whole models would report a mismatch regardless of the batchnorm running stats
models_equal(learn1.model[0], learn2.model[0])

False

In [None]:
#slow
dls1, dls2 = (ImageDataLoaders.from_folder(path, valid_pct=0.2) for _ in range(2))
learn1, learn2 = (cnn_learner(d, resnet18, pretrained=True, train_bn=False, cbs=BnFreeze) for d in (dls1, dls2))

for lrn in (learn1, learn2): lrn.fit(1, lr=0.02)

# With BnFreeze the bodies (`model[0]`) must still be identical after training, running stats included
assert models_equal(learn1.model[0], learn2.model[0])

epoch,train_loss,valid_loss,time
0,0.402903,0.437258,00:04


epoch,train_loss,valid_loss,time
0,0.618615,0.176437,00:04


## Export -

In [None]:
#hide
# Regenerate the library modules from the `#export` cells; with no argument this presumably
# converts every notebook in the project, not just this one — confirm against the nbdev docs
from nbdev.export import notebook2script
notebook2script()