import math
import logging

import torch
import torch.nn.functional as F

from modules import LVCBlock

LRELU_SLOPE = 0.1

class UnivNet(torch.nn.Module):
    """UnivNet generator module."""

    def __init__(self, h, use_weight_norm=True):
        super().__init__()
        in_channels = h.cond_in_channels
        out_channels = h.out_channels
        inner_channels = h.cg_channels
        cond_channels = h.num_mels
        upsample_ratios = h.upsample_rates
        lvc_layers_each_block = h.num_lvc_blocks
        lvc_kernel_size = h.lvc_kernels
        kpnet_hidden_channels = h.lvc_hidden_channels
        kpnet_conv_size = h.lvc_conv_size
        dropout = h.dropout

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.lvc_block_nums = len(upsample_ratios)
        # total upsampling factor from conditioning frames to waveform samples
        self.upsample_factor = math.prod(upsample_ratios)
        # context frames replicated around the conditioning in inference();
        # assumed to default to 0 when the config does not provide it
        self.aux_context_window = getattr(h, "aux_context_window", 0)

        # define first convolution: lift the noise input to the inner width
        self.first_conv = torch.nn.Conv1d(in_channels, inner_channels,
                                          kernel_size=7, padding=(7 - 1) // 2,
                                          dilation=1, bias=True)

        # define residual blocks (one LVC block per upsampling stage)
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for n in range(self.lvc_block_nums):
            cond_hop_length = cond_hop_length * upsample_ratios[n]
            lvcb = LVCBlock(
                in_channels=inner_channels,
                cond_channels=cond_channels,
                upsample_ratio=upsample_ratios[n],
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
            )
            self.lvc_blocks.append(lvcb)

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.Conv1d(inner_channels, out_channels, kernel_size=7,
                            padding=(7 - 1) // 2, dilation=1, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()
    def forward(self, x, c):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, in_channels, T').
            c (Tensor): Local conditioning auxiliary features (B, cond_channels, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T' * prod(upsample_rates)).
        """
        x = self.first_conv(x)
        for n in range(self.lvc_block_nums):
            x = self.lvc_blocks[n](x, c)

        # apply final layers
        for f in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = f(x)
        x = torch.tanh(x)
        return x
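
    # Shape sketch for forward(), assuming upsample_rates that multiply to 256
    # (e.g. [8, 8, 4]); the tensors below are illustrative, not from this repo:
    #
    #   z = torch.randn(2, model.in_channels, 10)      # (B, noise_dim, T'=10)
    #   mel = torch.randn(2, model.cond_channels, 10)  # (B, num_mels, T'=10)
    #   wav = model(z, mel)                            # (B, out_channels, 2560)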
    def remove_weight_norm(self):
        """Remove weight normalization from all layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
                logging.debug(f"Weight norm is removed from {m}.")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
    def apply_weight_norm(self):
        """Apply weight normalization to all convolution layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
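
    # Hedged usage note: weight norm is applied in __init__ for training; a
    # typical inference path (a sketch, not prescribed by this repo) removes
    # it once so the reparameterized weights fold back into the plain convs:
    #
    #   model = UnivNet(h)          # weight norm applied on construction
    #   model.remove_weight_norm()  # fuse before export/inference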
    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        """Return the receptive field size of a stack of dilated convolutions."""
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1
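
    # Worked example with hypothetical WaveNet-style settings (these numbers
    # are illustrative, not read from the config): layers=30, stacks=3,
    # kernel_size=3 gives dilations 1, 2, ..., 512 cycled three times
    # (sum 3069), so the receptive field is (3 - 1) * 3069 + 1 = 6139:
    #
    #   UnivNet._get_receptive_field_size(layers=30, stacks=3, kernel_size=3)
    #   # -> 6139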
    @property
    def receptive_field_size(self):
        """Return receptive field size.

        Note: carried over from the ParallelWaveGAN-style reference; it reads
        ``self.layers``, ``self.stacks`` and ``self.kernel_size``, which this
        class does not set in ``__init__``, so set them before using it.
        """
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
    def inference(self, c=None, x=None):
        """Perform inference.

        Args:
            c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T', C).
            x (Union[Tensor, ndarray]): Input noise signal (T', in_channels).

        Returns:
            Tensor: Output tensor (T, out_channels).
        """
        if x is not None:
            if not isinstance(x, torch.Tensor):
                x = torch.tensor(x, dtype=torch.float).to(next(self.parameters()).device)
            x = x.transpose(1, 0).unsqueeze(0)
        else:
            assert c is not None
            # the LVC blocks upsample internally, so the noise is drawn at the
            # conditioning frame rate with the channel width first_conv expects
            x = torch.randn(1, self.in_channels, len(c)).to(next(self.parameters()).device)
        if c is not None:
            if not isinstance(c, torch.Tensor):
                c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device)
            c = c.transpose(1, 0).unsqueeze(0)
            c = torch.nn.ReplicationPad1d(self.aux_context_window)(c)
        return self.forward(x, c).squeeze(0).transpose(1, 0)
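
# Inference usage sketch (the feature file below is hypothetical): inference()
# accepts an un-batched (T', num_mels) array, batches and transposes it, and
# returns an un-batched (T, out_channels) waveform tensor:
#
#   mel = np.load("mel.npy")          # hypothetical (T', num_mels) features
#   with torch.no_grad():
#       wav = model.inference(c=mel)  # (T' * prod(upsample_rates), 1)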
'''
to run this file directly, the top-level `from modules import LVCBlock`
must resolve; if the project layout uses a package-relative import such as
`from .modules import LVCBlock`, change it to the absolute form first
'''
if __name__ == '__main__':
    from types import SimpleNamespace

    # smoke-test config: these values are assumptions chosen so the assert
    # below holds (upsample_rates multiply to 256), not official hyperparameters
    h = SimpleNamespace(
        cond_in_channels=64,
        out_channels=1,
        cg_channels=32,
        num_mels=80,
        upsample_rates=[8, 8, 4],
        num_lvc_blocks=4,
        lvc_kernels=3,
        lvc_hidden_channels=64,
        lvc_conv_size=3,
        dropout=0.0,
    )
    model = UnivNet(h)

    x = torch.randn(3, 64, 10)  # (B, in_channels, T')
    c = torch.randn(3, 80, 10)  # (B, num_mels, T')
    print(c.shape)              # torch.Size([3, 80, 10])

    y = model(x, c)             # (B, 1, T' * prod(upsample_rates))
    print(y.shape)              # torch.Size([3, 1, 2560])
    assert y.shape == torch.Size([3, 1, 2560])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)  # the original file recorded 4527362 for its own config