test/dynamo/test_compiler_bisector.py (64 additions, 0 deletions)
@@ -9,6 +9,7 @@
from torch._dynamo.test_case import TestCase
from torch._inductor import config
from torch._inductor.bisect_helper import BisectionManager
from torch.library import _scoped_library, Library
from torch.testing._internal.inductor_utils import HAS_CUDA


@@ -23,6 +24,23 @@

@requires_cuda
class TestCompilerBisector(TestCase):
test_ns = "_test_bisector"

def tearDown(self):
if hasattr(torch.ops, self.test_ns):
delattr(torch.ops, self.test_ns)
if hasattr(self, "lib"):
del self.lib.m
del self.lib

def get_op(self, name):
return getattr(getattr(torch.ops, self.test_ns), name).default

def get_lib(self):
lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901
self.lib = lib
return lib

def test_bad_decomp(self):
mod = import_module("torch._inductor.compile_fx")

@@ -78,6 +96,52 @@ def test_fn():
self.assertEqual(out.bisect_number, 1)
self.assertTrue("aten.exponential" in out.debug_info)

def test_crossref(self):
test_ns = "bisect_ops"
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define("foo(Tensor x) -> Tensor")
op = self.get_op("foo")

class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
with torch._C._AutoDispatchBelowAutograd():
with torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(
torch._C.DispatchKey.ADInplaceOrView
)
):
return op(x)

@staticmethod
def backward(ctx, gx):
return gx

def foo_impl(x):
return x.view_as(x).clone()

def foo_meta(x):
return x.view_as(x)

lib.impl("foo", Foo.apply, "Autograd")
lib.impl("foo", foo_impl, "CPU")
lib.impl("foo", foo_meta, "Meta")

x = torch.tensor(3.14159 / 3, requires_grad=True)

def test_fn():
torch._dynamo.reset()

try:
torch.testing.assert_allclose(torch.compile(op)(x), op(x))
except Exception:
return False
return True

out = BisectionManager.do_bisect(test_fn)
self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")

def test_emulate_precision_casts(self):
def test_fn():
torch._dynamo.reset()
torch/_dynamo/backends/debugging.py (30 additions, 2 deletions)
@@ -118,6 +118,23 @@ def run(args):
return run


def fake_crossref_boxed_nop(fx_g, example_inputs):
def run(args):
with torch._subclasses.CrossRefFakeMode():
return torch.fx.Interpreter(fx_g).boxed_run(args)

run._boxed_call = True
return run


def get_nop_func():
return (
boxed_nop
if not torch._functorch.config.fake_tensor_crossref
else fake_crossref_boxed_nop
)


# Useful for debugging purpose
# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
def aot_eager(
@@ -166,8 +183,8 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
with functorch_config.patch(config_patches):
return aot_autograd(
# these are taken from memory_efficient_fusion()
-            fw_compiler=boxed_nop,
-            bw_compiler=boxed_nop,
+            fw_compiler=get_nop_func(),
+            bw_compiler=get_nop_func(),
# NB: lambda here is to delay import of inductor
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
@@ -183,6 +200,17 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
)


def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
with functorch_config.patch(fake_tensor_crossref=True):
return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs)


register_backend(
name="aot_eager_decomp_partition_crossref",
compiler_fn=aot_eager_decomp_partition_crossref,
)


# AOT Autograd with torchscript backend. Default partitioner.
# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
# by using the relevant fuser with torch.jit.fuser(...)
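Usage sketch (not part of the diff): the backend name comes from the register_backend call above, assuming this PR is applied; the toy function and tensor are illustrative only.

import torch

def f(x):
    return x.sin().cos()

x = torch.randn(8, requires_grad=True)

# Compile through AOT Autograd with decompositions and the default partitioner,
# while CrossRefFakeMode cross-checks fake-tensor metadata on every invocation
# of the compiled graph (see fake_crossref_boxed_nop above).
compiled = torch.compile(f, backend="aot_eager_decomp_partition_crossref")
compiled(x).sum().backward()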
torch/_functorch/config.py (4 additions, 0 deletions)
@@ -162,6 +162,10 @@ def remote_autograd_cache_default() -> Optional[bool]:
# tokens.
unlift_effect_tokens = False


# Run aot eager decomp partition with CrossRefFakeMode
fake_tensor_crossref = False
Contributor: to check, your PR made this both a config option and a backend?

Contributor Author: that's correct


# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
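To illustrate the point raised in the review comment above, a sketch (assuming this PR is applied; the compiled function is illustrative): the dedicated backend simply patches this flag and delegates to aot_eager_decomp_partition, so patching it manually should have the same effect.

import torch
import torch._functorch.config as functorch_config

def f(x):
    return x.relu()

x = torch.randn(4, requires_grad=True)

# fake_tensor_crossref is read by get_nop_func() when aot_eager_decomp_partition
# picks its forward/backward compilers, so the patch must cover compilation
# (which happens lazily on the first call of the compiled function).
with functorch_config.patch(fake_tensor_crossref=True):
    torch.compile(f, backend="aot_eager_decomp_partition")(x)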
torch/_inductor/bisect_helper.py (2 additions, 0 deletions)
@@ -53,6 +53,8 @@ def __post_init__(self) -> None:
"decomposition"
), # number of decompositions we apply in tracing
], # TODO - add cse ?
# applies CrossRefFakeMode on invocation
"aot_eager_decomp_partition_crossref": [],
"inductor": [
BisectSubsystem(
"post_grad_passes"
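For context, a minimal sketch of driving the bisector over this entry (mirroring test_crossref above; the check function is illustrative). Because the entry has no subsystems, the bisector can only localize a failure to the backend as a whole.

import torch
import torch._dynamo
from torch._inductor.bisect_helper import BisectionManager

def check():
    torch._dynamo.reset()
    try:
        x = torch.randn(4)
        torch.testing.assert_close(torch.compile(torch.sin)(x), torch.sin(x))
    except Exception:
        return False
    return True

# do_bisect re-runs check() under different configurations; with the entry
# above, "aot_eager_decomp_partition_crossref" is one of the backends it can
# report in the result's .backend field.
result = BisectionManager.do_bisect(check)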
torch/_subclasses/fake_utils.py (6 additions, 0 deletions)
@@ -82,13 +82,15 @@ def __init__(
*,
check_strides=True,
check_aliasing=True,
only_check_ops_with_meta=True,
):
super().__init__()
self.ignore_op_fn = (
ignore_op_fn if ignore_op_fn is not None else lambda fn: False
)
self.check_strides = check_strides
self.check_aliasing = check_aliasing
self.only_check_ops_with_meta = only_check_ops_with_meta

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
@@ -105,6 +107,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
aten.set_.source_Storage_storage_offset,
)
and not self.ignore_op_fn(func)
and (
not self.only_check_ops_with_meta
or torch._subclasses.fake_impls.has_meta(func)
)
and torch.Tag.dynamic_output_shape not in func.tags
and torch.Tag.inplace_view not in func.tags
and torch.Tag.data_dependent_output not in func.tags
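A sketch of exercising the new flag directly (the tensors and ops are illustrative; the constructor parameters are the ones shown above):

import torch
from torch._subclasses.fake_utils import CrossRefFakeMode

x = torch.randn(8, 8)

# Re-run each dispatched op on fake tensors and compare metadata with the real
# result. With only_check_ops_with_meta=True (the new default), ops without a
# meta implementation are skipped rather than cross-checked.
with CrossRefFakeMode(only_check_ops_with_meta=True):
    y = (x @ x).relu()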