From b733da11aac8a241fa922aca3a73d0ea72e6768f Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 19 Jan 2021 17:21:24 -0800 Subject: [PATCH] [FX] Fix NoneType annotation in generated code [ghstack-poisoned] --- test/test_fx.py | 3 +++ torch/fx/graph.py | 2 +- torch/fx/graph_module.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 4f01e876854e..089c793ff239 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -815,6 +815,9 @@ def test_remove_uses(self): self.assertTrue(neg not in relu.users) + def test_nonetype_annotation(self): + eb = torch.nn.EmbeddingBag(3, 4) + symbolic_trace(eb) def test_construct_root_dict(self): graph : torch.fx.Graph = torch.fx.Graph() diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 34bbc98cf9e0..09e7f2bfc8d9 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -8,7 +8,7 @@ import re def _shadows_builtin_name(name: str) -> bool: - return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan'} + return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan', 'NoneType'} def _is_magic(x: str) -> bool: return x.startswith('__') and x.endswith('__') diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index fc68cdab5677..80d16a12036f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -36,7 +36,8 @@ def patched_getline(*args, **kwargs): linecache.getlines = patched_getline def _forward_from_src(src : str): - gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan} + # If you add more globals here, remember to add their names to fx.graph._shadows_builtin_name! + gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan, 'NoneType' : type(None)} exec_with_source(src, gbls) return gbls['forward']