/
rnn_cell_wrapper_v2.py
131 lines (103 loc) · 5.26 KB
/
rnn_cell_wrapper_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN wrappers for TF v2."""
# Note that all the APIs under this module are exported as tf.nn.*. This is
# because those APIs originally came from tf.nn.rnn_cell_impl. They are ported
# here to avoid a cyclic dependency issue during serialization. These APIs will
# probably be deprecated and removed in the future, since a similar API is
# available in the existing Keras RNN API.
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Common base for RNN cell wrappers that must work with TF v2.

  Together with `rnn_cell_impl._RNNCellWrapperV1`, this class makes it possible
  to write wrappers that are usable under both V1 and V2, and it supplies the
  helper methods needed for that.

  Note: `call` delegates to `self._call_wrapped_cell`, which is not defined in
  this class; presumably it is supplied by the wrapper implementation classes
  that concrete wrappers mix in alongside this base.
  """

  def __init__(self, cell, *args, **kwargs):
    """Stores the wrapped cell and inspects the signature of its `call`.

    Args:
      cell: The cell being wrapped. Its `call` method is introspected here.
      *args: Positional arguments forwarded to the parent constructor.
      **kwargs: Keyword arguments forwarded to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.cell = cell
    call_spec = tf_inspect.getfullargspec(cell.call)
    # The wrapped cell can receive `training` either through an explicitly
    # named argument or through a catch-all `**kwargs` parameter.
    names_training = "training" in call_spec.args
    takes_var_keywords = call_spec.varkw is not None
    self._expects_training_arg = names_training or takes_var_keywords

  def call(self, inputs, state, **kwargs):
    """Performs one step of the wrapped RNN cell.

    By the time `call` runs, the wrapper is assumed to have been built already,
    which means the wrapped cell was built through its own `build` method and
    its `call` method may be invoked directly.

    As a result, a wrapped cell and a bare cell can be used interchangeably
    through `build` and `call`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    step_fn = self.cell.call
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=step_fn, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell and marks this wrapper as built."""
    self.cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    """Returns the parent config extended with a `"cell"` entry.

    The `"cell"` entry holds the wrapped cell's class name and its own config.
    """
    cell_entry = {
        "class_name": self.cell.__class__.__name__,
        "config": self.cell.get_config(),
    }
    merged = dict(super().get_config())
    # The wrapper's own entry wins over any same-named key from the parent.
    merged["cell"] = cell_entry
    return merged

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a wrapper from a config dictionary.

    Args:
      config: A config mapping containing a `"cell"` entry describing the
        wrapped cell; the remaining entries are passed to the constructor as
        keyword arguments. The mapping itself is not modified.
      custom_objects: Optional value forwarded to the layer deserializer.

    Returns:
      A new instance of `cls` wrapping the deserialized cell.
    """
    # Work on a copy so that popping "cell" does not mutate the caller's dict.
    ctor_kwargs = config.copy()
    # Imported at call time rather than at module top; presumably this avoids
    # a circular import with the serialization module.
    from tensorflow.python.keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cell_spec = ctor_kwargs.pop("cell")
    wrapped_cell = deserialize_layer(cell_spec, custom_objects=custom_objects)
    return cls(wrapped_cell, **ctor_kwargs)
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell.

  Combines the shared `DropoutWrapperBase` implementation with the V2 wrapper
  base. Constructing it around a Keras `LSTMCell` raises `ValueError`.
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)
    # `self.cell` is read only after the parent constructors have run, so the
    # compatibility check has to come after the `super()` call.
    wraps_keras_lstm = isinstance(self.cell, recurrent.LSTMCell)
    if wraps_keras_lstm:
      raise ValueError(
          "keras LSTM cell does not work with DropoutWrapper. "
          "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
          "instead.")

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs.

  Combines the shared `ResidualWrapperBase` implementation with the V2 wrapper
  base; the constructor only forwards its arguments to the parent classes.
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device.

  Combines the shared `DeviceWrapperBase` implementation with the V2 wrapper
  base; the constructor only forwards its arguments to the parent classes.
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__