Skip to content

Commit

Permalink
Fix saving/loading issue for sparse accuracy metrics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 329566746
Change-Id: Ic933352c1f0c5b131fade91e14db45828767e9e9
  • Loading branch information
k-w-w authored and tensorflower-gardener committed Sep 1, 2020
1 parent 8744e4b commit 5adacc8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 6 deletions.
26 changes: 22 additions & 4 deletions tensorflow/python/keras/engine/training.py
Expand Up @@ -2582,15 +2582,33 @@ def _should_eval(self, epoch, validation_freq):
# Functions below exist only as v1 / v2 compatibility shims.
######################################################################

def _get_compile_args(self):
"""Used for saving or cloning a Model."""
def _get_compile_args(self, user_metrics=True):
"""Used for saving or cloning a Model.
Args:
user_metrics: Whether to return user-supplied metrics or `Metric` objects.
Defaults to returning the user-supplied metrics.
Returns:
Dictionary of arguments that were used when compiling the model.
"""
self._assert_compile_was_called()
# pylint: disable=protected-access

saved_metrics = self.compiled_metrics._user_metrics
saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics

if not user_metrics:
if saved_metrics is not None:
saved_metrics = self.compiled_metrics._metrics
if saved_weighted_metrics is not None:
saved_weighted_metrics = self.compiled_metrics._weighted_metrics

compile_args = {
'optimizer': self.optimizer,
'loss': self.compiled_loss._user_losses,
'metrics': self.compiled_metrics._user_metrics,
'weighted_metrics': self.compiled_metrics._user_weighted_metrics,
'metrics': saved_metrics,
'weighted_metrics': saved_weighted_metrics,
'loss_weights': self.compiled_loss._user_loss_weights,
}
# pylint: enable=protected-access
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/keras/engine/training_v1.py
Expand Up @@ -2810,7 +2810,8 @@ def _get_distribution_strategy(self):
def _trackable_saved_model_saver(self):
return model_serialization.ModelSavedModelSaver(self)

def _get_compile_args(self):
def _get_compile_args(self, user_metrics=True):
del user_metrics
self._assert_compile_was_called()
kwargs = {
'loss': self.loss,
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/python/keras/saving/saved_model/revive_test.py
Expand Up @@ -295,6 +295,29 @@ def test_revive_network(self, model_cls):
revived = keras_load.load(self.path, compile=False)
self._assert_revived_correctness(model, revived)

def test_load_compiled_metrics(self):
model = testing_utils.get_small_sequential_mlp(1, 3)

# Compile with dense categorical accuracy
model.compile('rmsprop', 'mse', 'acc')
x = np.random.random((5, 10)).astype(np.float32)
y_true = np.random.random((5, 3)).astype(np.float32)
model.train_on_batch(x, y_true)

model.save(self.path, include_optimizer=True, save_format='tf')
revived = keras_load.load(self.path, compile=True)
self.assertAllClose(model.test_on_batch(x, y_true),
revived.test_on_batch(x, y_true))

# Compile with sparse categorical accuracy
model.compile('rmsprop', 'mse', 'acc')
y_true = np.random.randint(0, 3, (5, 1)).astype(np.float32)
model.train_on_batch(x, y_true)
model.save(self.path, include_optimizer=True, save_format='tf')
revived = keras_load.load(self.path, compile=True)
self.assertAllClose(model.test_on_batch(x, y_true),
revived.test_on_batch(x, y_true))


if __name__ == '__main__':
ops.enable_eager_execution()
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/keras/saving/saving_utils.py
Expand Up @@ -173,7 +173,7 @@ def model_metadata(model, include_optimizer=True, require_config=True):
'Prefer using a Keras optimizer instead '
'(see keras.io/optimizers).')
elif model._compile_was_called: # pylint: disable=protected-access
training_config = model._get_compile_args() # pylint: disable=protected-access
training_config = model._get_compile_args(user_metrics=False) # pylint: disable=protected-access
training_config.pop('optimizer', None) # Handled separately.
metadata['training_config'] = _serialize_nested_config(training_config)
if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
Expand Down

0 comments on commit 5adacc8

Please sign in to comment.