[FakeTensor] Workaround FFT ops with incorrect meta strides (#106319)
Currently there are FFT operators that raise `UnsupportedOperatorException`
because their meta implementations sometimes give incorrect strides. This change works
around the problem for static shapes by falling back to eager execution to compute
the correct output strides. Calls with dynamic shapes are still unsupported.
Pull Request resolved: #106319
Approved by: https://github.com/ezyang
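
To illustrate the new behavior, here is a minimal sketch (assuming a PyTorch build that contains this commit; the shapes, dtype, and variable names are arbitrary choices for the example):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# With static shapes, FFT ops under FakeTensorMode now fall back to the
# eager kernel, so the fake output carries the real kernel's strides.
mode = FakeTensorMode(allow_fallback_kernels=True)
x = torch.randn(8, 8, dtype=torch.complex64)
fake_x = mode.from_tensor(x)
with mode:
    out = torch.fft.fft(fake_x)  # previously raised UnsupportedOperatorException
print(out.shape, out.stride())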
peterbell10 authored and pytorchmergebot committed Aug 7, 2023
1 parent 66d90e8 commit d4d090e
Showing 5 changed files with 60 additions and 48 deletions.
38 changes: 18 additions & 20 deletions test/functorch/test_aotdispatch.py
@@ -2748,7 +2748,6 @@ def forward(self, x):
 aot_autograd_failures = {
     # data-dependent control flow
     xfail('cov'),
-    xfail('istft'),
     xfail('nn.functional.gaussian_nll_loss'),
     xfail('tensor_split'),
     xfail('corrcoef'),
@@ -2770,24 +2769,6 @@ def forward(self, x):
     xfail('_segment_reduce', 'lengths'),
     skip('nn.functional.nll_loss', ''), # UBSAN failure!

-    # many complex operators incorrect striding, metadata
-    xfail('fft.fft', ''),
-    xfail('fft.hfft2', ''),
-    xfail('fft.hfft', ''),
-    xfail('fft.hfftn', ''),
-    xfail('fft.ifft', ''),
-    xfail('fft.ihfft2', ''),
-    xfail('fft.ihfft', ''),
-    xfail('fft.ihfftn', ''),
-    xfail('fft.irfft2', ''),
-    xfail('fft.irfft', ''),
-    xfail('fft.irfftn', ''),
-    xfail('fft.rfft2', ''),
-    xfail('fft.rfft', ''),
-    xfail('fft.rfftn', ''),
-
-    xfail('stft', ''),
-
     # Misc
     xfail('to_sparse'),
     xfail('corrcoef'),
@@ -2850,12 +2831,29 @@ def forward(self, x):
     xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
-    xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
     xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
     decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),
+
+    # many complex operators incorrect striding, metadata
+    xfail('fft.fft', ''),
+    xfail('fft.hfft2', ''),
+    xfail('fft.hfft', ''),
+    xfail('fft.hfftn', ''),
+    xfail('fft.ifft', ''),
+    xfail('fft.ihfft2', ''),
+    xfail('fft.ihfft', ''),
+    xfail('fft.ihfftn', ''),
+    xfail('fft.irfft2', ''),
+    xfail('fft.irfft', ''),
+    xfail('fft.irfftn', ''),
+    xfail('fft.rfft2', ''),
+    xfail('fft.rfft', ''),
+    xfail('fft.rfftn', ''),
+
+    xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
 }

 def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
2 changes: 2 additions & 0 deletions test/test_ops.py
@@ -2016,6 +2016,8 @@ def test_refs_are_in_decomp_table(self, op):

 fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
     xfail("_segment_reduce", "lengths"),
+    xfail("fft.ihfftn"),  # Mismatch in aten._conj_physical.default
+    xfail("fft.ihfft2"),  # Mismatch in aten._conj_physical.default
     skip('nn.functional.ctc_loss'),
 }

35 changes: 17 additions & 18 deletions test/test_proxy_tensor.py
@@ -1488,22 +1488,6 @@ def f(t):
     xfail('nanquantile'),
     xfail('narrow'),

-    # many complex operators incorrect striding, metadata
-    skip('fft.fft', ''),
-    skip('fft.hfft2', ''),
-    skip('fft.hfft', ''),
-    skip('fft.hfftn', ''),
-    skip('fft.ifft', ''),
-    skip('fft.ihfft2', ''),
-    skip('fft.ihfft', ''),
-    skip('fft.ihfftn', ''),
-    skip('fft.irfft2', ''),
-    skip('fft.irfft', ''),
-    skip('fft.irfftn', ''),
-    skip('fft.rfft2', ''),
-    skip('fft.rfft', ''),
-    skip('fft.rfftn', ''),
-
     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
     xfail('sparse.sampled_addmm'),
     xfail('sparse.mm', 'reduce'),
@@ -1524,8 +1508,6 @@ def f(t):
     xfail('repeat_interleave'),
     # ASAN failures due to divide by 0
     skip('nn.functional.nll_loss'),
-
-    xfail("stft"),
 }

 symbolic_tensor_failures = {
@@ -1583,6 +1565,23 @@ def f(t):
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
+
+    # many complex operators incorrect striding, metadata
+    xfail('fft.fft', ''),
+    xfail('fft.hfft2', ''),
+    xfail('fft.hfft', ''),
+    xfail('fft.hfftn', ''),
+    xfail('fft.ifft', ''),
+    xfail('fft.ihfft2', ''),
+    xfail('fft.ihfft', ''),
+    xfail('fft.ihfftn', ''),
+    xfail('fft.ihfft2', ''),
+    xfail('fft.irfft2', ''),
+    xfail('fft.irfft', ''),
+    xfail('fft.irfftn', ''),
+    xfail('fft.rfft2', ''),
+    xfail('fft.rfft', ''),
+    xfail('fft.rfftn', ''),
 }
 symbolic_tensor_segfaults = {
     skip('nn.functional.batch_norm') # Segfault??
31 changes: 22 additions & 9 deletions torch/_subclasses/fake_tensor.py
@@ -428,10 +428,7 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
     )


-# Many of these operators mutate striding in place and output conj depending on input
-# that is not reflected in meta registration.
-# TODO: fix registrations, add all existing impls that are correct
-def unsupported_complex_op(op):
+def stride_incorrect_op(op):
     if op.namespace not in ("aten", "prims"):
         return False
     if op is aten._fft_c2c.default:
@@ -443,10 +440,26 @@ def unsupported_complex_op(op):
     return False


-# These operators mutate striding in place and output conj depending on input
-# that is not reflected in meta registration
-@register_op_impl(unsupported_complex_op)
-def unsupported_fft(fake_mode, func, *args, **kwargs):
+# These operators have meta implementations with incorrect strides
+@register_op_impl(stride_incorrect_op)
+def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
+    # This is a workaround for meta implementations with incorrect strides
+
+    def is_symbolic(x):
+        if isinstance(x, FakeTensor):
+            return x._has_symbolic_sizes_strides
+        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+            return True
+        return False
+
+    # For static shapes, we can fall back to eager for the real strides
+    if fake_mode.allow_fallback_kernels:
+        require_dynamic = any(
+            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
+        )
+        if not require_dynamic:
+            return run_fallback_kernel(fake_mode, func, args, kwargs, None)
+
     raise UnsupportedOperatorException(func)


@@ -1436,7 +1449,7 @@ def dispatch(self, func, types, args=(), kwargs=None):
         if (
             "prims::" in func._schema.name
             and hasattr(func, "prim_meta_impl")
-            and not unsupported_complex_op(func)
+            and not stride_incorrect_op(func)
         ):
             with self:
                 return func.prim_meta_impl(*args, **kwargs)
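
A short usage sketch of the two branches above. The `static_shapes` argument to `from_tensor` is an assumption about the testing API and is not part of this diff; treat it as illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedOperatorException
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
x = torch.randn(8, 8, dtype=torch.complex64)

# Static shapes: no symbolic sizes/strides, so the eager fallback runs.
static = mode.from_tensor(x, static_shapes=True)  # static_shapes kwarg assumed
with mode:
    torch.fft.fft(static)

# Dynamic shapes: is_symbolic() is true for the input, so the op still raises.
dynamic = mode.from_tensor(x, static_shapes=False)
try:
    with mode:
        torch.fft.fft(dynamic)
except UnsupportedOperatorException:
    print("dynamic shapes are still unsupported")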
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_utils.py
@@ -1857,7 +1857,7 @@ def print_repro_on_failure(repro_str):
         # NB: Hacking the exception args is the cleanest way I've found to append
         # failure reproduction info without poisoning the stack trace.
         if len(e.args) >= 1:
-            e.args = (str(e.args[0]) + f"\n{repro_str}",) + e.args[1:]
+            e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
         raise

 # "min_satisfying_examples" setting has been deprecated in hypothesis
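
The `e.args` rewrite above is behavior-preserving cleanup; this tiny standalone sketch shows what the new line does (the message, the extra argument, and the repro string are made up for illustration):

repro_str = "To reproduce: pytest test_foo.py -k test_bar"  # hypothetical
try:
    raise ValueError("original message", 42)
except ValueError as e:
    # Append the repro info to the first arg, preserving any remaining args.
    if len(e.args) >= 1:
        e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
    print(e.args)  # ('original message\nTo reproduce: pytest test_foo.py -k test_bar', 42)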
