In [None]:
# default_exp models.utils

In [None]:
# hide
# Development convenience: reload imported modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Module utilities
> Utility functions for modules.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from grade_classif.core import ifnone
from grade_classif.models.hooks import Hooks
from grade_classif.imports import *

In [None]:
# export
def named_leaf_modules(model, name=''):
    """Collect the leaf modules of `model`, tagging each with its dotted path.

    A leaf is a module without children. Every leaf found gets its
    dot-separated path from the root stored in a `name` attribute
    (this mutates the module).

    Args:
        model: module whose tree is walked via `named_children()`.
        name: path accumulated so far; recursion helper, leave it as ''
            when calling manually.

    Returns:
        List of leaf modules in `named_children()` traversal order.
    """
    children = list(model.named_children())

    # Base case: no children means `model` itself is a leaf.
    if len(children) == 0:
        model.name = name
        return [model]

    # The prefix is the same for every child, so compute it once.
    prefix = '' if name == '' else name + '.'

    leaves = []
    for child_name, child in children:
        leaves.extend(named_leaf_modules(child, prefix + child_name))
    return leaves

Recursive function that returns all leaf modules of `model`, storing each module's qualified name (as _parent\_n.parent\_n-1.(...).parent\_0.module\_name_) on it as a `name` attribute. `name` is a convenience argument for recursion and should always be an empty string when manually called.

In [None]:
# export
def get_sizes(model, input_shape=(3, 224, 224), leaf_modules=None):
    """Record the output shape of each hooked module during one forward pass.

    A random batch of 2 samples of shape `input_shape` is fed through
    `model` while forward hooks capture every hooked module's output.

    Args:
        model: module to probe. It is switched to eval mode and left in
            that mode afterwards.
        input_shape: shape of a single input sample, without the batch
            dimension.
        leaf_modules: modules to hook. When None, all leaf modules of
            `model` (as returned by `named_leaf_modules`) are used.

    Returns:
        Tuple `(sizes, modules)` of numpy arrays, both sorted by the order
        in which the modules were called. `sizes` is a 2-D integer array
        only when all captured outputs have the same number of dimensions.

    Side effects:
        Each hooked module receives a `k` attribute holding its call index.
    """
    # Resolve the default lazily and with the correct argument order:
    # `named_leaf_modules` takes the model first and tags every leaf with a
    # `name` attribute, which should not happen when the caller passes a list.
    if leaf_modules is None:
        leaf_modules = named_leaf_modules(model)

    call_index = 0

    def _hook(module, input, output):
        # Stamp the module with its position in the forward pass.
        nonlocal call_index
        module.k = call_index
        call_index += 1
        return module, output

    # Build the probe input on the same device as the model's parameters
    # (falls back to the default device for parameter-less models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    with Hooks(leaf_modules, _hook) as hooks:
        x = torch.rand(2, *input_shape, device=device)
        # Only shapes are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            model.eval()(x)
        sizes = [hook.stored[1].shape for hook in hooks]
        mods = [hook.stored[0] for hook in hooks]

    idxs = np.argsort([mod.k for mod in mods])
    return np.array(sizes)[idxs], np.array(mods)[idxs]

Get a tuple `(sizes, modules)` where `sizes` contains the output shapes of all `leaf_modules` from `model` and `modules` contains the modules themselves. Both are sorted by the order in which the modules were called during the forward pass.

In [None]:
# export
def gaussian_mask(m, s, d, R, C):
    """Build an `R` x `C` gaussian attention mask.

    Row `r` holds a gaussian over the column indices `0..C-1`, centred at
    `m + r * d` with standard deviation `s`, and divided by its row sum.

    Args:
        m: centre of the first row's gaussian.
        s: standard deviation of the gaussians.
        d: distance between the centres of consecutive rows.
        R: number of rows.
        C: number of columns.

    Returns:
        Float tensor of shape `(R, C)` for scalar `m`, `s` and `d`.
    """
    # Column vector of row indices and row vector of column indices, so
    # the arithmetic below broadcasts to an (R, C) grid.
    row_idx = torch.arange(R, dtype=torch.float32).unsqueeze(1)
    col_idx = torch.arange(C, dtype=torch.float32).unsqueeze(0)

    centres = m + row_idx * d
    scaled_offsets = (col_idx - centres) / s
    weights = torch.exp(-.5 * torch.square(scaled_offsets))

    # eps in the denominator guards against division by zero.
    row_totals = weights.sum(1, keepdims=True) + 1e-8
    return weights / row_totals

Create a gaussian attention mask with mean `m`, standard deviation `s`, distance between centers `d`, `R` rows and `C` columns. Explanations for gaussian attention can be found in [this blog post](http://akoriorek.github.io/ml/2017/10/14/visual-attention.html).

In [None]:
# export
def get_num_features(model):
    """Return dimension 1 of the last-called module's output shape.

    Relies on `get_sizes` with its default input shape; the shapes it
    returns must form a 2-D array for the `[-1, 1]` lookup to work.
    """
    output_shapes = get_sizes(model)[0]
    # Last row = last module called; column 1 = the dimension after batch.
    return output_shapes[-1, 1]

Get the number of features from the last layer of `model`.

In [None]:
#hide
# Convert the project's notebooks into library modules (see the "Converted ..." output below).
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
