Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`InverseMelScale`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: InverseMelScale

.. automethod:: forward


:hidden:`MelSpectrogram`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
111 changes: 111 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,25 @@ def test_batch_MelScale(self):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

def test_batch_InverseMelScale(self):
n_fft = 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it makes sense to vary some of these parameters a bit in a few subsequent tests. In particular to also include edge cases to see what the error behavior is. Unless this is verified through other tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with running test on multiple parameters.
However, I think that it can be accomplished better with reorganizing the whole test suite.

Right now Tester class contains all kinds of test, (like batch test, torch script test, librosa compatibility test etc...) and it was hard to tell what type of test I should add and where to add.
By creating separate test suite for different test, it will be much easier.
So I think, I would add that kind parameterized test later, and add similar things to the existing ones.

n_mels = 32
n_stft = 5
mel_spec = torch.randn(2, n_mels, 32) ** 2

# Single then transform then batch
expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

# Batch then transform
computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

# shape = (3, 2, n_mels, 32)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self.assertTrue(torch.allclose(computed, expected, atol=1.0))

def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)

Expand Down Expand Up @@ -509,5 +528,97 @@ def test_scriptmodule_TimeMasking(self):
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)


class TestLibrosaConsistency(unittest.TestCase):
test_dirpath = None
test_dir = None

@classmethod
def setUpClass(cls):
cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir()

def _to_librosa(self, sound):
return sound.cpu().numpy().squeeze()

def _get_sample_data(self, *asset_paths, **kwargs):
file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)

sound, sample_rate = torchaudio.load(file_path, **kwargs)
return sound.mean(dim=0, keepdim=True), sample_rate

@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_MelScale(self):
"""MelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
hop_length = n_fft // 4

# Prepare spectrogram input. We use torchaudio to compute one.
sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
spec_ta = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
spec_lr = spec_ta.cpu().numpy().squeeze()
# Perform MelScale with torchaudio and librosa
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
melspec_lr = librosa.feature.melspectrogram(
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using self.assertTrue might yield a nicer message if this fails.

In the future and in a separate PR we might want to look into introducing some of the Unittest extensions that PyTorch implements that'll enable things such as self.assertAllClose and also does torch.Tensor specific checks such as dtype,memory layout etc. . allclose might do upcasting, broadcasting etc., but actually we care that those properties match. cc @vincentqb

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using self.assertTrue might yield a nicer message if this fails.

I found it opposite. Using assertTrue on torch.allclose only says

>       self.assertTrue(torch.allclose(spec_ta, spec_lr, atol=threshold))
E       AssertionError: False is not true

test/test_transforms.py:618: AssertionError

whereas assert says
(although this is still hard to read due to combination of multiple line messages and pytest's annotation)

>       assert torch.allclose(spec_ta, spec_lr, atol=threshold)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x121552eb0>(tensor([[[0.8752, 0.8655, 0.6858,  ..., 0.7232, 0.3609, 0.2115],\n         [0.7756, 1.1142, 0.9477,  ..., 0.8303, 1.985...0338, 0.0434, 0.0437,  ..., 0.0581, 0.0294, 0.0445],\n         [0.4310, 0.7263, 0.4167,  ..., 0.1131, 0.5628, 0.8183]]]), tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n          0.0000e+00, 0.0000e+00],\n         [7.7557e-0...e-04, 2.5366e-04],\n         [3.0709e-11, 5.1357e-12, 3.0357e-12,  ..., 0.0000e+00,\n          4.2634e-11, 1.0426e-10]]]), atol=1.0)
E        +    where <built-in method allclose of type object at 0x121552eb0> = torch.allclose

test/test_transforms.py:618: AssertionError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combined with your comment on parameterized test, I think reorganizing test structure and using PyTorch's helper functions to show a good example of how to write a test will be great benefit for all developers.


@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
n_stft = n_fft // 2 + 1
hop_length = n_fft // 4

# Prepare mel spectrogram input. We use torchaudio to compute one.
sound, sample_rate = self._get_sample_data(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
spec_orig = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
melspec_lr = melspec_ta.cpu().numpy().squeeze()
# Perform InverseMelScale with torch audio and librosa
spec_ta = transforms.InverseMelScale(
n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
spec_lr = librosa.feature.inverse.mel_to_stft(
melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
spec_lr = torch.from_numpy(spec_lr[None, ...])

# Align dimensions
# librosa does not return power spectrogram while torchaudio returns power spectrogram
spec_orig = spec_orig.sqrt()
spec_ta = spec_ta.sqrt()

threshold = 2.0
# This threshold was choosen empirically, based on the following observation
#
# torch.dist(spec_lr, spec_ta, p=float('inf'))
# >>> tensor(1.9666)
#
# The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise.
# This is because they use different approximation algorithms and resulting values can live
# in different magnitude. (although most of them are very close)
# See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
assert torch.allclose(spec_ta, spec_lr, atol=threshold)

threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
#
# torch.dist(spec_orig, spec_ta, p=1)
# >>> tensor(1644.3516)
# torch.dist(spec_orig, spec_lr, p=1)
# >>> tensor(1420.7103)
# torch.dist(spec_lr, spec_ta, p=1)
# >>> tensor(943.2759)
assert torch.dist(spec_orig, spec_ta, p=1) < threshold


if __name__ == '__main__':
unittest.main()
85 changes: 85 additions & 0 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
'GriffinLim',
'AmplitudeToDB',
'MelScale',
'InverseMelScale',
'MelSpectrogram',
'MFCC',
'MuLawEncoding',
Expand Down Expand Up @@ -233,6 +234,90 @@ def forward(self, specgram):
return mel_specgram


class InverseMelScale(torch.nn.Module):
r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.

It minimizes the euclidian norm between the input mel-spectrogram and the product between
the estimated spectrogram and the filter banks using SGD.

Args:
n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
n_mels (int): Number of mel filterbanks. (Default: ``128``)
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
f_min (float): Minimum frequency. (Default: ``0.``)
f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
max_iter (int): Maximum number of optimization iterations.
tolerance_loss (float): Value of loss to stop optimization at.
tolerance_change (float): Difference in losses to stop optimization at.
sgdargs (dict): Arguments for the SGD optimizer.
"""
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs']

def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max or float(sample_rate // 2)
self.f_min = f_min
self.max_iter = max_iter
self.tolerance_loss = tolerance_loss
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)

fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)

def forward(self, melspec):
r"""
Args:
melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

Returns:
torch.Tensor: Linear scale spectrogram of size (..., freq, time)
"""
# pack batch
shape = melspec.size()
melspec = melspec.view(-1, shape[-2], shape[-1])

n_mels, time = shape[-2], shape[-1]
freq, _ = self.fb.size() # (freq, n_mels)
melspec = melspec.transpose(-1, -2)
assert self.n_mels == n_mels

specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
dtype=melspec.dtype, device=melspec.device)

optim = torch.optim.SGD([specgram], **self.sgdargs)

loss = float('inf')
for _ in range(self.max_iter):
optim.zero_grad()
diff = melspec - specgram.matmul(self.fb)
new_loss = diff.pow(2).sum(axis=-1).mean()
# take sum over mel-frequency then average over other dimensions
# so that loss threshold is applied par unit timeframe
new_loss.backward()
optim.step()
specgram.data = specgram.data.clamp(min=0)

new_loss = new_loss.item()
if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
break
loss = new_loss

specgram.requires_grad_(False)
specgram = specgram.clamp(min=0).transpose(-1, -2)

# unpack batch
specgram = specgram.view(shape[:-2] + (freq, time))
return specgram


class MelSpectrogram(torch.nn.Module):
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
Expand Down