Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,39 @@ def test_mu_law_companding(self):
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100)
melscale_transform = transforms.MelScale()
melscale_transform(specgram)

melscale_transform_copy = transforms.MelScale(n_stft=1000)
melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

fb = melscale_transform.fb
fb_copy = melscale_transform_copy.fb

self.assertEqual(fb_copy.size(), (1000, 128))
self.assertTrue(torch.allclose(fb, fb_copy))

def test_melspectrogram_load_save(self):
waveform = self.waveform.float()
mel_spectrogram_transform = transforms.MelSpectrogram()
mel_spectrogram_transform(waveform)

mel_spectrogram_transform_copy = transforms.MelSpectrogram()
mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

window = mel_spectrogram_transform.spectrogram.window
window_copy = mel_spectrogram_transform_copy.spectrogram.window

fb = mel_spectrogram_transform.mel_scale.fb
fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

self.assertTrue(torch.allclose(window, window_copy))
# the default for n_fft = 400 and n_mels = 128
self.assertEqual(fb_copy.size(), (201, 128))
self.assertTrue(torch.allclose(fb, fb_copy))

def test_mel2(self):
top_db = 80.
s2db = transforms.AmplitudeToDB('power', top_db)
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non
hop_length=self.hop_length,
pad=self.pad, window_fn=window_fn, power=2,
normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to read off the size of the spectrogram instead of recomputing it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not unless the spectrogram computes a specgram and we look at the tensor's dimension

Copy link
Contributor

@vincentqb vincentqb Aug 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that constant n_fft // 2 + 1 is used in a few places within Spectrogram, we could attach it to Spectrogram. If that's our inferred default dimension, we could also just attach a default_dimension to Spectrogram, in case that ever changes, but that would add clutter to the Spectrogram interface.

Alright, I'm ok with having the value recomputed here. If there's ever a change in this, the test will fail :)


@torch.jit.script_method
def forward(self, waveform):
Expand Down