Skip to content

Commit

Permalink
Fix serial/deserialization issue for rnn wrapper.
Browse files Browse the repository at this point in the history
Split the v2 implementation into separate module so that the v1 and v2
class can have same name in the json config.

Fixing for #26581.

PiperOrigin-RevId: 251074685
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Jun 1, 2019
1 parent 98a0c57 commit e62dc43
Show file tree
Hide file tree
Showing 16 changed files with 1,102 additions and 596 deletions.
1 change: 1 addition & 0 deletions tensorflow/python/BUILD
Expand Up @@ -3328,6 +3328,7 @@ py_library(
srcs = [
"ops/rnn_cell.py",
"ops/rnn_cell_impl.py",
"ops/rnn_cell_wrapper_impl.py",
],
srcs_version = "PY2AND3",
deps = [
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/python/keras/BUILD
Expand Up @@ -387,6 +387,7 @@ py_library(
"layers/pooling.py",
"layers/recurrent.py",
"layers/recurrent_v2.py",
"layers/rnn_cell_wrapper_v2.py",
"layers/serialization.py",
"layers/wrappers.py",
"utils/kernelized_utils.py",
Expand Down Expand Up @@ -961,6 +962,22 @@ tf_py_test(
],
)

tf_py_test(
name = "rnn_cell_wrapper_v2_test",
size = "medium",
srcs = ["layers/rnn_cell_wrapper_v2_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
shard_count = 4,
tags = [
"notsan",
],
)

tf_py_test(
name = "time_distributed_learning_phase_test",
size = "small",
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/python/keras/layers/__init__.py
Expand Up @@ -177,6 +177,11 @@
from tensorflow.python.keras.layers.wrappers import Bidirectional
from tensorflow.python.keras.layers.wrappers import TimeDistributed

# # RNN Cell wrappers.
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DeviceWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper

# Serialization functions
from tensorflow.python.keras.layers.serialization import deserialize
from tensorflow.python.keras.layers.serialization import serialize
Expand Down
122 changes: 122 additions & 0 deletions tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
@@ -0,0 +1,122 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing for RNN wrappers for TF v2."""

# Note that all the APIs under this module are exported as tf.nn.*. This is due
# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported
# here to avoid the cyclic dependency issue for serialization. These APIs will
# probably be deprecated and removed in future since similar API is available in
# existing Keras RNN API.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.keras.layers import AbstractRNNCell
from tensorflow.python.ops import rnn_cell_wrapper_impl
from tensorflow.python.util.tf_export import tf_export


class _RNNCellWrapperV2(AbstractRNNCell):
"""Base class for cells wrappers V2 compatibility.
This class along with `rnn_cell_impl._RNNCellWrapperV1` allows to define
wrappers that are compatible with V1 and V2, and defines helper methods for
this purpose.
"""

def __init__(self, cell, *args, **kwargs):
super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
self.cell = cell

def call(self, inputs, state, **kwargs):
"""Runs the RNN cell step computation.
When `call` is being used, we assume that the wrapper object has been built,
and therefore the wrapped cells has been built via its `build` method and
its `call` method can be used directly.
This allows to use the wrapped cell and the non-wrapped cell equivalently
when using `call` and `build`.
Args:
inputs: A tensor with wrapped cell's input.
state: A tensor or tuple of tensors with wrapped cell's state.
**kwargs: Additional arguments passed to the wrapped cell's `call`.
Returns:
A pair containing:
- Output: A tensor with cell's output.
- New state: A tensor or tuple of tensors with new wrapped cell's state.
"""
return self._call_wrapped_cell(
inputs, state, cell_call_fn=self.cell.call, **kwargs)

def build(self, inputs_shape):
"""Builds the wrapped cell."""
self.cell.build(inputs_shape)
self.built = True

def get_config(self):
config = {
"cell": {
"class_name": self.cell.__class__.__name__,
"config": self.cell.get_config()
},
}
base_config = super(_RNNCellWrapperV2, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
config = config.copy()
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
cell = deserialize_layer(config.pop("cell"), custom_objects=custom_objects)
return cls(cell, **config)


@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
_RNNCellWrapperV2):
"""Operator adding dropout to inputs and outputs of the given cell."""

def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(DropoutWrapper, self).__init__(*args, **kwargs)

__init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
_RNNCellWrapperV2):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""

def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(ResidualWrapper, self).__init__(*args, **kwargs)

__init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
_RNNCellWrapperV2):
"""Operator that ensures an RNNCell runs on a particular device."""

def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(DeviceWrapper, self).__init__(*args, **kwargs)

__init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__

0 comments on commit e62dc43

Please sign in to comment.