Skip to content

Commit

Permalink
Revert D24269034: [fx] Refactor Tracer so that find_module and root a…
Browse files Browse the repository at this point in the history
…rgs creation could be overridden by implementations

Test Plan: revert-hammer

Differential Revision:
D24269034 (7b2e8be)

Original commit changeset: d7b67f2349dd

fbshipit-source-id: 7dd709b585f82d52d9b9973508137e36d5b5871e
  • Loading branch information
malfet authored and facebook-github-bot committed Oct 20, 2020
1 parent cda88e8 commit 8f12c0e
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions torch/fx/symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

def _find_module(root: torch.nn.Module, m: torch.nn.Module):
for n, p in root.named_modules():
if m is p:
return n
raise NameError('module is not installed as a submodule')

def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
co = fn.__code__
co_flags = co.co_flags & ~HAS_VARSTUFF
Expand Down Expand Up @@ -113,32 +119,34 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo
"""
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)

def path_of_module(self, mod):
for n, p in self.root.named_modules():
if mod is p:
return n
raise NameError('module is not installed as a submodule')

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
module_qualified_name = self.path_of_module(m)
def call_module(self, m: torch.nn.Module, module_qualified_name: str, forward: Callable[..., Any], args, kwargs):
if not self.is_leaf_module(m, module_qualified_name):
return forward(*args, **kwargs)
return self.create_proxy('call_module', module_qualified_name, args, kwargs)

def create_args_for_root(self, root_fn, is_module):
co = root_fn.__code__
def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
if isinstance(root, torch.nn.Module):
self.root = root
fn = type(root).forward
else:
self.root = torch.nn.Module()
fn = root
self.graph = Graph()

assert isinstance(fn, FunctionType)
co = fn.__code__
total_args = co.co_argcount + co.co_kwonlyargcount
names_iter = iter(co.co_varnames)
args : List[Any] = []
skip_arg_idx = 0
if is_module:
if isinstance(root, torch.nn.Module):
skip_arg_idx = 1
next(names_iter) # skip self
args.append(self.root)
args.append(root)

def proxy_placeholder(name: str):
return self.create_proxy('placeholder', name, (), {},
type_expr=root_fn.__annotations__.get(name, None))
type_expr=fn.__annotations__.get(name, None))

args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args))

Expand All @@ -148,31 +156,17 @@ def proxy_placeholder(name: str):
args.append(proxy_placeholder('*' + next(names_iter)))
if co.co_flags & inspect.CO_VARKEYWORDS:
args.append(proxy_placeholder('**' + next(names_iter)))
root_fn = _patch_function(root_fn, len(args))

return root_fn, args

def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
is_module = isinstance(root, torch.nn.Module)
if is_module:
self.root = root
fn = type(root).forward
else:
self.root = torch.nn.Module()
fn = root
self.graph = Graph()

assert isinstance(fn, FunctionType)

fn, args = self.create_args_for_root(fn, is_module)
fn = _patch_function(fn, len(args))

orig_call = torch.nn.Module.__call__

def module_call_wrapper(mod, *args, **kwargs):
module_qualified_name = _find_module(self.root, mod)

def forward(*args, **kwargs):
return orig_call(mod, *args, **kwargs)

return self.call_module(mod, forward, args, kwargs)
return self.call_module(mod, module_qualified_name, forward, args, kwargs)

try:
torch.nn.Module.__call__ = module_call_wrapper
Expand Down

0 comments on commit 8f12c0e

Please sign in to comment.