Skip to content

Commit

Permalink
[FX] Add Node.all_input_nodes
Browse files Browse the repository at this point in the history
ghstack-source-id: 4e39d5d60fabffd899cda0ceb52a3dd75c9136e2
Pull Request resolved: #48270
  • Loading branch information
James Reed committed Nov 19, 2020
1 parent 678fe9f commit 5575134
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/test_fx.py
Expand Up @@ -551,6 +551,21 @@ def forward(self, x):
x = torch.rand(3, 4)
self.assertEqual(loaded(x), traced(x))

def test_all_input_nodes(self):
graph : torch.fx.Graph = torch.fx.Graph()
a : torch.fx.Node = graph.placeholder('x')
b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
c : torch.fx.Node = graph.get_attr('y_attr')
d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
graph.output(e)
graph.lint()

self.assertEqual(b.all_input_nodes, [a])
self.assertEqual(c.all_input_nodes, [])
self.assertEqual(d.all_input_nodes, [b, c])
self.assertEqual(e.all_input_nodes, [d])

def test_deepcopy_graphmodule_with_transform(self):
st = SimpleTest()
traced = symbolic_trace(st)
Expand Down
12 changes: 12 additions & 0 deletions torch/fx/node.py
Expand Up @@ -136,6 +136,18 @@ def kwargs(self, k : Dict[str, Argument]):
"""
self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore

@property
def all_input_nodes(self) -> List['Node']:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over `args` and `kwargs` and only collecting the values that
are Nodes
"""
all_nodes : List['Node'] = []
map_arg(self.args, lambda n: all_nodes.append(n))
map_arg(self.kwargs, lambda n: all_nodes.append(n))
return all_nodes

def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
self._args = new_args
self._kwargs = new_kwargs
Expand Down

0 comments on commit 5575134

Please sign in to comment.