diff --git a/test/test_fx.py b/test/test_fx.py index 349941c72f86..1796ad2e87ef 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1104,6 +1104,43 @@ def forward(self, x): traced = torch.fx.symbolic_trace(Foo()) assert(all('constant' not in node.target for node in traced.graph.nodes)) + def test_single_default_arg(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, y=1): + return y + + m = M() + self.checkGraphModule(m, ()) + self.checkGraphModule(m, (3,)) + + def test_multiple_default_args(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, y=1, z=2): + return y + z + + m = M() + self.checkGraphModule(m, ()) + self.checkGraphModule(m, (3,)) + self.checkGraphModule(m, (3, 4)) + + def test_regular_and_default_args(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y=1): + return x + y + + m = M() + self.checkGraphModule(m, (2,)) + self.checkGraphModule(m, (2, 3)) + if __name__ == '__main__': run_tests() diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index c7fbd6fbf0ea..f3804c515612 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -44,7 +44,8 @@ def forward(self, x): The semantics are as follows: - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. - `target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to + `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 0c98ebe89447..d737a1a65629 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -344,7 +344,8 @@ def type_repr(o : Any): if node.op == 'placeholder': assert isinstance(node.target, str) maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - free_vars.append(f'{node.target}{maybe_type_annotation}') + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') raw_name = node.target.replace('*', '') if raw_name != node.name: body.append(f'{node.name} = {raw_name}\n') diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index d1672d332f14..317e039223a0 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -27,6 +27,15 @@ def proxy(self, node: Node) -> 'Proxy': def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr : Optional[Any] = None): + ''' + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the `args` tuple. `args` is + otherwise empty for `placeholder` Nodes. + ''' args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) assert isinstance(args_, tuple) diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 44f0ffba98e0..20566bb58e6e 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -125,8 +125,15 @@ def create_args_for_root(self, root_fn, is_module): next(names_iter) # skip self args.append(self.root) + sig = inspect.signature(fn_for_analysis) + def proxy_placeholder(name: str): - return self.create_proxy('placeholder', name, (), {}, + if name[0] == '*': + default = () # type: ignore + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore + return self.create_proxy('placeholder', name, default, {}, type_expr=fn_for_analysis.__annotations__.get(name, None)) args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))