From 3483d567169bc79d381386c3e217e7aad9cb2ee8 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 19 Apr 2024 09:37:01 -0700
Subject: [PATCH] [dynamo] Allow inlining of hooks for the top module

ghstack-source-id: d0e3ef1e8145ffd5479328c2db1989791e3d47b7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124501
---
 torch/_dynamo/eval_frame.py  | 19 ++++++++++++-------
 torch/_dynamo/trace_rules.py | 18 ++++++++++++++++++
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 99a466523aadd..a6126d9649304 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -137,15 +137,14 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):
 
     def _initialize(self):
         # Do this stuff in constructor to lower overhead slightly
-        if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
-            self._orig_mod.forward
-        ):
-            # This may be a torch.nn.* instance in trace_rules.py which
-            # won't trigger a frame evaluation workaround to add an extra
-            # frame we can capture
+
+        if trace_rules.should_wrap_top_module():
             self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
         else:
             # Invoke hooks outside of dynamo then pickup the inner frame
+            # TODO(export-team) - This part is only run for export. Remove the
+            # if condition when export has made adjustments to account for
+            # wrapping top module.
             self.forward = self.dynamo_ctx(self._orig_mod.__call__)
 
         if hasattr(self._orig_mod, "_initialize_hook"):
@@ -1215,7 +1214,13 @@ def result_capturing_wrapper(*graph_inputs):
             automatic_dynamic_shapes=False,
             capture_dynamic_output_shape_ops=True,
             capture_scalar_outputs=True,
-        ):
+        ), trace_rules.dont_wrap_top_module():
+            # TODO(export-team) - discrepancy between torch.compile and
+            # torch.export because torch.compile is planning to inline the
+            # _call_impl (one level above forward) to inline hooks. But doing
+            # that for export breaks many tests because (1) tests are hardcoded
+            # to assume that tracing starts from forward, and (2) some
+            # discrepancies between strict and non strict mode.
             opt_f = optimize_assert(
                 dynamo_normalization_capturing_compiler,
                 hooks=Hooks(
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 763f6482cb92e..11a6dcd85d632 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -32,6 +32,7 @@
 import unittest
 import weakref
 from collections import defaultdict
+from contextlib import contextmanager
 from typing import Any, Callable, cast, Dict, List, Optional, Set, Union
 
 np: Optional[types.ModuleType] = None
@@ -127,6 +128,23 @@
 """
 
 
+
+_TLS = threading.local()
+
+
+@contextmanager
+def dont_wrap_top_module():
+    old = getattr(_TLS, "wrap_top_module", True)
+    _TLS.wrap_top_module = False
+    try:
+        yield False
+    finally:
+        _TLS.wrap_top_module = old
+
+
+def should_wrap_top_module():
+    return getattr(_TLS, "wrap_top_module", True)
+
+
 manual_torch_name_rule_map = {
     "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
     "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
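
Below is a minimal usage sketch, not part of the patch, showing the intended semantics of the two helpers this change adds to torch/_dynamo/trace_rules.py. The import path and behaviour follow the diff above; treat it as an illustration rather than a supported public API.

    import threading

    from torch._dynamo import trace_rules

    # Default: the top module is wrapped so dynamo can inline its hooks.
    assert trace_rules.should_wrap_top_module()

    # Export-style callers opt out for the duration of a block (as done
    # around optimize_assert in eval_frame.py), so tracing still starts
    # from forward() as the export tests expect.
    with trace_rules.dont_wrap_top_module():
        assert not trace_rules.should_wrap_top_module()

    # The previous value is restored once the block exits.
    assert trace_rules.should_wrap_top_module()

    # The flag lives in a threading.local(), so disabling it in one thread
    # does not leak into concurrent compilations on other threads.
    seen = []
    with trace_rules.dont_wrap_top_module():
        t = threading.Thread(
            target=lambda: seen.append(trace_rules.should_wrap_top_module())
        )
        t.start()
        t.join()
    assert seen == [True]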