Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torchaudio melscale 'slaney' instead of librosa in WaveRNN pipeline preprocessing #1444

Merged
merged 2 commits into from Apr 15, 2021

Conversation

discort
Copy link
Contributor

@discort discort commented Apr 9, 2021

cc #593

@discort
Copy link
Contributor Author

discort commented Apr 9, 2021

cc @mthrok @vincentqb

@@ -270,11 +270,11 @@ def main(args):

transforms = torch.nn.Sequential(
torchaudio.transforms.Spectrogram(**melkwargs),
LinearToMel(
torchaudio.transforms.MelScale(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have an extra bandwidth, can you add a test that verifies that librosa.feature.melspectrogram(S=spectrogram, ...) and MelScale()(spectrogram) produce the same result?

We do have a test for MelSpectrogram here

@nested_params(
[
param(n_fft=400, hop_length=200, n_mels=64),
param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=32),
],
[param(norm=norm) for norm in [None, 'slaney']],
[param(mel_scale=mel_scale) for mel_scale in ['htk', 'slaney']],
)
def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000
waveform = get_sinusoid(
sample_rate=sample_rate, n_channels=1,
).to(self.device, self.dtype)
expected = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(),
sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, norm=norm,
htk=mel_scale == "htk")
result = T.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels,
n_fft=n_fft, norm=norm, mel_scale=mel_scale,
).to(self.device, self.dtype)(waveform)[0]
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)

It will be nice to add one for MelScale as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talking about MelSpectrogram, how about using MelSpectrogram directly?

Copy link
Contributor Author

@discort discort Apr 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talking about MelSpectrogram, how about using MelSpectrogram directly?

Makes sense.
Fixed it and checked manually.

@mthrok
Copy link
Collaborator

mthrok commented Apr 13, 2021

@vincentqb Review please.

@vincentqb
Copy link
Contributor

LGTM, but we need to make sure the convergence of the model has not been affected. @discort -- have you reran the training loop to see? I'm rerunning the model on my side to see if the convergence profile changed.

@discort
Copy link
Contributor Author

discort commented Apr 15, 2021

thanks for getting back to me @vincentqb
Unfortunately I don't have resources to check the convergency. Let me know if there are any problems and next week I'll try to allocate some staff to check it manually.

@vincentqb
Copy link
Contributor

Alright, LGTM, thanks!

Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@vincentqb vincentqb merged commit e061b26 into pytorch:master Apr 15, 2021
@discort discort deleted the replace_wavernn_librosa branch April 16, 2021 07:36
carolineechen pushed a commit to carolineechen/audio that referenced this pull request Apr 30, 2021
…ne preprocessing (pytorch#1444)

* Use torchaudio melscale instead of librosa
mthrok pushed a commit to mthrok/audio that referenced this pull request Dec 13, 2022
* Parametrizaitons tutorial

* Add remove_parametrization

* Correct name

* minor

* Proper version number

* Fuzzy spellcheck

* version

* Remove _tutorial from name

* Forgot to add the file...

* Rename parametrizations_tutorial by parametrizations everywhere
Add Alban's suggestions
Correct the code
Beter spacing after enumeration

* Minor

* Add more comments

* Minor

* Prefer unicode over math

* Minor

* minor

* Corrections

Co-authored-by: Brian Johnson <brianjo@fb.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants