Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix serial/deserialization issue for rnn wrapper.
Split the v2 implementation into separate module so that the v1 and v2 class can have same name in the json config. Fixing for #26581. PiperOrigin-RevId: 251074685
- Loading branch information
1 parent
98a0c57
commit e62dc43
Showing
16 changed files
with
1,102 additions
and
596 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Module implementing for RNN wrappers for TF v2.""" | ||
|
||
# Note that all the APIs under this module are exported as tf.nn.*. This is due | ||
# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported | ||
# here to avoid the cyclic dependency issue for serialization. These APIs will | ||
# probably be deprecated and removed in future since similar API is available in | ||
# existing Keras RNN API. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
|
||
from tensorflow.python.keras.layers import AbstractRNNCell | ||
from tensorflow.python.ops import rnn_cell_wrapper_impl | ||
from tensorflow.python.util.tf_export import tf_export | ||
|
||
|
||
class _RNNCellWrapperV2(AbstractRNNCell): | ||
"""Base class for cells wrappers V2 compatibility. | ||
This class along with `rnn_cell_impl._RNNCellWrapperV1` allows to define | ||
wrappers that are compatible with V1 and V2, and defines helper methods for | ||
this purpose. | ||
""" | ||
|
||
def __init__(self, cell, *args, **kwargs): | ||
super(_RNNCellWrapperV2, self).__init__(*args, **kwargs) | ||
self.cell = cell | ||
|
||
def call(self, inputs, state, **kwargs): | ||
"""Runs the RNN cell step computation. | ||
When `call` is being used, we assume that the wrapper object has been built, | ||
and therefore the wrapped cells has been built via its `build` method and | ||
its `call` method can be used directly. | ||
This allows to use the wrapped cell and the non-wrapped cell equivalently | ||
when using `call` and `build`. | ||
Args: | ||
inputs: A tensor with wrapped cell's input. | ||
state: A tensor or tuple of tensors with wrapped cell's state. | ||
**kwargs: Additional arguments passed to the wrapped cell's `call`. | ||
Returns: | ||
A pair containing: | ||
- Output: A tensor with cell's output. | ||
- New state: A tensor or tuple of tensors with new wrapped cell's state. | ||
""" | ||
return self._call_wrapped_cell( | ||
inputs, state, cell_call_fn=self.cell.call, **kwargs) | ||
|
||
def build(self, inputs_shape): | ||
"""Builds the wrapped cell.""" | ||
self.cell.build(inputs_shape) | ||
self.built = True | ||
|
||
def get_config(self): | ||
config = { | ||
"cell": { | ||
"class_name": self.cell.__class__.__name__, | ||
"config": self.cell.get_config() | ||
}, | ||
} | ||
base_config = super(_RNNCellWrapperV2, self).get_config() | ||
return dict(list(base_config.items()) + list(config.items())) | ||
|
||
@classmethod | ||
def from_config(cls, config, custom_objects=None): | ||
config = config.copy() | ||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top | ||
cell = deserialize_layer(config.pop("cell"), custom_objects=custom_objects) | ||
return cls(cell, **config) | ||
|
||
|
||
@tf_export("nn.RNNCellDropoutWrapper", v1=[]) | ||
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, | ||
_RNNCellWrapperV2): | ||
"""Operator adding dropout to inputs and outputs of the given cell.""" | ||
|
||
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation | ||
super(DropoutWrapper, self).__init__(*args, **kwargs) | ||
|
||
__init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ | ||
|
||
|
||
@tf_export("nn.RNNCellResidualWrapper", v1=[]) | ||
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, | ||
_RNNCellWrapperV2): | ||
"""RNNCell wrapper that ensures cell inputs are added to the outputs.""" | ||
|
||
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation | ||
super(ResidualWrapper, self).__init__(*args, **kwargs) | ||
|
||
__init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ | ||
|
||
|
||
@tf_export("nn.RNNCellDeviceWrapper", v1=[]) | ||
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, | ||
_RNNCellWrapperV2): | ||
"""Operator that ensures an RNNCell runs on a particular device.""" | ||
|
||
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation | ||
super(DeviceWrapper, self).__init__(*args, **kwargs) | ||
|
||
__init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ |
Oops, something went wrong.