Skip to content

Commit

Permalink
Keras SavedModel: Ignore custom metrics failure when compile=False
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed Jan 11, 2021
1 parent df05368 commit c4e6c63
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
22 changes: 17 additions & 5 deletions tensorflow/python/keras/saving/saved_model/load.py
Expand Up @@ -135,7 +135,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin

# Recreate layers and metrics using the info stored in the metadata.
keras_loader = KerasObjectLoader(metadata, object_graph_def)
keras_loader.load_layers()
keras_loader.load_layers(compile=compile)

# Generate a dictionary of all loaded nodes.
nodes_to_load = {'root': None}
Expand Down Expand Up @@ -360,7 +360,7 @@ def _add_children_recreated_from_config(self, obj, proto, node_id):
obj_child, child_proto, child_id)
self.loaded_nodes[child_id] = obj_child, setter

def load_layers(self):
def load_layers(self, compile=True): # pylint: disable=redefined-builtin
"""Load all layer nodes from the metadata."""
# Load metrics after models and layers, since it's likely that models
# and layers will create the metric when initialized (this avoids wasting
Expand All @@ -376,9 +376,21 @@ def load_layers(self):
node_metadata.metadata)

for node_metadata in metric_list:
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
node_metadata.node_id, node_metadata.identifier,
node_metadata.metadata)
try:
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
node_metadata.node_id, node_metadata.identifier,
node_metadata.metadata)
except ValueError:
# Metrics are only needed when the model is compiled later. We ignore
# errors when trying to load custom metrics when `compile=False` until
# custom metrics are serialized properly (b/135550038).
if compile:
raise
logging.warning('Unable to restore custom metric. Please ensure that '
'the layer implements `get_config` and `from_config` '
'when saving. In addition, please use the '
'`custom_objects` arg when calling `load_model()`.')


def _load_layer(self, node_id, identifier, metadata):
"""Load a single layer from a SavedUserObject proto."""
Expand Down
20 changes: 20 additions & 0 deletions tensorflow/python/keras/saving/saved_model/saved_model_test.py
Expand Up @@ -1147,6 +1147,26 @@ def update_state(self, value):
self._test_metric_save_and_load(
metric, self._save_model_dir(), 1, test_sample_weight=False)

@keras_parameterized.run_with_all_model_types
def test_custom_metric_model(self):

class CustomMetric(keras.metrics.MeanSquaredError):
pass

model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.compile(
loss='mse',
optimizer='rmsprop',
metrics=[CustomMetric()])

saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
with self.assertRaisesRegex(ValueError, 'custom_objects'):
keras_load.load(saved_model_dir)

keras_load.load(saved_model_dir, compile=False)



if __name__ == '__main__':
test.main()

0 comments on commit c4e6c63

Please sign in to comment.