Skip to content

Commit

Permalink
fix noise of unvoiced segment, match details of official repo. (#35)
Browse files Browse the repository at this point in the history
* fix leakyrelu slope, reflection pad

* Modify generator and discriminator to match official repo (#31)

* add weight norm at shortcut conv

* fix wrong import statement at generator.py

* fix #30, deploy changes to pytorch hub
  • Loading branch information
seungwonpark committed Dec 2, 2019
1 parent 1245ca2 commit b6db549
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# configuration
config/*
!config/default.yaml
temp-restore.yaml

# logs, checkpoints
chkpt/
Expand Down
6 changes: 3 additions & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from model.generator import Generator

model_params = {
'nvidia_tacotron2_LJ11_epoch3200': {
'nvidia_tacotron2_LJ11_epoch6400': {
'mel_channel': 80,
'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.2-alpha/nvidia_tacotron2_LJ11_epoch3200_v02.pt',
'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.3-alpha/nvidia_tacotron2_LJ11_epoch6400.pt',
},
}


def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True):
params = model_params[model_name]
model = Generator(params['mel_channel'])

Expand Down
18 changes: 11 additions & 7 deletions model/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,29 @@ def __init__(self):

self.discriminator = nn.ModuleList([
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7)),
nn.LeakyReLU(),
nn.ReflectionPad1d(7),
nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)),
nn.LeakyReLU(),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)),
nn.LeakyReLU(),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)),
nn.LeakyReLU(),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)),
nn.LeakyReLU(),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)),
nn.LeakyReLU(),
nn.LeakyReLU(0.2, inplace=True),
),
nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)),
])
Expand Down Expand Up @@ -58,3 +59,6 @@ def forward(self, x):
for feat in features:
print(feat.shape)
print(score.shape)

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
25 changes: 15 additions & 10 deletions model/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F

from .res_stack import ResStack
#from res_stack import ResStack
# from res_stack import ResStack

MAX_WAV_VALUE = 32768.0

Expand All @@ -14,30 +14,32 @@ def __init__(self, mel_channel):
self.mel_channel = mel_channel

self.generator = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),

nn.LeakyReLU(),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),

ResStack(256),

nn.LeakyReLU(),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),

ResStack(128),

nn.LeakyReLU(),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),

ResStack(64),

nn.LeakyReLU(),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),

ResStack(32),

nn.LeakyReLU(),
nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1, padding=3)),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
nn.Tanh(),
)

Expand Down Expand Up @@ -84,11 +86,14 @@ def inference(self, mel):
from res_stack import ResStack
'''
if __name__ == '__main__':
model = Generator(7)
model = Generator(80)

x = torch.randn(3, 7, 10)
x = torch.randn(3, 80, 10)
print(x.shape)

y = model(x)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
2 changes: 1 addition & 1 deletion model/multiscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):

self.pooling = nn.ModuleList(
[Identity()] +
[nn.AvgPool1d(kernel_size=4, stride=2, padding=2) for _ in range(1, 3)]
[nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
)

def forward(self, x):
Expand Down
27 changes: 17 additions & 10 deletions model/res_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@ class ResStack(nn.Module):
def __init__(self, channel):
super(ResStack, self).__init__()

self.layers = nn.ModuleList([
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LeakyReLU(),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i, padding=3**i)),
nn.LeakyReLU(),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=1, padding=1)),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3**i),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
)
for i in range(3)
])

self.shortcuts = nn.ModuleList([
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
for i in range(3)
])

def forward(self, x):
for layer in self.layers:
x = x + layer(x)
for block, shortcut in zip(self.blocks, self.shortcuts):
x = shortcut(x) + block(x)
return x

def remove_weight_norm(self):
for layer in self.layers:
nn.utils.remove_weight_norm(layer[1])
nn.utils.remove_weight_norm(layer[3])
for block, shortcut in zip(self.blocks, self.shortcuts):
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
nn.utils.remove_weight_norm(shortcut)

0 comments on commit b6db549

Please sign in to comment.