Skip to content

Commit

Permalink
Run test only on CPU and CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 7, 2020
1 parent cc0bca4 commit 45a146f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/test_torch.py
Expand Up @@ -11888,7 +11888,7 @@ def test_fft_input_modification(self, device):
_ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2))
self.assertEqual(half_spectrum, half_spectrum_copy)

@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
@dtypes(torch.double)
def test_istft_round_trip_simple_cases(self, device, dtype):
Expand All @@ -11901,7 +11901,7 @@ def _test(input, n_fft, length):
_test(torch.ones(4, dtype=dtype, device=device), 4, 4)
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)

@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
@dtypes(torch.double)
def test_istft_round_trip_various_params(self, device, dtype):
Expand Down Expand Up @@ -11979,6 +11979,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
for i, pattern in enumerate(patterns):
_test_istft_is_inverse_of_stft(pattern)

@onlyOnCPUAndCUDA
def test_istft_throws(self, device):
"""istft should throw exception for invalid parameters"""
stft = torch.zeros((3, 5, 2), device=device)
Expand All @@ -11994,7 +11995,7 @@ def test_istft_throws(self, device):
self.assertRaises(AssertionError, torch.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(AssertionError, torch.istft, torch.zeros((0, 3, 2)), 2)

@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@dtypes(torch.double)
def test_istft_of_sine(self, device, dtype):
def _test(amplitude, L, n):
Expand Down Expand Up @@ -12027,6 +12028,7 @@ def _test(amplitude, L, n):
_test(amplitude=80, L=9, n=6)
_test(amplitude=99, L=10, n=7)

@onlyOnCPUAndCUDA
@dtypes(torch.double)
def test_istft_linearity(self, device, dtype):
num_trials = 100
Expand Down Expand Up @@ -12090,7 +12092,7 @@ def _test(data_size, kwargs):
for data_size, kwargs in patterns:
_test(data_size, kwargs)

@skipCUDAIfRocm
@onlyOnCPUAndCUDA
def test_batch_istft(self, device):
original = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
Expand Down

0 comments on commit 45a146f

Please sign in to comment.