Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wavernn example pipeline #749

Merged
merged 29 commits into from Jul 21, 2020
Merged

Conversation

jimchen90
Copy link
Contributor

@jimchen90 jimchen90 commented Jun 24, 2020

This is a reference example using WaveRNN model to train on LJSpeech. The structure will be inspired by #632 and WaveRNN.

There are at least a few more things to do:

  • Add bg_iterator and README.
  • Add torchaudio transforms on mel-spectrogram.

Related to #446

Stack:

Add MelResNet Block #705, #751
Add Upsampling Block #724
Add WaveRNN Model #735
Add example pipeline with WaveRNN #749
Remove underscore of wavernn model #810

cc @cpuhrsch @zhangguanheng66
internal

@jimchen90 jimchen90 requested a review from vincentqb June 24, 2020 22:54
@vincentqb vincentqb changed the title Add wavernn example Add wavernn example pipeline Jun 25, 2020
@jimchen90 jimchen90 mentioned this pull request Jun 25, 2020
Comment on lines 27 to 32
bits = 16 if self.mode == 'MOL' else self.n_bits

x = (x + 1.) * (2 ** bits - 1) / 2
x = torch.clamp(x, min=0, max=2 ** bits - 1)

return mel.squeeze(0), x.int().squeeze(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This converts representation of a waveform from [-1, 1] to 16-bit integer representation. For instance, this is done in load_wav already. Since this is an important step and can be generalized, let's make this into a function within torchaudio. One point of discussion is whether we add that directly in WaveRNN.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function has been added as normalized_waveform_to_bits function in processing.py.

@jimchen90 jimchen90 mentioned this pull request Jun 25, 2020
@codecov
Copy link

codecov bot commented Jun 25, 2020

Codecov Report

Merging #749 into master will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #749      +/-   ##
==========================================
+ Coverage   89.87%   89.88%   +0.01%     
==========================================
  Files          34       34              
  Lines        2666     2660       -6     
==========================================
- Hits         2396     2391       -5     
+ Misses        270      269       -1     
Impacted Files Coverage Δ
torchaudio/models/_wavernn.py 99.03% <100.00%> (+0.85%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 209858e...b306f68. Read the comment docs.

@jimchen90 jimchen90 mentioned this pull request Jun 25, 2020
@vincentqb
Copy link
Contributor

btw, can you add a README.md to discuss the pipeline?

@vincentqb
Copy link
Contributor

It'd be nice to get a baseline by comparing the error you get here to the output obtained by Griffin-Lim, say, and in other norms too L^1, L^2 for instance.

@PetrochukM
Copy link

Since this is not the original WaveRNN model, I'd recommend renaming it as "FatchordWaveRNN" or something similar.

@jimchen90 jimchen90 force-pushed the pipeline_wavernn branch 2 times, most recently from dc0fd1b to 04bfe24 Compare July 8, 2020 13:33
@jimchen90 jimchen90 requested a review from vincentqb July 18, 2020 20:26
Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Minor things to address:

@jimchen90 jimchen90 merged commit fac1bba into pytorch:master Jul 21, 2020
mthrok pushed a commit to mthrok/audio that referenced this pull request Feb 26, 2021
…ials

Fix formatting and clean up tutorial on quantized transfer learning
mpc001 pushed a commit to mpc001/audio that referenced this pull request Aug 4, 2023
Co-authored-by: Shen Li <shenli@devfair017.maas>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants