Skip to content

Commit

Permalink
[FX] Added partial concrete values for symbolic tracing (#51609)
Browse files Browse the repository at this point in the history
Summary:
Currently it's passed in a dict but might be worth considering whether we want to support other methods of passing it in (like a list corresponding to the positional args).

Pull Request resolved: #51609

Reviewed By: zou3519

Differential Revision: D26224464

Pulled By: Chillee

fbshipit-source-id: 305769db1a6e5fdcfb9e7dcacfdf153acd057a5a
  • Loading branch information
Chillee authored and facebook-github-bot committed Feb 4, 2021
1 parent 2e8e560 commit 2d305b9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
12 changes: 12 additions & 0 deletions test/test_fx.py
Expand Up @@ -1597,6 +1597,18 @@ def test_no_mutation(self):
with self.assertRaisesRegex(NotImplementedError, "new_args"):
x[0] = 4

def test_partial_trace(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
if y:
return 2 * x
else:
return x
mod = Foo()
mod_true = symbolic_trace(mod, concrete_args={'y': True})
mod_false = symbolic_trace(mod, concrete_args={'y': False})
self.assertEqual(mod_true(3), 6)
self.assertEqual(mod_false(3), 3)

def run_getitem_target():
from torch.fx.symbolic_trace import _wrapped_methods_to_patch
Expand Down
13 changes: 8 additions & 5 deletions torch/fx/symbolic_trace.py
Expand Up @@ -223,7 +223,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu
return forward(*args, **kwargs)
return self.create_proxy('call_module', module_qualified_name, args, kwargs)

def create_args_for_root(self, root_fn, is_module):
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
"""
Create ``placeholder`` nodes corresponding to the signature of the ``root``
Module. This method introspects root's signature and emits those
Expand All @@ -249,6 +249,8 @@ def create_args_for_root(self, root_fn, is_module):
sig = inspect.signature(fn_for_analysis)

def proxy_placeholder(name: str):
if concrete_args is not None and name in concrete_args:
return concrete_args[name]
if name[0] == '*':
default = () # type: ignore
else:
Expand All @@ -269,7 +271,7 @@ def proxy_placeholder(name: str):

return root_fn, args

def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
Expand Down Expand Up @@ -315,7 +317,7 @@ def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]):
assert isinstance(fn, FunctionType)

fn_globals = fn.__globals__ # run before it gets patched
fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module))
fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args)

parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls

Expand Down Expand Up @@ -585,7 +587,7 @@ def my_custom_function(x, y):
_wrapped_fns_to_patch.append((f.f_globals, fn_name))
return fn_or_name

def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
"""Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
Expand All @@ -594,12 +596,13 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule:
Args:
root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies.
Returns:
GraphModule: a Module created from the recorded operations from ``root``.
"""
tracer = Tracer()
graph = tracer.trace(root)
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)

0 comments on commit 2d305b9

Please sign in to comment.