Skip to content

Commit

Permalink
Rolling back change that is no longer needed.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571267300
  • Loading branch information
pabloduque0 authored and Copybara-Service committed Oct 6, 2023
1 parent c83a236 commit 272400f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 42 deletions.
20 changes: 3 additions & 17 deletions tensorflow_hub/keras_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,6 @@ class KerasLayer(tf.keras.layers.Layer):
load_options: Optional, `tf.saved_model.LoadOptions` object that specifies
options for loading when a Python string is provided as `handle`. This
argument can only be used from TensorFlow 2.3 onwards.
force_keras_loading: Whether model should be wrapped around
tf.keras.models.Model which is equivalent as being loaded as
tf.keras.models.load_model.
**kwargs: Forwarded to Keras' base Layer constructor.
"""

Expand All @@ -138,7 +135,6 @@ def __init__(
output_key=None,
output_shape=None,
load_options=None,
force_keras_loading=False,
**kwargs):
# Note: for compatibility with keras-model serialization this layer is
# json-serializable. If you add or change arguments here, please also update
Expand All @@ -158,12 +154,7 @@ def __init__(
_convert_nest_to_shapes(output_shape))

self._load_options = load_options
self._func = load_module(
handle,
tags,
self._load_options,
force_keras_loading=force_keras_loading,
)
self._func = load_module(handle, tags, self._load_options)
self._is_hub_module_v1 = getattr(self._func, "_is_hub_module_v1", False)

# Update with the defaults when using legacy TF1 Hub format.
Expand Down Expand Up @@ -437,8 +428,7 @@ def _shape_as_tuple(x):
return tf.nest.map_structure(_shape_as_tuple, x)


def load_module(
handle, tags=None, load_options=None, force_keras_loading=False):
def load_module(handle, tags=None, load_options=None):
if callable(handle):
if tags is not None:
raise ValueError("Passing a callable handle is mutually exclusive "
Expand Down Expand Up @@ -466,11 +456,7 @@ def load_module(
set_load_options = load_options or load_context.get_load_options()
except ImportError: # Expected before TF2.4.
set_load_options = load_options
return module_v2.load(
handle, tags=tags,
options=set_load_options,
force_keras_loading=force_keras_loading,
)
return module_v2.load(handle, tags=tags, options=set_load_options)


def func_has_training_argument(func):
Expand Down
11 changes: 1 addition & 10 deletions tensorflow_hub/module_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def resolve(handle):
return registry.resolver(handle)


def load(handle, tags=None, options=None, force_keras_loading=False):
def load(handle, tags=None, options=None):
"""Resolves a handle and loads the resulting module.
This is the preferred API to load a Hub module in low-level TensorFlow 2.
Expand Down Expand Up @@ -80,9 +80,6 @@ def load(handle, tags=None, options=None, force_keras_loading=False):
options: Optional, `tf.saved_model.LoadOptions` object that specifies
options for loading. This argument can only be used from TensorFlow 2.3
onwards.
force_keras_loading: Whether model should be wrapped around
tf.keras.models.Model which is equivalent as being loaded as
tf.keras.models.load_model.
Returns:
A trackable object (see tf.saved_model.load() documentation for details).
Expand All @@ -99,9 +96,6 @@ def load(handle, tags=None, options=None, force_keras_loading=False):
if tags is None and is_hub_module_v1:
tags = []

if force_keras_loading and is_hub_module_v1:
raise ValueError("`force_keras_loading` is not supported for v1 modules.")

saved_model_path = os.path.join(
tf.compat.as_bytes(module_path),
tf.compat.as_bytes(tf.saved_model.SAVED_MODEL_FILENAME_PB))
Expand All @@ -124,8 +118,5 @@ def load(handle, tags=None, options=None, force_keras_loading=False):
module_path, tags=tags, options=options)
else:
obj = tf.compat.v1.saved_model.load_v2(module_path, tags=tags)

if force_keras_loading:
obj = tf.keras.models.Model(obj)
obj._is_hub_module_v1 = is_hub_module_v1 # pylint: disable=protected-access
return obj
15 changes: 0 additions & 15 deletions tensorflow_hub/module_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,6 @@ def test_load_without_string(self):
with self.assertRaisesRegex(ValueError, 'Expected a string, got.*'):
module_v2.load(0)

def test_load_force_keras_loading_loads_keras_model_instance(self):
export_dir = os.path.join(self.get_temp_dir(), 'keras_model')
_save_plus_one_saved_model_v2(export_dir)

loaded_model = module_v2.load(export_dir, force_keras_loading=True)

self.assertIsInstance(loaded_model, tf.keras.Model)

def test_load_force_keras_loading_raises_exception_on_tf1(self):
export_dir = os.path.join(self.get_temp_dir(), 'keras_model')
exception_mesage = '`force_keras_loading` is not supported for v1 modules.'
_save_plus_one_hub_module_v1(export_dir)

with self.assertRaisesRegex(ValueError, exception_mesage):
module_v2.load(export_dir, force_keras_loading=True)

if __name__ == '__main__':
# In TF 1.15.x, we need to enable V2-like behavior, notably eager execution.
Expand Down

0 comments on commit 272400f

Please sign in to comment.