Skip to content

Commit

Permalink
update visualizer
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko committed May 9, 2017
1 parent e8f9bd2 commit aad9b64
Show file tree
Hide file tree
Showing 7 changed files with 10,783 additions and 7,651 deletions.
1,425 changes: 789 additions & 636 deletions fast-neural-style.ipynb

Large diffs are not rendered by default.

1,178 changes: 710 additions & 468 deletions nin-export.ipynb

Large diffs are not rendered by default.

5,498 changes: 3,323 additions & 2,175 deletions resnet-18-at-export.ipynb

Large diffs are not rendered by default.

1,967 changes: 1,134 additions & 833 deletions resnet-18-export.ipynb

Large diffs are not rendered by default.

3,482 changes: 2,000 additions & 1,482 deletions resnet-34-export.ipynb

Large diffs are not rendered by default.

43 changes: 34 additions & 9 deletions visualize.py
@@ -1,7 +1,21 @@
from graphviz import Digraph
import torch
from torch.autograd import Variable

def make_dot(var):

def make_dot(var, params):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
param_map = {id(v): k for k, v in params.items()}

node_attr = dict(style='filled',
shape='box',
align='left',
Expand All @@ -10,18 +24,29 @@ def make_dot(var):
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()

def size_to_str(size):
return '('+(', ').join(['%d'% v for v in size])+')'

def add_nodes(var):
if var not in seen:
if isinstance(var, Variable):
value = '('+(', ').join(['%d'% v for v in var.size()])+')'
dot.node(str(id(var)), str(value), fillcolor='lightblue')
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
node_name = '%s\n %s' % (param_map[id(u)], size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'previous_functions'):
for u in var.previous_functions:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
add_nodes(var.creator)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
4,841 changes: 2,793 additions & 2,048 deletions wide-resnet-50-2-export.ipynb

Large diffs are not rendered by default.

0 comments on commit aad9b64

Please sign in to comment.