Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add pretrained model via pytorch hub, remove weightnorm at inference (resolves #16, #18) (#19)
* remove weight_norm, add torch hub * add pretrain url * specify branch * fix * don't use hp in .inference * update README.md
- Loading branch information
1 parent
1725a75
commit f825a48
Showing 5 changed files with 102 additions and 14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
dependencies = ['torch'] | ||
import torch | ||
from model.generator import Generator | ||
|
||
# Registry of available generator configurations, keyed by model name.
# For each entry, `melgan` passes 'mel_channel' to the Generator constructor
# and downloads the pretrained checkpoint from 'model_url'.
model_params = {
    'nvidia_tacotron2_LJ11_epoch3200': dict(
        mel_channel=80,
        model_url='https://github.com/seungwonpark/melgan/releases/download/v0.1-alpha/nvidia_tacotron2_LJ11_epoch3200.pt',
    ),
}
|
||
|
||
def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
    """Build a MelGAN generator, optionally loading released weights.

    This is the entry point resolved by ``torch.hub.load(..., 'melgan')``.

    Args:
        model_name: Key into ``model_params`` selecting the configuration.
        pretrained: When true, download the checkpoint at the entry's
            ``model_url`` and load its ``'model_g'`` state dict.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url`` to
            control the download progress display.

    Returns:
        The ``Generator`` instance after ``eval(inference=True)`` was called.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_params``.
    """
    params = model_params[model_name]
    model = Generator(params['mel_channel'])

    if pretrained:
        # Deserialize onto the CPU: a checkpoint saved from CUDA tensors
        # would otherwise fail to load on a machine without a GPU.
        # load_state_dict copies the values into the model's own parameters,
        # so the caller can still move the model to a GPU afterwards.
        checkpoint = torch.hub.load_state_dict_from_url(
            params['model_url'],
            map_location='cpu',
            progress=progress,
        )
        model.load_state_dict(checkpoint['model_g'])

    # NOTE(review): `eval` takes a project-specific `inference` flag here;
    # per the commit message it presumably strips weight norm -- confirm
    # against model/generator.py.
    model.eval(inference=True)

    return model
|
||
|
||
if __name__ == '__main__':
    # Smoke test: fetch the generator through torch.hub and run one
    # inference pass on a random mel-spectrogram.
    vocoder = torch.hub.load('seungwonpark/melgan:hub', 'melgan')
    mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here

    print('Input mel-spectrogram shape: {}'.format(mel.shape))

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        print('Moving data & model to GPU')
        vocoder, mel = vocoder.cuda(), mel.cuda()

    # Gradients are not needed for inference.
    with torch.no_grad():
        audio = vocoder.inference(mel)

    print('Output audio shape: {}'.format(audio.shape))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters