[FakeTensor] Fallback to eager for FFT ops
FFT ops' meta implementations don't have the correct strides, so FakeTensor currently just raises; we need to fall back to eager instead.

ghstack-source-id: df7a43236c27ccf5b7ebeaf1248d8dd03b48e58d
Pull Request resolved: pytorch#106319
peterbell10 committed Aug 7, 2023
1 parent 1cc0026 commit 346ff0d
Showing 5 changed files with 60 additions and 48 deletions.
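
For context, a minimal sketch of the underlying issue (illustrative only; the exact strides printed depend on the backend and the dims being transformed):

```python
import torch

# Eager FFT kernels can return outputs whose strides depend on how the
# transform is computed internally (e.g. the transformed dim being moved
# to the end); the meta implementations don't reproduce those strides,
# so FakeTensor can't trust them.
x = torch.randn(4, 4)
out = torch.fft.rfft(x, dim=0)  # aten._fft_r2c under the hood
print(out.stride())             # strides chosen by the real kernel
print(out.is_contiguous())      # may be False, unlike the meta kernel's guess
```
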
38 changes: 18 additions & 20 deletions test/functorch/test_aotdispatch.py
@@ -2748,7 +2748,6 @@ def forward(self, x):
aot_autograd_failures = {
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
@@ -2770,24 +2769,6 @@ def forward(self, x):
xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''),

# Misc
xfail('to_sparse'),
xfail('corrcoef'),
@@ -2850,12 +2831,29 @@ def forward(self, x):
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
}

def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
2 changes: 2 additions & 0 deletions test/test_ops.py
@@ -2016,6 +2016,8 @@ def test_refs_are_in_decomp_table(self, op):

fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default
xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default
skip('nn.functional.ctc_loss'),
}

35 changes: 17 additions & 18 deletions test/test_proxy_tensor.py
@@ -1488,22 +1488,6 @@ def f(t):
xfail('nanquantile'),
xfail('narrow'),

# many complex operators incorrect striding, metadata
skip('fft.fft', ''),
skip('fft.hfft2', ''),
skip('fft.hfft', ''),
skip('fft.hfftn', ''),
skip('fft.ifft', ''),
skip('fft.ihfft2', ''),
skip('fft.ihfft', ''),
skip('fft.ihfftn', ''),
skip('fft.irfft2', ''),
skip('fft.irfft', ''),
skip('fft.irfftn', ''),
skip('fft.rfft2', ''),
skip('fft.rfft', ''),
skip('fft.rfftn', ''),

# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
@@ -1524,8 +1508,6 @@ def f(t):
xfail('repeat_interleave'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),

xfail("stft"),
}

symbolic_tensor_failures = {
@@ -1583,6 +1565,23 @@ def f(t):
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.ihfft2', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
31 changes: 22 additions & 9 deletions torch/_subclasses/fake_tensor.py
@@ -428,10 +428,7 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
)


# Many of these operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration.
# TODO: fix registrations, add all existing impls that are correct
def unsupported_complex_op(op):
def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
@@ -443,10 +440,26 @@ def unsupported_complex_op(op):
    return False


# These operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration
@register_op_impl(unsupported_complex_op)
def unsupported_fft(fake_mode, func, *args, **kwargs):
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            return run_fallback_kernel(fake_mode, func, args, kwargs, None)

    raise UnsupportedOperatorException(func)


@@ -1436,7 +1449,7 @@ def dispatch(self, func, types, args=(), kwargs=None):
        if (
            "prims::" in func._schema.name
            and hasattr(func, "prim_meta_impl")
            and not unsupported_complex_op(func)
            and not stride_incorrect_op(func)
        ):
            with self:
                return func.prim_meta_impl(*args, **kwargs)
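
A hypothetical usage sketch of the two paths introduced above (assumes a build containing this commit, and that the collapsed middle of stride_incorrect_op matches FFT ops such as aten._fft_r2c):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Static shapes: no input is symbolic, so the workaround runs the eager
# kernel via run_fallback_kernel and the fake output gets real strides.
with FakeTensorMode(allow_fallback_kernels=True) as mode:
    x = mode.from_tensor(torch.randn(8, 8))
    y = torch.fft.rfft(x, dim=0)
    print(y.shape, y.stride())  # strides come from the eager kernel

# With symbolic sizes (dynamic shapes), is_symbolic() is True for an input,
# the fallback is skipped, and UnsupportedOperatorException is raised; that
# is why the FFT xfails above moved into symbolic_tensor_failures instead
# of being deleted.
```
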
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_utils.py
@@ -1857,7 +1857,7 @@ def print_repro_on_failure(repro_str):
        # NB: Hacking the exception args is the cleanest way I've found to append
        # failure reproduction info without poisoning the stack trace.
        if len(e.args) >= 1:
            e.args = (str(e.args[0]) + f"\n{repro_str}",) + e.args[1:]
            e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
        raise

# "min_satisfying_examples" setting has been deprecated in hypothesis
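
For illustration, a tiny self-contained sketch of the e.args idiom this hunk tidies up (the repro_str value is hypothetical):

```python
# Rebuild the args tuple so the first element (the message) gains the
# repro info while any remaining args are preserved unchanged.
repro_str = "To reproduce, run: pytest test/test_foo.py"  # hypothetical value
try:
    raise ValueError("original failure message")
except ValueError as e:
    if len(e.args) >= 1:
        e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
    print(e)  # the original message, then the repro string on a new line
```
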
