-
Notifications
You must be signed in to change notification settings - Fork 278
/
dot.py
190 lines (162 loc) · 7.2 KB
/
dot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from collections import namedtuple
from distutils.version import LooseVersion
from graphviz import Digraph
import torch
from torch.autograd import Variable
import warnings
# A minimal record describing one node of the exported graph.
Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))

# Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
SAVED_PREFIX = "_saved_"


def get_fn_name(fn, show_attrs, max_attr_chars):
    """Build the display label for an autograd ``grad_fn`` node.

    Args:
        fn: a backward-graph node; its type name becomes the label title.
        show_attrs: when True, append a two-column table of the node's
            ``_saved_*`` attributes (saved tensors are shown as
            placeholders, everything else via ``str``).
        max_attr_chars: cap on the value-column width; floored at 3 so the
            ``...`` truncation marker always fits.

    Returns:
        The node name, optionally followed by a separator rule and one
        ``key: value`` line per saved attribute.
    """
    name = str(type(fn).__name__)
    if not show_attrs:
        return name

    # Collect `_saved_*` attributes, stripped of their prefix.
    attrs = dict()
    for attr in dir(fn):
        if not attr.startswith(SAVED_PREFIX):
            continue
        val = getattr(fn, attr)
        attr = attr[len(SAVED_PREFIX):]
        if torch.is_tensor(val):
            # Placeholder: tensor contents are not rendered in the label.
            attrs[attr] = "[saved tensor]"
        elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
            attrs[attr] = "[saved tensors]"
        else:
            attrs[attr] = str(val)
    if not attrs:
        return name

    # Lay out a table: left-justified keys, right-justified values.
    max_attr_chars = max(max_attr_chars, 3)
    col1width = max(len(k) for k in attrs)
    # All values are strings at this point, so len() is safe directly.
    col2width = min(max(len(v) for v in attrs.values()), max_attr_chars)
    # The separator rule spans the table but is never narrower than the title.
    sep = "-" * max(col1width + col2width + 2, len(name))

    def truncate(s):
        # Keep within col2width, ending with "..." when cut short.
        return s[:col2width - 3] + "..." if len(s) > col2width else s

    params = '\n'.join(f"{k:<{col1width}}: {truncate(v):>{col2width}}"
                       for k, v in attrs.items())
    return name + '\n' + sep + '\n' + params
def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
    """ Produces Graphviz representation of PyTorch autograd graph.

    If a node represents a backward function, it is gray. Otherwise, the node
    represents a tensor and is either blue, orange, or green:
     - Blue: reachable leaf tensors that requires grad (tensors whose `.grad`
         fields will be populated during `.backward()`)
     - Orange: saved tensors of custom autograd functions as well as those
         saved by built-in backward nodes
     - Green: tensor passed in as outputs
     - Dark green: if any output is a view, we represent its base tensor with
         a dark green node.

    Args:
        var: output tensor (or tuple of output tensors)
        params: dict of (name, tensor) to add names to node that requires grad
        show_attrs: whether to display non-tensor attributes of backward nodes
            (Requires PyTorch version >= 1.9)
        show_saved: whether to display saved tensor nodes that are not by custom
            autograd functions. Saved tensor nodes for custom functions, if
            present, are always displayed. (Requires PyTorch version >= 1.9)
        max_attr_chars: if show_attrs is `True`, sets max number of characters
            to display for any given attribute.

    Returns:
        A graphviz Digraph of the autograd graph reachable from `var`.
    """
    # `_saved_*` introspection on grad_fn only exists from torch 1.9 onward;
    # warn (rather than fail) when those options are requested on older torch.
    # NOTE(review): distutils.version is deprecated and removed in Python 3.12;
    # consider migrating this check to packaging.version — TODO confirm.
    if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
       (show_attrs or show_saved):
        warnings.warn(
            "make_dot: showing grad_fn attributes and saved variables"
            " requires PyTorch version >= 1.9. (This does NOT apply to"
            " saved tensors saved by custom autograd functions.)")

    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        # Map tensor identity -> user-supplied name, for labelling leaf nodes.
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='10',
                     ranksep='0.1',
                     height='0.2',
                     fontname='monospace')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # Tracks every grad_fn / tensor already emitted, so shared nodes and the
    # edges into them are drawn only once.
    seen = set()

    def size_to_str(size):
        # e.g. a size of [2, 3] -> "(2, 3)"
        return '(' + (', ').join(['%d' % v for v in size]) + ')'

    def get_var_name(var, name=None):
        # Label for a tensor node: its param name (if known) plus its size.
        if not name:
            name = param_map[id(var)] if id(var) in param_map else ''
        return '%s\n %s' % (name, size_to_str(var.size()))

    def add_nodes(fn):
        # Recursively emit nodes/edges for the grad_fn graph rooted at `fn`.
        # Tensors never come through here; they go via add_base_tensor.
        assert not torch.is_tensor(fn)
        if fn in seen:
            return
        seen.add(fn)

        if show_saved:
            for attr in dir(fn):
                if not attr.startswith(SAVED_PREFIX):
                    continue
                val = getattr(fn, attr)
                # Marked seen even when not a tensor, so it is never re-emitted.
                seen.add(val)
                attr = attr[len(SAVED_PREFIX):]
                if torch.is_tensor(val):
                    # Undirected edge: saved tensors hang off their grad_fn.
                    dot.edge(str(id(fn)), str(id(val)), dir="none")
                    dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange')
                if isinstance(val, tuple):
                    for i, t in enumerate(val):
                        if torch.is_tensor(t):
                            # Index each element of a saved-tensor tuple.
                            name = attr + '[%s]' % str(i)
                            dot.edge(str(id(fn)), str(id(t)), dir="none")
                            dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange')

        if hasattr(fn, 'variable'):
            # if grad_accumulator, add the node for `.variable`
            var = fn.variable
            seen.add(var)
            dot.node(str(id(var)), get_var_name(var), fillcolor='lightblue')
            dot.edge(str(id(var)), str(id(fn)))

        # add the node for this grad_fn
        dot.node(str(id(fn)), get_fn_name(fn, show_attrs, max_attr_chars))

        # recurse
        if hasattr(fn, 'next_functions'):
            for u in fn.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(fn)))
                    add_nodes(u[0])

        # note: this used to show .saved_tensors in pytorch0.2, but stopped
        # working* as it was moved to ATen and Variable-Tensor merged
        # also note that this still works for custom autograd functions
        if hasattr(fn, 'saved_tensors'):
            for t in fn.saved_tensors:
                seen.add(t)
                dot.edge(str(id(t)), str(id(fn)), dir="none")
                dot.node(str(id(t)), get_var_name(t), fillcolor='orange')

    def add_base_tensor(var, color='darkolivegreen1'):
        # Emit an output tensor node, then walk into its grad_fn graph.
        if var in seen:
            return
        seen.add(var)
        dot.node(str(id(var)), get_var_name(var), fillcolor=color)
        if (var.grad_fn):
            add_nodes(var.grad_fn)
            dot.edge(str(id(var.grad_fn)), str(id(var)))
        if var._is_view():
            # A view's base tensor gets a dark green node and a dotted edge.
            add_base_tensor(var._base, color='darkolivegreen3')
            dot.edge(str(id(var._base)), str(id(var)), style="dotted")

    # handle multiple outputs
    if isinstance(var, tuple):
        for v in var:
            add_base_tensor(v)
    else:
        add_base_tensor(var)

    resize_graph(dot)

    return dot
def make_dot_from_trace(trace):
    """Removed: trace-based visualization now lives in PyTorch core.

    See https://pytorch.org/docs/stable/tensorboard.html

    Raises:
        NotImplementedError: always; kept only so old callers get a
            pointed error instead of an AttributeError.
    """
    # from tensorboardX
    message = ("This function has been moved to pytorch core and "
               "can be found here: https://pytorch.org/docs/stable/tensorboard.html")
    raise NotImplementedError(message)
def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Grow the graph's canvas in proportion to how much content it holds.

    Mutates ``dot`` in place by setting its ``size`` graph attribute to a
    square ``"<edge>,<edge>"`` spec.

    Args:
        dot: a graphviz ``Digraph`` whose ``body`` length approximates the
            number of nodes and edges.
        size_per_element: canvas units allocated per body row.
        min_size: lower bound on the canvas edge length.
    """
    # Each entry in `dot.body` is roughly one node or edge statement.
    content_size = len(dot.body) * size_per_element
    edge = max(min_size, content_size)
    dot.graph_attr.update(size=f"{edge},{edge}")