Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FX] Added partial concrete values for symbolic tracing #51609

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_fx.py
Expand Up @@ -1597,6 +1597,18 @@ def test_no_mutation(self):
with self.assertRaisesRegex(NotImplementedError, "new_args"):
x[0] = 4

def test_partial_trace(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
if y:
return 2 * x
else:
return x
mod = Foo()
mod_true = symbolic_trace(mod, concrete_args={'y': True})
mod_false = symbolic_trace(mod, concrete_args={'y': False})
self.assertEqual(mod_true(3), 6)
self.assertEqual(mod_false(3), 3)

def run_getitem_target():
from torch.fx.symbolic_trace import _wrapped_methods_to_patch
Expand Down
13 changes: 8 additions & 5 deletions torch/fx/symbolic_trace.py
Expand Up @@ -223,7 +223,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu
return forward(*args, **kwargs)
return self.create_proxy('call_module', module_qualified_name, args, kwargs)

def create_args_for_root(self, root_fn, is_module):
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
"""
Create ``placeholder`` nodes corresponding to the signature of the ``root``
Module. This method introspects root's signature and emits those
Expand All @@ -249,6 +249,8 @@ def create_args_for_root(self, root_fn, is_module):
sig = inspect.signature(fn_for_analysis)

def proxy_placeholder(name: str):
if concrete_args is not None and name in concrete_args:
return concrete_args[name]
if name[0] == '*':
default = () # type: ignore
else:
Expand All @@ -269,7 +271,7 @@ def proxy_placeholder(name: str):

return root_fn, args

def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, any]] = None) -> Graph:
Chillee marked this conversation as resolved.
Show resolved Hide resolved
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
Expand Down Expand Up @@ -315,7 +317,7 @@ def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]):
assert isinstance(fn, FunctionType)

fn_globals = fn.__globals__ # run before it gets patched
fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module))
fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args)

parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls

Expand Down Expand Up @@ -585,7 +587,7 @@ def my_custom_function(x, y):
_wrapped_fns_to_patch.append((f.f_globals, fn_name))
return fn_or_name

def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, any]] = None) -> GraphModule:
"""Symbolic tracing API

Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
Expand All @@ -594,12 +596,13 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
Args:
root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies.

Returns:
GraphModule: a Module created from the recorded operations from ``root``.

"""
tracer = Tracer()
graph = tracer.trace(root)
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)