-
Notifications
You must be signed in to change notification settings - Fork 613
/
Copy pathesn_cell.py
215 lines (191 loc) · 8.17 KB
/
esn_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements ESN Cell."""
import tensorflow as tf
import tensorflow.keras as keras
from typeguard import typechecked
from tensorflow_addons.utils.types import (
Activation,
Initializer,
)
@tf.keras.utils.register_keras_serializable(package="Addons")
class ESNCell(keras.layers.AbstractRNNCell):
    """Echo State recurrent Network (ESN) cell.

    This implements the recurrent cell from the paper:

        H. Jaeger
        "The "echo state" approach to analysing and training recurrent neural networks".
        GMD Report 148, German National Research Center for Information Technology, 2001.
        https://www.researchgate.net/publication/215385037

    All weights of the cell (input kernel, reservoir matrix and bias) are
    created with `trainable=False`, so they keep their initial values.

    Example:

    >>> inputs = np.random.random([30,23,9]).astype(np.float32)
    >>> esn_cell = tfa.rnn.ESNCell(4)
    >>> rnn = tf.keras.layers.RNN(esn_cell, return_sequences=True, return_state=True)
    >>> outputs, memory_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> memory_state.shape
    TensorShape([30, 4])

    Args:
      units: Positive integer, dimensionality in the reservoir.
      connectivity: Float between 0 and 1.
        Connection probability between two reservoir units.
        Default: 0.1.
      leaky: Float between 0 and 1.
        Leaking rate of the reservoir. The new state is
        `(1 - leaky) * previous_state + leaky * activation(...)`, so a value
        of 1 disables leaky integration.
        Default: 1.
      spectral_radius: Float between 0 and 1.
        Desired spectral radius of the recurrent weight matrix. When
        `use_norm2` is True this is only an upper bound of the actual
        spectral radius.
        Default: 0.9.
      use_norm2: Boolean, whether to rescale the reservoir by its Frobenius
        norm (square root of the sum of squared entries), which is an upper
        bound of the spectral radius, so that the echo state property is
        satisfied. It avoids computing the eigenvalues, whose cost grows
        roughly cubically with `units`.
        Default: False.
      use_bias: Boolean, whether the layer uses a bias vector.
        Default: True.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
        Default: `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix,
        used for the linear transformation of the recurrent state. Its output
        is then sparsified by `connectivity` and rescaled to `spectral_radius`.
        Default: `glorot_uniform`.
      bias_initializer: Initializer for the bias vector.
        Default: `zeros`.

    Call arguments:
      inputs: A 2D tensor (batch x input_size).
      state: List of state tensors corresponding to the previous timestep;
        only `state[0]` (batch x units) is used.
    """

    @typechecked
    def __init__(
        self,
        units: int,
        connectivity: float = 0.1,
        leaky: float = 1,
        spectral_radius: float = 0.9,
        use_norm2: bool = False,
        use_bias: bool = True,
        activation: Activation = "tanh",
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.connectivity = connectivity
        self.leaky = leaky
        self.spectral_radius = spectral_radius
        self.use_norm2 = use_norm2
        self.use_bias = use_bias
        # Resolve string identifiers once so the rest of the class can call
        # these objects directly.
        self.activation = tf.keras.activations.get(activation)
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.recurrent_initializer = tf.keras.initializers.get(recurrent_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)

        self._state_size = units
        self._output_size = units

    @property
    def state_size(self):
        """Size of the recurrent state (equal to `units`)."""
        return self._state_size

    @property
    def output_size(self):
        """Size of the per-step output (equal to `units`)."""
        return self._output_size

    def build(self, inputs_shape):
        """Creates the non-trainable reservoir, input kernel and bias weights.

        Args:
          inputs_shape: Shape of one time step of input. Its last dimension
            sizes the input kernel and must therefore be known.

        Raises:
          ValueError: If the last dimension of `inputs_shape` is unknown.
        """
        input_size = tf.compat.dimension_value(tf.TensorShape(inputs_shape)[-1])
        if input_size is None:
            # `inputs_shape` may be a plain tuple; `"%s" % some_tuple` would
            # raise TypeError ("not all arguments converted") instead of this
            # ValueError, so the shape is interpolated with an f-string.
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]. "
                f"Shape received is {inputs_shape}"
            )

        def _esn_recurrent_initializer(shape, dtype, partition_info=None):
            """Builds a sparse reservoir matrix rescaled to `spectral_radius`."""
            # `self.recurrent_initializer` was already resolved in __init__.
            recurrent_weights = self.recurrent_initializer(shape, dtype)

            # Keep each connection with probability `connectivity`.
            connectivity_mask = tf.cast(
                tf.math.less_equal(tf.random.uniform(shape), self.connectivity),
                dtype,
            )
            recurrent_weights = tf.math.multiply(recurrent_weights, connectivity_mask)

            # Satisfy the necessary condition for the echo state property
            # `max(|eig(W)|) < 1` by rescaling W.
            if self.use_norm2:
                # The Frobenius norm is an upper bound of the spectral radius,
                # so scaling by it is a cheaper, conservative alternative to
                # computing the eigenvalues.
                recurrent_norm2 = tf.math.sqrt(
                    tf.math.reduce_sum(tf.math.square(recurrent_weights))
                )
                # Add 1 to the denominator only when the norm is exactly 0,
                # avoiding a division by zero for an all-zero reservoir.
                is_norm2_0 = tf.cast(tf.math.equal(recurrent_norm2, 0), dtype)
                scaling_factor = self.spectral_radius / (
                    recurrent_norm2 + 1 * is_norm2_0
                )
            else:
                abs_eig_values = tf.abs(tf.linalg.eig(recurrent_weights)[0])
                # divide_no_nan yields 0 when every eigenvalue is 0.
                scaling_factor = tf.math.divide_no_nan(
                    self.spectral_radius, tf.reduce_max(abs_eig_values)
                )

            return tf.multiply(recurrent_weights, scaling_factor)

        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=[self.units, self.units],
            initializer=_esn_recurrent_initializer,
            trainable=False,
            dtype=self.dtype,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=[input_size, self.units],
            initializer=self.kernel_initializer,
            trainable=False,
            dtype=self.dtype,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=[self.units],
                initializer=self.bias_initializer,
                trainable=False,
                dtype=self.dtype,
            )

        self.built = True

    def call(self, inputs, state):
        """Runs one reservoir update step.

        Args:
          inputs: Tensor of shape (batch x input_size).
          state: List whose first element is the previous state
            (batch x units).

        Returns:
          A tuple `(output, output)`: the new state is also the cell output.
        """
        # A single matmul over [inputs, state] @ [[kernel], [recurrent_kernel]]
        # equals inputs @ kernel + state @ recurrent_kernel.
        in_matrix = tf.concat([inputs, state[0]], axis=1)
        weights_matrix = tf.concat([self.kernel, self.recurrent_kernel], axis=0)

        output = tf.linalg.matmul(in_matrix, weights_matrix)
        if self.use_bias:
            output = output + self.bias
        output = self.activation(output)

        # Leaky integration between the previous state and the new activation.
        output = (1 - self.leaky) * state[0] + self.leaky * output

        return output, output

    def get_config(self):
        """Returns the constructor arguments needed to re-create this cell."""
        config = {
            "units": self.units,
            "connectivity": self.connectivity,
            "leaky": self.leaky,
            "spectral_radius": self.spectral_radius,
            "use_norm2": self.use_norm2,
            "use_bias": self.use_bias,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": tf.keras.initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer),
        }
        base_config = super().get_config()
        return {**base_config, **config}