diff --git a/test/test_fx.py b/test/test_fx.py index f22105ac2562..d8f83406fb9d 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1597,6 +1597,18 @@ def test_no_mutation(self): with self.assertRaisesRegex(NotImplementedError, "new_args"): x[0] = 4 + def test_partial_trace(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + if y: + return 2 * x + else: + return x + mod = Foo() + mod_true = symbolic_trace(mod, concrete_args={'y': True}) + mod_false = symbolic_trace(mod, concrete_args={'y': False}) + self.assertEqual(mod_true(3), 6) + self.assertEqual(mod_false(3), 3) def run_getitem_target(): from torch.fx.symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index ec36735c8bcc..edfb90a2b003 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -223,7 +223,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu return forward(*args, **kwargs) return self.create_proxy('call_module', module_qualified_name, args, kwargs) - def create_args_for_root(self, root_fn, is_module): + def create_args_for_root(self, root_fn, is_module, concrete_args=None): """ Create ``placeholder`` nodes corresponding to the signature of the ``root`` Module. This method introspects root's signature and emits those @@ -249,6 +249,8 @@ def create_args_for_root(self, root_fn, is_module): sig = inspect.signature(fn_for_analysis) def proxy_placeholder(name: str): + if concrete_args is not None and name in concrete_args: + return concrete_args[name] if name[0] == '*': default = () # type: ignore else: @@ -269,7 +271,7 @@ def proxy_placeholder(name: str): return root_fn, args - def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: + def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` can either be an ``nn.Module`` instance or a Python callable. @@ -315,7 +317,7 @@ def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]): assert isinstance(fn, FunctionType) fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module)) + fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args) parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls @@ -585,7 +587,7 @@ def my_custom_function(x, y): _wrapped_fns_to_patch.append((f.f_globals, fn_name)) return fn_or_name -def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule: +def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: """Symbolic tracing API Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` @@ -594,12 +596,13 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule: Args: root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies. Returns: GraphModule: a Module created from the recorded operations from ``root``. """ tracer = Tracer() - graph = tracer.trace(root) + graph = tracer.trace(root, concrete_args) name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ return GraphModule(tracer.root, graph, name)