Force RNN modules to be inlined
RNN modules call Tensor.set_ internally with a Storage, which is a no-go for AOTAutograd.
Inline into them so that we can graph break.
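
For context, here is a minimal sketch of the pattern at issue (illustrative only, not the actual torch.nn.modules.rnn code): parameter flattening re-points a tensor at a contiguous Storage via Tensor.set_, and that Storage argument is what AOTAutograd cannot trace.

```python
import torch

# Hypothetical stand-in for what RNN parameter flattening does internally:
# re-point an existing parameter at a fresh, contiguous Storage.
weight = torch.nn.Parameter(torch.randn(4, 4))
flat = weight.detach().contiguous().clone()

# Tensor.set_ with a Storage argument is the call AOTAutograd cannot trace,
# so dynamo has to inline the module and graph break when it reaches this.
weight.data.set_(flat.storage(), 0, flat.size(), flat.stride())
```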

Fixes pytorch/functorch#586

Test strategy:

```
./benchmarks/torchbench.py --inductor  -dcuda --no-skip -k tts_angular
```

Note that inductor is still failing after this PR, but in a different way.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang committed Aug 23, 2022
1 parent ea455b7 commit 120e93d
Showing 1 changed file with 7 additions and 2 deletions.
torchdynamo/allowed_functions.py:

```diff
@@ -116,15 +116,20 @@ def _allowed_function_ids():
 
     def _is_allowed_module_prefix(obj):
         allowed_modules = ("torch", "math")
-        disallowed_modules = "torch.optim."
+        # torch.nn.modules.rnn is disallowed because these modules internally
+        # flatten their parameters. This flattening process will call
+        # Tensor.set_ with a Storage, and Storages cannot be traced with
+        # AOTAutograd; so we need to graph-break. To ensure this, we inline
+        # these functions, rather than keep them opaquely in the graph.
+        disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.")
         allowed_modules_dot = tuple([x + "." for x in allowed_modules])
         module = inspect.getmodule(obj)
         if module is None:
             return False
 
         mod_name = module.__name__
 
-        if mod_name.startswith(disallowed_modules):
+        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False
 
         return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
```
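
For illustration, here is a standalone sketch of the allow/deny check after this patch; it mirrors the helper in the diff with plain module-name strings rather than importing anything from torchdynamo.

```python
# Self-contained sketch of the prefix filter, mirroring the patched helper.
allowed_modules = ("torch", "math")
disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.")
allowed_modules_dot = tuple(x + "." for x in allowed_modules)

def is_allowed_module_prefix(mod_name):
    # Deny anything under a disallowed prefix first, then fall back to the
    # torch/math allow list.
    if any(mod_name.startswith(m) for m in disallowed_modules):
        return False
    return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

print(is_allowed_module_prefix("torch.nn.functional"))  # True
print(is_allowed_module_prefix("torch.optim.adam"))     # False
print(is_allowed_module_prefix("math"))                 # True
```

In the real helper, mod_name comes from inspect.getmodule(obj).__name__ for each object torchdynamo considers marking as allowed.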
