Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pretrained model via pytorch hub, remove weightnorm at inference (resolves #16, #18) #19

Merged
merged 6 commits into from
Oct 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910

- MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
- This repository use identical mel-spectrogram function from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so this can be directly used to convert output from NVIDIA's tacotron2 into raw-audio.
- TODO: Planning to publish pretrained model via [PyTorch Hub](https://pytorch.org/hub).
- Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).

![](./assets/gd.png)

Expand All @@ -27,13 +27,31 @@ pip install -r requirements.txt
- `python trainer.py -c [config yaml file] -n [name of the run]`
- `tensorboard --logdir logs/`

## Pretrained model

Try with Google Colab: TODO

```python
import torch
vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
vocoder.eval()
mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here

if torch.cuda.is_available():
vocoder = vocoder.cuda()
mel = mel.cuda()

with torch.no_grad():
audio = vocoder.inference(mel)
```

## Inference

- `python inference.py -p [checkpoint path] -i [input mel path]`

## Results

See audio samples at: http://swpark.me/melgan/.
See audio samples at: http://swpark.me/melgan/.

![](./assets/lj-tensorboard.png)

Expand Down
41 changes: 41 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
dependencies = ['torch']
import torch
from model.generator import Generator

model_params = {
'nvidia_tacotron2_LJ11_epoch3200': {
'mel_channel': 80,
'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.1-alpha/nvidia_tacotron2_LJ11_epoch3200.pt',
},
}


def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
params = model_params[model_name]
model = Generator(params['mel_channel'])

if pretrained:
state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
progress=progress)
model.load_state_dict(state_dict['model_g'])

model.eval(inference=True)

return model


if __name__ == '__main__':
vocoder = torch.hub.load('seungwonpark/melgan:hub', 'melgan')
mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here

print('Input mel-spectrogram shape: {}'.format(mel.shape))

if torch.cuda.is_available():
print('Moving data & model to GPU')
vocoder = vocoder.cuda()
mel = mel.cuda()

with torch.no_grad():
audio = vocoder.inference(mel)

print('Output audio shape: {}'.format(audio.shape))
14 changes: 2 additions & 12 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main(args):

model = Generator(hp.audio.n_mel_channels).cuda()
model.load_state_dict(checkpoint['model_g'])
model.eval()
model.eval(inference=False)

with torch.no_grad():
for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
Expand All @@ -29,17 +29,7 @@ def main(args):
mel = mel.unsqueeze(0)
mel = mel.cuda()

# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).cuda()
mel = torch.cat((mel, zero), axis=2)

audio = model(mel)
audio = audio.squeeze() # collapse all dimension except time axis
audio = audio[:-(hp.audio.hop_length*10)]
audio = MAX_WAV_VALUE * audio
audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE)
audio = audio.short()
audio = model.inference(mel)
audio = audio.cpu().detach().numpy()

out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
Expand Down
34 changes: 34 additions & 0 deletions model/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from .res_stack import ResStack
#from res_stack import ResStack

MAX_WAV_VALUE = 32768.0


class Generator(nn.Module):
def __init__(self, mel_channel):
super(Generator, self).__init__()
self.mel_channel = mel_channel

self.generator = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)),
Expand Down Expand Up @@ -42,6 +45,37 @@ def forward(self, mel):
mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram
return self.generator(mel)

def eval(self, inference=False):
super(Generator, self).eval()

# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()

def remove_weight_norm(self):
for idx, layer in enumerate(self.generator):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
except:
layer.remove_weight_norm()

def inference(self, mel):
hop_length = 256
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
mel = torch.cat((mel, zero), axis=2)

audio = self.forward(mel)
audio = audio.squeeze() # collapse all dimension except time axis
audio = audio[:-(hop_length*10)]
audio = MAX_WAV_VALUE * audio
audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
audio = audio.short()

return audio


'''
to run this, fix
Expand Down
5 changes: 5 additions & 0 deletions model/res_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ def forward(self, x):
for layer in self.layers:
x = x + layer(x)
return x

def remove_weight_norm(self):
for layer in self.layers:
nn.utils.remove_weight_norm(layer[1])
nn.utils.remove_weight_norm(layer[3])