/
rnn_cell_wrapper_v2.py
127 lines (99 loc) · 5.02 KB
/
rnn_cell_wrapper_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing for RNN wrappers for TF v2."""
# Note that all the APIs under this module are exported as tf.nn.*. This is
# because those APIs came from tf.nn.rnn_cell_impl. They are ported here to
# avoid the cyclic dependency issue for serialization. These APIs will probably
# be deprecated and removed in the future, since a similar API is available in
# the existing Keras RNN API.
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.util.tf_export import tf_export
class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Shared base class giving cell wrappers V2 compatibility.

  Used together with `rnn_cell_impl._RNNCellWrapperV1`, it makes it possible to
  write wrappers that work with both V1 and V2, and it provides the helper
  methods needed for that purpose.
  """

  def __init__(self, cell, *args, **kwargs):
    """Stores the wrapped cell and inspects the signature of its `call`.

    Args:
      cell: The cell being wrapped. Its `call` method is inspected here.
      *args: Positional arguments forwarded to the parent constructor.
      **kwargs: Keyword arguments forwarded to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.cell = cell
    # The wrapper advertises a `training` argument only when the wrapped cell's
    # `call` can take one, either by name or through a `**kwargs` catch-all.
    call_spec = tf_inspect.getfullargspec(cell.call)
    takes_var_keywords = call_spec.varkw is not None
    self._expects_training_arg = (
        "training" in call_spec.args or takes_var_keywords)

  def call(self, inputs, state, **kwargs):
    """Runs one step of the RNN cell computation.

    By the time `call` is used, the wrapper is assumed to be built, which means
    the wrapped cell was built through its own `build` method and its `call`
    method can be invoked directly. As a result the wrapped and the non-wrapped
    cell can be used interchangeably through `call` and `build`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    # `_call_wrapped_cell` is not defined in this class; it has to be supplied
    # by another class in the concrete wrapper's hierarchy.
    wrapped_call = self.cell.call
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=wrapped_call, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell and marks this wrapper as built."""
    wrapped_cell = self.cell
    wrapped_cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    """Returns the parent config extended with a serialized `"cell"` entry."""
    cell_entry = {
        "class_name": self.cell.__class__.__name__,
        "config": self.cell.get_config(),
    }
    # Start from the parent's config so that a pre-existing "cell" key, if any,
    # is overridden by the entry built above.
    merged = dict(super().get_config())
    merged["cell"] = cell_entry
    return merged

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates a wrapper from the dictionary produced by `get_config`.

    Args:
      config: Configuration dictionary containing a `"cell"` entry. It is
        copied before being modified, so the caller's dictionary is untouched.
      custom_objects: Optional value forwarded to the layer deserializer when
        rebuilding the wrapped cell.

    Returns:
      A new instance of `cls` wrapping the deserialized cell; the remaining
      config entries are passed to the constructor as keyword arguments.
    """
    remaining = config.copy()
    # Imported at call time rather than at module top (see pylint directive).
    from tensorflow.python.keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cell_spec = remaining.pop("cell")
    cell = deserialize_layer(cell_spec, custom_objects=custom_objects)
    return cls(cell, **remaining)
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell.

  Wrapping a Keras `LSTMCell` is rejected with a `ValueError` at construction
  time.
  """

  # No docstring here: `__init__.__doc__` is assigned from the base class below.
  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)
    wrapped_is_keras_lstm = isinstance(self.cell, recurrent.LSTMCell)
    if wrapped_is_keras_lstm:
      message = ("keras LSTM cell does not work with DropoutWrapper. "
                 "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                 "instead.")
      raise ValueError(message)

  # Reuse the constructor documentation of the shared implementation.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  # Pure delegation; the constructor exists so that its docstring can be set
  # from the shared implementation right below.
  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation.
  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  # Pure delegation; the constructor exists so that its docstring can be set
  # from the shared implementation right below.
  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation.
  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__