## https://github.com/fastai/fastai2/pull/82

The goal of this notebook is to show how a change to `Hook` could help us write hook functions that accumulate state. 

In this example, we will:
- initialize our model using a simplified LSUV (https://arxiv.org/abs/1511.06422)
    - the simplification is that we don't pre-init with orthonormal matrices
- save stats at each step in the LSUV loop
    - to keep it simple, we just print the stats
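
For reference, the simplified LSUV step that `do_lsuv` applies to each conv layer $l$ can be written as follows, where $B_l$ is that layer's output on one fixed batch and $\text{tol}$ is the `tolerance` argument:

$$
\sigma_l = \operatorname{std}(B_l), \qquad
W_l \leftarrow \frac{W_l}{\sigma_l}, \qquad
\text{repeated until } \lvert \sigma_l - 1 \rvert \le \text{tol}
$$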

_The ask is_ ... can anyone help me understand how we can make this change to `Hook` using the `funcs_kwargs` decorator (instead of `can_call_with_n_positional_args`)?

- `funcs_kwargs` is in fastcore/nbs/01_foundation.ipynb
    - `DataLoader` in fastai2/nbs/02_data.load.ipynb is a good `funcs_kwargs` example
- `Hook` is in fastai2/nbs/15_callback.hook.ipynb

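As we understand it, `funcs_kwargs` lets keyword arguments named in a class's `_methods` list replace those methods on an instance. The sketch below is a toy re-implementation of that pattern written only for this note: `toy_funcs_kwargs`, `ToyHook` and `count_calls` are hypothetical names, and fastcore's real `funcs_kwargs` may differ in details (for example, whether it binds the passed function to the instance). It shows one possible shape for the change: if the passed function is bound as a method, it receives the hook as its *first* argument.

```python
import functools
import types

def toy_funcs_kwargs(cls):
    "Toy decorator (NOT fastcore's code): kwargs named in `cls._methods` replace those methods per instance."
    old_init = cls.__init__

    @functools.wraps(old_init)
    def __init__(self, *args, **kwargs):
        for name in cls._methods:
            fn = kwargs.pop(name, None)
            # Bind the plain function to this instance, so it is called as fn(self, ...)
            if fn is not None: setattr(self, name, types.MethodType(fn, self))
        old_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls

@toy_funcs_kwargs
class ToyHook:
    "Hypothetical stand-in for `Hook`, without any PyTorch registration."
    _methods = ['hook_func']
    def __init__(self): self.stored = None
    def hook_func(self, m, i, o): return o  # default behaviour: store the output
    def hook_fn(self, m, i, o): self.stored = self.hook_func(m, i, o)

def count_calls(hook, m, i, o):
    "Accumulates state on the hook: the number of forward passes seen so far."
    return (hook.stored or 0) + 1

h = ToyHook(hook_func=count_calls)
for _ in range(3): h.hook_fn(None, None, None)
print(h.stored)  # 3
```

With this shape no argument-counting is needed, but a passed function would have the signature `(hook, m, i, o)` rather than `(m, i, o, hook)`.
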
In [1]:
import PIL # hack: Pillow 7 removed PILLOW_VERSION, which older torchvision versions still import
PIL.PILLOW_VERSION = PIL.__version__

from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *

Change `Hook.hook_fn` to allow hook functions with either:
- 3 args (module, input, output) or
- 4 args (module, input, output, hook), where `hook` is the `Hook` instance itself

In [2]:
import inspect

def can_call_with_n_positional_args(fn, n):
    "return `True` if fn can be called with n positional arguments, `False` otherwise"
    def _len(o): return 0 if o is None else len(o)
    fas = inspect.getfullargspec(fn)
    def _min(): return _len(fas.args) - _len(fas.defaults)
    def _max(): return 99999 if fas.varargs else len(fas.args)
    if inspect.ismethod(fn): n += 1 # add one for self
    return n >= _min() and n <= _max()

def hook_fn(self, module, input, output):
    "Applies `hook_func` to `module`, `input`, `output` and optionally `self`."
    if self.detach:
        input,output = to_detach(input, cpu=self.cpu, gather=self.gather),to_detach(output, cpu=self.cpu, gather=self.gather)
    args = [module, input, output]
    if can_call_with_n_positional_args(self.hook_func, 4): args.append(self)
    self.stored = self.hook_func(*args)

Hook.hook_fn = hook_fn
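
The helper is plain Python, so its behaviour can be checked without fastai or a GPU. This sketch copies it (using `float('inf')` in place of `99999`) and probes a few hypothetical hook signatures (`three`, `four`, `opt` and `StatsHook` are illustrative names, not part of fastai2):

```python
import inspect

def can_call_with_n_positional_args(fn, n):
    "Return `True` if `fn` can be called with `n` positional arguments, `False` otherwise."
    def _len(o): return 0 if o is None else len(o)
    fas = inspect.getfullargspec(fn)
    lo = _len(fas.args) - _len(fas.defaults)             # required positional args
    hi = float('inf') if fas.varargs else len(fas.args)  # *args accepts any number
    if inspect.ismethod(fn): n += 1  # getfullargspec still lists `self` for bound methods
    return lo <= n <= hi

# Hypothetical hook functions with different signatures
def three(m, i, o): pass
def four(m, i, o, hook): pass
def opt(m, i, o, hook=None): pass

class StatsHook:
    def store(self, m, i, o): pass  # bound method: `self` is supplied automatically

for fn in (three, four, opt, StatsHook().store, lambda *a: None):
    print(getattr(fn, '__name__', fn), can_call_with_n_positional_args(fn, 4))
# three False / four True / opt True / store False / <lambda> True
```

Note that `opt`, whose `hook` parameter has a default, also reports `True` for 4 args, so `hook_fn` would pass the hook to it.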

In [3]:
bs = 256
source = untar_data(URLs.MNIST)
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
                   get_items=get_image_files, 
                   get_y=parent_label)
dataloaders = dblock.dataloaders(source, path=source, bs=bs)

In [4]:
def conv_layer(in_channels, out_channels, kernel_size, stride):
    return [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1), Mish()]

Create a model with an architecture similar to FitNet-MNIST (https://arxiv.org/abs/1412.6550).

In [5]:
net = nn.Sequential(
    *conv_layer(3, 16, 3, 2),
    *conv_layer(16, 16, 3, 1),
    nn.MaxPool2d(4, 2),
    *conv_layer(16, 16, 3, 1),
    *conv_layer(16, 16, 3, 1),
    nn.MaxPool2d(4, 2),
    *conv_layer(16, 12, 3, 1),
    *conv_layer(12, 12, 3, 1),
    nn.AdaptiveAvgPool2d(output_size=1),
    nn.Flatten(),
    nn.Linear(12, 10)).cuda()
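
As a sanity check on the architecture, the spatial size after each conv and max-pool layer follows floor((size + 2*padding - kernel_size) / stride) + 1. This sketch assumes 28x28 MNIST inputs (the `DataBlock` above applies no resize) and walks the layers of `net` in order:

```python
def conv_out(size, kernel_size, stride, padding=0):
    "Output size of Conv2d / MaxPool2d (dilation 1, floor mode)."
    return (size + 2 * padding - kernel_size) // stride + 1

# (layer, kernel_size, stride, padding) for the spatial layers of `net`, in order
layers = [
    ('conv', 3, 2, 1), ('conv', 3, 1, 1), ('maxpool', 4, 2, 0),
    ('conv', 3, 1, 1), ('conv', 3, 1, 1), ('maxpool', 4, 2, 0),
    ('conv', 3, 1, 1), ('conv', 3, 1, 1),
]
size, sizes = 28, [28]
for _, k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [28, 14, 14, 6, 6, 6, 2, 2, 2]
```

`AdaptiveAvgPool2d(output_size=1)` then reduces the 2x2 maps to 1x1, so `Flatten` yields the 12 features that `nn.Linear(12, 10)` expects.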

We can subclass `Hook` and easily access the stored state:

In [6]:
class StoreStatsHook(Hook):
    def __init__(self, m):
        super().__init__(m, self.store_stats)
        self.stored = L()

    def store_stats(self, m, i, o):
        "save history of stats in list `hook.stored` and latest stats in `hook.stored.mean` and `hook.stored.std`"
        stored = self.stored
        stored.mean, stored.std = o.data.mean().item(), o.data.std().item()
        stored.append([stored.mean, stored.std])
        return stored

Or we could do the same thing with a plain function, in 3 fewer lines of code.

Note: if the function can't receive a `hook` argument, it can't accumulate state on the hook.

In [7]:
def store_stats(m, i, o, hook):
    "save history of stats in list `hook.stored` and latest stats in `hook.stored.mean` and `hook.stored.std`"
    if hook.stored is None: hook.stored = L()
    stored = hook.stored
    stored.mean, stored.std = o.data.mean().item(), o.data.std().item()
    stored.append([stored.mean, stored.std])
    return stored
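
The same accumulation can be exercised without PyTorch. In this sketch `FakeHook` and `History` are hypothetical stand-ins (for `Hook` and fastcore's `L`), and the 3-vs-4 argument decision uses `inspect.signature` parameter counting, a simplification that ignores defaults and `*args`:

```python
import inspect
import statistics

class FakeHook:
    "Hypothetical stand-in for fastai's Hook: calls hook_func with 3 args, or 4 if it takes a hook."
    def __init__(self, hook_func):
        self.hook_func, self.stored = hook_func, None
        self.pass_hook = len(inspect.signature(hook_func).parameters) == 4
    def __call__(self, module, inp, out):
        args = [module, inp, out] + ([self] if self.pass_hook else [])
        self.stored = self.hook_func(*args)

class History(list):
    "A list that can also carry the latest mean/std as attributes."

def store_stats(m, i, o, hook):
    "Append [mean, std] of `o` to `hook.stored` and keep the latest values as attributes."
    if hook.stored is None: hook.stored = History()
    stored = hook.stored
    stored.mean, stored.std = statistics.mean(o), statistics.stdev(o)
    stored.append([stored.mean, stored.std])
    return stored

hook = FakeHook(store_stats)
hook(None, None, [1.0, 2.0, 3.0])  # simulated forward passes
hook(None, None, [2.0, 4.0, 6.0])
print(hook.stored)  # [[2.0, 1.0], [4.0, 2.0]]
```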

In [8]:
def do_lsuv(tolerance=1e-3):
    "re-initialize the model using simplified LSUV and return history of stats for each layer"
    stats = L()
    xb, yb = dataloaders.one_batch()
    net.eval()
    with torch.no_grad():
        for module in net:
            if not isinstance(module, nn.Conv2d): continue
            # these hooks both do the same thing
            hook = Hook(module, store_stats)
            # hook = StoreStatsHook(module)
            net(xb)  # forward pass runs the hook and fills `hook.stored`
            while abs(hook.stored.std - 1) > tolerance:
                module.weight.data /= hook.stored.std
                net(xb)  # re-measure after rescaling
            hook.remove()
            stats.append(hook.stored)
    return stats
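
To see why the loop needs more than one step, here is a toy, PyTorch-free version of the same rescaling on a single dense layer with random weights, small random biases and one fixed batch. The layer sizes, seed and `toy_lsuv` name are assumptions for illustration. Because the biases are not rescaled, dividing the weights by the measured std only gets close to 1, and a few repeats are needed:

```python
import random
import statistics

def toy_lsuv(weights, biases, xs, tolerance=1e-3, max_iters=100):
    "Rescale `weights` in place until the std of the layer's outputs is ~1; return [mean, std] history."
    def forward():
        # dense-layer outputs, flattened over batch and output units
        return [sum(w * x for w, x in zip(row, xrow)) + b
                for xrow in xs for row, b in zip(weights, biases)]
    history = []
    for _ in range(max_iters):
        out = forward()
        mean, std = statistics.mean(out), statistics.stdev(out)
        history.append([mean, std])
        if abs(std - 1) <= tolerance: break
        for row in weights:
            for j in range(len(row)): row[j] /= std
    return history

random.seed(0)
n_in, n_out, bs = 20, 8, 64
weights = [[random.gauss(0, 0.05) for _ in range(n_in)] for _ in range(n_out)]
biases = [random.gauss(0, 0.05) for _ in range(n_out)]
xs = [[random.gauss(0, 1) for _ in range(n_in)] for _ in range(bs)]
history = toy_lsuv(weights, biases, xs)
for mean, std in history: print(f'{mean:.4f} {std:.4f}')
```

Like `hook.stored` in `do_lsuv`, the history ends with the measurement that passed the tolerance check.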

In [9]:
for stats in do_lsuv(): print(stats)

(#3) [[-0.022715555503964424, 0.23131683468818665],[-0.039343997836112976, 0.9175119996070862],[-0.0412888340651989, 0.9990230798721313]]
(#3) [[-0.02419188804924488, 0.34461209177970886],[-0.04715002700686455, 1.0021822452545166],[-0.04707375168800354, 0.9999953508377075]]
(#3) [[-0.23670406639575958, 0.8159131407737732],[-0.2889198958873749, 0.9958274960517883],[-0.29010841250419617, 0.999923586845398]]
(#3) [[-0.00022016761067789048, 0.306287556886673],[0.050328124314546585, 0.9886533617973328],[0.05116439610719681, 0.9999657869338989]]
(#3) [[-0.10975136607885361, 0.5082138776779175],[-0.25173482298851013, 1.0027832984924316],[-0.2509334981441498, 0.9999889135360718]]
(#3) [[-0.02601805329322815, 0.1902713179588318],[-0.1723552942276001, 0.9768433570861816],[-0.17663948237895966, 1.0000594854354858]]


In [10]:
learn = Learner(dataloaders, net, opt_func=ranger, metrics=accuracy, loss_func=LabelSmoothingCrossEntropy())

In [11]:
learn.fit_flat_cos(3, lr=1e-2, wd=1e-3)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.718474 | 0.63654 | 0.9678 | 00:07 |
| 1 | 0.604962 | 0.569373 | 0.9817 | 00:07 |
| 2 | 0.55926 | 0.549785 | 0.9863 | 00:07 |
