From 120e93d48feca1896f2a08ab196c2221c62c891d Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Tue, 23 Aug 2022 13:33:07 -0700
Subject: [PATCH] Force RNN modules to be inlined

They call Tensor.set_ internally with Storage, which is a no-go for
AOTAutograd.  Inline into them so that we can graph break.

Fixes https://github.com/pytorch/functorch/issues/586

Test strategy:
```
./benchmarks/torchbench.py --inductor -dcuda --no-skip -k tts_angular
```

Note that inductor is still failing, but differently, after this PR.

Signed-off-by: Edward Z. Yang
---
 torchdynamo/allowed_functions.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/torchdynamo/allowed_functions.py b/torchdynamo/allowed_functions.py
index 7b17461b03..ffc04f8885 100644
--- a/torchdynamo/allowed_functions.py
+++ b/torchdynamo/allowed_functions.py
@@ -116,7 +116,12 @@ def _allowed_function_ids():
 
 def _is_allowed_module_prefix(obj):
     allowed_modules = ("torch", "math")
-    disallowed_modules = "torch.optim."
+    # torch.nn.modules.rnn is disallowed because these modules internally
+    # flatten their parameters. This flattening process will call
+    # Tensor.set_ with a Storage, and Storages cannot be traced with
+    # AOTAutograd, so we need to graph break. To ensure this, we inline
+    # these functions, rather than keep them opaquely in the graph.
+    disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.")
     allowed_modules_dot = tuple([x + "." for x in allowed_modules])
     module = inspect.getmodule(obj)
     if module is None:
@@ -124,7 +129,7 @@ def _is_allowed_module_prefix(obj):
 
     mod_name = module.__name__
-    if mod_name.startswith(disallowed_modules):
+    if any(mod_name.startswith(m) for m in disallowed_modules):
         return False
 
     return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
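
For context, a minimal sketch (not part of the patch) of the Tensor.set_-with-Storage pattern the commit message describes. The manual flattening and tensor names below are illustrative, not the actual torch.nn.modules.rnn implementation; the sketch only shows the kind of call that AOTAutograd cannot trace and at which TorchDynamo therefore needs to graph break:

```python
import torch

# Two "parameters" we want to pack into one contiguous buffer, loosely
# mimicking what RNN parameter flattening does (illustrative only).
weights = [torch.randn(4, 4) for _ in range(2)]
flat = torch.cat([w.reshape(-1) for w in weights])

# Re-point each original tensor at a slice of the flat buffer's Storage.
# This Tensor.set_(storage, offset, size) call is the operation that
# AOTAutograd cannot trace, hence the need to graph break here.
offset = 0
for w in weights:
    w.set_(flat.storage(), offset, w.size())
    offset += w.numel()

# All weights now alias the single flat buffer.
assert weights[0].data_ptr() == flat.data_ptr()
```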