WIP: [FakeTensor] Fallback to eager for FFT ops
FFT ops' meta implementations don't have the correct strides, so currently
FakeTensor just raises; instead, we need to fall back to eager.

ghstack-source-id: 0af17f168826fc8611610a815f97935aa20065b7
Pull Request resolved: #106319
peterbell10 committed Jul 31, 2023
1 parent 61e5245 commit ebaebe3
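
The mismatch is easy to probe by comparing what the eager kernel and the meta kernel report for the same call. A minimal sketch (hypothetical repro; the exact stride tuples depend on build and backend, but on affected builds the two can disagree):

import torch

# Compare the strides the real kernel produces against the meta kernel's.
# A disagreement here is exactly what breaks FakeTensor's metadata
# propagation for FFT ops.
x = torch.randn(4, 6, dtype=torch.complex64)
eager_out = torch.fft.fftn(x, dim=(0,))
meta_out = torch.fft.fftn(x.to("meta"), dim=(0,))
print("eager strides:", eager_out.stride())
print("meta strides: ", meta_out.stride())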
Showing 5 changed files with 74 additions and 60 deletions.
40 changes: 20 additions & 20 deletions test/functorch/test_aotdispatch.py
@@ -2731,7 +2731,6 @@ def forward(self, x):
aot_autograd_failures = {
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
@@ -2753,24 +2752,6 @@ def forward(self, x):
xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''),

# Misc
xfail('to_sparse'),
xfail('corrcoef'),
@@ -2843,12 +2824,31 @@ def forward(self, x):
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.fftn', ''),
xfail('fft.fft2', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ifftn', ''),
xfail('fft.ifft2', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
}

def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
2 changes: 2 additions & 0 deletions test/test_ops.py
@@ -2016,6 +2016,8 @@ def test_refs_are_in_decomp_table(self, op):

fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default
xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default
skip('nn.functional.ctc_loss'),
}

39 changes: 21 additions & 18 deletions test/test_proxy_tensor.py
@@ -1489,22 +1489,6 @@ def f(t):
xfail('nanquantile'),
xfail('narrow'),

# many complex operators incorrect striding, metadata
skip('fft.fft', ''),
skip('fft.hfft2', ''),
skip('fft.hfft', ''),
skip('fft.hfftn', ''),
skip('fft.ifft', ''),
skip('fft.ihfft2', ''),
skip('fft.ihfft', ''),
skip('fft.ihfftn', ''),
skip('fft.irfft2', ''),
skip('fft.irfft', ''),
skip('fft.irfftn', ''),
skip('fft.rfft2', ''),
skip('fft.rfft', ''),
skip('fft.rfftn', ''),

# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
@@ -1525,8 +1509,6 @@ def f(t):
xfail('repeat_interleave'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),

xfail("stft"),
}

symbolic_tensor_failures = {
@@ -1594,6 +1576,27 @@ def f(t):
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.fftn', ''),
xfail('fft.fft2', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ifftn', ''),
xfail('fft.ifft2', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
51 changes: 30 additions & 21 deletions torch/_subclasses/fake_tensor.py
@@ -7,7 +7,18 @@
import weakref
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from weakref import ReferenceType

import torch
@@ -372,11 +383,17 @@ def __call__(
op_implementations = []


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def register_op_impl(
run_impl_check: Union[
Callable[[OpOverload], bool], OpOverload, Sequence[OpOverload]
]
):
def impl_decorator(op_impl):
global op_implementations
if isinstance(run_impl_check, OpOverload):
op_implementations.append((lambda func: func == run_impl_check, op_impl))
elif isinstance(run_impl_check, Sequence):
op_implementations.append((lambda func: func in run_impl_check, op_impl))
else:
op_implementations.append((run_impl_check, op_impl))
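
With Sequence support, a single impl can now be registered for several overloads at once instead of spelling out a membership predicate. A usage sketch (op names hypothetical):

# Before: covering several overloads required an explicit predicate, e.g.
#   @register_op_impl(lambda func: func in (aten.opA.default, aten.opB.default))
# Now the sequence itself is enough:
@register_op_impl([aten.opA.default, aten.opB.default])  # hypothetical ops
def shared_impl(fake_mode, func, *args, **kwargs):
    ...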

@@ -428,25 +445,17 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
)


# Many of these operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration.
# TODO: fix registrations, add all existing impls that are correct
def unsupported_complex_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
return False

op_name = op.name()
if "fft" in op_name:
return True
return False
@register_op_impl([aten._fft_r2c.default, aten._fft_c2c.default, aten._fft_c2r.default])
def fft_stride_workaround(fake_mode, func, *args, **kwargs):
# This is a workaround for the FFT meta implementations having incorrect strides
def extractor(input, *args, **kwargs):
return input

input = extractor(*args, **kwargs)
if not input._has_symbolic_sizes_strides and fake_mode.allow_fallback_kernels:
# For static shapes, we can fall back to eager for the real strides
return run_fallback_kernel(fake_mode, func, args, kwargs, None)

# These operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration
@register_op_impl(unsupported_complex_op)
def unsupported_fft(fake_mode, func, *args, **kwargs):
raise UnsupportedOperatorException(func)
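
The extractor helper is a small argument-binding trick: passing *args/**kwargs through a function whose first parameter is named input lets Python's own argument matching locate the input tensor, whether it was given positionally or by keyword. A standalone sketch:

def extractor(input, *args, **kwargs):
    return input

assert extractor(1, 2) == 1           # positional
assert extractor(n=2, input=1) == 1   # keyword, in any order

Gating the fallback on allow_fallback_kernels and on the input having no symbolic sizes/strides makes sense: running the real kernel to recover strides is only possible when the shapes are concrete.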


@@ -1397,9 +1406,9 @@ def dispatch(self, func, types, args=(), kwargs=None):
# TODO - we should be use the prim aten impl
# TODO - fix prims complex ops
if (
"prims::" in func._schema.name
func.namespace == "prims"
and hasattr(func, "prim_meta_impl")
and not unsupported_complex_op(func)
and "fft" not in func.name()
):
with self:
return func.prim_meta_impl(*args, **kwargs)
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_utils.py
@@ -1857,7 +1857,7 @@ def print_repro_on_failure(repro_str):
# NB: Hacking the exception args is the cleanest way I've found to append
# failure reproduction info without poisoning the stack trace.
if len(e.args) >= 1:
e.args = (e.args[0] + f"\n{repro_str}",) + e.args[1:]
e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
raise

# "min_satisfying_examples" setting has been deprecated in hypothesis
