Skip to content

Commit

Permalink
Save/load TorchScript object in test (#1446)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 14, 2021
1 parent 931555c commit 5c696b5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from parameterized import parameterized

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio_unittest.common_utils import (
skipIfRocm,
)


class Functional(common_utils.TestBaseMixin):
class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)

ts_func = torch.jit.script(func)
path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

output = func(tensor)
ts_output = ts_func(tensor)
if shape_only:
Expand Down Expand Up @@ -565,15 +569,18 @@ def func(tensor):
self._assert_consistency(func, tensor)


class FunctionalComplex:
class FunctionalComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None

def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)

path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
skipIfRocm,
TempDirMixin,
TestBaseMixin,
)


class Transforms(common_utils.TestBaseMixin):
class Transforms(TempDirMixin, TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch.jit.script(transform)
path = self.get_temp_path('transform.zip')
torch.jit.script(transform).save(path)
ts_transform = torch.jit.load(path)

output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)
Expand All @@ -39,8 +44,8 @@ def test_AmplitudeToDB(self):
self._assert_consistency(T.AmplitudeToDB(), spec)

def test_MelScale(self):
spec_f = torch.rand((1, 6, 201))
self._assert_consistency(T.MelScale(), spec_f)
spec_f = torch.rand((1, 201, 6))
self._assert_consistency(T.MelScale(n_stft=201), spec_f)

def test_MelSpectrogram(self):
tensor = torch.rand((1, 1000))
Expand Down Expand Up @@ -100,7 +105,7 @@ def test_SpectralCentroid(self):
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)


class TransformsComplex:
class TransformsComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
Expand All @@ -109,7 +114,10 @@ def _assert_consistency(self, transform, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.real_dtype)
ts_transform = torch.jit.script(transform)

path = self.get_temp_path('transform.zip')
torch.jit.script(transform).save(path)
ts_transform = torch.jit.load(path)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
Expand Down

0 comments on commit 5c696b5

Please sign in to comment.