-
Notifications
You must be signed in to change notification settings - Fork 729
Add test for InverseMelScale #448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ca2dcd5
92ab3cd
90602f0
aa819e1
227db11
d1e2cd2
c57e01b
b5cec06
297d80a
7f44b62
2e80d9f
75c2591
e985299
de3e426
b36568a
668f0f6
6137916
e142f6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -410,6 +410,25 @@ def test_batch_MelScale(self): | |
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) | ||
self.assertTrue(torch.allclose(computed, expected)) | ||
|
||
def test_batch_InverseMelScale(self): | ||
n_fft = 8 | ||
n_mels = 32 | ||
n_stft = 5 | ||
mel_spec = torch.randn(2, n_mels, 32) ** 2 | ||
|
||
# Single then transform then batch | ||
expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1) | ||
|
||
# Batch then transform | ||
computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1)) | ||
|
||
# shape = (3, 2, n_mels, 32) | ||
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) | ||
|
||
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield | ||
vincentqb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# exactly same result. For this reason, tolerance is very relaxed here. | ||
self.assertTrue(torch.allclose(computed, expected, atol=1.0)) | ||
|
||
def test_batch_compute_deltas(self): | ||
specgram = torch.randn(2, 31, 2786) | ||
|
||
|
@@ -509,5 +528,97 @@ def test_scriptmodule_TimeMasking(self): | |
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False) | ||
|
||
|
||
class TestLibrosaConsistency(unittest.TestCase): | ||
test_dirpath = None | ||
test_dir = None | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir() | ||
|
||
def _to_librosa(self, sound): | ||
return sound.cpu().numpy().squeeze() | ||
|
||
def _get_sample_data(self, *asset_paths, **kwargs): | ||
file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths) | ||
|
||
sound, sample_rate = torchaudio.load(file_path, **kwargs) | ||
return sound.mean(dim=0, keepdim=True), sample_rate | ||
|
||
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available') | ||
def test_MelScale(self): | ||
mthrok marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""MelScale transform is comparable to that of librosa""" | ||
n_fft = 2048 | ||
n_mels = 256 | ||
hop_length = n_fft // 4 | ||
|
||
# Prepare spectrogram input. We use torchaudio to compute one. | ||
sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3') | ||
spec_ta = F.spectrogram( | ||
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, | ||
hop_length=hop_length, win_length=n_fft, power=2, normalized=False) | ||
spec_lr = spec_ta.cpu().numpy().squeeze() | ||
# Perform MelScale with torchaudio and librosa | ||
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta) | ||
melspec_lr = librosa.feature.melspectrogram( | ||
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, | ||
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None) | ||
# Note: Using relaxed rtol instead of atol | ||
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using self.assertTrue might yield a nicer message if this fails. In the future and in a separate PR we might want to look into introducing some of the Unittest extensions that PyTorch implements that'll enable things such as self.assertAllClose and also does torch.Tensor specific checks such as dtype,memory layout etc. . allclose might do upcasting, broadcasting etc., but actually we care that those properties match. cc @vincentqb There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I found it opposite. Using
whereas
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combined with your comment on parameterized test, I think reorganizing test structure and using PyTorch's helper functions to show a good example of how to write a test will be great benefit for all developers. |
||
|
||
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available') | ||
def test_InverseMelScale(self): | ||
"""InverseMelScale transform is comparable to that of librosa""" | ||
n_fft = 2048 | ||
n_mels = 256 | ||
n_stft = n_fft // 2 + 1 | ||
hop_length = n_fft // 4 | ||
|
||
# Prepare mel spectrogram input. We use torchaudio to compute one. | ||
sound, sample_rate = self._get_sample_data( | ||
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14) | ||
spec_orig = F.spectrogram( | ||
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, | ||
hop_length=hop_length, win_length=n_fft, power=2, normalized=False) | ||
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig) | ||
melspec_lr = melspec_ta.cpu().numpy().squeeze() | ||
# Perform InverseMelScale with torch audio and librosa | ||
spec_ta = transforms.InverseMelScale( | ||
n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta) | ||
spec_lr = librosa.feature.inverse.mel_to_stft( | ||
melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None) | ||
spec_lr = torch.from_numpy(spec_lr[None, ...]) | ||
|
||
# Align dimensions | ||
# librosa does not return power spectrogram while torchaudio returns power spectrogram | ||
spec_orig = spec_orig.sqrt() | ||
spec_ta = spec_ta.sqrt() | ||
|
||
threshold = 2.0 | ||
# This threshold was choosen empirically, based on the following observation | ||
# | ||
# torch.dist(spec_lr, spec_ta, p=float('inf')) | ||
# >>> tensor(1.9666) | ||
# | ||
# The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise. | ||
# This is because they use different approximation algorithms and resulting values can live | ||
# in different magnitude. (although most of them are very close) | ||
# See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm | ||
# See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf | ||
# distance over frequencies. | ||
assert torch.allclose(spec_ta, spec_lr, atol=threshold) | ||
|
||
threshold = 1700.0 | ||
# This threshold was choosen empirically, based on the following observations | ||
# | ||
# torch.dist(spec_orig, spec_ta, p=1) | ||
# >>> tensor(1644.3516) | ||
# torch.dist(spec_orig, spec_lr, p=1) | ||
# >>> tensor(1420.7103) | ||
# torch.dist(spec_lr, spec_ta, p=1) | ||
# >>> tensor(943.2759) | ||
assert torch.dist(spec_orig, spec_ta, p=1) < threshold | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if it makes sense to vary some of these parameters a bit in a few subsequent tests. In particular to also include edge cases to see what the error behavior is. Unless this is verified through other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with running test on multiple parameters.
However, I think that it can be accomplished better with reorganizing the whole test suite.
Right now Tester class contains all kinds of test, (like batch test, torch script test, librosa compatibility test etc...) and it was hard to tell what type of test I should add and where to add.
By creating separate test suite for different test, it will be much easier.
So I think, I would add that kind parameterized test later, and add similar things to the existing ones.