Mark parts of spectral tests as slow (#46509)
Summary:
According to https://app.circleci.com/pipelines/github/pytorch/pytorch/228154/workflows/31951076-b633-4391-bd0d-b2953c940876/jobs/8290059
TestFFTCUDA.test_fftn_backward_cuda_complex128 takes 242 seconds to finish, with most of that time spent checking the 2nd-order gradient.

Refactor the common part of `test_fft_backward` and `test_fftn_backward` into `_fft_grad_check_helper`
Introduce a `slowAwareTest` decorator
Split the tests into fast and slow parts by checking the 2nd-order gradient only in the slow part (a sketch of the split follows below)
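
For orientation, a minimal standalone sketch of the fast/slow split (`check_grads` is a hypothetical stand-in for the new `_fft_grad_check_helper`; the real `TEST_WITH_SLOW` flag comes from `torch.testing._internal.common_utils`):

```python
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck

TEST_WITH_SLOW = False  # stand-in; the real flag lives in torch.testing._internal.common_utils

def check_grads(test_fn, inputs):
    # Fast part: the first-order gradient check always runs.
    assert gradcheck(test_fn, inputs)
    # Slow part: the expensive second-order check runs only in slow mode.
    if TEST_WITH_SLOW:
        assert gradgradcheck(test_fn, inputs)

# Example: differentiate torch.fft.fft through real<->complex views,
# mirroring the workaround used by the helper in the diff below.
x = torch.randn(4, 2, dtype=torch.double, requires_grad=True)  # real view of 4 complex values
check_grads(lambda t: torch.view_as_real(torch.fft.fft(torch.view_as_complex(t))), (x,))
```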

Pull Request resolved: #46509

Reviewed By: walterddr

Differential Revision: D24378901

Pulled By: malfet

fbshipit-source-id: 606670c2078480219905f63b9b278b835e760a66
malfet authored and facebook-github-bot committed Oct 19, 2020
1 parent e7564b0 commit 172ed51
Showing 2 changed files with 36 additions and 36 deletions.
67 changes: 31 additions & 36 deletions test/test_spectral_ops.py
@@ -6,10 +6,11 @@
 import itertools
 
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, _assertGradAndGradgradChecks)
+    (TestCase, run_tests, TEST_WITH_SLOW, TEST_NUMPY, TEST_LIBROSA, slowAwareTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, precisionOverride,
      skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA)
+from torch.autograd.gradcheck import gradgradcheck
 
 from distutils.version import LooseVersion
 from typing import Optional, List
@@ -307,6 +308,28 @@ def test_fft_half_errors(self, device, dtype):
             with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
                 fn(x)
 
+
+    def _fft_grad_check_helper(self, fname, input, args):
+        torch_fn = getattr(torch.fft, fname)
+        # Workaround for gradcheck's poor support for complex input
+        # Use real input instead and put view_as_complex into the graph
+        if input.dtype.is_complex:
+            def test_fn(x):
+                out = torch_fn(torch.view_as_complex(x), *args)
+                return torch.view_as_real(out) if out.is_complex() else out
+            inputs = (torch.view_as_real(input).detach().requires_grad_(),)
+        else:
+            def test_fn(x):
+                out = torch_fn(x, *args)
+                return torch.view_as_real(out) if out.is_complex() else out
+            inputs = (input.detach().requires_grad_(),)
+
+        self.assertTrue(torch.autograd.gradcheck(test_fn, inputs))
+        if TEST_WITH_SLOW:
+            self.assertTrue(gradgradcheck(test_fn, inputs))
+
+
+    @slowAwareTest
     @skipCPUIfNoMkl
     @skipCUDAIfRocm
     @onlyOnCPUAndCUDA
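
For context on the comment in the helper: `gradcheck` at the time handled complex tensors poorly, so the helper feeds it a real tensor and moves the real↔complex reinterpretation inside the graph. A standalone illustration of the round trip (not part of the commit):

```python
import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j])
r = torch.view_as_real(z)   # real view, shape (2, 2): [[1., 2.], [3., -4.]]
assert torch.equal(torch.view_as_complex(r), z)  # lossless round trip over shared storage
```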
@@ -321,7 +344,7 @@ def test_fft_backward(self, device, dtype):
             # dim
             (-1, 0),
             # norm
-            (None, "forward", "backward", "ortho")
+            (None, "forward", "backward", "ortho") if TEST_WITH_SLOW else (None,)
         ))
 
         fft_functions = ['fft', 'ifft', 'hfft', 'irfft']
@@ -330,27 +353,11 @@ def test_fft_backward(self, device, dtype):
             fft_functions += ['rfft', 'ihfft']
 
         for fname in fft_functions:
-            torch_fn = getattr(torch.fft, fname)
-
             for iargs in test_args:
                 args = list(iargs)
                 input = args[0]
                 args = args[1:]
 
-                # Workaround for gradcheck's poor support for complex input
-                # Use real input instead and put view_as_complex into the graph
-                if dtype.is_complex:
-                    def test_fn(x):
-                        out = torch_fn(torch.view_as_complex(x), *args)
-                        return torch.view_as_real(out) if out.is_complex() else out
-                    inputs = (torch.view_as_real(input).detach().requires_grad_(),)
-                else:
-                    def test_fn(x):
-                        out = torch_fn(x, *args)
-                        return torch.view_as_real(out) if out.is_complex() else out
-                    inputs = (input.detach().requires_grad_(),)
-
-                _assertGradAndGradgradChecks(self, test_fn, inputs)
+                self._fft_grad_check_helper(fname, input, args)
 
         # nd-fft tests

@@ -441,6 +448,7 @@ def test_fftn_round_trip(self, device, dtype):
             self.assertEqual(x, y, exact_dtype=(
                 forward != torch.fft.fftn or x.is_complex()))
 
+    @slowAwareTest
     @skipCPUIfNoMkl
     @skipCUDAIfRocm
     @onlyOnCPUAndCUDA
@@ -457,7 +465,9 @@ def test_fftn_backward(self, device, dtype):
             (1, None, (0,)),
             (1, (11,), (0,)),
         ]
-        norm_modes = (None, "forward", "backward", "ortho")
+        if not TEST_WITH_SLOW:
+            transform_desc = [desc for desc in transform_desc if desc[0] < 3]
+        norm_modes = (None, "forward", "backward", "ortho") if TEST_WITH_SLOW else (None, )
 
         fft_functions = ['fftn', 'ifftn', 'irfftn']
         # Real-only functions
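
A quick demonstration of what the fast-mode pruning above does, assuming the first element of each description is the input dimensionality (the entries here are hypothetical):

```python
# Hypothetical (input_ndim, s, dim) descriptions, shaped like those in the test.
transform_desc = [(3, None, (0, -1)), (2, None, None), (1, (11,), (0,))]
fast = [desc for desc in transform_desc if desc[0] < 3]
print(fast)  # [(2, None, None), (1, (11,), (0,))] -- the 3D case is dropped
```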
@@ -469,22 +479,7 @@
             input = torch.randn(*shape, device=device, dtype=dtype)
 
             for fname, norm in product(fft_functions, norm_modes):
-                torch_fn = getattr(torch.fft, fname)
-
-                # Workaround for gradcheck's poor support for complex input
-                # Use real input instead and put view_as_complex into the graph
-                if dtype.is_complex:
-                    def test_fn(x):
-                        out = torch_fn(torch.view_as_complex(x), s, dim, norm)
-                        return torch.view_as_real(out) if out.is_complex() else out
-                    inputs = (torch.view_as_real(input).detach().requires_grad_(),)
-                else:
-                    def test_fn(x):
-                        out = torch_fn(x, s, dim, norm)
-                        return torch.view_as_real(out) if out.is_complex() else out
-                    inputs = (input.detach().requires_grad_(),)
-
-                _assertGradAndGradgradChecks(self, test_fn, inputs)
+                self._fft_grad_check_helper(fname, input, (s, dim, norm))
 
     @skipCUDAIfRocm
     @skipCPUIfNoMkl
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_utils.py
@@ -513,6 +513,11 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def slowAwareTest(fn):
+    fn.__dict__['slow_test'] = True
+    return fn
+
+
 def skipCUDAMemoryLeakCheckIf(condition):
     def dec(fn):
         if getattr(fn, '_do_cuda_memory_leak_check', True):  # if current True
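
The decorator itself is deliberately minimal: it only tags the function, and the test machinery that consumes the `slow_test` attribute lives outside this diff. A hypothetical usage sketch:

```python
def slowAwareTest(fn):
    fn.__dict__['slow_test'] = True
    return fn

@slowAwareTest
def test_something():
    pass

assert test_something.slow_test  # the tag is the decorator's only observable effect
```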
