Skip to content

Commit

Permalink
[FX] Fix NoneType annotation in generated code
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
James Reed committed Jan 20, 2021
1 parent 526659d commit b733da1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions test/test_fx.py
Expand Up @@ -815,6 +815,9 @@ def test_remove_uses(self):

self.assertTrue(neg not in relu.users)

def test_nonetype_annotation(self):
eb = torch.nn.EmbeddingBag(3, 4)
symbolic_trace(eb)

def test_construct_root_dict(self):
graph : torch.fx.Graph = torch.fx.Graph()
Expand Down
2 changes: 1 addition & 1 deletion torch/fx/graph.py
Expand Up @@ -8,7 +8,7 @@
import re

def _shadows_builtin_name(name: str) -> bool:
return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan'}
return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan', 'NoneType'}

def _is_magic(x: str) -> bool:
return x.startswith('__') and x.endswith('__')
Expand Down
3 changes: 2 additions & 1 deletion torch/fx/graph_module.py
Expand Up @@ -36,7 +36,8 @@ def patched_getline(*args, **kwargs):
linecache.getlines = patched_getline

def _forward_from_src(src : str):
gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan}
# If you add more globals here, remember to add their names to fx.graph._shadows_builtin_name!
gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan, 'NoneType' : type(None)}
exec_with_source(src, gbls)
return gbls['forward']

Expand Down

0 comments on commit b733da1

Please sign in to comment.