Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tf-addons for upcoming keras 3 default. #2858

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorflow_addons/image/tests/distort_image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_adjust_random_hue_in_yiq(shape, style, dtype):
y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
test_utils.assert_allclose_according_to_type(
y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8
y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=1.1
)


Expand All @@ -121,11 +121,11 @@ def test_invalid_channels_hsv():

def test_adjust_hsv_in_yiq_unknown_shape():
fn = tf.function(distort_image_ops.adjust_hsv_in_yiq).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float64)
tf.TensorSpec(shape=None, dtype=tf.float32)
)
for shape in (2, 3, 3), (4, 2, 3, 3):
image_np = np.random.rand(*shape) * 255.0
image_tf = tf.constant(image_np)
image_tf = tf.constant(image_np, dtype=tf.float32)
np.testing.assert_allclose(
_adjust_hue_in_yiq_np(image_np, 0), fn(image_tf), rtol=2e-4, atol=1e-4
)
Expand Down
17 changes: 14 additions & 3 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,20 @@
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked

if Version(tf.__version__).release >= Version("2.13").release:
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
if Version(tf.__version__).release >= Version("2.16").release:
# Determine if loading keras 2 or 3.
if (
hasattr(tf.keras, "version")
and Version(tf.keras.version()).release >= Version("3.0").release
):
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
from keras.src import backend
from keras.src.utils import tf_utils
else:
from tf_keras.src import backend
from tf_keras.src.utils import tf_utils
elif Version(tf.__version__).release >= Version("2.13").release:
from keras.src import backend
from keras.src.utils import tf_utils
else:
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_addons/optimizers/lazy_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,6 @@ def _resource_scatter_operate(self, resource, indices, update, resource_scatter_
}

return resource_scatter_op(**resource_update_kwargs)

def get_config(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason this is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, with keras 3 it throws an error that LazyAdam doesn't have a get_config method.

return super().get_config()
133 changes: 133 additions & 0 deletions tensorflow_addons/rnn/abstract_rnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for RNN cells.

Adapted from legacy github.com/keras-team/tf-keras.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason this is needed? Is it because of the usage of tf.keras.layers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AbstractRNNCell only exists in tf_keras (it previously existed in keras 2, but doesn't exist in keras 3).

"""

import tensorflow as tf


def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
if inputs is not None:
batch_size = tf.shape(inputs)[0]
dtype = inputs.dtype
return _generate_zero_filled_state(batch_size, cell.state_size, dtype)


def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
"""Generate a zero filled tensor with shape [batch_size, state_size]."""
if batch_size_tensor is None or dtype is None:
raise ValueError(
"batch_size and dtype cannot be None while constructing initial state: "
"batch_size={}, dtype={}".format(batch_size_tensor, dtype)
)

def create_zeros(unnested_state_size):
flat_dims = tf.TensorShape(unnested_state_size).as_list()
init_state_size = [batch_size_tensor] + flat_dims
return tf.zeros(init_state_size, dtype=dtype)

if tf.nest.is_nested(state_size):
return tf.nest.map_structure(create_zeros, state_size)
else:
return create_zeros(state_size)


class AbstractRNNCell(tf.keras.layers.Layer):
"""Abstract object representing an RNN cell.

This is a base class for implementing RNN cells with custom behavior.

Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`.

Examples:

```python
class MinimalRNNCell(AbstractRNNCell):

def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)

@property
def state_size(self):
return self.units

def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True

def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, output
```

This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.

An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.batch_size`.
"""

def call(self, inputs, states):
"""The function that contains the logic for one RNN step calculation.

Args:
inputs: the input tensor, which is a slide from the overall RNN input by
the time dimension (usually the second dimension).
states: the state tensor from previous step, which has the same shape
as `(batch, state_size)`. In the case of timestep 0, it will be the
initial state user specified, or zero filled tensor otherwise.

Returns:
A tuple of two tensors:
1. output tensor for the current timestep, with size `output_size`.
2. state tensor for next step, which has the shape of `state_size`.
"""
raise NotImplementedError("Abstract method")

@property
def state_size(self):
"""size(s) of state(s) used by this cell.

It can be represented by an Integer, a TensorShape or a tuple of Integers
or TensorShapes.
"""
raise NotImplementedError("Abstract method")

@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")

def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
4 changes: 2 additions & 2 deletions tensorflow_addons/rnn/esn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
"""Implements ESN Cell."""

import tensorflow as tf
import tensorflow.keras as keras
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
Activation,
Initializer,
)


@tf.keras.utils.register_keras_serializable(package="Addons")
class ESNCell(keras.layers.AbstractRNNCell):
class ESNCell(AbstractRNNCell):
"""Echo State recurrent Network (ESN) cell.
This implements the recurrent cell from the paper:
H. Jaeger
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_addons/rnn/nas_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"""Implements NAS Cell."""

import tensorflow as tf
import tensorflow.keras as keras
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
FloatTensorLike,
TensorLike,
Expand All @@ -27,7 +27,7 @@


@tf.keras.utils.register_keras_serializable(package="Addons")
class NASCell(keras.layers.AbstractRNNCell):
class NASCell(AbstractRNNCell):
"""Neural Architecture Search (NAS) recurrent network cell.

This implements the recurrent cell from the paper:
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/seq2seq/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ py_library(
"//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so",
],
deps = [
"//tensorflow_addons/rnn",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/seq2seq/attention_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tensorflow as tf

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.types import (
AcceptableDTypes,
Expand Down Expand Up @@ -1577,7 +1578,7 @@ def _compute_attention(
return attention, alignments, next_attention_state


class AttentionWrapper(tf.keras.layers.AbstractRNNCell):
class AttentionWrapper(AbstractRNNCell):
"""Wraps another RNN cell with attention.

Example:
Expand Down
16 changes: 7 additions & 9 deletions tensorflow_addons/text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ package(default_visibility = ["//visibility:public"])
py_library(
name = "text",
srcs = glob(["*.py"]),
data = select({
"//tensorflow_addons:windows": [
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
data = [
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/rnn",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
] + select({
"//tensorflow_addons:windows": [],
"//conditions:default": [
"//tensorflow_addons/custom_ops/text:_parse_time_op.so",
"//tensorflow_addons/custom_ops/text:_skip_gram_ops.so",
"//tensorflow_addons/testing",
"//tensorflow_addons/utils",
],
}),
)
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import TensorLike
from typeguard import typechecked
from typing import Optional, Tuple
Expand Down Expand Up @@ -403,7 +404,7 @@ def viterbi_decode(score: TensorLike, transition_params: TensorLike) -> tf.Tenso
return viterbi, viterbi_score


class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
class CrfDecodeForwardRnnCell(AbstractRNNCell):
"""Computes the forward decoding in a linear-chain CRF."""

@typechecked
Expand Down
10 changes: 1 addition & 9 deletions tensorflow_addons/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,10 @@
import pytest
import tensorflow as tf

from packaging.version import Version
from tensorflow_addons import options
from tensorflow_addons.utils import resource_loader

if Version(tf.__version__).release >= Version("2.13").release:
# New versions of Keras require importing from `keras.src` when
# importing internal symbols.
from keras.src.testing_infra.test_utils import layer_test # noqa: F401
elif Version(tf.__version__) >= Version("2.9"):
from keras.testing_infra.test_utils import layer_test # noqa: F401
else:
from keras.testing_utils import layer_test # noqa: F401
from tensorflow_addons.utils.tf_test_utils import layer_test # noqa

NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))
WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2])
Expand Down
Loading
Loading