Skip to content

Commit

Permalink
fix version check
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed May 29, 2024
1 parent fb1ca07 commit 627d884
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False, **kwargs):
if hasattr(self, 'add_variable_from_reference'):
original_add_variable_from_reference = self.add_variable_from_reference

# pylint: disable=protected-access
def _distributed_apply(distribution, grads_and_vars, name, apply_state):
"""`apply_gradients` using a `DistributionStrategy`."""

Expand Down Expand Up @@ -208,9 +209,8 @@ def apply_grad_to_update_var(var, grad):
args=(grad,),
group=False)
replica_context = distribute_ctx.get_replica_context()
# pylint: disable=protected-access
if (replica_context is None or replica_context is
distribute_ctx._get_default_replica_context()):
if (replica_context is None or replica_context
is distribute_ctx._get_default_replica_context()):
# In cross-replica context, extended.update returns a list of
# update ops from all replicas (group=False).
update_ops.extend(update_op)
Expand Down
34 changes: 23 additions & 11 deletions tensorflow_recommenders_addons/utils/resource_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pkg_resources
import tensorflow as tf
import warnings
from packaging.version import parse as parse_version

abi_warning_already_raised = False
SKIP_CUSTOM_OPS = False
Expand All @@ -37,13 +38,32 @@ def get_required_tf_version():
"TFRA installation.",
UserWarning,
)
return tf.__version__
return tf.__version__, tf.__version__

pkg_info = pkg.requires()
low_version, high_version = None, None

for x in pkg_info:
if x.name in ["tensorflow", "tensorflow-gpu"]:
return x.specs[0][1]
assert False, "Fail to get required TensorFlow version of TFRA!"
for spec in x.specs:
if spec[0] == ">=":
low_version = spec[1]
elif spec[0] == "<=":
high_version = spec[1]
if low_version and high_version:
return low_version, high_version

assert False, f"Fail to get required TensorFlow version of TFRA: {pkg_info[0]} {low_version} {high_version}"


def abi_is_compatible():
if "dev" in tf.__version__:
return False
low_version, high_version = get_required_tf_version()

current_version = parse_version(tf.__version__)
return parse_version(low_version) <= current_version <= parse_version(
high_version)


def get_devices(device_type="GPU"):
Expand Down Expand Up @@ -138,14 +158,6 @@ def display_warning_if_incompatible(self):
abi_warning_already_raised = True


def abi_is_compatible():
if "dev" in tf.__version__:
return False

required_tf_version = get_required_tf_version()
return tf.__version__ == required_tf_version


def prefix_op_name(op_name):
"""
In order to keep compatibility of existing models,
Expand Down
54 changes: 54 additions & 0 deletions tensorflow_recommenders_addons/utils/tests/test_resource_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
from unittest.mock import patch, Mock

import pkg_resources
import tensorflow as tf

from tensorflow_recommenders_addons.utils.resource_loader import abi_is_compatible


class TestTensorFlowCompatibility(unittest.TestCase):

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.12.0')
def test_compatible_version(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertTrue(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.10.0')
def test_incompatible_version_below_range(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.16.0')
def test_incompatible_version_above_range(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())

@patch('pkg_resources.get_distribution')
@patch('tensorflow.__version__', '2.13.0-dev20240528')
def test_dev_version(self, mock_get_distribution):
mock_pkg = Mock()
mock_requirement = Mock()
mock_requirement.name = 'tensorflow'
mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')]
mock_pkg.requires.return_value = [mock_requirement]
mock_get_distribution.return_value = mock_pkg
self.assertFalse(abi_is_compatible())
31 changes: 26 additions & 5 deletions tensorflow_recommenders_addons/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,44 @@
# limitations under the License.
# ==============================================================================
"""Types for typing functions signatures."""
# pylint: disable=protected-access

from typing import Union, Callable, List

import numpy as np
import tensorflow as tf

Number = Union[float, int, np.float16, np.float32, np.float64, np.int8,
np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
np.uint64,]
Number = Union[
float,
int,
np.float16,
np.float32,
np.float64,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
]

Initializer = Union[None, dict, str, Callable]
Regularizer = Union[None, dict, str, Callable]
Constraint = Union[None, dict, str, Callable]
Activation = Union[None, str, Callable]
Optimizer = Union[tf.keras.optimizers.Optimizer, str]

TensorLike = Union[List[Union[Number, list]], tuple, Number, np.ndarray,
tf.Tensor, tf.SparseTensor, tf.Variable,]
TensorLike = Union[
List[Union[Number, list]],
tuple,
Number,
np.ndarray,
tf.Tensor,
tf.SparseTensor,
tf.Variable,
]
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None]
# pylint: enable=protected-access

0 comments on commit 627d884

Please sign in to comment.