-
Notifications
You must be signed in to change notification settings - Fork 29
/
modules.py
287 lines (239 loc) · 10.1 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import tensorflow as tf
from tacotron.models.zoneout_LSTM import ZoneoutLSTMCell
from tensorflow.contrib.rnn import LSTMBlockCell
from hparams import hparams
from tensorflow.contrib.rnn import GRUCell
from tacotron.utils.util import shape_list
def VAE(inputs, input_lengths, filters, kernel_size, strides, num_units, is_training, scope):
    """Encode a reference input into a sampled latent vector (VAE style).

    The input is summarized by ReferenceEncoder. Two dense layers map that
    summary to the mean (``mu``) and log-variance (``log_var``) of a diagonal
    Gaussian. A sample is then drawn with the reparameterization trick:
    ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

    Args:
        inputs, input_lengths, filters, kernel_size, strides, is_training:
            forwarded unchanged to ReferenceEncoder.
        num_units: integer, dimensionality of the latent vector.
        scope: variable scope name for all variables created here.

    Returns:
        (output, mu, log_var): the sampled latent vector, the mean and the
        log-variance, each with ``num_units`` features per example.
    """
    with tf.variable_scope(scope):
        outputs = ReferenceEncoder(
            inputs=inputs,
            input_lengths=input_lengths,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            is_training=is_training)
        mu = tf.layers.dense(outputs, num_units, name='mean')
        log_var = tf.layers.dense(outputs, num_units, name='vari')
        # log_var is a log-*variance*, so std = exp(log_var / 2).
        # exp(log_var) would be the variance itself, i.e. std squared.
        # NOTE(review): confirm the KL term elsewhere uses the same
        # log-variance convention (e.g. 1 + log_var - mu^2 - exp(log_var)).
        std = tf.exp(0.5 * log_var)
        # Noise is drawn regardless of is_training.
        eps = tf.random_normal(shape=[tf.shape(mu)[0], num_units], mean=0.0, stddev=1.0)
        output = mu + eps * std
        return output, mu, log_var
def ReferenceEncoder(inputs, input_lengths, filters, kernel_size, strides, is_training, scope='reference_encoder'):
    """Summarize an input sequence with a conv2d stack followed by a GRU.

    The input gets a trailing channel axis of size 1. It then passes through
    one conv2d layer (ReLU, batch norm) per entry of ``filters``. The last two
    axes of the result are merged, and a 128-unit GRU runs over it.

    Args:
        inputs: tensor to encode. After expand_dims it must be rank 4 for
            conv2d, so presumably [batch, frames, bins] (TODO confirm).
        input_lengths: passed as ``sequence_length`` to tf.nn.dynamic_rnn.
        filters: iterable of output-channel counts, one per conv2d layer.
        kernel_size: kernel size used by every conv2d layer.
        strides: strides used by every conv2d layer.
        is_training: forwarded to conv2d (batch normalization mode).
        scope: variable scope name.

    Returns:
        The GRU final state returned by tf.nn.dynamic_rnn (128 units per example).
    """
    with tf.variable_scope(scope):
        # Add a channel axis so tf.layers.conv2d can consume the input.
        x = tf.expand_dims(inputs, axis=-1)
        for layer_idx, num_channels in enumerate(filters):
            x = conv2d(x, num_channels, kernel_size, strides, tf.nn.relu,
                       is_training, 'conv2d_{}'.format(layer_idx))

        # Flatten the last two axes so the GRU receives one feature vector
        # per remaining step along axis 1.
        dims = shape_list(x)
        x = tf.reshape(x, dims[:-2] + [dims[2] * dims[3]])

        # NOTE(review): input_lengths is used as-is even though the strided
        # convolutions above may have shortened axis 1. Confirm whether it
        # should be divided (ceil) by the time stride of each conv layer.
        _, final_state = tf.nn.dynamic_rnn(
            cell=GRUCell(128),
            inputs=x,
            sequence_length=input_lengths,
            dtype=tf.float32
        )
        return final_state
def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
    """Apply conv1d ('same' padding) -> batch norm -> activation -> dropout.

    Args:
        inputs: tensor accepted by tf.layers.conv1d.
        kernel_size: convolution kernel size.
        channels: number of output filters.
        activation: callable applied after batch normalization. It is called
            unconditionally, so it must not be None.
        is_training: controls batch-norm mode and whether dropout is active.
        scope: variable scope name; also used to build the dropout op name.

    Returns:
        The dropped-out activations. The rate is hparams.tacotron_dropout_rate.
    """
    dropout_rate = hparams.tacotron_dropout_rate
    with tf.variable_scope(scope):
        convolved = tf.layers.conv1d(inputs, filters=channels, kernel_size=kernel_size,
                                     activation=None, padding='same')
        normalized = tf.layers.batch_normalization(convolved, training=is_training)
        activated = activation(normalized)
        dropout_name = 'dropout_{}'.format(scope)
        return tf.layers.dropout(activated, rate=dropout_rate, training=is_training,
                                 name=dropout_name)
def conv2d(inputs, filters, kernel_size, strides, activation, is_training, scope):
    """Apply conv2d ('same' padding) -> batch norm -> optional activation.

    Args:
        inputs: tensor accepted by tf.layers.conv2d.
        filters: number of output filters.
        kernel_size: convolution kernel size.
        strides: convolution strides.
        activation: callable applied after batch normalization, or None to
            return the batch-normalized tensor directly.
        is_training: controls batch-normalization mode.
        scope: variable scope name.

    Returns:
        The batch-normalized tensor, passed through ``activation`` when given.
    """
    with tf.variable_scope(scope):
        conv2d_output = tf.layers.conv2d(
            inputs, filters=filters, kernel_size=kernel_size, strides=strides, padding='same')
        batch_norm_output = tf.layers.batch_normalization(
            conv2d_output, training=is_training, name='batch_norm')
        # Batch normalization applies regardless of the activation. With
        # activation=None, return the normalized tensor, not the raw conv output.
        if activation is None:
            return batch_norm_output
        return activation(batch_norm_output)
class EncoderConvolutions:
    """Stack of 1-D convolutional layers applied to the encoder inputs.

    Each layer is built with ``conv1d`` (convolution, batch normalization,
    activation, dropout). The layer count is ``hparams.enc_conv_num_layers``.
    """

    def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.relu, scope=None):
        """
        Args:
            is_training: Boolean, training vs. inference. Forwarded to conv1d,
                where it controls batch normalization and dropout.
            kernel_size: tuple or integer, size of the convolution kernels.
            channels: integer, number of output channels of every layer.
            activation: callable, activation applied in every encoder conv layer.
            scope: encoder-convolutions variable scope; defaults to 'enc_conv_layers'.
        """
        super(EncoderConvolutions, self).__init__()
        self.is_training = is_training
        self.kernel_size = kernel_size
        self.channels = channels
        self.activation = activation
        if scope is None:
            scope = 'enc_conv_layers'
        self.scope = scope

    def __call__(self, inputs):
        """Run every convolutional layer in sequence and return the final output."""
        with tf.variable_scope(self.scope):
            x = inputs
            for layer_number in range(1, hparams.enc_conv_num_layers + 1):
                layer_scope = 'conv_layer_{}_'.format(layer_number) + self.scope
                x = conv1d(x, self.kernel_size, self.channels, self.activation,
                           self.is_training, layer_scope)
            return x
class EncoderRNN:
    """Bidirectional, single-layer zoneout LSTM over the encoder features."""

    def __init__(self, is_training, size=256, zoneout=0.1, scope=None):
        """
        Args:
            is_training: Boolean, forwarded to ZoneoutLSTMCell (presumably to
                toggle zoneout between training and inference; TODO confirm).
            size: integer, number of LSTM units per direction.
            zoneout: factor used for both zoneout_factor_cell and
                zoneout_factor_output.
            scope: variable scope name; defaults to 'encoder_LSTM'.
        """
        super(EncoderRNN, self).__init__()
        self.is_training = is_training
        self.size = size
        self.zoneout = zoneout
        self.scope = scope if scope is not None else 'encoder_LSTM'
        # NOTE(review): this single cell object is passed as BOTH the forward
        # and the backward cell in __call__. Whether the two directions share
        # weights depends on how ZoneoutLSTMCell creates its variables (not
        # visible here). Confirm that this is intended.
        self._cell = ZoneoutLSTMCell(size, is_training,
                                     zoneout_factor_cell=zoneout,
                                     zoneout_factor_output=zoneout)

    def __call__(self, inputs, input_lengths):
        """Run the bidirectional LSTM.

        Returns:
            Forward and backward outputs concatenated along axis 2.
        """
        with tf.variable_scope(self.scope):
            (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
                self._cell,
                self._cell,
                inputs,
                sequence_length=input_lengths,
                dtype=tf.float32)
            return tf.concat([fw_outputs, bw_outputs], axis=2)
class Prenet:
    """Fully connected layers with dropout, used as an information bottleneck for the attention.
    """

    def __init__(self, is_training, layer_sizes=None, activation=tf.nn.relu, scope=None):
        """
        Args:
            is_training: Boolean, stored on the instance. Note that dropout in
                __call__ is always active, regardless of this flag.
            layer_sizes: list of integers, one dense layer per entry with that
                many units. Defaults to [256, 256] when None.
            activation: callable, activation function of each dense layer.
            scope: Prenet variable scope; defaults to 'prenet'.
        """
        super(Prenet, self).__init__()
        self.drop_rate = hparams.tacotron_dropout_rate
        # None sentinel instead of a list default: a list default would be a
        # single object shared by every Prenet created without layer_sizes.
        self.layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
        self.is_training = is_training
        self.activation = activation
        self.scope = 'prenet' if scope is None else scope

    def __call__(self, inputs):
        """Apply each dense layer followed by dropout; return the last output."""
        x = inputs
        with tf.variable_scope(self.scope):
            for i, size in enumerate(self.layer_sizes):
                dense = tf.layers.dense(x, units=size, activation=self.activation,
                                        name='dense_{}'.format(i + 1))
                # Dropout is forced on (training=True) even at inference. The
                # original author cites the paper: prenet dropout introduces
                # diversity in generation.
                x = tf.layers.dropout(dense, rate=self.drop_rate, training=True,
                                      name='dropout_{}'.format(i + 1) + self.scope)
        return x
class DecoderRNN:
    """Stack of unidirectional zoneout LSTM cells wrapped in a MultiRNNCell."""

    def __init__(self, is_training, layers=2, size=1024, zoneout=0.1, scope=None):
        """
        Args:
            is_training: Boolean, forwarded to every ZoneoutLSTMCell.
            layers: integer, number of stacked LSTM cells.
            size: integer, number of units in each LSTM cell.
            zoneout: factor used for both zoneout_factor_cell and
                zoneout_factor_output of every cell.
            scope: variable scope name; defaults to 'decoder_rnn'.
        """
        super(DecoderRNN, self).__init__()
        self.is_training = is_training
        self.layers = layers
        self.size = size
        self.zoneout = zoneout
        self.scope = scope if scope is not None else 'decoder_rnn'
        # One independent cell object per layer, stacked with tuple states.
        self.rnn_layers = [
            ZoneoutLSTMCell(size, is_training,
                            zoneout_factor_cell=zoneout,
                            zoneout_factor_output=zoneout)
            for _ in range(layers)
        ]
        self._cell = tf.contrib.rnn.MultiRNNCell(self.rnn_layers, state_is_tuple=True)

    def __call__(self, inputs, states):
        """Advance the stacked cell one step inside this object's scope.

        Returns whatever MultiRNNCell.__call__ returns for (inputs, states).
        """
        with tf.variable_scope(self.scope):
            return self._cell(inputs, states)
class FrameProjection:
    """Dense projection to ``shape`` output units (e.g. r * num_mels or num_mels)."""

    def __init__(self, shape=80, activation=None, scope=None):
        """
        Args:
            shape: integer, number of output units of the dense layer.
            activation: callable or None. With None the projection is linear.
            scope: variable scope name; defaults to 'Linear_projection'.
        """
        super(FrameProjection, self).__init__()
        self.shape = shape
        self.activation = activation
        self.scope = scope if scope is not None else 'Linear_projection'

    def __call__(self, inputs):
        """Project ``inputs`` with a single dense layer."""
        layer_name = 'projection_{}'.format(self.scope)
        with tf.variable_scope(self.scope):
            return tf.layers.dense(inputs, units=self.shape, activation=self.activation,
                                   name=layer_name)
class StopProjection:
    """Dense projection for stop-token prediction, with an activation at inference only."""

    def __init__(self, is_training, shape=hparams.outputs_per_step, activation=tf.nn.sigmoid, scope=None):
        """
        Args:
            is_training: Boolean. When true, __call__ returns raw logits. The
                sigmoid is presumably applied inside the training loss
                (sigmoid cross-entropy; confirm at the call site).
            shape: integer, number of output units. Defaults to
                hparams.outputs_per_step, evaluated when the class is defined.
            activation: callable applied to the logits at inference only.
            scope: variable scope name; defaults to 'stop_token_projection'.
        """
        super(StopProjection, self).__init__()
        self.is_training = is_training
        self.shape = shape
        self.activation = activation
        self.scope = scope if scope is not None else 'stop_token_projection'

    def __call__(self, inputs):
        """Return logits while training, ``activation(logits)`` otherwise."""
        with tf.variable_scope(self.scope):
            logits = tf.layers.dense(inputs, units=self.shape, activation=None,
                                     name='projection_{}'.format(self.scope))
            return logits if self.is_training else self.activation(logits)
class Postnet:
    """Stack of 1-D convolutions applied to the decoder output.

    ``hparams.postnet_num_layers`` conv1d layers are built. All but the last
    use ``activation``; the last one is linear (identity activation).
    """

    def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.tanh, scope=None):
        """
        Args:
            is_training: Boolean, training vs. inference. Forwarded to conv1d,
                where it controls batch normalization and dropout.
            kernel_size: tuple or integer, size of the convolution kernels.
            channels: integer, number of output channels of every layer.
            activation: callable, activation of every layer except the last.
            scope: Postnet variable scope; defaults to 'postnet_convolutions'.
        """
        super(Postnet, self).__init__()
        self.is_training = is_training
        self.kernel_size = kernel_size
        self.channels = channels
        self.activation = activation
        self.scope = 'postnet_convolutions' if scope is None else scope

    def __call__(self, inputs):
        """Run all postnet layers in sequence and return the final output."""
        num_layers = hparams.postnet_num_layers
        with tf.variable_scope(self.scope):
            x = inputs
            for i in range(num_layers - 1):
                x = conv1d(x, self.kernel_size, self.channels, self.activation,
                           self.is_training, 'conv_layer_{}_'.format(i + 1) + self.scope)
            # The final layer is numbered after the configured layer count
            # rather than a fixed 5. A fixed index mislabels it, and for more
            # than 5 layers collides with loop layer 5's scope. With the
            # default of 5 layers the name is unchanged.
            x = conv1d(x, self.kernel_size, self.channels, lambda t: t, self.is_training,
                       'conv_layer_{}_'.format(num_layers) + self.scope)
        return x