Skip to content

Commit

Permalink
[FX] Add Node.all_input_nodes (#48270)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #48270

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D25100241

Pulled By: jamesr66a

fbshipit-source-id: f742f5a13debebb5be37f7c0045c121f6eaff1d5
  • Loading branch information
James Reed authored and facebook-github-bot committed Nov 20, 2020
1 parent aa8aa30 commit 998c4ca
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/test_fx.py
Expand Up @@ -551,6 +551,21 @@ def forward(self, x):
x = torch.rand(3, 4)
self.assertEqual(loaded(x), traced(x))

def test_all_input_nodes(self):
graph : torch.fx.Graph = torch.fx.Graph()
a : torch.fx.Node = graph.placeholder('x')
b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
c : torch.fx.Node = graph.get_attr('y_attr')
d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
graph.output(e)
graph.lint()

self.assertEqual(b.all_input_nodes, [a])
self.assertEqual(c.all_input_nodes, [])
self.assertEqual(d.all_input_nodes, [b, c])
self.assertEqual(e.all_input_nodes, [d])

def test_deepcopy_graphmodule_with_transform(self):
st = SimpleTest()
traced = symbolic_trace(st)
Expand Down
12 changes: 12 additions & 0 deletions torch/fx/node.py
Expand Up @@ -136,6 +136,18 @@ def kwargs(self, k : Dict[str, Argument]):
"""
self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore

@property
def all_input_nodes(self) -> List['Node']:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over `args` and `kwargs` and only collecting the values that
are Nodes
"""
all_nodes : List['Node'] = []
map_arg(self.args, lambda n: all_nodes.append(n))
map_arg(self.kwargs, lambda n: all_nodes.append(n))
return all_nodes

def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
self._args = new_args
self._kwargs = new_kwargs
Expand Down

0 comments on commit 998c4ca

Please sign in to comment.