From 36d507141e5ec08c04365140600d21edf75056c3 Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:18:45 +0900 Subject: [PATCH 1/6] remove weight_norm, add torch hub --- README.md | 20 +++++++++++++++++++- hubconf.py | 40 ++++++++++++++++++++++++++++++++++++++++ inference.py | 14 ++------------ model/generator.py | 33 +++++++++++++++++++++++++++++++++ model/res_stack.py | 5 +++++ 5 files changed, 99 insertions(+), 13 deletions(-) create mode 100644 hubconf.py diff --git a/README.md b/README.md index 8495be9..52519cf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910 - MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow). - This repository use identical mel-spectrogram function from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so this can be directly used to convert output from NVIDIA's tacotron2 into raw-audio. -- TODO: Planning to publish pretrained model via [PyTorch Hub](https://pytorch.org/hub). +- Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub). ![](./assets/gd.png) @@ -27,6 +27,24 @@ pip install -r requirements.txt - `python trainer.py -c [config yaml file] -n [name of the run]` - `tensorboard --logdir logs/` +## Pretrained model + +Try with Google Colab: + +```python +import torch +vocoder = torch.hub.load('seungwonpark/melgan', 'melgan') +vocoder.eval() +mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here + +if torch.cuda.is_available(): + vocoder = vocoder.cuda() + mel = mel.cuda() + +with torch.no_grad(): + audio = vocoder(mel) +``` + ## Inference - `python inference.py -p [checkpoint path] -i [input mel path]` diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..4df2f3d --- /dev/null +++ b/hubconf.py @@ -0,0 +1,40 @@ +dependencies = ['torch'] +from model.generator import Generator + +model_params = { + 'nvidia_tacotron2_LJ11_epoch3200': { + 'mel_channel': 80, + 'model_url': '', + }, +} + + +def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True): + params = model_params[model_name] + model = Generator(params['mel_channel']) + + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(params['model_url'], + progress=progress) + model.load_state_dict(state_dict['model_g']) + + model.eval(inference=True) + + return model + + +if __name__ == '__main__': + vocoder = torch.hub.load('seungwonpark/melgan', 'melgan') + mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here + + print('Input mel-spectrogram shape: {}'.format(mel.shape)) + + if torch.cuda.is_available(): + print('Moving data & model to GPU') + vocoder = vocoder.cuda() + mel = mel.cuda() + + with torch.no_grad(): + audio = vocoder.inference(mel) + + print('Output audio shape: {}'.format(audio.shape)) diff --git a/inference.py b/inference.py index 88e9b66..78d757f 100644 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ def main(args): model = Generator(hp.audio.n_mel_channels).cuda() model.load_state_dict(checkpoint['model_g']) - model.eval() + model.eval(inference=False) with torch.no_grad(): for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))): @@ -29,17 +29,7 @@ def main(args): mel = mel.unsqueeze(0) mel = mel.cuda() - # pad input mel with zeros to cut artifact - # see https://github.com/seungwonpark/melgan/issues/8 - zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).cuda() - mel = torch.cat((mel, zero), axis=2) - - audio = model(mel) - audio = audio.squeeze() # collapse all dimension except time axis - audio = audio[:-(hp.audio.hop_length*10)] - audio = MAX_WAV_VALUE * audio - audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE) - audio = audio.short() + audio = model.inference(hp, mel) audio = audio.cpu().detach().numpy() out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) diff --git a/model/generator.py b/model/generator.py index c6c9c4a..82d60d8 100644 --- a/model/generator.py +++ b/model/generator.py @@ -5,10 +5,13 @@ from .res_stack import ResStack #from res_stack import ResStack +MAX_WAV_VALUE = 32768.0 + class Generator(nn.Module): def __init__(self, mel_channel): super(Generator, self).__init__() + self.mel_channel = mel_channel self.generator = nn.Sequential( nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)), @@ -42,6 +45,36 @@ def forward(self, mel): mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram return self.generator(mel) + def eval(self, inference=False): + super(Generator, self).eval() + + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + for idx, layer in enumerate(self.generator): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except: + layer.remove_weight_norm() + + def inference(self, hp, mel): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).to(mel.device) + mel = torch.cat((mel, zero), axis=2) + + audio = self.forward(mel) + audio = audio.squeeze() # collapse all dimension except time axis + audio = audio[:-(hp.audio.hop_length*10)] + audio = MAX_WAV_VALUE * audio + audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) + audio = audio.short() + + return audio + ''' to run this, fix diff --git a/model/res_stack.py b/model/res_stack.py index 6512409..37d9fc3 100644 --- a/model/res_stack.py +++ b/model/res_stack.py @@ -22,3 +22,8 @@ def forward(self, x): for layer in self.layers: x = x + layer(x) return x + + def remove_weight_norm(self): + for layer in self.layers: + nn.utils.remove_weight_norm(layer[1]) + nn.utils.remove_weight_norm(layer[3]) From 17a79ac7b48f9981082a8c8dc9d109b8aa300d0e Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:35:18 +0900 Subject: [PATCH 2/6] add pretrain url --- hubconf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 4df2f3d..27d3e96 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,10 +1,11 @@ dependencies = ['torch'] +import torch from model.generator import Generator model_params = { 'nvidia_tacotron2_LJ11_epoch3200': { 'mel_channel': 80, - 'model_url': '', + 'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.1-alpha/nvidia_tacotron2_LJ11_epoch3200.pt', }, } From 8e6654bd1cdb638fd8fb47780d81243193adb1d9 Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:36:58 +0900 Subject: [PATCH 3/6] specify branch --- hubconf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 27d3e96..393a502 100644 --- a/hubconf.py +++ b/hubconf.py @@ -25,7 +25,7 @@ def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progre if __name__ == '__main__': - vocoder = torch.hub.load('seungwonpark/melgan', 'melgan') + vocoder = torch.hub.load('seungwonpark/melgan[:hub]', 'melgan') mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here print('Input mel-spectrogram shape: {}'.format(mel.shape)) From f9586e5c8cb67e35de1d743431c5623fc96d4d3b Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:38:18 +0900 Subject: [PATCH 4/6] fix --- hubconf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 393a502..d2bc180 100644 --- a/hubconf.py +++ b/hubconf.py @@ -25,7 +25,7 @@ def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progre if __name__ == '__main__': - vocoder = torch.hub.load('seungwonpark/melgan[:hub]', 'melgan') + vocoder = torch.hub.load('seungwonpark/melgan:hub', 'melgan') mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here print('Input mel-spectrogram shape: {}'.format(mel.shape)) From 24bce022cfc78b805bf6b740d5c745399072e003 Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:43:47 +0900 Subject: [PATCH 5/6] don't use hp in .inference --- inference.py | 2 +- model/generator.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/inference.py b/inference.py index 78d757f..c60580a 100644 --- a/inference.py +++ b/inference.py @@ -29,7 +29,7 @@ def main(args): mel = mel.unsqueeze(0) mel = mel.cuda() - audio = model.inference(hp, mel) + audio = model.inference(mel) audio = audio.cpu().detach().numpy() out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) diff --git a/model/generator.py b/model/generator.py index 82d60d8..2614aee 100644 --- a/model/generator.py +++ b/model/generator.py @@ -60,15 +60,16 @@ def remove_weight_norm(self): except: layer.remove_weight_norm() - def inference(self, hp, mel): + def inference(self, mel): + hop_length = 256 # pad input mel with zeros to cut artifact # see https://github.com/seungwonpark/melgan/issues/8 - zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).to(mel.device) + zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device) mel = torch.cat((mel, zero), axis=2) audio = self.forward(mel) audio = audio.squeeze() # collapse all dimension except time axis - audio = audio[:-(hp.audio.hop_length*10)] + audio = audio[:-(hop_length*10)] audio = MAX_WAV_VALUE * audio audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) audio = audio.short() From 4db45d90832f8f29efefcca902d6b74358376bca Mon Sep 17 00:00:00 2001 From: Seungwon Park Date: Mon, 28 Oct 2019 18:50:50 +0900 Subject: [PATCH 6/6] update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 52519cf..c84d2bc 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install -r requirements.txt ## Pretrained model -Try with Google Colab: +Try with Google Colab: TODO ```python import torch @@ -42,7 +42,7 @@ if torch.cuda.is_available(): mel = mel.cuda() with torch.no_grad(): - audio = vocoder(mel) + audio = vocoder.inference(mel) ``` ## Inference @@ -51,7 +51,7 @@ with torch.no_grad(): ## Results -See audio samples at: http://swpark.me/melgan/. +See audio samples at: http://swpark.me/melgan/. ![](./assets/lj-tensorboard.png)