In [None]:
#default_exp representation

In [None]:
#export
from fastai2.vision.all import *

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Fixture for the examples below: the MNIST_TINY sample dataset and a ResNet-18
# learner built without pretrained weights (`pretrained=False`).
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, arch=resnet18, pretrained=False)
m = learn.model  # bare nn.Module, handy for poking at the model directly

# Representation

> Functions and utilities to build JSON-serializable representations of PyTorch/fastai objects.

## TODO
* recursively obtain representations

## Model representation

First, we want to represent any model as a graph. A graph consists of nodes and links: the nodes will be a collection of `nn.Module`s and `nn.Parameter`s, and the links will describe the connections between them.

In [None]:
#export
class Representation:
    'Representation of a `Learner`: wraps the serialisable description (`data`) produced by `Learner.to_representation`.'
    def __init__(self, data):
        "Store `data` (typically the nested dict built by `to_representation`) as-is."
        self.data = data
    # Deliberately terse: `data` can be large, so it is not included in the repr.
    def __repr__(self): return f'{self.__class__.__name__} ()'

In [None]:
#hide
# @patch
# def to_representation(self:nn.Module, *xb):
#     "Gets a summary of `self` using `xb`"
#     sample_inputs,infos = layer_info(self, *xb)
#     n,bs = 64,find_bs(xb)
#     name = self.__class__.__name__
#     modules,params,xtra = L(),L(),{}
#     inp_shape = list(apply(lambda x:x.shape, xb)[0])
#     out_shape = list(apply(lambda x:x.shape, xb)[0])
#     infos = L([o for o in infos if o is not None]) #see comment in previous cell
#     for i,(typ,np,trn,sz) in infos.enumerate():
#         modules.append({'idx':i, 'name':typ})
    
#     if isinstance(self, nn.Sequential):
#         idxs = modules.map(lambda x: x['idx'])
#         links = idxs[:-1].map_zipwith(lambda a,b: {'source':a, 'target':b}, idxs[1:])
#     else: raise NotImplementedError()
    
#     return Representation(name, inp_shape, out_shape, modules, params, links, xtra)
    
#     inp_sz = _print_shapes(apply(lambda x:x.shape, xb), bs)
#     res = f"{self.__class__.__name__} (Input shape: {inp_sz})\n"
#     res += "=" * n + "\n"
#     res += f"{'Layer (type)':<20} {'Output Shape':<20} {'Param #':<10} {'Trainable':<10}\n"
#     res += "=" * n + "\n"
#     ps,trn_ps = 0,0
#     infos = [o for o in infos if o is not None] #see comment in previous cell
#     for typ,np,trn,sz in infos:
#         if sz is None: continue
#         ps += np
#         if trn: trn_ps += np
#         res += f"{typ:<20} {_print_shapes(sz, bs)[:19]:<20} {np:<10,} {str(trn):<10}\n"
#         res += "_" * n + "\n"
#     res += f"\nTotal params: {ps:,}\n"
#     res += f"Total trainable params: {trn_ps:,}\n"
#     res += f"Total non-trainable params: {ps - trn_ps:,}\n\n"
#     return Representation(name, inp_shape, out_shape, modules, params, links, xtra), PrettyString(res)

In [None]:
#hide
# @patch
# def to_representation(self:Learner):
#     "Gets a summary of the model, optimizer and loss function."
#     xb = self.dls.train.one_batch()[:self.dls.train.n_inp]
#     return self.model.to_representation(*xb)
#     res = self.model.summary(*xb)
#     res += f"Optimizer used: {self.opt_func}\nLoss function: {self.loss_func}\n\n"
#     if self.opt is not None:
#         res += f"Model " + ("unfrozen\n\n" if self.opt.frozen_idx==0 else f"frozen up to parameter group number {self.opt.frozen_idx}\n\n")
#     res += "Callbacks:\n" + '\n'.join(f"  - {cb}" for cb in sort_by_run(self.cbs))
#     return PrettyString(res)

In [None]:
#export
@patch
def to_representation(self:Learner):
    "Build a `Representation` of this Learner (currently just its model, labelled 'Model') for a web client."
    model_data = self.model.to_representation(name='Model')
    return Representation(model_data)

In [None]:
#export
@patch
def to_representation(self:nn.Module, name=None, index=0):
    "Describe `self` as a dict with `name`, `index` and `type`, plus `nodes`/`links` for its direct children when present."
    # `name` falls back to the module's class name; `type` only distinguishes Sequential vs. anything else.
    kind = 'Sequential' if isinstance(self, nn.Sequential) else 'Module'
    rep = {'name': ifnone(name, self.__class__.__name__), 'index': index, 'type': kind}
    nodes,links = get_module_nodes(self)
    # Leaf modules get no 'nodes'/'links' keys at all (rather than empty lists).
    if nodes: rep['nodes'] = nodes
    if links: rep['links'] = links
    return rep

In [None]:
#export
@typedispatch
def get_module_nodes(module:nn.Module):
    "Return `(nodes, links)` for the direct children of `module`, each node built by the child's own `to_representation`."
    # Children of an nn.Sequential are named by position ('0', '1', ...), so prefix the class name for readability.
    prefix_with_class = isinstance(module, nn.Sequential)
    nodes = []
    for idx,(child_name,child) in enumerate(module.named_children()):
        if prefix_with_class: child_name = f'{child.__class__.__name__}_{child_name}'
        nodes.append(child.to_representation(child_name, idx))
    # Chain consecutive children: 0->1, 1->2, ...
    # NOTE(review): for non-Sequential modules this follows `named_children()` order,
    # which need not match the actual data flow in `forward` (e.g. residual blocks) — confirm intent.
    links = [{'source':idx-1, 'target':idx} for idx in range(1, len(nodes))]
    return nodes,links

In [None]:
r = learn.to_representation()
# First 100 characters of the nested dict; slicing yields a plain str either way.
str(r.data)[:100]

"{'name': 'Model', 'index': 0, 'type': 'Sequential', 'nodes': [{'name': 'Sequential_0', 'index': 0, '"

In [None]:
#export
@patch
def to_json(self:Representation, **kwargs):
    "Serialise `self.data` to a JSON string; extra `kwargs` (e.g. `indent=2`) are forwarded to `json.dumps`."
    import json  # explicit, rather than relying on names leaked by the `fastai2.vision.all` star import
    return json.dumps(self.data, **kwargs)

In [None]:
json_text = r.to_json()
json_text[:100]

'{"name": "Model", "index": 0, "type": "Sequential", "nodes": [{"name": "Sequential_0", "index": 0, "'

In [None]:
#hide
# fastai's built-in text summary of the same learner (hidden cell); a reference point when checking the representation above.
learn.summary()

Sequential (Input shape: ['64 x 3 x 28 x 28'])
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               64 x 64 x 14 x 14    9,408      True      
________________________________________________________________
BatchNorm2d          64 x 64 x 14 x 14    128        True      
________________________________________________________________
ReLU                 64 x 64 x 14 x 14    0          False     
________________________________________________________________
MaxPool2d            64 x 64 x 7 x 7      0          False     
________________________________________________________________
Conv2d               64 x 64 x 7 x 7      36,864     True      
________________________________________________________________
BatchNorm2d          64 x 64 x 7 x 7      128        True      
________________________________________________________________
ReLU                 64 x 64 x 7 x 7      0          False     
___________________________________________________

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Export the `#export` cells of every notebook in the project to their Python modules.
notebook2script()

Converted 00_representation.ipynb.
Converted 01_explorer.ipynb.
Converted 10_tutorial.ipynb.
Converted index.ipynb.
