Skip to content

Commit

Permalink
[fx] use a linked list for nodes (#45708)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45708

This makes it possible to define reasonable semantics for what happens
when a node in the list is deleted. In particular the iteration over nodes
will continue at the node that was after the deleted node _when it was deleted_.
If the new node is also deleted, we skip it and, continue to the node after it.
Eventually we either reach a node still in the list or we reach the end of the list.

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D24089516

Pulled By: zdevito

fbshipit-source-id: d01312d11fe381c8d910a83a08582a2219f47dda
  • Loading branch information
zdevito authored and facebook-github-bot committed Oct 13, 2020
1 parent 31ee5d8 commit 88dcb95
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 45 deletions.
16 changes: 8 additions & 8 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,8 @@ def test_wrong_topo(self):
c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
graph.output(d)
nodes = graph._nodes
nodes[2], nodes[3] = nodes[3], nodes[2]
nodes = list(graph.nodes)
nodes[3].append(nodes[2])
with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
graph.lint()

Expand Down Expand Up @@ -802,7 +802,7 @@ def forward(self, a, b):
dag = partitioner.partition_graph(traced, devices)
for node in traced.graph.nodes:
assert node.op == 'output' or node.partition_ids == [1]
nodes = traced.graph.nodes
nodes = list(traced.graph.nodes)
res_dag = DAG()
res_dag.create_node(0, [], [1], [], [])
res_dag.create_node(1, [0], [], [nodes[0], nodes[1]], [nodes[2]])
Expand Down Expand Up @@ -868,7 +868,7 @@ def is_leaf_module(self, m : torch.nn.Module, qualname : str):
kwargs = node.kwargs
# Neg doesn't have in-place
kwargs.pop('inplace')
with torch.fx.graph.insert_before(node):
with rn18_traced.graph.inserting_before(node):
new_node = rn18_traced.graph.call_function(
the_function=torch.neg, args=node.args, kwargs=node.kwargs)
node.replace_all_uses_with(replace_with=new_node)
Expand All @@ -883,7 +883,7 @@ def test_insertion_point(self):
b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
output : torch.fx.Node = graph.output(b)

with torch.fx.graph.insert_before(b):
with graph.inserting_before(b):
neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
_, *relu_args = b.args
b.args = (neg, *relu_args)
Expand All @@ -903,7 +903,7 @@ def test_move_before(self):
neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
_, *relu_args = b.args
b.args = (neg, *relu_args)
graph.move_node_before(to_move=neg, before=b)
b.prepend(neg)

gm = torch.fx.GraphModule(torch.nn.Module(), graph)

Expand Down Expand Up @@ -951,7 +951,7 @@ def forward(self, x):
combined_graph = torch.fx.Graph()
output_node = combined_graph.graph_copy(inline_into.graph, {})

input_node = to_inline.graph.nodes[0]
input_node = list(to_inline.graph.nodes)[0]
assert input_node and input_node.op == 'placeholder'

val_map = {input_node : output_node}
Expand All @@ -968,7 +968,7 @@ def test_multi_insert_point(self):
x = torch.fx.Proxy(graph.placeholder('x'))
relu = torch.relu(x)

with torch.fx.graph.insert_before(relu.node):
with graph.inserting_before(relu.node):
y = torch.neg(x)
z = torch.tanh(y)

Expand Down
121 changes: 85 additions & 36 deletions torch/fx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,50 @@ def _type_repr(obj):
return obj.__name__
return repr(obj)

class insert_before:
def __init__(self, n : Node):
self.n = n
class _InsertPoint:
def __init__(self, graph, new_insert):
self.graph = graph
self.orig_insert, graph._insert = graph._insert, new_insert

def __enter__(self):
self.orig_insert_point = self.n.graph._insert_point
self.n.graph._insert_point = self.n
pass

def __exit__(self, type, value, tb):
self.n.graph._insert_point = self.orig_insert_point
self.graph._insert = self.orig_insert

class _node_list:
def __init__(self, graph: 'Graph', direction: str = '_next'):
assert direction in ['_next', '_prev']
self.graph = graph
self.direction = direction

def __len__(self):
return self.graph._len

def __iter__(self):
root, direction = self.graph._root, self.direction
cur = getattr(root, direction)
while cur is not root:
if not cur._erased:
yield cur
cur = getattr(cur, direction)

def __reversed__(self):
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')

class Graph:
def __init__(self):
"""
Construct an empty Graph.
"""
self._nodes : List[Node] = []
self._root : Node = Node(self, '', 'root', '', (), {})
self._used_names : Dict[str, int] = {} # base name -> number
self._insert_point : Optional[Node] = None
self._insert = self._root.prepend
self._len = 0

@property
def nodes(self):
return tuple(self._nodes)
return _node_list(self)

def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]:
"""
Expand All @@ -107,7 +128,7 @@ def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argume
items by this function. Returns the equivalent output value of `g` with
Nodes switched to refer to nodes in `self`.
"""
for node in g._nodes:
for node in g.nodes:
if node in val_map:
continue
if node.op == 'output':
Expand All @@ -126,26 +147,10 @@ def create_node(self, op: str, target: Target,
kwargs = {} if kwargs is None else kwargs
sanitized_name = self._register_name_used(name) if name is not None else self._name(target)
n = Node(self, sanitized_name, op, target, args, kwargs, type_expr)
if self._insert_point is not None:
before_idx = self._nodes.index(self._insert_point)
self._nodes.insert(before_idx, n)
else:
self._nodes.append(n)
self._insert(n)
self._len += 1
return n

def move_node_before(self, to_move : Node, before : Node):
"""
Move node `to_move` before `before` in the Graph. Both `Node` arguments
must be present in this graph.
"""
# TODO: Computationally inefficient
if to_move.graph != self or before.graph != self:
raise RuntimeError('Node arguments must belong to this Graph!')
node_idx = self._nodes.index(to_move)
before_idx = self._nodes.index(before)
self._nodes.insert(before_idx, self._nodes.pop(node_idx))


def erase_node(self, to_erase : Node):
"""
Erases the node `to_erase` from the `Graph`. Throws an exception if
Expand All @@ -155,11 +160,55 @@ def erase_node(self, to_erase : Node):
raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
f'users in the graph: {to_erase.users}!')

node_indices = [i for i, n in enumerate(self._nodes) if n == to_erase]
for idx in reversed(node_indices):
self._nodes.pop(idx)
to_erase._remove_from_list()
to_erase._erased = True # iterators may retain handles to erased nodes
self._len -= 1

def inserting_before(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
When used within a 'with' statement, this will temporary set the insert point and
then restore it when the with statement exits:
with g.inserting_before(n):
... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) # set the insert point permanently
Args:
n (Optional[Node]): The node before which to insert. If None this will insert before
the beginning of the entire graph.
Returns:
A resource manager that will restore the insert point on `__exit__`.
"""
if n is None:
return self.inserting_after(self._root)
assert n.graph == self, "Node to insert before is not in graph."
return _InsertPoint(self, n.prepend)

def inserting_after(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
When used within a 'with' statement, this will temporary set the insert point and
then restore it when the with statement exits:
with g.inserting_after(n):
... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) # set the insert point permanently
Args:
n (Optional[Node]): The node before which to insert. If None this will insert after
the beginning of the entire graph.
Returns:
A resource manager that will restore the insert point on `__exit__`.
"""
if n is None:
return self.inserting_before(self._root)
assert n.graph == self, "Node to insert after is not in graph."
return _InsertPoint(self, n.append)

# sugar for above when you know the op
# sugar for create_node when you know the op
def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
return self.create_node('placeholder', name, type_expr=type_expr)

Expand Down Expand Up @@ -269,7 +318,7 @@ def type_repr(o : Any):
register_modules_used(typename)
return typename

for node in self._nodes:
for node in self.nodes:
if node.op == 'placeholder':
assert isinstance(node.target, str)
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
Expand Down Expand Up @@ -372,7 +421,7 @@ def format_node(n : Node) -> Optional[str]:
f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'


node_strs = [format_node(node) for node in self._nodes]
node_strs = [format_node(node) for node in self.nodes]
param_str = ', '.join(placeholder_names)
s = f'graph({param_str}){maybe_return_typename[0]}:'
for node_str in node_strs:
Expand Down Expand Up @@ -402,7 +451,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None:

seen_names : Set[str] = set()
seen_values : Set[Node] = set()
for node in self._nodes:
for node in self.nodes:
if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
if node.graph is not self:
Expand All @@ -417,7 +466,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None:

# Check targets are legit
if root:
for node in self._nodes:
for node in self.nodes:
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
target_atoms = node.target.split('.')
Expand Down
42 changes: 41 additions & 1 deletion torch/fx/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
type : Optional[Any] = None) -> None:
self.graph = graph
self.name = name # unique name of value being created
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
if op in ['call_method', 'call_module']:
assert isinstance(target, str)
Expand All @@ -51,6 +51,46 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
# does not produce a value, it's more of a notation. Thus, this value
# describes the type of args[0] in the `return` node.
self.type : Optional[Any] = type
self._prev = self
self._next = self
self._erased = False

@property
def next(self) -> 'Node':
return self._next

@property
def prev(self) -> 'Node':
return self._prev

def prepend(self, x: 'Node'):
"""Insert x before this node in the list of nodes in the graph.
Before: p -> self
bx -> x -> ax
After: p -> x -> self
bx -> ax
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x

def append(self, x: 'Node'):
"""Insert x after this node in the list of nodes in the graph.
Equvalent to `self.next.prepend(x)`
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next.prepend(x)

def _remove_from_list(self):
p, n = self._prev, self._next
p._next, n._prev = n, p

@property
def args(self) -> Tuple[Argument, ...]:
Expand Down

0 comments on commit 88dcb95

Please sign in to comment.