In [None]:
#default_exp representation

In [None]:
#export
from fastai.vision.all import *

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Demo setup: the tiny MNIST sample and an untrained resnet18 learner
# (pretrained=False), used in the cells below to illustrate the representation.
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, pretrained=False)

# Representation

> Functions and utilities to build representations of PyTorch/fastai objects.

## Model representation

First, we want to represent any model as a graph. A graph consists of nodes and links: the nodes are a collection of `nn.Module`s and `nn.Parameter`s, and the links describe the connections between them.

In [None]:
#export
class Representation:
    "Container for the root `Node` describing a `Learner`, as built by `Learner.to_representation`."
    def __init__(self, data):
        self.data = data  # root `Node` of the representation tree

    def __repr__(self):
        header = type(self).__name__
        # Layout: "<ClassName> (", the wrapped data, then a closing ")" on its own line.
        return '\n'.join([f'{header} (', f'{self.data}', ')'])

In [None]:
#export
class Node:
    "Represents a Module or Parameter."
    def __init__(self, name, idx, typ, obj=None, nodes=None, links=None, xtra=None):
        # name: display name; idx: index of the node (for module children, its position
        # among siblings); typ: kind label such as 'Module', 'Sequential', 'Input' or
        # 'Output'; obj: the wrapped object, if any.
        store_attr('name,idx,typ,obj', self)
        # Fresh containers per instance, so nodes never share mutable defaults.
        self.nodes = ifnone(nodes, [])   # child `Node`s
        self.links = ifnone(links, [])   # {'source': i, 'target': j} dicts between children
        self.xtra  = ifnone(xtra , {})   # extra metadata, e.g. 'shape', 'params', 'path'

    def __repr__(self):
        "Multi-line summary of the node; child nodes are indented under their parent."
        res = f'({self.__class__.__name__}): {self.name}, type: {self.typ}, idx={self.idx}'
        if len(self.xtra): res += f'\nxtra: {self.xtra}'
        if len(self.nodes):
            res += '\nnodes: ['
            # Indent *every* line of a child's (possibly multi-line) repr. Indenting only
            # the first line left the child's `xtra`/`nodes` lines at column 0, making it
            # unclear which node they belonged to.
            for node in self.nodes: res += '\n  ' + f'{node}'.replace('\n', '\n  ')
            res += '\n]'

        if len(self.links):
            res += '\nlinks: ['
            for link in self.links: res += f'\n  {link}'
            res += '\n]'

        return res

In [None]:
#export
@patch
def to_representation(self:nn.Module, name=None, idx=0, path=None, xtra=None):
    "Obtains information about the Module and stores it on a `Node`."
    # Work on a copy so the caller's `xtra` dict is never mutated by the updates below.
    xtra = {} if xtra is None else dict(xtra)
    name = ifnone(name, self.__class__.__name__)
    # `path` is the child's name as passed by `_get_module_nodes` (from `named_children`).
    if path is not None: xtra['path'] = path
    typ = 'Sequential' if isinstance(self, nn.Sequential) else 'Module'
    nodes,links = _get_module_nodes(self)
    # `_xtra` is attached temporarily by the forward hook in `Learner.to_representation`.
    if hasattr(self, '_xtra'): xtra.update(self._xtra)
    return Node(name, idx, typ, self, nodes, links, xtra)

def _get_module_nodes(module:nn.Module):
    "Obtain the `Node` representation for all the module childrens."
    nodes,links = [],[]
    is_seq = isinstance(module, nn.Sequential)
    for i,(n,m) in enumerate(module.named_children()):
        name = f'{n}_{m.__class__.__name__}' if is_seq else n
        nodes.append(m.to_representation(name, i, n))
        if i>0: links.append({'source':i-1, 'target':i})

    return nodes,links

In [None]:
show_doc(nn.Module.to_representation)

<h4 id="Module.to_representation" class="doc_header"><code>Module.to_representation</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Module.to_representation</code>(**`name`**=*`None`*, **`idx`**=*`0`*, **`path`**=*`None`*, **`xtra`**=*`None`*)

Obtain information of the Module and stores it on a [`Node`](/fastexplorer/representation.html#Node).

In [None]:
#export
@patch
def to_representation(self:Learner):
    "Gets a representation of the Learner to be passed to a web client."
    # Only the inputs of one training batch are needed to trace output shapes.
    xb = self.dls.train.one_batch()[0]
    def _get_info(m, i, o):
        # Forward hook: stash parameter counts and the output shape on each layer;
        # `Module.to_representation` merges `_xtra` into that layer's `Node`.
        params,trainable = total_params(m)
        m._xtra = {'params': params, 'trainable': trainable, 'shape': o.shape}

    model = self.model.to(xb.device)
    layers = flatten_model(model)
    # Remember each submodule's train/eval flag so it can be restored afterwards.
    modes = [(m, m.training) for m in model.modules()]
    try:
        # A single inference pass just to fire the hooks: eval mode avoids updating
        # BatchNorm statistics, no_grad avoids building an autograd graph.
        with Hooks(layers, _get_info), torch.no_grad(): model.eval()(xb)
        model_node = model.to_representation(xtra={'open': True})
    finally:
        for m,training in modes: m.training = training
        # Remove the temporary hook data (a layer skipped by forward has none).
        for layer in layers:
            if hasattr(layer, '_xtra'): del layer._xtra

    # Fill container shapes bottom-up first, so Output (and then Learner) can use them.
    _update_shapes(model_node)
    nodes = [Node('Input', 0, 'Input', xtra={'shape':list(xb.shape)}),
             model_node,
             Node('Output', 0, 'Output', xtra={'shape': model_node.xtra.get('shape')})]
    # Chain Input -> model -> Output: one link per consecutive pair (no dangling target).
    links = [{'source':i, 'target':i+1} for i in range(len(nodes)-1)]
    rep = Representation(Node('Learner', 0, 'Learner', nodes=nodes, links=links, xtra={'open': True}))
    _update_shapes(rep.data)   # Learner shape <- Output shape
    return rep

def _update_shapes(node):
    shape = node.xtra.get('shape')
    childs = node.nodes
    if (shape is None) and len(childs):
        for n in node.nodes: _update_shapes(n)
        node.xtra['shape'] = childs[-1].xtra.get('shape')

In [None]:
show_doc(Learner.to_representation)

<h4 id="Learner.to_representation" class="doc_header"><code>Learner.to_representation</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.to_representation</code>()

Gets a representation of the Learner to be passed to a web client.

In [None]:
r = learn.to_representation()
# Slice the text first, then wrap it: slicing a `PrettyString` yields a plain `str`,
# which displays with quotes and escaped `\n`s instead of real line breaks.
PrettyString(str(r.data)[:300])

"(Node): Learner, type: Learner, idx=0\nxtra: {'open': True, 'shape': None}\nnodes: [\n  (Node): Input, type: Input, idx=0\nxtra: {'shape': [64, 3, 28, 28]}\n  (Node): Sequential, type: Sequential, idx=0\nxtra: {'open': True, 'shape': torch.Size([64, 2])}\nnodes: [\n  (Node): 0_Sequential, type: Sequential, "

In [None]:
#export
@patch
def get_dict(self:Node):
    "Converts the `Node` into a plain dict, recursing into its child nodes."
    out = {'name': self.name, 'type': self.typ, 'index': self.idx}
    # Optional sections are only emitted when non-empty.
    optional = [('nodes', [child.get_dict() for child in self.nodes]),
                ('links', self.links),
                ('xtra',  self.xtra)]
    out.update((key, val) for key, val in optional if len(val))
    return out

In [None]:
show_doc(Node.get_dict)

<h4 id="Node.get_dict" class="doc_header"><code>Node.get_dict</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Node.get_dict</code>()

Gets the dictionary of the [`Node`](/fastexplorer/representation.html#Node).

In [None]:
#export
@patch
def to_json(self:Representation):
    "Serializes the Learner `Representation` (via `Node.get_dict`) into a JSON string."
    payload = self.data.get_dict()
    return json.dumps(payload)

In [None]:
show_doc(Representation.to_json)

<h4 id="Representation.to_json" class="doc_header"><code>Representation.to_json</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Representation.to_json</code>()

Gets the seriable json from the Leaner [`Representation`](/fastexplorer/representation.html#Representation).

In [None]:
# Preview the first 300 characters of the JSON meant for the web client.
r.to_json()[:300]

'{"name": "Learner", "type": "Learner", "index": 0, "nodes": [{"name": "Input", "type": "Input", "index": 0, "xtra": {"shape": [64, 3, 28, 28]}}, {"name": "Sequential", "type": "Sequential", "index": 0, "nodes": [{"name": "0_Sequential", "type": "Sequential", "index": 0, "nodes": [{"name": "0_Conv2d"'

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Convert the project's notebooks into library code (each converted notebook is listed below).
notebook2script()

Converted 00_representation.ipynb.
Converted 01_explorer.ipynb.
Converted 02_loss_landscape.ipynb.
Converted index.ipynb.
