diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py index 3bbc4b39c..abffb188b 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py @@ -121,6 +121,7 @@ def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False, **kwargs): if hasattr(self, 'add_variable_from_reference'): original_add_variable_from_reference = self.add_variable_from_reference + # pylint: disable=protected-access def _distributed_apply(distribution, grads_and_vars, name, apply_state): """`apply_gradients` using a `DistributionStrategy`.""" @@ -208,9 +209,8 @@ def apply_grad_to_update_var(var, grad): args=(grad,), group=False) replica_context = distribute_ctx.get_replica_context() - # pylint: disable=protected-access - if (replica_context is None or replica_context is - distribute_ctx._get_default_replica_context()): + if (replica_context is None or replica_context + is distribute_ctx._get_default_replica_context()): # In cross-replica context, extended.update returns a list of # update ops from all replicas (group=False). update_ops.extend(update_op) diff --git a/tensorflow_recommenders_addons/utils/resource_loader.py b/tensorflow_recommenders_addons/utils/resource_loader.py index 01daf643f..a0e5a1c9e 100644 --- a/tensorflow_recommenders_addons/utils/resource_loader.py +++ b/tensorflow_recommenders_addons/utils/resource_loader.py @@ -18,6 +18,7 @@ import pkg_resources import tensorflow as tf import warnings +from packaging.version import parse as parse_version abi_warning_already_raised = False SKIP_CUSTOM_OPS = False @@ -37,13 +38,32 @@ def get_required_tf_version(): "TFRA installation.", UserWarning, ) - return tf.__version__ + return tf.__version__, tf.__version__ pkg_info = pkg.requires() + low_version, high_version = None, None + for x in pkg_info: if x.name in ["tensorflow", "tensorflow-gpu"]: - return x.specs[0][1] - assert False, "Fail to get required TensorFlow version of TFRA!" + for spec in x.specs: + if spec[0] == ">=": + low_version = spec[1] + elif spec[0] == "<=": + high_version = spec[1] + if low_version and high_version: + return low_version, high_version + + assert False, f"Fail to get required TensorFlow version of TFRA: {pkg_info[0]} {low_version} {high_version}" + + +def abi_is_compatible(): + if "dev" in tf.__version__: + return False + low_version, high_version = get_required_tf_version() + + current_version = parse_version(tf.__version__) + return parse_version(low_version) <= current_version <= parse_version( + high_version) def get_devices(device_type="GPU"): @@ -138,14 +158,6 @@ def display_warning_if_incompatible(self): abi_warning_already_raised = True -def abi_is_compatible(): - if "dev" in tf.__version__: - return False - - required_tf_version = get_required_tf_version() - return tf.__version__ == required_tf_version - - def prefix_op_name(op_name): """ In order to keep compatibility of existing models, diff --git a/tensorflow_recommenders_addons/utils/tests/test_resource_loader.py b/tensorflow_recommenders_addons/utils/tests/test_resource_loader.py new file mode 100644 index 000000000..33530d2c2 --- /dev/null +++ b/tensorflow_recommenders_addons/utils/tests/test_resource_loader.py @@ -0,0 +1,54 @@ +import unittest +from unittest.mock import patch, Mock + +import pkg_resources +import tensorflow as tf + +from tensorflow_recommenders_addons.utils.resource_loader import abi_is_compatible + + +class TestTensorFlowCompatibility(unittest.TestCase): + + @patch('pkg_resources.get_distribution') + @patch('tensorflow.__version__', '2.12.0') + def test_compatible_version(self, mock_get_distribution): + mock_pkg = Mock() + mock_requirement = Mock() + mock_requirement.name = 'tensorflow' + mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')] + mock_pkg.requires.return_value = [mock_requirement] + mock_get_distribution.return_value = mock_pkg + self.assertTrue(abi_is_compatible()) + + @patch('pkg_resources.get_distribution') + @patch('tensorflow.__version__', '2.10.0') + def test_incompatible_version_below_range(self, mock_get_distribution): + mock_pkg = Mock() + mock_requirement = Mock() + mock_requirement.name = 'tensorflow' + mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')] + mock_pkg.requires.return_value = [mock_requirement] + mock_get_distribution.return_value = mock_pkg + self.assertFalse(abi_is_compatible()) + + @patch('pkg_resources.get_distribution') + @patch('tensorflow.__version__', '2.16.0') + def test_incompatible_version_above_range(self, mock_get_distribution): + mock_pkg = Mock() + mock_requirement = Mock() + mock_requirement.name = 'tensorflow' + mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')] + mock_pkg.requires.return_value = [mock_requirement] + mock_get_distribution.return_value = mock_pkg + self.assertFalse(abi_is_compatible()) + + @patch('pkg_resources.get_distribution') + @patch('tensorflow.__version__', '2.13.0-dev20240528') + def test_dev_version(self, mock_get_distribution): + mock_pkg = Mock() + mock_requirement = Mock() + mock_requirement.name = 'tensorflow' + mock_requirement.specs = [('>=', '2.11.0'), ('<=', '2.15.1')] + mock_pkg.requires.return_value = [mock_requirement] + mock_get_distribution.return_value = mock_pkg + self.assertFalse(abi_is_compatible()) diff --git a/tensorflow_recommenders_addons/utils/types.py b/tensorflow_recommenders_addons/utils/types.py index 945a82b39..d30a2d72b 100644 --- a/tensorflow_recommenders_addons/utils/types.py +++ b/tensorflow_recommenders_addons/utils/types.py @@ -13,15 +13,28 @@ # limitations under the License. # ============================================================================== """Types for typing functions signatures.""" +# pylint: disable=protected-access from typing import Union, Callable, List import numpy as np import tensorflow as tf -Number = Union[float, int, np.float16, np.float32, np.float64, np.int8, - np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, - np.uint64,] +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] Initializer = Union[None, dict, str, Callable] Regularizer = Union[None, dict, str, Callable] @@ -29,7 +42,15 @@ Activation = Union[None, str, Callable] Optimizer = Union[tf.keras.optimizers.Optimizer, str] -TensorLike = Union[List[Union[Number, list]], tuple, Number, np.ndarray, - tf.Tensor, tf.SparseTensor, tf.Variable,] +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, +] FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None] +# pylint: enable=protected-access