WIP: [FakeTensor] Workaround FFT ops with incorrect meta strides
This registers a custom fake tensor implementation for the core FFT ops that
takes the strides from eager where possible, or otherwise uses unbacked symints
to represent the strides.

ghstack-source-id: 0bf9c52c04d54284ce4074c6b98245c38f94e507
Pull Request resolved: pytorch#106319
peterbell10 committed Jul 31, 2023
1 parent ced5e8b commit a0a2bd5
Showing 3 changed files with 47 additions and 37 deletions.
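
For context, a sketch of the mismatch being worked around: the meta kernels for the FFT ops can report output strides that disagree with the strides the eager kernels actually produce. The repro below is illustrative only; whether and how the strides diverge depends on the build and backend FFT implementation.

    import torch

    x = torch.randn(4, 6)
    eager_strides = torch.fft.rfft(x).stride()            # strides chosen by the real kernel
    meta_strides = torch.fft.rfft(x.to("meta")).stride()  # strides from the meta kernel
    # Before this change these could disagree, poisoning stride-dependent
    # reasoning during fake tensor propagation.
    print(eager_strides, meta_strides)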
16 changes: 0 additions & 16 deletions test/test_proxy_tensor.py
@@ -1489,22 +1489,6 @@ def f(t):
     xfail('nanquantile'),
     xfail('narrow'),
-
-    # many complex operators incorrect striding, metadata
-    skip('fft.fft', ''),
-    skip('fft.hfft2', ''),
-    skip('fft.hfft', ''),
-    skip('fft.hfftn', ''),
-    skip('fft.ifft', ''),
-    skip('fft.ihfft2', ''),
-    skip('fft.ihfft', ''),
-    skip('fft.ihfftn', ''),
-    skip('fft.irfft2', ''),
-    skip('fft.irfft', ''),
-    skip('fft.irfftn', ''),
-    skip('fft.rfft2', ''),
-    skip('fft.rfft', ''),
-    skip('fft.rfftn', ''),

     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
     xfail('sparse.sampled_addmm'),
     xfail('sparse.mm', 'reduce'),
67 changes: 46 additions & 21 deletions torch/_subclasses/fake_tensor.py
@@ -7,7 +7,18 @@
 import weakref
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 from weakref import ReferenceType

 import torch
@@ -372,11 +383,17 @@ def __call__(
 op_implementations = []


-def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
+def register_op_impl(
+    run_impl_check: Union[
+        Callable[[OpOverload], bool], OpOverload, Sequence[OpOverload]
+    ]
+):
     def impl_decorator(op_impl):
         global op_implementations
         if isinstance(run_impl_check, OpOverload):
             op_implementations.append((lambda func: func == run_impl_check, op_impl))
+        elif isinstance(run_impl_check, Sequence):
+            op_implementations.append((lambda func: func in run_impl_check, op_impl))
         else:
             op_implementations.append((run_impl_check, op_impl))
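
With this change, register_op_impl accepts a single OpOverload, a sequence of overloads, or a predicate over the dispatched op. A hedged usage sketch (the impl names below are hypothetical, not part of the commit):

    @register_op_impl(aten.ones_like.default)  # a single overload
    def ones_like_impl(fake_mode, func, *args, **kwargs):
        ...

    @register_op_impl([aten.zeros.default, aten.empty.memory_format])  # a sequence
    def factory_impl(fake_mode, func, *args, **kwargs):
        ...

    @register_op_impl(lambda func: func.namespace == "aten")  # a predicate
    def aten_predicate_impl(fake_mode, func, *args, **kwargs):
        ...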

@@ -428,26 +445,34 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
 )


-# Many of these operators mutate striding in place and output conj depending on input
-# that is not reflected in meta registration.
-# TODO: fix registrations, add all existing impls that are correct
-def unsupported_complex_op(op):
-    if op.namespace not in ("aten", "prims"):
-        return False
-    if op is aten._fft_c2c.default:
-        return False
+# Create a tensor with unknown strides
+def empty_with_unbacked_strides(fake_mode, func, shape, dtype, device):
+    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_dynamic_stride_ops:
+        raise UnsupportedOperatorException(func)

-    op_name = op.name()
-    if "fft" in op_name:
-        return True
-    return False
+    stride = [fake_mode.shape_env.create_unbacked_symint() for _ in shape]
+    with in_kernel_invocation_manager(fake_mode):
+        out = aten.empty_strided(shape, stride=stride, dtype=dtype, device=device)
+    return FakeTensor(fake_mode, out, device)


-# These operators mutate striding in place and output conj depending on input
-# that is not reflected in meta registration
-@register_op_impl(unsupported_complex_op)
-def unsupported_fft(fake_mode, func, *args, **kwargs):
-    raise UnsupportedOperatorException(func)
+@register_op_impl([aten._fft_r2c.default, aten._fft_c2c.default, aten._fft_c2r.default])
+def fft_stride_workaround(fake_mode, func, *args, **kwargs):
+    # This is a workaround for the FFT meta implementations having incorrect strides
+    def extractor(input, *args, **kwargs):
+        return input
+
+    input = extractor(*args, **kwargs)
+    if not input._has_symbolic_sizes_strides and fake_mode.allow_fallback_kernels:
+        # For static shapes, we can fall back to eager for the real strides
+        return run_fallback_kernel(fake_mode, func, args, kwargs, None)
+
+    # For dynamic shapes, we use unbacked symints for the strides
+    with in_kernel_invocation_manager(fake_mode):
+        output = func(*args, **kwargs)
+    return empty_with_unbacked_strides(
+        fake_mode, func, output.shape, output.dtype, output.device
+    )


 # Don't default to default device handling,
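
To make the "unbacked symints for strides" idea concrete, here is a hedged, standalone sketch. It assumes only the ShapeEnv API used in the diff above; the printed symbol spelling varies across versions.

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    shape = (4, 3)
    # One fresh, unconstrained symbol per dimension stands in for each stride,
    # so nothing downstream can rely on a concrete (possibly wrong) value.
    stride = [shape_env.create_unbacked_symint() for _ in shape]
    print(stride)  # e.g. [i0, i1] -- symbols with no backing concrete value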
@@ -1397,9 +1422,9 @@ def dispatch(self, func, types, args=(), kwargs=None):
         # TODO - we should use the prim aten impl
         # TODO - fix prims complex ops
         if (
-            "prims::" in func._schema.name
+            func.namespace == "prims"
             and hasattr(func, "prim_meta_impl")
-            and not unsupported_complex_op(func)
+            and "fft" not in func.name()
         ):
             with self:
                 return func.prim_meta_impl(*args, **kwargs)
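
With unsupported_complex_op gone, the dispatch fast path now tests the op directly. A hedged illustration of the two accessors the new condition relies on (shown with an aten FFT op for concreteness; the condition itself only fires for prims ops):

    import torch

    op = torch.ops.aten._fft_r2c.default
    print(op.namespace)        # 'aten'
    print(op.name())           # 'aten::_fft_r2c'
    print("fft" in op.name())  # True, so FFT ops skip the prim meta fast path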
1 change: 1 addition & 0 deletions torch/fx/experimental/symbolic_shapes.py
@@ -1898,6 +1898,7 @@ def __init__(
         self, *,
         allow_scalar_outputs=True,
         allow_dynamic_output_shape_ops=True,
+        allow_dynamic_stride_ops=True,
         # NB: These are legacy configurations that help us make good choices
         # when the constraint/dynamic dims are not explicitly passed to us.
         # Ideally we will fix all call sites to be explicit and not have
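
A hedged usage sketch of the new flag, which is keyword-only per the signature above. Opting out restores the previous behavior: the FFT workaround raises UnsupportedOperatorException instead of minting unbacked stride symbols.

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Default: FFT ops with symbolic inputs get unbacked symint strides.
    env = ShapeEnv()

    # Opt out: fake tensor propagation for these ops raises
    # UnsupportedOperatorException again.
    strict_env = ShapeEnv(allow_dynamic_stride_ops=False)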

