Add autograd test for T.Spectrogram/T.MelSpectrogram (#1340)
mthrok committed Mar 31, 2021
1 parent c0bfb03 commit e4a0bd2
Showing 3 changed files with 78 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_cpu_test.py
@@ -0,0 +1,6 @@
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestMixin


class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
    device = 'cpu'
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_cuda_test.py
@@ -0,0 +1,10 @@
from torchaudio_unittest.common_utils import (
    PytorchTestCase,
    skipIfNoCuda,
)
from .autograd_test_impl import AutogradTestMixin


@skipIfNoCuda
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
    device = 'cuda'
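
The two files above are deliberately thin: each concrete class only pins `device`, while the shared test logic lives in `AutogradTestMixin` (shown next), so every test in the mixin runs once per device. As a self-contained illustration of this mixin pattern (the class and test names below are hypothetical, not torchaudio's actual helpers), the dispatch works roughly like this:

import unittest

import torch


class DeviceTestMixin:
    # Concrete subclasses pin this to 'cpu' or 'cuda'; the test runner only
    # collects the subclasses, since the mixin is not itself a TestCase.
    device = None

    def test_tensor_lands_on_device(self):
        x = torch.zeros(1, device=self.device)
        assert x.device.type == torch.device(self.device).type


class DeviceCPUTest(DeviceTestMixin, unittest.TestCase):
    device = 'cpu'


@unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available')
class DeviceCUDATest(DeviceTestMixin, unittest.TestCase):
    device = 'cuda'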
62 changes: 62 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -0,0 +1,62 @@
from typing import List

from parameterized import parameterized
import torch
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
)


class AutogradTestMixin(TestBaseMixin):
    def assert_grad(
            self,
            transform: torch.nn.Module,
            inputs: List[torch.Tensor],
            *,
            nondet_tol: float = 0.0,
    ):
        # gradcheck compares analytical gradients against finite differences,
        # which requires double precision, so cast both the transform and its
        # inputs to float64 on the target device before checking.
        transform = transform.to(dtype=torch.float64, device=self.device)

        inputs_ = []
        for i in inputs:
            i.requires_grad = True
            inputs_.append(i.to(dtype=torch.float64, device=self.device))
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)

    @parameterized.expand([
        ({'pad': 0, 'normalized': False, 'power': None}, ),
        ({'pad': 3, 'normalized': False, 'power': None}, ),
        ({'pad': 0, 'normalized': True, 'power': None}, ),
        ({'pad': 3, 'normalized': True, 'power': None}, ),
        ({'pad': 0, 'normalized': False, 'power': 1.0}, ),
        ({'pad': 3, 'normalized': False, 'power': 1.0}, ),
        ({'pad': 0, 'normalized': True, 'power': 1.0}, ),
        ({'pad': 3, 'normalized': True, 'power': 1.0}, ),
        ({'pad': 0, 'normalized': False, 'power': 2.0}, ),
        ({'pad': 3, 'normalized': False, 'power': 2.0}, ),
        ({'pad': 0, 'normalized': True, 'power': 2.0}, ),
        ({'pad': 3, 'normalized': True, 'power': 2.0}, ),
    ])
    def test_spectrogram(self, kwargs):
        # replication_pad1d_backward_cuda is not deterministic and
        # gives a very small (~2.7756e-17) difference.
        #
        # See https://github.com/pytorch/pytorch/issues/54093
        transform = T.Spectrogram(**kwargs)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    def test_melspectrogram(self):
        # replication_pad1d_backward_cuda is not deterministic and
        # gives a very small (~2.7756e-17) difference.
        #
        # See https://github.com/pytorch/pytorch/issues/54093
        sample_rate = 8000
        transform = T.MelSpectrogram(sample_rate=sample_rate)
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
        self.assert_grad(transform, [waveform], nondet_tol=1e-10)
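
Outside the mixin, the essence of `assert_grad` can be reproduced in a few lines. A minimal standalone sketch follows; the small `n_fft` and input length are illustrative choices to keep the numerical Jacobian cheap, not values taken from the commit:

import torch
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T

# Both the module and its input are cast to float64, since gradcheck
# validates analytical gradients against finite differences.
transform = T.Spectrogram(n_fft=64).to(torch.float64)
waveform = torch.randn(2, 100, dtype=torch.float64, requires_grad=True)

assert gradcheck(transform, [waveform])
assert gradgradcheck(transform, [waveform])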
