From 55751347de79495495197949784e2960281eda58 Mon Sep 17 00:00:00 2001 From: James Reed Date: Thu, 19 Nov 2020 12:12:15 -0800 Subject: [PATCH] [FX] Add Node.all_input_nodes ghstack-source-id: 4e39d5d60fabffd899cda0ceb52a3dd75c9136e2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48270 --- test/test_fx.py | 15 +++++++++++++++ torch/fx/node.py | 12 ++++++++++++ 2 files changed, 27 insertions(+) diff --git a/test/test_fx.py b/test/test_fx.py index c2838ff56dd4..5a47c729f7eb 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -551,6 +551,21 @@ def forward(self, x): x = torch.rand(3, 4) self.assertEqual(loaded(x), traced(x)) + def test_all_input_nodes(self): + graph : torch.fx.Graph = torch.fx.Graph() + a : torch.fx.Node = graph.placeholder('x') + b : torch.fx.Node = graph.call_module('linear_mod', args=(a,)) + c : torch.fx.Node = graph.get_attr('y_attr') + d : torch.fx.Node = graph.call_function(operator.add, args=(b, c)) + e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) + graph.output(e) + graph.lint() + + self.assertEqual(b.all_input_nodes, [a]) + self.assertEqual(c.all_input_nodes, []) + self.assertEqual(d.all_input_nodes, [b, c]) + self.assertEqual(e.all_input_nodes, [d]) + def test_deepcopy_graphmodule_with_transform(self): st = SimpleTest() traced = symbolic_trace(st) diff --git a/torch/fx/node.py b/torch/fx/node.py index dd304a801155..8c484e0ab421 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -136,6 +136,18 @@ def kwargs(self, k : Dict[str, Argument]): """ self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore + @property + def all_input_nodes(self) -> List['Node']: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over `args` and `kwargs` and only collecting the values that + are Nodes + """ + all_nodes : List['Node'] = [] + map_arg(self.args, lambda n: all_nodes.append(n)) + map_arg(self.kwargs, lambda n: all_nodes.append(n)) + return all_nodes + def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): self._args = new_args self._kwargs = new_kwargs