# Interactive Visualization of PyTorch JIT Modules

Copyright 2022 by Thomas Viehmann

Recently, PyTorch turned 5 years old. Today is the 5th anniversary of my first post to the PyTorch forums.
I have had the honor of contributing to PyTorch for quite a while and have met many great people from the community.

To celebrate, I'm dusting off a two-year-old visualization notebook and making the visualizations clickable to expand modules.

One caveat: the code is very hacky, and I would not expect node names to be completely stable across PyTorch releases, so you might not get the same pre-expanded diagrams.

So without further ado, here is the visualization. (I left the old text in here.)
 
I license this code with the CC-BY-SA 4.0 license. Please link to my blog post or the original github source (linked from the blog post) with the attribution notice.


## Introduction

Did you ever wish to get a concise picture of your PyTorch model's structure and found that too hard to get?


Recently, I did some work that involved looking at model structure in some detail. For my write-up, I wanted to get a diagram of some model structures. Even though it is a relatively common model, searching for a diagram didn't turn up anything in the shape I was looking for.

So how can we get the model structure for PyTorch models? The first stop is probably the neat string representation that PyTorch provides for `nn.Module`s - even without doing anything, it'll also cover our custom models pretty well. It is, however, not without shortcomings.

Let's look at TorchVision's ResNet18 basic block as an example.

In [1]:
import torchvision
m = torchvision.models.resnet18()
m.layer1[0]

BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

So we have two convs and two batch norms. But how are things connected? Is there one ReLU?

Looking at the forward method (you can get this using Python's `inspect` module or `??` in IPython), we see some important details not in the summary:

In [6]:
import inspect
print(inspect.getsource(m.layer1[0].forward))

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out



So we missed the entire residual bit. Also, there are two ReLUs. Arguably, it is wrong to re-use stateless modules like this. It'll haunt you when you do things like quantization (because the module then becomes stateful due to the quantization parameters) and it's mixing things too much. If you want stateless, use the functional interface.

But we can build a visualization based on JITed modules.

We recurse into calls to make subgraphs. We have to take some care that the edges connecting a subgraph to the outer graph are part of the outer graph, but other than that, it is very straightforward, even though the details are messy.
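
To make the remark about edges concrete, here is a minimal sketch that writes DOT source by hand instead of going through the `graphviz` package. The helper `render_dot` and all node names (`inp_x`, `block.conv`, ...) are made up for illustration and are not part of the notebook. The submodule's nodes and internal edges go into a `cluster_` subgraph, while the edges that cross the cluster boundary are declared in the outer graph; to my understanding, Graphviz would otherwise treat the outer endpoint as a member of the cluster as well.

```python
def render_dot(outer_nodes, cluster_label, cluster_nodes, inner_edges, crossing_edges):
    """Return DOT source with a single cluster (all names are hypothetical)."""
    lines = ["digraph G {"]
    for n in outer_nodes:
        lines.append(f'  "{n}";')
    # The submodule's own nodes and internal edges live inside the cluster.
    lines.append('  subgraph "cluster_0" {')
    lines.append(f'    label="{cluster_label}";')
    for n in cluster_nodes:
        lines.append(f'    "{n}";')
    for a, b in inner_edges:
        lines.append(f'    "{a}" -> "{b}";')
    lines.append("  }")
    # Edges crossing the cluster boundary are declared in the outer graph.
    for a, b in crossing_edges:
        lines.append(f'  "{a}" -> "{b}";')
    lines.append("}")
    return "\n".join(lines)


dot_source = render_dot(
    outer_nodes=["inp_x", "out_0"],
    cluster_label="block (BasicBlock)",
    cluster_nodes=["block.inp_x", "block.conv", "block.out_0"],
    inner_edges=[("block.inp_x", "block.conv"), ("block.conv", "block.out_0")],
    crossing_edges=[("inp_x", "block.inp_x"), ("block.out_0", "out_0")],
)
print(dot_source)
```

Pasting the printed source into any Graphviz renderer should show `inp_x` and `out_0` outside the box labelled `block (BasicBlock)`.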

In [7]:
import graphviz

def make_graph(mod, classes_to_visit=None, classes_found=None, dot=None, prefix="",
               input_preds=None, 
               parent_dot=None):
    preds = {}
    
    def find_name(i, self_input, suffix=None):
        if i == self_input:
            return suffix
        cur = i.node().s("name")
        if suffix is not None:
            cur = cur + '.' + suffix
        of = next(i.node().inputs())
        return find_name(of, self_input, suffix=cur)

    gr = mod.graph
    toshow = []
    # list(traced_model.graph.nodes())[0]
    self_input = next(gr.inputs())
    self_type = self_input.type().str().split('.')[-1]
    preds[self_input] = (set(), set()) # inps, ops
    
    if dot is None:
        dot = graphviz.Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})
        #dot.attr('node', shape='box')

    seen_inpnames = set()
    seen_edges = set()
    
    def add_edge(dot, n1, n2):
        if (n1, n2) not in seen_edges:
            seen_edges.add((n1, n2))
            dot.edge(n1, n2)

    def make_edges(pr, inpname, name, op, edge_dot=dot):
        if op:
            if inpname not in seen_inpnames:
                seen_inpnames.add(inpname)
                label_lines = [[]]
                line_len = 0
                for w in op:
                    if line_len >= 20:
                        label_lines.append([])
                        line_len = 0
                    label_lines[-1].append(w)
                    line_len += len(w) + 1
                edge_dot.node(inpname, label='\n'.join([' '.join(w) for w in label_lines]), shape='box', style='rounded')
                for p in pr:
                    add_edge(edge_dot, p, inpname)
            add_edge(edge_dot, inpname, name)
        else:
            for p in pr:
                add_edge(edge_dot, p, name)

    for nr, i in enumerate(list(gr.inputs())[1:]):
        name = prefix+'inp_'+i.debugName()
        preds[i] = {name}, set()
        dot.node(name, shape='ellipse')
        if input_preds is not None:
            pr, op = input_preds[nr]
            make_edges(pr, 'inp_'+name, name, op, edge_dot=parent_dot)
        
    def is_relevant_type(t):
        kind = t.kind()
        if kind == 'TensorType':
            return True
        if kind in ('ListType', 'OptionalType'):
            return is_relevant_type(t.getElementType())
        if kind == 'TupleType':
            return any([is_relevant_type(tt) for tt in t.elements()])
        return False

    for n in gr.nodes():
        only_first_ops = {'aten::expand_as'}
        rel_inp_end = 1 if n.kind() in only_first_ops else None
            
        relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if is_relevant_type(i.type())]
        relevant_outputs = [o for o in n.outputs() if is_relevant_type(o.type())]
        if n.kind() == 'prim::CallMethod':
            fq_submodule_name = '.'.join([nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')])
            submodule_type = list(n.inputs())[0].type().str().split('.')[-1]
            submodule_name = find_name(list(n.inputs())[0], self_input)
            name = prefix+'.'+n.output().debugName()
            label = prefix+submodule_name+' (' + submodule_type + ')'
            if classes_found is not None:
                classes_found.add(fq_submodule_name)
            if ((classes_to_visit is None and
                 (not fq_submodule_name.startswith('torch.nn') or 
                  fq_submodule_name.startswith('torch.nn.modules.container')))
                or (classes_to_visit is not None and 
                    (submodule_type in classes_to_visit
                    or fq_submodule_name in classes_to_visit))):
                # go into subgraph
                sub_prefix = prefix+submodule_name+'.'
                with dot.subgraph(name="cluster_"+name) as sub_dot:
                    sub_dot.attr(label=label)
                    submod = mod
                    for k in  submodule_name.split('.'):
                        submod = getattr(submod, k)
                    make_graph(submod, dot=sub_dot, prefix=sub_prefix,
                              input_preds = [preds[i] for i in list(n.inputs())[1:]],
                              parent_dot=dot, classes_to_visit=classes_to_visit,
                              classes_found=classes_found)
                for i, o in enumerate(n.outputs()):
                    preds[o] = {sub_prefix+f'out_{i}'}, set()
            else:
                dot.node(name, label=label, shape='box')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
        elif n.kind() == 'prim::CallFunction':
            funcname = list(n.inputs())[0].type().__repr__().split('.')[-1]
            name = prefix+'.'+n.output().debugName()
            label = funcname
            dot.node(name, label=label, shape='box')
            for i in relevant_inputs:
                pr, op = preds[i]
                make_edges(pr, prefix+i.debugName(), name, op)
            for o in n.outputs():
                preds[o] = {name}, set()
        else:
            unseen_ops = {'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index', 
                          'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                          'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                          'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                          'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                          'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                          }
        
            absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') # probably also partially absorbing ops. :/
            if False:
                print(n.kind())
                #DEBUG['kinds'].add(n.kind())
                #DEBUG[n.kind()] = n
                label = n.kind().split('::')[-1].rstrip('_')
                name = prefix+'.'+relevant_outputs[0].debugName()
                dot.node(name, label=label, shape='box', style='rounded')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
            if True:
                label = n.kind().split('::')[-1].rstrip('_')
                pr, op = set(), set()
                for i in relevant_inputs:
                    apr, aop = preds[i]
                    pr |= apr
                    op |= aop
                if pr and n.kind() not in unseen_ops:
                    print(n.kind(), n)
                if n.kind() in absorbing_ops:
                    pr, op = set(), set()
                elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops:
                    op.add(label)
                for o in n.outputs():
                    preds[o] = pr, op

    for i, o in enumerate(gr.outputs()):
        name = prefix+f'out_{i}'
        dot.node(name, shape='ellipse')
        pr, op = preds[o]
        make_edges(pr, 'inp_'+name, name, op)
    return dot
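
The bookkeeping in the final `else` branch is probably the least obvious part: every value carries a pair of sets, the box names it ultimately comes from and the labels of the small ops applied since. Below is a stripped-down sketch of just that propagation on plain tuples rather than JIT nodes. The op lists are shortened, the `(kind, inputs, output)` node format and the function name `propagate` are assumptions for illustration, and the check for tensor-typed inputs and outputs is left out.

```python
UNSEEN_OPS = {"aten::view", "aten::permute"}  # reshaping ops: passed through without a label
ABSORBING_OPS = {"aten::size"}                # shape queries: the data-flow trail ends here


def propagate(nodes, preds):
    """Fill preds[output] = (predecessor box names, pending op labels).

    nodes is a list of (kind, input_names, output_name) tuples, a
    hypothetical stand-in for the nodes of a TorchScript graph.
    """
    for kind, inputs, output in nodes:
        pr, op = set(), set()
        for i in inputs:
            apr, aop = preds[i]
            pr |= apr  # union of everything the inputs came from
            op |= aop  # union of the ops already pending on the inputs
        if kind in ABSORBING_OPS:
            pr, op = set(), set()
        elif inputs and kind not in UNSEEN_OPS:
            # same label cleanup as above: drop the namespace and the in-place underscore
            op.add(kind.split("::")[-1].rstrip("_"))
        preds[output] = (pr, op)
    return preds


preds = {"x": ({"inp_x"}, set())}
nodes = [
    ("aten::view", ["x"], "v"),       # invisible: only forwards the predecessors
    ("aten::add_", ["v", "x"], "s"),  # visible: adds the label "add"
    ("aten::size", ["s"], "n"),       # absorbing: forgets where it came from
]
result = propagate(nodes, preds)
for value_name in ("v", "s", "n"):
    print(value_name, result[value_name])
```

When a module call later consumes `s`, `make_edges` would draw one rounded box labelled `add` between `inp_x` and that module instead of one box per op.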

In [8]:
import ipywidgets
import IPython.display
import IPython
import tempfile
import os

# global callback and register
registered_jit_visualizations = {}
def on_click_in_jit_visualization(objid, t):
    # careful, if this throws an error, you won't notice
    registered_jit_visualizations[objid](t)
    return

class JITVisualizer:
    def __init__(self, mod, *, classes_to_visit=None, submodules_to_visit=None, expanded_names=None):
        self.mod = mod
        self.classes_to_visit = classes_to_visit or []  # default to no subclasses
        self.classes_found = set()
        self.submodules_to_visit = submodules_to_visit or set()
        self.expanded_names = expanded_names or set()
        self.js_display_id = None
        self.clickseq = []

    def make_node(self, dot, name, *, clickable=False, **kwargs):
        assert name not in self.node_ids
        nid = str(len(self.node_names))
        self.node_names.append(name)
        self.node_ids[name] = nid
        if 'label' not in kwargs:
            dot.node(self.node_ids[name], label=name, **kwargs)
        else:
            dot.node(self.node_ids[name], **kwargs)

    def make_edge(self, dot, n1, n2):
        dot.edge(self.node_ids[n1], self.node_ids[n2])

    def make_graph(self):
        self.node_ids = {}
        self.node_names = [None]
        self.cluster_ids = {}
        self.cluster_names = [None]
        return self._make_graph(mod=self.mod)

    def cluster_id(self, name):
        cid = 'cluster_' + str(len(self.cluster_names))
        self.cluster_names.append(name)
        self.cluster_ids[name] = cid
        return cid

    def is_relevant_type(self, t):
        kind = t.kind()
        if kind == 'TensorType':
            return True
        if kind in ('ListType', 'OptionalType'):
            return self.is_relevant_type(t.getElementType())
        if kind == 'TupleType':
            return any([self.is_relevant_type(tt) for tt in t.elements()])
        return False

    def make_svg(self):
        self.DEBUG_could_expand = []
        graph = self.make_graph()
        with tempfile.TemporaryDirectory() as d:
            res = graph.render(os.path.join(d, 'gr'))
            # read the rendered SVG back and close the file before the directory is removed
            with open(res) as f:
                svg = f.read()
        return svg

    def _make_graph(self, *, mod, dot=None, prefix="", input_preds=None, parent_dot=None):
        gr = mod.graph
        preds = {}

        def find_name(i, self_input, suffix=None):
            if i == self_input:
                return suffix
            cur = i.node().s("name")
            if suffix is not None:
                cur = cur + '.' + suffix
            of = next(i.node().inputs())
            return find_name(of, self_input, suffix=cur)

        toshow = []
        # list(traced_model.graph.nodes())[0]
        self_input = next(gr.inputs())
        self_type = self_input.type().str().split('.')[-1]
        preds[self_input] = (set(), set())  # inps, ops

        if dot is None:
            dot = graphviz.Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})

        seen_inpnames = set()
        seen_edges = set()

        def add_edge(dot, n1, n2):
            if (n1, n2) not in seen_edges:
                seen_edges.add((n1, n2))
                self.make_edge(dot, n1, n2)
                #dot.edge(n1, n2)

        def make_edges(pr, inpname, name, op, edge_dot=dot):
            if op:
                if inpname not in seen_inpnames:
                    seen_inpnames.add(inpname)
                    label_lines = [[]]
                    line_len = 0
                    for w in op:
                        if line_len >= 20:
                            label_lines.append([])
                            line_len = 0
                        label_lines[-1].append(w)
                        line_len += len(w) + 1
                    self.make_node(edge_dot, inpname, label='\n'.join([' '.join(w) for w in label_lines]), shape='box',
                                  style='rounded')
                    for p in pr:
                        add_edge(edge_dot, p, inpname)
                add_edge(edge_dot, inpname, name)
            else:
                for p in pr:
                    add_edge(edge_dot, p, name)

        for nr, i in enumerate(list(gr.inputs())[1:]):
            name = prefix + 'inp_' + i.debugName()
            preds[i] = {name}, set()
            self.make_node(dot, name, shape='ellipse')
            if input_preds is not None:
                pr, op = input_preds[nr]
                make_edges(pr, 'inp_' + name, name, op, edge_dot=parent_dot)

        for n in gr.nodes():
            only_first_ops = {'aten::expand_as'}
            rel_inp_end = 1 if n.kind() in only_first_ops else None

            relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if self.is_relevant_type(i.type())]
            relevant_outputs = [o for o in n.outputs() if self.is_relevant_type(o.type())]
            if n.kind() == 'prim::CallMethod':
                fq_submodule_name = '.'.join(
                    [nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')])
                submodule_type = list(n.inputs())[0].type().str().split('.')[-1]
                submodule_name = find_name(list(n.inputs())[0], self_input)
                name = prefix + '.' + n.output().debugName()
                label = prefix + submodule_name + ' (' + submodule_type + ')'
                self.classes_found.add(fq_submodule_name)  # debugging
                if ((self.classes_to_visit is None and
                     (not fq_submodule_name.startswith('torch.nn') or
                      fq_submodule_name.startswith('torch.nn.modules.container')))
                    or (self.classes_to_visit is not None and
                        (submodule_type in self.classes_to_visit
                         or fq_submodule_name in self.classes_to_visit))
                    or (name in self.expanded_names)):
                    # go into subgraph
                    sub_prefix = prefix + submodule_name + '.'
                    # print("name=cluster_" + name, f"{sub_prefix=}")
                    
                    with dot.subgraph(name=self.cluster_id(name)) as sub_dot:
                        sub_dot.attr(label=label)
                        submod = mod
                        for k in submodule_name.split('.'):
                            submod = getattr(submod, k)
                        self._make_graph(mod=submod, dot=sub_dot, prefix=sub_prefix,
                                   input_preds=[preds[i] for i in list(n.inputs())[1:]],
                                   parent_dot=dot)
                    for i, o in enumerate(n.outputs()):
                        preds[o] = {sub_prefix + f'out_{i}'}, set()
                else:
                    self.DEBUG_could_expand.append(name)
                    # print("could expand", name)
                    self.make_node(dot, name, label=label, shape='box', clickable=True)
                    for i in relevant_inputs:
                        pr, op = preds[i]
                        make_edges(pr, prefix + i.debugName(), name, op)
                    for o in n.outputs():
                        preds[o] = {name}, set()
            elif n.kind() == 'prim::CallFunction':
                funcname = list(n.inputs())[0].type().__repr__().split('.')[-1]
                name = prefix + '.' + n.output().debugName()
                label = funcname
                self.make_node(dot, name, label=label, shape='box')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix + i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
            else:
                unseen_ops = {'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index',
                              'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                              'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                              'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                              'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                              #'aten::cat', 
                              'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                              }

                absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')  # probably also partially absorbing ops. :/

                label = n.kind().split('::')[-1].rstrip('_')
                pr, op = set(), set()
                for i in relevant_inputs:
                    apr, aop = preds[i]
                    pr |= apr
                    op |= aop
                #if pr and n.kind() not in unseen_ops:
                #    print(n.kind(), n)
                if n.kind() in absorbing_ops:
                    pr, op = set(), set()
                elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops:
                    op.add(label)
                for o in n.outputs():
                    preds[o] = pr, op

        for i, o in enumerate(gr.outputs()):
            name = prefix + f'out_{i}'
            self.make_node(dot, name, shape='ellipse')
            pr, op = preds[o]
            make_edges(pr, 'inp_' + name, name, op)
        return dot

    def update_javascript(self):
        registered_jit_visualizations[id(self)] = self.show_with_lastclick
        js = '''  var my_click_event_func = function(e){
             var kernel = IPython.notebook.kernel;
             var label = $('title', e.currentTarget)[0].textContent;
             //window.alert('hello!' + label);
             if (label.startsWith('cluster_')) {
               label = label.slice(8) + "1";
             } else {
               label = label + "2";
             }
             //window.alert('hello!' + label);
             kernel.execute("on_click_in_jit_visualization(''' +str(id(self))+''', "+ label+ ")");
             $(".tv-graph .node").on("click", my_click_event_func);
             $(".tv-graph .cluster").on("click", my_click_event_func);
          };
          $(".tv-graph .node").on("click", my_click_event_func);
          $(".tv-graph .cluster").on("click", my_click_event_func);          
          '''

        if self.js_display_id is None:
            self.js_display_id = IPython.display.display(IPython.display.Javascript(js), display_id=True)
        else:
            self.js_display_id.update(IPython.display.Javascript(js))

    def show_with_lastclick(self, i=None):
        if i is not None:
            typ = i % 10
            i = i // 10
            if typ == 1:
                # cluster
                self.expanded_names.remove(self.cluster_names[i])
            else:
                self.expanded_names.add(self.node_names[i])
            self.last_clicked_typ = typ
            self.last_clicked = i
            self.clickseq.append((typ, i))
        svg = self.make_svg()
        self.html.value = svg
        self.update_javascript()


    def show_interactive_graph(self):
        self.html = ipywidgets.HTML("")
        self.html.add_class('tv-graph')
        IPython.display.display(self.html)
        self.show_with_lastclick()
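
The click handling relies on a small encoding: Graphviz puts the node or cluster name into the SVG `<title>`, the JavaScript handler appends `1` for clusters and `2` for nodes, and `show_with_lastclick` splits that last digit off again. A pure-Python sketch of the round trip, with hypothetical helper names `encode_click` and `decode_click` that do not exist in the notebook:

```python
def encode_click(svg_title):
    """Mirror of the JavaScript handler: tag clusters with 1, nodes with 2."""
    if svg_title.startswith("cluster_"):
        return int(svg_title[len("cluster_"):] + "1")
    return int(svg_title + "2")


def decode_click(code):
    """Mirror of show_with_lastclick: the last decimal digit is the type."""
    typ = code % 10
    index = code // 10
    return ("cluster" if typ == 1 else "node"), index


cluster_code = encode_click("cluster_3")  # a click on the third cluster
node_code = encode_click("17")            # a click on the node with id 17
print(cluster_code, decode_click(cluster_code))
print(node_code, decode_click(node_code))
```

Passing a single integer keeps the string handed to `kernel.execute` free of quoting issues, at the price of this slightly cryptic arithmetic.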


## Applications


Let's apply it! These are the pictures from my blog post along with the code that generated them.

The following code is from the [transformers library](https://github.com/huggingface/transformers/) (Copyright 2018- The Hugging Face team. Apache Licensed.).

In [29]:

import transformers

from transformers import BertModel, BertTokenizer, BertConfig
import numpy

import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

model.eval()
for p in model.parameters():
    p.requires_grad_(False)

transformers.__version__

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


'4.12.0.dev0'

In [30]:
# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
traced_model.eval()
for p in traced_model.parameters():
    p.requires_grad_(False)

In [31]:
if 0:
    # resolving functions?
    t = fn.type()
    def lookup(fn):
        n = str(fn.type()).split('.')[1:]
        res = globals()[n[0]]
        for nc in n[1:]:
            res = getattr(res, nc)
        return res
    lookup(fn).graph

In [39]:
viz = JITVisualizer(traced_model, expanded_names={'.3379'})
viz.show_interactive_graph()



In [43]:
mod = getattr(traced_model.encoder.layer, "0") # traced_model.encoder.layer[0]


viz = JITVisualizer(getattr(traced_model.encoder.layer, "0"), expanded_names={'.9', 'attention..7'}) # classes_to_visit={'BertAttention', 'BertSelfAttention'}
viz.show_interactive_graph()


In [44]:
import torchvision

In [45]:
m = torchvision.models.resnet18()
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)])

In [46]:
m = torchvision.models.resnet18()
m.layer1[0]

BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [47]:
print(inspect.getsource(m.layer1[0].forward))

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out



In [53]:
viz = JITVisualizer(tm, expanded_names={'.1617', 'layer1..6', 'layer1..7'})
viz.show_interactive_graph()



In [54]:
viz = JITVisualizer(getattr(tm.layer1, "0"))
viz.show_interactive_graph()



In [60]:
m = torchvision.models.segmentation.fcn_resnet50()
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)], strict=False)
#d = make_graph(tm, classes_to_visit={'IntermediateLayerGetter', 'FCNHead'})
viz = JITVisualizer(tm, expanded_names={'.4164', '.4165'})
viz.show_interactive_graph()


In [61]:
class Detection(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = torchvision.models.detection.fasterrcnn_resnet50_fpn().eval()
    def forward(self, inp):
        assert inp.shape[0] == 1
        res, = self.m(inp)
        return res['boxes'], res['labels'], res['scores']

tm = torch.jit.trace(Detection(), [torch.randn(1, 3, 224, 224)], check_trace=False)


  assert inp.shape[0] == 1
  (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
  torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
  torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
  A = Ax4 // 4
  C = AxC // A
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
  torch.tensor(s, dtype=torch.float32, device=boxes.device)
  / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)


In [67]:
viz = JITVisualizer(tm)
viz.show_interactive_graph()



In [66]:
viz.expanded_names = {'.8220', 'm..92', 'm..93'}
viz.show_interactive_graph()
