Skip to content

Commit

Permalink
Update sparse_funcs to include primtorch types (#107421)
Browse files Browse the repository at this point in the history
Fixes #107335.

A few issues have been identified while enabling this test and filed:
#105986
#108204
#108205

Pull Request resolved: #107421
Approved by: https://github.com/ezyang
  • Loading branch information
ekamiti authored and pytorchmergebot committed Sep 5, 2023
1 parent e27ddd2 commit 0ef2556
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 6 deletions.
17 changes: 12 additions & 5 deletions test/test_spectral_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,18 @@ def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
# TODO: Remove torch.half error when complex32 is fully implemented
sample = first_sample(self, op.sample_inputs(device, dtype))
device_type = torch.device(device).type
# FIXME: https://github.com/pytorch/pytorch/issues/108204
default_msg = (
r"(Unsupported dtype|"
r"FFT doesn't support (tensors*|transforms) of type|"
r"expected scalar type \w+ but found|)"
)
if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
err_msg = "Unsupported dtype "
err_msg = default_msg
elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
else:
err_msg = "Unsupported dtype "
err_msg = default_msg
with self.assertRaisesRegex(RuntimeError, err_msg):
op(sample.input, *sample.args, **sample.kwargs)

Expand Down Expand Up @@ -444,11 +450,12 @@ def test_fftn_round_trip(self, device, dtype):
allowed_dtypes=[torch.float, torch.cfloat])
def test_fftn_invalid(self, device, dtype, op):
a = torch.rand(10, 10, 10, device=device, dtype=dtype)

with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
# FIXME: https://github.com/pytorch/pytorch/issues/108205
errMsg = r"(dims must be unique|duplicate value in the list of dims)"
with self.assertRaisesRegex(RuntimeError, errMsg):
op(a, dim=(0, 1, 0))

with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
with self.assertRaisesRegex(RuntimeError, errMsg):
op(a, dim=(2, -1))

with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
Expand Down
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21585,7 +21585,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
Expand Down
157 changes: 157 additions & 0 deletions torch/testing/_internal/opinfo/definitions/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,34 +650,110 @@ def mt(shape, **kwargs):
SpectralFuncPythonRefInfo(
"_refs.fft.fft",
torch_opinfo_name="fft.fft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifft",
torch_opinfo_name="fft.ifft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft",
torch_opinfo_name="fft.rfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfft",
torch_opinfo_name="fft.irfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
# TODO: internally promoted to complex64 so not rejected
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft",
torch_opinfo_name="fft.hfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft",
torch_opinfo_name="fft.ihfft",
skips=(
# _refs.fft.* functions have inconsistent behavior for empty tensors
# https://github.com/pytorch/pytorch/issues/105986
DecorateInfo(unittest.expectedFailure, "TestFFT", "test_empty_fft"),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fftn",
torch_opinfo_name="fft.fftn",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifftn",
torch_opinfo_name="fft.ifftn",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfftn",
Expand All @@ -686,14 +762,67 @@ def mt(shape, **kwargs):
SpectralFuncPythonRefInfo(
"_refs.fft.irfftn",
torch_opinfo_name="fft.irfftn",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfftn",
torch_opinfo_name="fft.hfftn",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108204
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fft_half_and_bfloat16_errors",
dtypes=[torch.bfloat16],
),
# FIXME: https://github.com/pytorch/pytorch/issues/108205
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fftn_invalid",
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfftn",
torch_opinfo_name="fft.ihfftn",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4}),
"TestFFT",
"test_reference_nd",
)
],
skips=(
# FIXME: https://github.com/pytorch/pytorch/issues/108205
DecorateInfo(
unittest.expectedFailure,
"TestFFT",
"test_fftn_invalid",
),
),
),
SpectralFuncPythonRefInfo(
"_refs.fft.fft2",
Expand All @@ -702,6 +831,13 @@ def mt(shape, **kwargs):
SpectralFuncPythonRefInfo(
"_refs.fft.ifft2",
torch_opinfo_name="fft.ifft2",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
"TestFFT",
"test_reference_nd",
)
],
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft2",
Expand All @@ -710,14 +846,35 @@ def mt(shape, **kwargs):
SpectralFuncPythonRefInfo(
"_refs.fft.irfft2",
torch_opinfo_name="fft.irfft2",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
"TestFFT",
"test_reference_nd",
)
],
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft2",
torch_opinfo_name="fft.hfft2",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
"TestFFT",
"test_reference_nd",
)
],
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft2",
torch_opinfo_name="fft.ihfft2",
decorators=[
DecorateInfo(
precisionOverride({torch.float: 2e-4}),
"TestFFT",
"test_reference_nd",
)
],
),
PythonRefInfo(
"_refs.fft.fftshift",
Expand Down

0 comments on commit 0ef2556

Please sign in to comment.