/
tcn.py
259 lines (216 loc) · 10.2 KB
/
tcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import keras.backend as K
import keras.layers
from keras import optimizers
from keras.engine.topology import Layer
from keras.layers import Activation, Lambda
from keras.layers import Conv1D, SpatialDropout1D
from keras.layers import Convolution1D, Dense
from keras.models import Input, Model
from typing import List, Tuple
def channel_normalization(x):
    # type: (Layer) -> Layer
    """Divide ``x`` by its maximum absolute value along axis 2.

    Assuming ``x`` is laid out as (batch, timesteps, channels) (TODO confirm
    for callers other than ``residual_block``), every timestep is scaled by
    the largest magnitude among its channels. A small epsilon is added to the
    denominator so an all-zero position does not cause a division by zero.

    Results lie in (-1, 1); for non-negative input (e.g. the ReLU output this
    is applied to in ``residual_block``) they lie in [0, 1). This tames
    ReLU's unbounded activations.

    Args:
        x: The tensor to normalize.
    Returns:
        A tensor with the same shape as ``x``, normalized as described above.
    """
    eps = 1e-5
    scale = K.max(K.abs(x), axis=2, keepdims=True) + eps
    return x / scale
def wave_net_activation(x):
    # type: (Layer) -> Layer
    """Apply the WaveNet-style gated activation ``tanh(x) * sigmoid(x)``.

    See https://deepmind.com/blog/wavenet-generative-model-raw-audio/ for the
    original description. Here both the tanh "filter" branch and the sigmoid
    "gate" branch are computed from the same input tensor.

    Args:
        x: The tensor to apply the activation to.
    Returns:
        The element-wise product of the two branches.
    """
    filter_branch = Activation('tanh')(x)
    gate_branch = Activation('sigmoid')(x)
    gated = keras.layers.multiply([filter_branch, gate_branch])
    return gated
def residual_block(x, s, i, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
    # type: (Layer, int, int, str, int, int, str, float, str) -> Tuple[Layer, Layer]
    """Build one residual block of the WaveNet-style TCN.

    The block applies these steps in order:

    1. a dilated convolution;
    2. the requested activation;
    3. spatial dropout;
    4. a 1x1 convolution.

    The result is then added back onto the block input.

    Args:
        x: The previous layer in the model. Its channel count must equal
            ``nb_filters``; otherwise the residual addition fails on a shape
            mismatch.
        s: The stack index, i.e. which stack in the overall TCN. It is only
            used to build unique layer names.
        i: The dilation rate of the dilated convolution. It is passed directly
            as ``dilation_rate``; it is not an exponent.
        activation: The activation to use. ``'norm_relu'`` (ReLU followed by
            channel normalization) and ``'wavenet'`` (gated tanh * sigmoid)
            are handled specially. Any other value is passed to
            ``keras.layers.Activation``.
        nb_filters: The number of convolutional filters to use in this block.
        kernel_size: The size of the convolutional kernel.
        padding: The padding used in the dilated convolution, 'same' or
            'causal'.
        dropout_rate: Float between 0 and 1. Fraction of the input units to
            drop.
        name: Name prefix for the layers. Useful when having multiple TCNs.
    Returns:
        A tuple ``(res_x, skip_out)``. ``res_x`` is the block input plus the
        block output, and feeds the next block. ``skip_out`` is the block
        output alone, used for skip connections.
    """
    original_x = x
    # NOTE: the '_tanh' in this layer name is historical and does not reflect
    # the activation used; it is kept so existing saved weights still match
    # by name.
    conv = Conv1D(filters=nb_filters, kernel_size=kernel_size,
                  dilation_rate=i, padding=padding,
                  name=name + '_dilated_conv_%d_tanh_s%d' % (i, s))(x)
    if activation == 'norm_relu':
        x = Activation('relu')(conv)
        x = Lambda(channel_normalization)(x)
    elif activation == 'wavenet':
        x = wave_net_activation(conv)
    else:
        x = Activation(activation)(conv)
    x = SpatialDropout1D(dropout_rate, name=name + '_spatial_dropout1d_%d_s%d_%f' % (i, s, dropout_rate))(x)
    # 1x1 conv: mixes channels without looking at neighbouring timesteps.
    x = Convolution1D(nb_filters, 1, padding='same')(x)
    res_x = keras.layers.add([original_x, x])
    return res_x, x
def process_dilations(dilations):
    """Normalize a dilation specification into dilation rates.

    For backwards compatibility, ``dilations`` may be given in either of two
    forms:

    - the dilation rates themselves, every entry a power of two
      (e.g. ``[1, 2, 4, 8]``);
    - exponents of two (e.g. ``[0, 1, 2, 3]``).

    If every entry is a power of two, the input object is returned unchanged.
    Otherwise each entry ``d`` is replaced by ``2 ** d``.

    The two forms are ambiguous for some inputs. For example ``[1, 2, 4]`` is
    always treated as dilation rates, even if exponents were intended.

    Args:
        dilations: A sequence of ints, or None.
    Returns:
        The dilation rates. None is passed through unchanged so that ``TCN``
        can apply its own default.
    """
    if dilations is None:
        return None

    def is_power_of_two(num):
        # num & (num - 1) clears the lowest set bit, so the result is zero
        # only when exactly one bit is set.
        return num != 0 and (num & (num - 1)) == 0

    if all(is_power_of_two(d) for d in dilations):
        return dilations
    return [2 ** d for d in dilations]
class TCN:
    """Builder for a temporal convolutional network (TCN).

    Usage: ``o = TCN(return_sequences=False, ...)(i)``. Here ``i`` is an input
    tensor, presumably of shape (batch_size, timesteps, input_dim).

    Args:
        nb_filters: The number of filters to use in the convolutional layers.
        kernel_size: The size of the kernel to use in each convolutional
            layer.
        nb_stacks: The number of stacks of residual blocks to use.
        dilations: The list of dilation rates, e.g.
            ``[1, 2, 4, 8, 16, 32, 64]``. Defaults to
            ``[1, 2, 4, 8, 16, 32]`` when None.
        activation: The activation to use (norm_relu, wavenet, relu...).
        padding: The padding to use in the convolutional layers, 'causal' or
            'same'.
        use_skip_connections: Boolean. If True, the output is the sum of the
            skip outputs of all residual blocks. Otherwise it is the output of
            the last residual block.
        dropout_rate: Float between 0 and 1. Fraction of the input units to
            drop.
        return_sequences: Boolean. If True, return the full output sequence.
            Otherwise return only the last timestep.
        name: Name prefix for the layers. Useful when having multiple TCNs.
    Returns:
        A callable object. Calling it on an input tensor returns the TCN
        output tensor.
    Raises:
        TypeError: If ``nb_filters`` is not an int. This typically means the
            pre-2.1.2 interface ``TCN(i, ...)`` was used.
        ValueError: If ``padding`` is neither 'causal' nor 'same'.
    """

    def __init__(self,
                 nb_filters=64,
                 kernel_size=2,
                 nb_stacks=1,
                 dilations=None,
                 activation='norm_relu',
                 padding='causal',
                 use_skip_connections=True,
                 dropout_rate=0.0,
                 return_sequences=True,
                 name='tcn'):
        # Check for the old calling convention first. An old positional call
        # such as TCN(i, 64, ...) shifts every argument, so the padding check
        # below would otherwise fail with a misleading error.
        #   old: o = tcn.TCN(i, return_sequences=False)
        #   new: o = tcn.TCN(return_sequences=False)(i)
        if not isinstance(nb_filters, int):
            raise TypeError(
                'An interface change occurred after the version 2.1.2.\n'
                'Before: tcn.TCN(i, return_sequences=False, ...)\n'
                'Now should be: tcn.TCN(return_sequences=False, ...)(i)\n'
                'Second solution is to pip install keras-tcn==2.1.2 to downgrade.')
        if padding not in ('causal', 'same'):
            raise ValueError("Only 'causal' or 'same' paddings are compatible for this layer.")
        self.name = name
        self.return_sequences = return_sequences
        self.dropout_rate = dropout_rate
        self.use_skip_connections = use_skip_connections
        self.activation = activation
        self.dilations = dilations
        self.nb_stacks = nb_stacks
        self.kernel_size = kernel_size
        self.nb_filters = nb_filters
        self.padding = padding

    def __call__(self, inputs):
        """Build the TCN on top of ``inputs`` and return the output tensor."""
        if self.dilations is None:
            self.dilations = [1, 2, 4, 8, 16, 32]
        # The 1x1 conv projects the input to nb_filters channels, so the first
        # residual addition has matching shapes.
        x = Convolution1D(self.nb_filters, 1, padding=self.padding, name=self.name + '_initial_conv')(inputs)
        skip_connections = []
        for s in range(self.nb_stacks):
            for i in self.dilations:
                x, skip_out = residual_block(x, s, i, self.activation, self.nb_filters,
                                             self.kernel_size, self.padding, self.dropout_rate, name=self.name)
                skip_connections.append(skip_out)
        if self.use_skip_connections:
            # keras.layers.add needs at least two inputs, so a single residual
            # block contributes its skip output directly.
            if len(skip_connections) > 1:
                x = keras.layers.add(skip_connections)
            elif skip_connections:
                x = skip_connections[0]
        x = Activation('relu')(x)
        if not self.return_sequences:
            # Keep only the last timestep.
            x = Lambda(lambda tt: tt[:, -1, :])(x)
        return x
def compiled_tcn(num_feat,  # type: int
                 num_classes,  # type: int
                 nb_filters,  # type: int
                 kernel_size,  # type: int
                 dilations,  # type: List[int]
                 nb_stacks,  # type: int
                 max_len,  # type: int
                 activation='norm_relu',  # type: str
                 padding='causal',  # type: str
                 use_skip_connections=True,  # type: bool
                 return_sequences=True,  # type: bool
                 regression=False,  # type: bool
                 dropout_rate=0.05,  # type: float
                 name='tcn'  # type: str
                 ):
    # type: (...) -> keras.Model
    """Create a compiled TCN model for regression or classification.

    Args:
        num_feat: The number of features of your input, i.e. the last
            dimension of (batch_size, timesteps, input_dim).
        num_classes: The size of the final dense layer, i.e. how many classes
            we are predicting. It is ignored when ``regression`` is True; the
            output then has a single unit.
        nb_filters: The number of filters to use in the convolutional layers.
        kernel_size: The size of the kernel to use in each convolutional
            layer.
        dilations: The list of dilations, e.g. ``[1, 2, 4, 8, 16, 32, 64]``.
            Exponents such as ``[0, 1, 2]`` are also accepted (see
            ``process_dilations``).
        nb_stacks: The number of stacks of residual blocks to use.
        max_len: The maximum sequence length. Use None if the sequence length
            is dynamic.
        activation: The activation to use.
        padding: The padding to use in the convolutional layers.
        use_skip_connections: Boolean. If True, the TCN output is the sum of
            the skip outputs of all residual blocks.
        return_sequences: Boolean. If True, return the full output sequence.
            Otherwise return only the last timestep.
        regression: Whether the output should be continuous or discrete.
        dropout_rate: Float between 0 and 1. Fraction of the input units to
            drop.
        name: Name prefix for the layers. Useful when having multiple TCNs.
    Returns:
        A compiled keras model. The loss is mean squared error for
        regression, and sparse categorical cross-entropy with an accuracy
        metric otherwise.
    """
    dilations = process_dilations(dilations)
    input_layer = Input(shape=(max_len, num_feat))
    # Keyword arguments so a change in TCN's parameter order cannot silently
    # misbind them.
    x = TCN(nb_filters=nb_filters,
            kernel_size=kernel_size,
            nb_stacks=nb_stacks,
            dilations=dilations,
            activation=activation,
            padding=padding,
            use_skip_connections=use_skip_connections,
            dropout_rate=dropout_rate,
            return_sequences=return_sequences,
            name=name)(input_layer)
    print('x.shape=', x.shape)

    if regression:
        x = Dense(1)(x)
        x = Activation('linear')(x)
    else:
        x = Dense(num_classes)(x)
        x = Activation('softmax')(x)
    output_layer = x

    print(f'model.x = {input_layer.shape}')
    print(f'model.y = {output_layer.shape}')
    model = Model(input_layer, output_layer)

    # `lr` is the argument name used by the Keras release this file targets.
    # NOTE(review): newer Keras renames it to `learning_rate`.
    adam = optimizers.Adam(lr=0.002, clipnorm=1.)

    if regression:
        model.compile(adam, loss='mean_squared_error')
        return model

    # https://github.com/keras-team/keras/pull/11373
    # It's now in Keras@master but still not available with pip.
    # TODO To remove later.
    # The function keeps the name `accuracy` because Keras uses it as the
    # metric key in training logs.
    def accuracy(y_true, y_pred):
        # Reshape in case y_true is (num_samples, 1) instead of (num_samples,).
        if K.ndim(y_true) == K.ndim(y_pred):
            y_true = K.squeeze(y_true, -1)
        # Convert dense class probabilities to label indices.
        y_pred_labels = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
        return K.cast(K.equal(y_true, y_pred_labels), K.floatx())

    model.compile(adam, loss='sparse_categorical_crossentropy', metrics=[accuracy])
    print('Adam with norm clipping.')
    return model