diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b460115a88..78a842916e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -37,6 +37,14 @@ Transforms are common audio transforms. They can be chained together using :clas .. automethod:: forward +:hidden:`InverseMelScale` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: InverseMelScale + + .. automethod:: forward + + :hidden:`MelSpectrogram` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_transforms.py b/test/test_transforms.py index 267ca16384..de308ed0ef 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -410,6 +410,25 @@ def test_batch_MelScale(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_batch_InverseMelScale(self): + n_fft = 8 + n_mels = 32 + n_stft = 5 + mel_spec = torch.randn(2, n_mels, 32) ** 2 + + # Single then transform then batch + expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1)) + + # shape = (3, 2, n_mels, 32) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + + # Because InverseMelScale runs SGD on randomly initialized values so they do not yield + # exactly same result. For this reason, tolerance is very relaxed here. + self.assertTrue(torch.allclose(computed, expected, atol=1.0)) + def test_batch_compute_deltas(self): specgram = torch.randn(2, 31, 2786) @@ -509,5 +528,97 @@ def test_scriptmodule_TimeMasking(self): _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False) +class TestLibrosaConsistency(unittest.TestCase): + test_dirpath = None + test_dir = None + + @classmethod + def setUpClass(cls): + cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir() + + def _to_librosa(self, sound): + return sound.cpu().numpy().squeeze() + + def _get_sample_data(self, *asset_paths, **kwargs): + file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths) + + sound, sample_rate = torchaudio.load(file_path, **kwargs) + return sound.mean(dim=0, keepdim=True), sample_rate + + @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available') + def test_MelScale(self): + """MelScale transform is comparable to that of librosa""" + n_fft = 2048 + n_mels = 256 + hop_length = n_fft // 4 + + # Prepare spectrogram input. We use torchaudio to compute one. + sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3') + spec_ta = F.spectrogram( + sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, + hop_length=hop_length, win_length=n_fft, power=2, normalized=False) + spec_lr = spec_ta.cpu().numpy().squeeze() + # Perform MelScale with torchaudio and librosa + melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta) + melspec_lr = librosa.feature.melspectrogram( + S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, + win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None) + # Note: Using relaxed rtol instead of atol + assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3) + + @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available') + def test_InverseMelScale(self): + """InverseMelScale transform is comparable to that of librosa""" + n_fft = 2048 + n_mels = 256 + n_stft = n_fft // 2 + 1 + hop_length = n_fft // 4 + + # Prepare mel spectrogram input. We use torchaudio to compute one. + sound, sample_rate = self._get_sample_data( + 'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14) + spec_orig = F.spectrogram( + sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, + hop_length=hop_length, win_length=n_fft, power=2, normalized=False) + melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig) + melspec_lr = melspec_ta.cpu().numpy().squeeze() + # Perform InverseMelScale with torch audio and librosa + spec_ta = transforms.InverseMelScale( + n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta) + spec_lr = librosa.feature.inverse.mel_to_stft( + melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None) + spec_lr = torch.from_numpy(spec_lr[None, ...]) + + # Align dimensions + # librosa does not return power spectrogram while torchaudio returns power spectrogram + spec_orig = spec_orig.sqrt() + spec_ta = spec_ta.sqrt() + + threshold = 2.0 + # This threshold was choosen empirically, based on the following observation + # + # torch.dist(spec_lr, spec_ta, p=float('inf')) + # >>> tensor(1.9666) + # + # The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise. + # This is because they use different approximation algorithms and resulting values can live + # in different magnitude. (although most of them are very close) + # See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm + # See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf + # distance over frequencies. + assert torch.allclose(spec_ta, spec_lr, atol=threshold) + + threshold = 1700.0 + # This threshold was choosen empirically, based on the following observations + # + # torch.dist(spec_orig, spec_ta, p=1) + # >>> tensor(1644.3516) + # torch.dist(spec_orig, spec_lr, p=1) + # >>> tensor(1420.7103) + # torch.dist(spec_lr, spec_ta, p=1) + # >>> tensor(943.2759) + assert torch.dist(spec_orig, spec_ta, p=1) < threshold + + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 707252a9f3..553d2fa403 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -14,6 +14,7 @@ 'GriffinLim', 'AmplitudeToDB', 'MelScale', + 'InverseMelScale', 'MelSpectrogram', 'MFCC', 'MuLawEncoding', @@ -233,6 +234,90 @@ def forward(self, specgram): return mel_specgram +class InverseMelScale(torch.nn.Module): + r"""Solve for a normal STFT from a mel frequency STFT, using a conversion + matrix. This uses triangular filter banks. + + It minimizes the euclidian norm between the input mel-spectrogram and the product between + the estimated spectrogram and the filter banks using SGD. + + Args: + n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. + n_mels (int): Number of mel filterbanks. (Default: ``128``) + sample_rate (int): Sample rate of audio signal. (Default: ``16000``) + f_min (float): Minimum frequency. (Default: ``0.``) + f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``) + max_iter (int): Maximum number of optimization iterations. + tolerance_loss (float): Value of loss to stop optimization at. + tolerance_change (float): Difference in losses to stop optimization at. + sgdargs (dict): Arguments for the SGD optimizer. + """ + __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss', + 'tolerance_change', 'sgdargs'] + + def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000, + tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None): + super(InverseMelScale, self).__init__() + self.n_mels = n_mels + self.sample_rate = sample_rate + self.f_max = f_max or float(sample_rate // 2) + self.f_min = f_min + self.max_iter = max_iter + self.tolerance_loss = tolerance_loss + self.tolerance_change = tolerance_change + self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9} + + assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + + fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) + self.register_buffer('fb', fb) + + def forward(self, melspec): + r""" + Args: + melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time) + + Returns: + torch.Tensor: Linear scale spectrogram of size (..., freq, time) + """ + # pack batch + shape = melspec.size() + melspec = melspec.view(-1, shape[-2], shape[-1]) + + n_mels, time = shape[-2], shape[-1] + freq, _ = self.fb.size() # (freq, n_mels) + melspec = melspec.transpose(-1, -2) + assert self.n_mels == n_mels + + specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True, + dtype=melspec.dtype, device=melspec.device) + + optim = torch.optim.SGD([specgram], **self.sgdargs) + + loss = float('inf') + for _ in range(self.max_iter): + optim.zero_grad() + diff = melspec - specgram.matmul(self.fb) + new_loss = diff.pow(2).sum(axis=-1).mean() + # take sum over mel-frequency then average over other dimensions + # so that loss threshold is applied par unit timeframe + new_loss.backward() + optim.step() + specgram.data = specgram.data.clamp(min=0) + + new_loss = new_loss.item() + if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change: + break + loss = new_loss + + specgram.requires_grad_(False) + specgram = specgram.clamp(min=0).transpose(-1, -2) + + # unpack batch + specgram = specgram.view(shape[:-2] + (freq, time)) + return specgram + + class MelSpectrogram(torch.nn.Module): r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram and MelScale.