Skip to content

Commit

Permalink
WIP: [FakeTensor] Workaround FFT ops with incorrect meta strides
Browse files Browse the repository at this point in the history
This registers a custom fake tensor implemantion for the core fft ops which
takes the strides from eager where possible, or otherwise uses unbacked symints
to represent the strides.

ghstack-source-id: c36c1e121f44133b6ba8c68d72e8efeb6ccae716
Pull Request resolved: pytorch#106319
  • Loading branch information
peterbell10 committed Jul 31, 2023
1 parent 61e5245 commit afa7e57
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 57 deletions.
20 changes: 0 additions & 20 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,7 +2731,6 @@ def forward(self, x):
aot_autograd_failures = {
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
Expand All @@ -2753,24 +2752,6 @@ def forward(self, x):
xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!

# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
xfail('fft.hfft', ''),
xfail('fft.hfftn', ''),
xfail('fft.ifft', ''),
xfail('fft.ihfft2', ''),
xfail('fft.ihfft', ''),
xfail('fft.ihfftn', ''),
xfail('fft.irfft2', ''),
xfail('fft.irfft', ''),
xfail('fft.irfftn', ''),
xfail('fft.rfft2', ''),
xfail('fft.rfft', ''),
xfail('fft.rfftn', ''),

xfail('stft', ''),

# Misc
xfail('to_sparse'),
xfail('corrcoef'),
Expand Down Expand Up @@ -2843,7 +2824,6 @@ def forward(self, x):
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
Expand Down
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,8 @@ def test_refs_are_in_decomp_table(self, op):

fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default
xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default
skip('nn.functional.ctc_loss'),
}

Expand Down
16 changes: 0 additions & 16 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,22 +1489,6 @@ def f(t):
xfail('nanquantile'),
xfail('narrow'),

# many complex operators incorrect striding, metadata
skip('fft.fft', ''),
skip('fft.hfft2', ''),
skip('fft.hfft', ''),
skip('fft.hfftn', ''),
skip('fft.ifft', ''),
skip('fft.ihfft2', ''),
skip('fft.ihfft', ''),
skip('fft.ihfftn', ''),
skip('fft.irfft2', ''),
skip('fft.irfft', ''),
skip('fft.irfftn', ''),
skip('fft.rfft2', ''),
skip('fft.rfft', ''),
skip('fft.rfftn', ''),

# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
Expand Down
68 changes: 47 additions & 21 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
import weakref
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from weakref import ReferenceType

import torch
Expand Down Expand Up @@ -372,11 +383,17 @@ def __call__(
op_implementations = []


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def register_op_impl(
run_impl_check: Union[
Callable[[OpOverload], bool], OpOverload, Sequence[OpOverload]
]
):
def impl_decorator(op_impl):
global op_implementations
if isinstance(run_impl_check, OpOverload):
op_implementations.append((lambda func: func == run_impl_check, op_impl))
elif isinstance(run_impl_check, Sequence):
op_implementations.append((lambda func: func in run_impl_check, op_impl))
else:
op_implementations.append((run_impl_check, op_impl))

Expand Down Expand Up @@ -428,26 +445,35 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
)


# Many of these operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration.
# TODO: fix registrations, add all existing impls that are correct
def unsupported_complex_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
return False
# Create a tensor with uknown strides
def fake_like_with_unbacked_strides(fake_mode, func, like):
if fake_mode.shape_env is None or not fake_mode.shape_env.allow_dynamic_stride_ops:
raise UnsupportedOperatorException(func)

shape, dtype, device = like.shape, like.dtype, like.device
stride = [fake_mode.shape_env.create_unbacked_symint() for _ in shape]
with in_kernel_invocation_manager(fake_mode):
out = aten.empty_strided(shape, stride=stride, dtype=dtype, device=device)
out._set_conj(like.is_conj())
out._set_neg(like.is_neg())
return FakeTensor(fake_mode, out, device)

op_name = op.name()
if "fft" in op_name:
return True
return False

@register_op_impl([aten._fft_r2c.default, aten._fft_c2c.default, aten._fft_c2r.default])
def fft_stride_workaround(fake_mode, func, *args, **kwargs):
# This is a workaround for the FFT meta implmentations having incorrect strides
def extractor(input, *args, **kwargs):
return input

# These operators mutate striding in place and output conj depending on input
# that is not reflected in meta registration
@register_op_impl(unsupported_complex_op)
def unsupported_fft(fake_mode, func, *args, **kwargs):
raise UnsupportedOperatorException(func)
input = extractor(*args, **kwargs)
if not input._has_symbolic_sizes_strides and fake_mode.allow_fallback_kernels:
# For static shapes, we can fall back to eager for the real strides
return run_fallback_kernel(fake_mode, func, args, kwargs, None)

# For dynamic shapes, we use unbacked symints for the strides
with in_kernel_invocation_manager(fake_mode):
output = func(*args, **kwargs)
return fake_like_with_unbacked_strides(fake_mode, func, output)


# Dont default to default device handling,
Expand Down Expand Up @@ -1397,9 +1423,9 @@ def dispatch(self, func, types, args=(), kwargs=None):
# TODO - we should be use the prim aten impl
# TODO - fix prims complex ops
if (
"prims::" in func._schema.name
func.namespace == "prims"
and hasattr(func, "prim_meta_impl")
and not unsupported_complex_op(func)
and "fft" not in func.name()
):
with self:
return func.prim_meta_impl(*args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,7 @@ def __init__(
self, *,
allow_scalar_outputs=True,
allow_dynamic_output_shape_ops=True,
allow_dynamic_stride_ops=True,
# NB: These are legacy configuration that help us make good choices
# when the constraint/dynamic dims are not explicitly passed to us.
# Ideally we will fix all call sites to be explicit and not have
Expand Down

0 comments on commit afa7e57

Please sign in to comment.