From b52120b2dd911897329959d9c8c40a981828f6f2 Mon Sep 17 00:00:00 2001 From: borundev Date: Thu, 13 May 2021 02:27:20 +0100 Subject: [PATCH 1/3] TorchScript will raise an error if MelScale constructor is called without n_stft --- .../transforms/torchscript_consistency_impl.py | 1 + torchaudio/transforms.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 506ca9af6c..8b9e4d4d72 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -61,6 +61,7 @@ def test_AmplitudeToDB(self): def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) + self._assert_consistency(T.MelScale(), spec_f) self._assert_consistency(T.MelScale(n_stft=201), spec_f) def test_MelSpectrogram(self): diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index c7a72c55a7..877fd0e0b1 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -276,14 +276,24 @@ def __init__(self, self.f_min = f_min self.norm = norm self.mel_scale = mel_scale + self.n_stft = n_stft assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( + fb = torch.empty(0) if self.n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb) + def __prepare_scriptable__(self): + r""" + Prepare self to be scripted + Returns: + MelScale: self + """ + assert self.n_stft is not None, ValueError("n_stft must be provided at construction") + return self + def forward(self, specgram: Tensor) -> Tensor: r""" Args: From 511bbc479badcef9bdfed7e7fd91c625228cbdd8 Mon Sep 17 00:00:00 2001 From: borundev Date: Fri, 14 May 2021 18:01:00 +0100 Subject: [PATCH 2/3] scripting test based on self.fb --- torchaudio/transforms.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 877fd0e0b1..930bd88a8f 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -276,22 +276,26 @@ def __init__(self, self.f_min = f_min self.norm = norm self.mel_scale = mel_scale - self.n_stft = n_stft assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - fb = torch.empty(0) if self.n_stft is None else F.create_fb_matrix( + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb) def __prepare_scriptable__(self): - r""" - Prepare self to be scripted + r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, + which does not work once the transform is scripted. However, this error does not happen + until the transform is executed. This is inconvenient especially if the resulting + TorchScript object is executed in other environments. Therefore, we check the + validity of `self.fb` here and fail if the resulting TS does not work. + Returns: MelScale: self """ - assert self.n_stft is not None, ValueError("n_stft must be provided at construction") + if self.fb.numel() == 0: + raise ValueError("n_stft must be provided at construction") return self def forward(self, specgram: Tensor) -> Tensor: From a11096497bbb37df6dc9cddc8bf2480631500548 Mon Sep 17 00:00:00 2001 From: borundev Date: Fri, 14 May 2021 18:03:28 +0100 Subject: [PATCH 3/3] bug fix --- .../transforms/torchscript_consistency_impl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 8b9e4d4d72..5258343181 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -59,9 +59,12 @@ def test_AmplitudeToDB(self): spec = torch.rand((6, 201)) self._assert_consistency(T.AmplitudeToDB(), spec) + def test_MelScale_invalid(self): + with self.assertRaises(ValueError): + torch.jit.script(T.MelScale()) + def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) - self._assert_consistency(T.MelScale(), spec_f) self._assert_consistency(T.MelScale(n_stft=201), spec_f) def test_MelSpectrogram(self):