diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 52d73ada157693..c1a9d9da8f8bdb 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -2582,15 +2582,33 @@ def _should_eval(self, epoch, validation_freq):
   # Functions below exist only as v1 / v2 compatibility shims.
   ######################################################################
 
-  def _get_compile_args(self):
-    """Used for saving or cloning a Model."""
+  def _get_compile_args(self, user_metrics=True):
+    """Used for saving or cloning a Model.
+
+    Args:
+      user_metrics: Whether to return user-supplied metrics or `Metric` objects.
+        Defaults to returning the user-supplied metrics.
+
+    Returns:
+      Dictionary of arguments that were used when compiling the model.
+    """
     self._assert_compile_was_called()
     # pylint: disable=protected-access
+
+    saved_metrics = self.compiled_metrics._user_metrics
+    saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics
+
+    if not user_metrics:
+      if saved_metrics is not None:
+        saved_metrics = self.compiled_metrics._metrics
+      if saved_weighted_metrics is not None:
+        saved_weighted_metrics = self.compiled_metrics._weighted_metrics
+
     compile_args = {
         'optimizer': self.optimizer,
         'loss': self.compiled_loss._user_losses,
-        'metrics': self.compiled_metrics._user_metrics,
-        'weighted_metrics': self.compiled_metrics._user_weighted_metrics,
+        'metrics': saved_metrics,
+        'weighted_metrics': saved_weighted_metrics,
         'loss_weights': self.compiled_loss._user_loss_weights,
     }
     # pylint: enable=protected-access
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 2c38de926669a7..41c86f3e6d72dc 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -2810,7 +2810,8 @@ def _get_distribution_strategy(self):
   def _trackable_saved_model_saver(self):
     return model_serialization.ModelSavedModelSaver(self)
 
-  def _get_compile_args(self):
+  def _get_compile_args(self, user_metrics=True):
+    del user_metrics
     self._assert_compile_was_called()
     kwargs = {
         'loss': self.loss,
diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py
index 5c4f8a23cfe493..693568d484834b 100644
--- a/tensorflow/python/keras/saving/saved_model/revive_test.py
+++ b/tensorflow/python/keras/saving/saved_model/revive_test.py
@@ -295,6 +295,29 @@ def test_revive_network(self, model_cls):
     revived = keras_load.load(self.path, compile=False)
     self._assert_revived_correctness(model, revived)
 
+  def test_load_compiled_metrics(self):
+    model = testing_utils.get_small_sequential_mlp(1, 3)
+
+    # Compile with dense categorical accuracy
+    model.compile('rmsprop', 'mse', 'acc')
+    x = np.random.random((5, 10)).astype(np.float32)
+    y_true = np.random.random((5, 3)).astype(np.float32)
+    model.train_on_batch(x, y_true)
+
+    model.save(self.path, include_optimizer=True, save_format='tf')
+    revived = keras_load.load(self.path, compile=True)
+    self.assertAllClose(model.test_on_batch(x, y_true),
+                        revived.test_on_batch(x, y_true))
+
+    # Compile with sparse categorical accuracy
+    model.compile('rmsprop', 'mse', 'acc')
+    y_true = np.random.randint(0, 3, (5, 1)).astype(np.float32)
+    model.train_on_batch(x, y_true)
+    model.save(self.path, include_optimizer=True, save_format='tf')
+    revived = keras_load.load(self.path, compile=True)
+    self.assertAllClose(model.test_on_batch(x, y_true),
+                        revived.test_on_batch(x, y_true))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index 9fdf81cae2a092..0c3d044e80d643 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -173,7 +173,7 @@ def model_metadata(model, include_optimizer=True, require_config=True):
                       'Prefer using a Keras optimizer instead '
                       '(see keras.io/optimizers).')
     elif model._compile_was_called:  # pylint: disable=protected-access
-      training_config = model._get_compile_args()  # pylint: disable=protected-access
+      training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
       training_config.pop('optimizer', None)  # Handled separately.
      metadata['training_config'] = _serialize_nested_config(training_config)
      if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
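
Not part of the patch: a minimal standalone sketch of why model_metadata now requests the
resolved Metric objects via _get_compile_args(user_metrics=False). It assumes a TF 2.x
install with tf.keras, and an explicit Dense model stands in for the internal
testing_utils.get_small_sequential_mlp helper used by the test. The user-supplied alias
'acc' is only resolved to a concrete accuracy function once target shapes are seen, so
serializing the raw string could let a revived model pick a different accuracy variant
than the one the model was trained with.

import numpy as np
import tensorflow as tf

# Same shapes as the new revive_test case: 10 input features, 3 outputs.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(3, activation='softmax', input_shape=(10,))])
model.compile('rmsprop', 'mse', metrics=['acc'])

x = np.random.random((5, 10)).astype(np.float32)

# Dense (one-hot style) targets: 'acc' resolves to categorical accuracy.
y_dense = np.random.random((5, 3)).astype(np.float32)
model.train_on_batch(x, y_dense)

# After re-compiling, sparse integer targets make the same 'acc' string
# resolve to sparse categorical accuracy instead.
model.compile('rmsprop', 'mse', metrics=['acc'])
y_sparse = np.random.randint(0, 3, (5, 1)).astype(np.float32)
model.train_on_batch(x, y_sparse)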