Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FakeTensor] Workaround FFT ops with incorrect meta strides #106319

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 18 additions & 20 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,7 +2748,6 @@ def forward(self, x):
aot_autograd_failures = {
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
Expand All @@ -2770,24 +2769,6 @@ def forward(self, x):
xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''),

# Misc
xfail('to_sparse'),
xfail('corrcoef'),
Expand Down Expand Up @@ -2850,12 +2831,29 @@ def forward(self, x):
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
}

def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
Expand Down
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,8 @@ def test_refs_are_in_decomp_table(self, op):

fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default
xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default
skip('nn.functional.ctc_loss'),
}

Expand Down
35 changes: 17 additions & 18 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,22 +1488,6 @@ def f(t):
xfail('nanquantile'),
xfail('narrow'),

# many complex operators incorrect striding, metadata
skip('fft.fft', ''),
skip('fft.hfft2', ''),
skip('fft.hfft', ''),
skip('fft.hfftn', ''),
skip('fft.ifft', ''),
skip('fft.ihfft2', ''),
skip('fft.ihfft', ''),
skip('fft.ihfftn', ''),
skip('fft.irfft2', ''),
skip('fft.irfft', ''),
skip('fft.irfftn', ''),
skip('fft.rfft2', ''),
skip('fft.rfft', ''),
skip('fft.rfftn', ''),

# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
Expand All @@ -1524,8 +1508,6 @@ def f(t):
xfail('repeat_interleave'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),

xfail("stft"),
}

symbolic_tensor_failures = {
Expand Down Expand Up @@ -1583,6 +1565,23 @@ def f(t):
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.ihfft2', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
Expand Down
31 changes: 22 additions & 9 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,7 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
)


# Many of these operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration.
# TODO: fix registrations, add all existing impls that are correct
def unsupported_complex_op(op):
def stride_incorrect_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
Expand All @@ -443,10 +440,26 @@ def unsupported_complex_op(op):
return False


# These operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration
@register_op_impl(unsupported_complex_op)
def unsupported_fft(fake_mode, func, *args, **kwargs):
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # Workaround for meta implementations with incorrect strides: when all
    # inputs have static shapes we can run the real (eager) kernel to obtain
    # the correct output strides; with symbolic shapes no such fallback
    # exists, so the operator is reported as unsupported.

    def is_symbolic(x):
        # True when x carries symbolic shape info — a FakeTensor with
        # symbolic sizes/strides, or a SymInt/SymFloat/SymBool scalar.
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            return run_fallback_kernel(fake_mode, func, args, kwargs, None)

    raise UnsupportedOperatorException(func)


Expand Down Expand Up @@ -1436,7 +1449,7 @@ def dispatch(self, func, types, args=(), kwargs=None):
if (
"prims::" in func._schema.name
and hasattr(func, "prim_meta_impl")
and not unsupported_complex_op(func)
and not stride_incorrect_op(func)
):
with self:
return func.prim_meta_impl(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,7 +1857,7 @@ def print_repro_on_failure(repro_str):
# NB: Hacking the exception args is the cleanest way I've found to append
# failure reproduction info without poisoning the stack trace.
if len(e.args) >= 1:
e.args = (str(e.args[0]) + f"\n{repro_str}",) + e.args[1:]
e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
raise

# "min_satisfying_examples" setting has been deprecated in hypothesis
Expand Down