Skip to content

Commit

Permalink
Support default args in symbolic tracing (#47615)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #47615

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D24865060

Pulled By: ansley

fbshipit-source-id: 32ff105a1fa9c4a8f00adc20e8d40d1b6bd7157f
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Nov 11, 2020
1 parent a5e9fa1 commit e914a1b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 3 deletions.
37 changes: 37 additions & 0 deletions test/test_fx.py
Expand Up @@ -1104,6 +1104,43 @@ def forward(self, x):
traced = torch.fx.symbolic_trace(Foo())
assert(all('constant' not in node.target for node in traced.graph.nodes))

def test_single_default_arg(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, y=1):
return y

m = M()
self.checkGraphModule(m, ())
self.checkGraphModule(m, (3,))

def test_multiple_default_args(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, y=1, z=2):
return y + z

m = M()
self.checkGraphModule(m, ())
self.checkGraphModule(m, (3,))
self.checkGraphModule(m, (3, 4))

def test_regular_and_default_args(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y=1):
return x + y

m = M()
self.checkGraphModule(m, (2,))
self.checkGraphModule(m, (2, 3))


if __name__ == '__main__':
run_tests()
3 changes: 2 additions & 1 deletion torch/fx/__init__.py
Expand Up @@ -44,7 +44,8 @@ def forward(self, x):
The semantics are as follows:
- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on.
`target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to
`target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument
denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to
the function parameters (e.g. `x`) in the graph printout.
- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy.
Expand Down
3 changes: 2 additions & 1 deletion torch/fx/graph.py
Expand Up @@ -344,7 +344,8 @@ def type_repr(o : Any):
if node.op == 'placeholder':
assert isinstance(node.target, str)
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
free_vars.append(f'{node.target}{maybe_type_annotation}')
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
if raw_name != node.name:
body.append(f'{node.name} = {raw_name}\n')
Expand Down
9 changes: 9 additions & 0 deletions torch/fx/proxy.py
Expand Up @@ -27,6 +27,15 @@ def proxy(self, node: Node) -> 'Proxy':

def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None):
'''
Create a Node from the given arguments, then return the Node
wrapped in a Proxy object.
If kind = 'placeholder', then we're creating a Node that
represents the parameter of a function. If we need to encode
a default parameter, we use the `args` tuple. `args` is
otherwise empty for `placeholder` Nodes.
'''
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
Expand Down
9 changes: 8 additions & 1 deletion torch/fx/symbolic_trace.py
Expand Up @@ -125,8 +125,15 @@ def create_args_for_root(self, root_fn, is_module):
next(names_iter) # skip self
args.append(self.root)

sig = inspect.signature(fn_for_analysis)

def proxy_placeholder(name: str):
return self.create_proxy('placeholder', name, (), {},
if name[0] == '*':
default = () # type: ignore
else:
param = sig.parameters[name]
default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore
return self.create_proxy('placeholder', name, default, {},
type_expr=fn_for_analysis.__annotations__.get(name, None))

args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))
Expand Down

0 comments on commit e914a1b

Please sign in to comment.