diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 506ca9af6c..5258343181 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -59,6 +59,10 @@ def test_AmplitudeToDB(self): spec = torch.rand((6, 201)) self._assert_consistency(T.AmplitudeToDB(), spec) + def test_MelScale_invalid(self): + with self.assertRaises(ValueError): + torch.jit.script(T.MelScale()) + def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) self._assert_consistency(T.MelScale(n_stft=201), spec_f) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index c7a72c55a7..930bd88a8f 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -284,6 +284,20 @@ def __init__(self, self.mel_scale) self.register_buffer('fb', fb) + def __prepare_scriptable__(self): + r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, + which does not work once the transform is scripted. However, this error does not happen + until the transform is executed. This is inconvenient especially if the resulting + TorchScript object is executed in other environments. Therefore, we check the + validity of `self.fb` here and fail if the resulting TS does not work. + + Returns: + MelScale: self + """ + if self.fb.numel() == 0: + raise ValueError("n_stft must be provided at construction") + return self + def forward(self, specgram: Tensor) -> Tensor: r""" Args: