"""A combination of Convolutional, Recurrent, and Fully-connected networks.
Authors
* Mirco Ravanelli 2020
* Peter Plantinga 2020
* Ju-Chieh Chou 2020
* Titouan Parcollet 2020
* Abdel 2020
"""
import torch

import speechbrain as sb


class CRDNN(sb.nnet.containers.Sequential):
"""This model is a combination of CNNs, RNNs, and DNNs.
This model expects 3-dimensional input [batch, time, feats] and
by default produces output of the size [batch, time, dnn_neurons].
One exception is if ``using_2d_pooling`` or ``time_pooling`` is True.
In this case, the time dimension will be downsampled.
Arguments
---------
input_size : int
The length of the expected input at the third dimension.
    input_shape : tuple
        An alternative to ``input_size``. Passing the full input shape makes
        it possible to place a CRDNN inside a Sequential with other classes.
activation : torch class
A class used for constructing the activation layers for CNN and DNN.
dropout : float
Neuron dropout rate as applied to CNN, RNN, and DNN.
cnn_blocks : int
The number of convolutional neural blocks to include.
cnn_channels : list of ints
A list of the number of output channels for each CNN block.
cnn_kernelsize : tuple of ints
The size of the convolutional kernels.
time_pooling : bool
Whether to pool the utterance on the time axis before the RNN.
time_pooling_size : int
The number of elements to pool on the time axis.
freq_pooling_size : int
The number of elements to pool on the frequency axis.
    rnn_class : torch class
        The type of RNN to use in the CRDNN network (LiGRU, LSTM, GRU, RNN).
inter_layer_pooling_size : list of ints
A list of the pooling sizes for each CNN block.
    using_2d_pooling : bool
        Whether to use 2D pooling (rather than 1D) after each CNN block.
rnn_layers : int
The number of recurrent RNN layers to include.
rnn_neurons : int
Number of neurons in each layer of the RNN.
rnn_bidirectional : bool
Whether this model will process just forward or in both directions.
    rnn_re_init : bool
        If True, an orthogonal initialization will be applied to the
        recurrent weights.
dnn_blocks : int
The number of linear neural blocks to include.
dnn_neurons : int
The number of neurons in the linear layers.
projection_dim : int
The number of neurons in the projection layer.
This layer is used to reduce the size of the flattened
representation obtained after the CNN blocks.
    use_rnnp : bool
        If True, a linear projection layer is added between RNN layers.

    Example
    -------
>>> inputs = torch.rand([10, 15, 60])
>>> model = CRDNN(input_shape=inputs.shape)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([10, 15, 512])
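
    A variant with time pooling enabled (a hedged sketch: the output length
    assumes the default pooling stride equals ``time_pooling_size``, so the
    15 time steps are halved to 7):

    >>> model = CRDNN(input_shape=inputs.shape, time_pooling=True)
    >>> model(inputs).shape
    torch.Size([10, 7, 512])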
"""
def __init__(
self,
input_size=None,
input_shape=None,
activation=torch.nn.LeakyReLU,
dropout=0.15,
cnn_blocks=2,
cnn_channels=[128, 256],
cnn_kernelsize=(3, 3),
time_pooling=False,
time_pooling_size=2,
freq_pooling_size=2,
rnn_class=sb.nnet.RNN.LiGRU,
inter_layer_pooling_size=[2, 2],
using_2d_pooling=False,
rnn_layers=4,
rnn_neurons=512,
rnn_bidirectional=True,
rnn_re_init=False,
dnn_blocks=2,
dnn_neurons=512,
projection_dim=-1,
use_rnnp=False,
):
if input_size is None and input_shape is None:
raise ValueError("Must specify one of input_size or input_shape")
if input_shape is None:
input_shape = [None, None, input_size]
super().__init__(input_shape=input_shape)
if cnn_blocks > 0:
self.append(sb.nnet.containers.Sequential, layer_name="CNN")
for block_index in range(cnn_blocks):
self.CNN.append(
CNN_Block,
channels=cnn_channels[block_index],
kernel_size=cnn_kernelsize,
using_2d_pool=using_2d_pooling,
pooling_size=inter_layer_pooling_size[block_index],
activation=activation,
dropout=dropout,
layer_name=f"block_{block_index}",
)
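        # Optionally downsample the sequence along the time axis (axis 1)
        # before the RNN, reducing the sequence length the RNN must process.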
if time_pooling:
self.append(
sb.nnet.pooling.Pooling1d(
pool_type="max",
input_dims=4,
kernel_size=time_pooling_size,
pool_axis=1,
),
layer_name="time_pooling",
)
        # This projection helps reduce the number of parameters when a large
        # number of CNN filters is used. A large number of CNN filters
        # combined with a large feature dimension often leads to a very
        # large flattened layer; this layer projects it back to something
        # reasonable.
if projection_dim != -1:
self.append(sb.nnet.containers.Sequential, layer_name="projection")
self.projection.append(
sb.nnet.linear.Linear,
n_neurons=projection_dim,
bias=True,
combine_dims=True,
layer_name="linear",
)
self.projection.append(
sb.nnet.normalization.LayerNorm, layer_name="norm"
)
self.projection.append(activation(), layer_name="act")
if rnn_layers > 0:
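            # With use_rnnp, each single-layer RNN is followed by a linear
            # projection plus dropout (an "RNNP"-style stack), reducing the
            # size of the representation passed between recurrent layers.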
if use_rnnp:
self.append(sb.nnet.containers.Sequential, layer_name="RNN")
for _ in range(rnn_layers):
self.append(
rnn_class,
hidden_size=rnn_neurons,
num_layers=1,
bidirectional=rnn_bidirectional,
re_init=rnn_re_init,
)
self.append(
sb.nnet.linear.Linear,
n_neurons=dnn_neurons,
bias=True,
combine_dims=True,
)
self.append(torch.nn.Dropout(p=dropout))
else:
self.append(
rnn_class,
layer_name="RNN",
hidden_size=rnn_neurons,
num_layers=rnn_layers,
dropout=dropout,
bidirectional=rnn_bidirectional,
re_init=rnn_re_init,
)
if dnn_blocks > 0:
self.append(sb.nnet.containers.Sequential, layer_name="DNN")
for block_index in range(dnn_blocks):
self.DNN.append(
DNN_Block,
neurons=dnn_neurons,
activation=activation,
dropout=dropout,
layer_name=f"block_{block_index}",
)


class CNN_Block(sb.nnet.containers.Sequential):
    """CNN Block, based on VGG blocks.

    Arguments
    ---------
input_shape : tuple
Expected shape of the input.
channels : int
Number of convolutional channels for the block.
    kernel_size : tuple
        Size of the 2d convolutional kernel.
activation : torch.nn.Module class
A class to be used for instantiating an activation layer.
using_2d_pool : bool
Whether to use 2d pooling or only 1d pooling.
pooling_size : int
Size of pooling kernel, duplicated for 2d pooling.
dropout : float
Rate to use for dropping channels.

    Example
    -------
>>> inputs = torch.rand(10, 15, 60)
>>> block = CNN_Block(input_shape=inputs.shape, channels=32)
>>> outputs = block(inputs)
>>> outputs.shape
torch.Size([10, 15, 30, 32])
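
    Note that with the default ``pooling_size=2`` the frequency axis is
    halved (60 -> 30) and a channel axis of size ``channels`` is added.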
"""
def __init__(
self,
input_shape,
channels,
kernel_size=[3, 3],
activation=torch.nn.LeakyReLU,
using_2d_pool=False,
pooling_size=2,
dropout=0.15,
):
super().__init__(input_shape=input_shape)
self.append(
sb.nnet.CNN.Conv2d,
out_channels=channels,
kernel_size=kernel_size,
layer_name="conv_1",
)
self.append(sb.nnet.normalization.LayerNorm, layer_name="norm_1")
self.append(activation(), layer_name="act_1")
self.append(
sb.nnet.CNN.Conv2d,
out_channels=channels,
kernel_size=kernel_size,
layer_name="conv_2",
)
self.append(sb.nnet.normalization.LayerNorm, layer_name="norm_2")
self.append(activation(), layer_name="act_2")
if using_2d_pool:
self.append(
sb.nnet.pooling.Pooling2d(
pool_type="max",
kernel_size=(pooling_size, pooling_size),
pool_axis=(1, 2),
),
layer_name="pooling",
)
else:
self.append(
sb.nnet.pooling.Pooling1d(
pool_type="max",
input_dims=4,
kernel_size=pooling_size,
pool_axis=2,
),
layer_name="pooling",
)
self.append(
sb.nnet.dropout.Dropout2d(drop_rate=dropout), layer_name="drop"
)


class DNN_Block(sb.nnet.containers.Sequential):
    """Block for linear layers.

    Arguments
    ---------
input_shape : tuple
Expected shape of the input.
neurons : int
Size of the linear layers.
activation : torch.nn.Module class
Class definition to use for constructing activation layers.
dropout : float
Rate to use for dropping neurons.

    Example
    -------
>>> inputs = torch.rand(10, 15, 128)
>>> block = DNN_Block(input_shape=inputs.shape, neurons=64)
>>> outputs = block(inputs)
>>> outputs.shape
torch.Size([10, 15, 64])
"""
def __init__(
self, input_shape, neurons, activation=torch.nn.LeakyReLU, dropout=0.15
):
super().__init__(input_shape=input_shape)
self.append(
sb.nnet.linear.Linear,
n_neurons=neurons,
layer_name="linear",
)
self.append(sb.nnet.normalization.BatchNorm1d, layer_name="norm")
self.append(activation(), layer_name="act")
self.append(torch.nn.Dropout(p=dropout), layer_name="dropout")
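

# A minimal usage sketch (illustrative, not part of the library API): build a
# small CRDNN and push a random batch through it. The argument values below
# are assumptions chosen only for this demonstration.
if __name__ == "__main__":
    demo_inputs = torch.rand([8, 40, 60])  # [batch, time, feats]
    demo_model = CRDNN(
        input_shape=demo_inputs.shape, rnn_layers=2, dnn_blocks=1
    )
    demo_outputs = demo_model(demo_inputs)
    print(demo_outputs.shape)  # torch.Size([8, 40, 512]) with these settings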