From dc7b061683960d21e5fea4dff0cdbd1b86565306 Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 22 Oct 2024 15:30:22 -0700 Subject: [PATCH 1/2] Add debug backend that applies CrossRefFakeMode, use in compiler bisector [ghstack-poisoned] --- test/dynamo/test_compiler_bisector.py | 64 +++++++++++++++++++++++++++ torch/_dynamo/backends/debugging.py | 32 +++++++++++++- torch/_functorch/config.py | 4 ++ torch/_inductor/bisect_helper.py | 2 + torch/_subclasses/fake_utils.py | 7 +++ 5 files changed, 107 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index f89f935b18cce..a5671dd4a20a4 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -9,6 +9,7 @@ from torch._dynamo.test_case import TestCase from torch._inductor import config from torch._inductor.bisect_helper import BisectionManager +from torch.library import _scoped_library, Library from torch.testing._internal.inductor_utils import HAS_CUDA @@ -23,6 +24,23 @@ @requires_cuda class TestCompilerBisector(TestCase): + test_ns = "_test_bisector" + + def tearDown(self): + if hasattr(torch.ops, self.test_ns): + delattr(torch.ops, self.test_ns) + if hasattr(self, "lib"): + del self.lib.m + del self.lib + + def get_op(self, name): + return getattr(getattr(torch.ops, self.test_ns), name).default + + def get_lib(self): + lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 + self.lib = lib + return lib + def test_bad_decomp(self): mod = import_module("torch._inductor.compile_fx") @@ -78,6 +96,52 @@ def test_fn(): self.assertEqual(out.bisect_number, 1) self.assertTrue("aten.exponential" in out.debug_info) + def test_crossref(self): + test_ns = "bisect_ops" + with _scoped_library(self.test_ns, "FRAGMENT") as lib: + lib.define("foo(Tensor x) -> Tensor") + op = self.get_op("foo") + + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python + with torch._C._AutoDispatchBelowAutograd(): + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet( + torch._C.DispatchKey.ADInplaceOrView + ) + ): + return op(x) + + @staticmethod + def backward(ctx, gx): + return gx + + def foo_impl(x): + return x.view_as(x).clone() + + def foo_meta(x): + return x.view_as(x) + + lib.impl("foo", Foo.apply, "Autograd") + lib.impl("foo", foo_impl, "CPU") + lib.impl("foo", foo_meta, "Meta") + + x = torch.tensor(3.14159 / 3, requires_grad=True) + + def test_fn(): + torch._dynamo.reset() + + try: + torch.testing.assert_allclose(torch.compile(op)(x), op(x)) + except Exception: + return False + return True + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref") + def test_emulate_precision_casts(self): def test_fn(): torch._dynamo.reset() diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index abd5111dbb1aa..94ed9b0865091 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -118,6 +118,23 @@ def run(args): return run +def fake_crossref_boxed_nop(fx_g, example_inputs): + def run(args): + with torch._subclasses.CrossRefFakeMode(): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def get_nop_func(): + return ( + boxed_nop + if not torch._functorch.config.fake_tensor_crossref + else fake_crossref_boxed_nop + ) + + # Useful for debugging purpose # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. def aot_eager( @@ -166,8 +183,8 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): with functorch_config.patch(config_patches): return aot_autograd( # these are taken from memory_efficient_fusion() - fw_compiler=boxed_nop, - bw_compiler=boxed_nop, + fw_compiler=get_nop_func(), + bw_compiler=get_nop_func(), # NB: lambda here is to delay import of inductor decompositions=lambda: import_module( "torch._inductor.compile_fx" @@ -183,6 +200,17 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): ) +def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs): + with functorch_config.patch(fake_tensor_crossref=True): + return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs) + + +register_backend( + name="aot_eager_decomp_partition_crossref", + compiler_fn=aot_eager_decomp_partition_crossref, +) + + # AOT Autograd with torchscript backend. Default partitioner. # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser # by using the relevant fuser with torch.jit.fuser(...) diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 8c042ee7ed56a..9d148de1aa794 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -162,6 +162,10 @@ def remote_autograd_cache_default() -> Optional[bool]: # tokens. unlift_effect_tokens = False + +# Run aot eager decomp partition with CrossRefFakeMode +fake_tensor_crossref = False + # This mode specifies that we should also keep track of the real # tensor along with the fake tensor, and do real compute. While # seemingly this eliminates the whole point of fake tensors, there are diff --git a/torch/_inductor/bisect_helper.py b/torch/_inductor/bisect_helper.py index b072aea53e529..5cb1dd5691804 100644 --- a/torch/_inductor/bisect_helper.py +++ b/torch/_inductor/bisect_helper.py @@ -53,6 +53,8 @@ def __post_init__(self) -> None: "decomposition" ), # number of decompositions we apply in tracing ], # TODO - add cse ? + # applies CrossRefFakeMode on invocation + "aot_eager_decomp_partition_crossref": [], "inductor": [ BisectSubsystem( "post_grad_passes" diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index 28fc7a4028917..2e3f5be03fc19 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -82,6 +82,7 @@ def __init__( *, check_strides=True, check_aliasing=True, + only_check_ops_with_meta=True, ): super().__init__() self.ignore_op_fn = ( @@ -89,11 +90,13 @@ def __init__( ) self.check_strides = check_strides self.check_aliasing = check_aliasing + self.only_check_ops_with_meta = only_check_ops_with_meta def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} fake_r = None + breakpoint() # empty_like excluded for now due to sparse complex # aten._to_dense.default this one is getting called with csc @@ -105,6 +108,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): aten.set_.source_Storage_storage_offset, ) and not self.ignore_op_fn(func) + and ( + not self.only_check_ops_with_meta + or torch._subclasses.fake_impls.has_meta(func) + ) and torch.Tag.dynamic_output_shape not in func.tags and torch.Tag.inplace_view not in func.tags and torch.Tag.data_dependent_output not in func.tags From 0a106ebc31ec4982c9c4cb414b4accc753ecfeec Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 24 Oct 2024 10:22:17 -0700 Subject: [PATCH 2/2] Update on "Add debug backend that applies CrossRefFakeMode, use in compiler bisector" I was debugging an internal ne divergence for a while that ended up being because of a bad meta. I added an explicit a config option and an explicit backend `aot_eager_decomp_partition_crossref` to enable the FakeCrossRefMode when running the graph. I added an explicit backend bc I suspect it will be useful for internal models but I'm also happy to leave as config option. It will only test ops that have meta to avoid memory overhead of hitting fallback path and running in eager. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned] --- torch/_subclasses/fake_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index 2e3f5be03fc19..c610ee9dbab40 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -96,7 +96,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} fake_r = None - breakpoint() # empty_like excluded for now due to sparse complex # aten._to_dense.default this one is getting called with csc