Trace calls with Python Enum values.
Fix: #82135

ghstack-source-id: 7023390a339bf3ad4a31ef5b01fa6a3599c3e544
Pull Request resolved: #109507
ysiraichi committed Sep 18, 2023
1 parent 8ff0036 commit 64932eb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
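In short, symbolic tracing now accepts raw Enum members as call arguments; previously they fell through to the `NotImplementedError` branch in the tracer's `create_arg` (see the torch/fx/proxy.py diff below). A minimal sketch of the newly supported behavior (`Color` and `f` are illustrative names, not part of this patch):

    import enum
    import torch.fx

    class Color(enum.Enum):
        RED = 1

    def f(x):
        x.append(Color.RED)  # the Enum member becomes an FX node argument
        return x[-1].value

    traced = torch.fx.symbolic_trace(f)  # raised NotImplementedError before this fix
    assert f([]) == traced([])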
20 changes: 20 additions & 0 deletions test/test_fx.py
@@ -3774,6 +3774,26 @@ def test_deepcopy_no_recursion(self):
         copy_m = copy.deepcopy(m)  # finishes
         self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
 
+    def test_enum(self):
+        from enum import Enum
+
+        class Foo(Enum):
+            A = 1
+            B = 2
+
+        def leaf_fn(arr, enum_val):
+            # Use the raw enum.
+            arr.append(enum_val)
+            return arr[-1].value
+
+        def foo(x):
+            # Pass the enum as an argument.
+            return leaf_fn(x, Foo.A)
+
+        traced = torch.fx.symbolic_trace(foo)
+        self.assertEqual(foo([]), traced([]))
+
+
 
 def run_getitem_target():
     from torch.fx._symbolic_trace import _wrapped_methods_to_patch
40 changes: 23 additions & 17 deletions torch/fx/graph.py
@@ -9,6 +9,7 @@
 from dataclasses import dataclass
 from contextlib import contextmanager
 import copy
+import enum
 import torch
 import keyword
 import re
@@ -389,18 +390,23 @@ def type_repr(o : Any):
             # Common case: this is a regular module name like 'foo.bar.baz'
             return add_global(typename, o)
 
+        def _get_repr(arg: Any) -> str:
+            # Handle NamedTuples (if it has `_fields`) via add_global.
+            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+                qualified_name = _get_qualified_name(type(arg))
+                global_name = add_global(qualified_name, type(arg))
+                return f"{global_name}{repr(tuple(arg))}"
+            elif isinstance(arg, torch._ops.OpOverload):
+                qualified_name = _get_qualified_name(arg)
+                global_name = add_global(qualified_name, arg)
+                return f"{global_name}"
+            elif isinstance(arg, enum.Enum):
+                cls = arg.__class__
+                clsname = add_global(cls.__name__, cls)
+                return f"{clsname}.{arg.name}"
+            return repr(arg)
+
         def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-            def _get_repr(arg):
-                # Handle NamedTuples (if it has `_fields`) via add_global.
-                if isinstance(arg, tuple) and hasattr(arg, '_fields'):
-                    qualified_name = _get_qualified_name(type(arg))
-                    global_name = add_global(qualified_name, type(arg))
-                    return f"{global_name}{repr(tuple(arg))}"
-                elif isinstance(arg, torch._ops.OpOverload):
-                    qualified_name = _get_qualified_name(arg)
-                    global_name = add_global(qualified_name, arg)
-                    return f"{global_name}"
-                return repr(arg)
             args_s = ', '.join(_get_repr(a) for a in args)
             kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
             if args_s and kwargs_s:
@@ -499,7 +505,7 @@ def emit_node(node : Node):
 
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
-                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
+                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
                 free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                 raw_name = node.target.replace('*', '')
                 if raw_name != repr(node):
@@ -508,7 +514,7 @@ def emit_node(node : Node):
             elif node.op == 'call_method':
                 assert isinstance(node.target, str)
                 body.append(
-                    f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
                     f'({_format_args(node.args[1:], node.kwargs)})')
                 return
             elif node.op == 'call_function':
@@ -517,14 +523,14 @@
                 if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
                     assert isinstance(node.args, tuple)
                     body.append(f'{repr(node)}{maybe_type_annotation} = '
-                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
                     return
 
                 # pretty print inplace operators; required for jit.script to work properly
                 # not currently supported in normal FX graphs, but generated by torchdynamo
                 if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
-                    body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
-                                f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; '
+                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
                     return
 
                 qualified_name = _get_qualified_name(node.target)
@@ -536,7 +542,7 @@ def emit_node(node : Node):
                    isinstance(node.args[1], str) and \
                    node.args[1].isidentifier() and \
                    len(node.args) == 2:
-                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
                     return
                 body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                 if node.meta.get('is_wrapped', False):
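Why codegen needs a dedicated Enum branch: `repr()` of an Enum member is not valid Python source, while the `ClassName.MEMBER` form produced by `_get_repr`, together with `add_global` registering the class on the generated module, parses back cleanly. A standalone illustration, runnable without torch (`Mode` is an assumed example name):

    import enum

    class Mode(enum.Enum):
        FAST = 1

    # repr() yields something like "<Mode.FAST: 1>", which cannot appear
    # verbatim in generated Python code:
    print(repr(Mode.FAST))

    # _get_repr instead emits "<class name>.<member name>", which is valid
    # source once the class is registered as a global of the module:
    print(f"{Mode.FAST.__class__.__name__}.{Mode.FAST.name}")  # Mode.FAST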
3 changes: 2 additions & 1 deletion torch/fx/proxy.py
@@ -1,3 +1,4 @@
+import enum
 import dis
 import copy
 import sys
@@ -286,7 +287,7 @@ def no_node(arg):
             kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
             return self.create_node("call_function", a.__class__, (), kwargs)
 
-        elif isinstance(a, base_types) or a is None or a is ...:
+        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
             return a
         raise NotImplementedError(f"argument of type: {type(a)}")
 
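With `create_arg` letting Enum members through as constants, they are stored verbatim in `node.args` and rendered by the codegen path above. A small sketch of both effects (illustrative names; assumes a build containing this patch):

    import enum
    import operator
    import torch.fx

    class Mode(enum.Enum):
        FAST = 1

    def f(d):
        return d[Mode.FAST]  # creates a getitem node whose second arg is the raw Enum

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target is operator.getitem:
            assert node.args[1] is Mode.FAST  # kept as a constant, not proxied
    print(gm.code)  # the emitted code indexes with Mode.FAST directly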
