Commit
Address most of the feedback
mthrok committed Mar 3, 2021
1 parent b3fbe41 commit b67b2c3
Showing 3 changed files with 11 additions and 14 deletions.
6 changes: 2 additions & 4 deletions test/torchaudio_unittest/transforms/autograd_cpu_test.py
@@ -1,8 +1,6 @@
-import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from .autograd_test_impl import AutogradTestCase
+from .autograd_test_impl import AutogradTestMixin


-class AutogradCPUTest(AutogradTestCase, PytorchTestCase):
+class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
     device = 'cpu'
-    dtype = torch.float64
6 changes: 2 additions & 4 deletions test/torchaudio_unittest/transforms/autograd_cuda_test.py
@@ -1,12 +1,10 @@
-import torch
 from torchaudio_unittest.common_utils import (
     PytorchTestCase,
     skipIfNoCuda,
 )
-from .autograd_test_impl import AutogradTestCase
+from .autograd_test_impl import AutogradTestMixin


 @skipIfNoCuda
-class AutogradCUDATest(AutogradTestCase, PytorchTestCase):
+class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
     device = 'cuda'
-    dtype = torch.float64
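A note on the pattern in the two test files above: all of the autograd test logic lives in the mixin, and each device-specific class only supplies a device attribute. Below is a minimal, self-contained sketch of that structure, using plain unittest and a toy differentiable function instead of torchaudio's PytorchTestCase and transforms (the Toy* names are illustrative, not part of this commit):

import unittest

import torch
from torch.autograd import gradcheck


class ToyAutogradMixin:
    # Shared test logic; the concrete subclass supplies `device`.
    device = None

    def test_scale_is_differentiable(self):
        x = torch.randn(3, dtype=torch.float64, device=self.device, requires_grad=True)
        assert gradcheck(lambda t: t * 2.0, (x,))


class ToyAutogradCPUTest(ToyAutogradMixin, unittest.TestCase):
    device = 'cpu'


@unittest.skipUnless(torch.cuda.is_available(), 'CUDA is not available')
class ToyAutogradCUDATest(ToyAutogradMixin, unittest.TestCase):
    device = 'cuda'


if __name__ == '__main__':
    unittest.main()

Because the mixin itself does not inherit from unittest.TestCase, its test methods only run through the concrete per-device subclasses.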
13 changes: 7 additions & 6 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -1,4 +1,5 @@
 from parameterized import parameterized
+import torch
 from torch.autograd import gradcheck, gradgradcheck
 import torchaudio.transforms as T

@@ -8,16 +9,16 @@
 )


-class AutogradTestCase(TestBaseMixin):
-    def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001):
-        transform = transform.to(self.device, self.dtype)
+class AutogradTestMixin(TestBaseMixin):
+    def assert_grad(self, transform, *inputs):
+        transform = transform.to(dtype=torch.float64, device=self.device)

         inputs_ = []
         for i in inputs:
             i.requires_grad = True
-            inputs_.append(i.to(dtype=self.dtype, device=self.device))
-        assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
-        assert gradgradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
+            inputs_.append(i.to(dtype=torch.float64, device=self.device))
+        assert gradcheck(transform, inputs_)
+        assert gradgradcheck(transform, inputs_)

     @parameterized.expand([
         ({'pad': 0, 'normalized': False, 'power': None}, ),
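The reworked helper in autograd_test_impl.py always casts the transform and its inputs to torch.float64, which is why the explicit eps/atol/rtol arguments could be dropped: gradcheck and gradgradcheck use default tolerances that assume double precision and are unreliable in float32. A standalone sketch of how such a helper is exercised, assuming torchaudio is installed (the device keyword and the Spectrogram settings are illustrative rather than taken from this commit):

import torch
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T


def assert_grad(transform, *inputs, device='cpu'):
    # Cast everything to float64 so gradcheck's default tolerances apply.
    transform = transform.to(dtype=torch.float64, device=device)
    inputs_ = []
    for i in inputs:
        i.requires_grad = True
        inputs_.append(i.to(dtype=torch.float64, device=device))
    assert gradcheck(transform, inputs_)
    assert gradgradcheck(transform, inputs_)


if __name__ == '__main__':
    waveform = torch.randn(1, 100)  # a short signal keeps the numerical checks fast
    assert_grad(T.Spectrogram(n_fft=32, power=2.0), waveform, device='cpu')
    print('gradcheck and gradgradcheck passed')

gradgradcheck repeats the numerical comparison for second-order gradients, so both first and second derivatives of the transform are covered.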
