
Raise ValueError when saving a model created in mirroredstrategy #40366

Closed
djdongjin opened this issue Jun 10, 2020 · 16 comments
Labels
comp:dist-strat Distribution Strategy related issues stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author TF 2.3 Issues related to TF 2.3 type:bug Bug

Comments

@djdongjin

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tf-nightly
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:


Describe the current behavior
The model is created inside a MirroredStrategy scope. When I save it with model.save(save_path) after training, it raises ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`. The error is raised in SyncOnReadVariable.assign_add (tensorflow/python/distribute/values.py). The relevant part of the traceback is:

  File "ncf_keras_main.py", line 85, in call
    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
    metric_obj(value)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
    replica_local_fn, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
    return ag_update_state(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.

I have also attached the complete traceback for reference.

Describe the expected behavior

Standalone code to reproduce the issue

Other info / logs

Traceback (most recent call last):
  File "ncf_keras_main.py", line 568, in <module>
    app.run(main)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "ncf_keras_main.py", line 563, in main
    logging.info("Result is %s", run_ncf(FLAGS))
  File "ncf_keras_main.py", line 351, in run_ncf
    keras_model.save("save_model")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1950, in save
    signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 134, in save_model
    signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 953, in save
    obj, export_dir, signatures, options, meta_graph_def)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1015, in _build_meta_graph
    checkpoint_graph_view)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 144, in list_functions
    self._serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 2543, in _list_functions_for_serialization
    Model, self)._list_functions_for_serialization(serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3014, in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 77, in functions_to_serialize
    serialization_cache).functions_to_serialize)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 92, in _get_serialized_attributes
    serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 51, in _get_serialized_attributes_internal
    default_signature = save_impl.default_save_signature(self.obj)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 205, in default_save_signature
    fn.get_concrete_function()
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1168, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1074, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2842, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3200, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3062, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 979, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py", line 132, in _wrapped_model
    outputs = model(inputs, training=False)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 385, in call
    inputs, training=training, mask=mask)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 507, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "ncf_keras_main.py", line 85, in call
    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
    metric_obj(value)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
    replica_local_fn, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
    return ag_update_state(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.
@nikitamaia
Member

nikitamaia commented Jun 10, 2020

Hi @djdongjin, can you please provide a reproducible example? What arguments did you pass when running the script? Thanks!

Note this section from the docs: "We don't allow operations like v.assign_add in a cross-replica context for sync on read variables", which is essentially the error message you are seeing.
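The rule quoted from the docs can be reproduced without Keras. This is a minimal sketch, assuming TF 2.x and a single-CPU MirroredStrategy (so it runs on a machine without GPUs). The variable mimics the sync-on-read, SUM-aggregated `total` that the traceback shows a Keras metric updating with assign_add.

```python
import tensorflow as tf

# Assumption: a single-CPU MirroredStrategy, so the sketch runs without GPUs.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    # Sync-on-read variable summed across replicas, the same kind of variable
    # the traceback shows a Keras metric using for its `total`.
    total = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM,
    )


def replica_step():
    # strategy.run executes this in a replica context: assign_add is allowed.
    total.assign_add(1.0)


strategy.run(replica_step)

cross_replica_error = None
try:
    # At the top level (outside strategy.run) the variable is updated in the
    # cross-replica context, which is exactly the case the docs disallow.
    total.assign_add(1.0)
except ValueError as e:
    cross_replica_error = str(e)

# Reading in the cross-replica context sums the per-replica values.
print(float(total.read_value()))
print(cross_replica_error)
```

The first assign_add succeeds because it runs per replica; the second raises the same ValueError as in the issue, and the variable keeps the value from the successful update.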

@ravikyram ravikyram added comp:dist-strat Distribution Strategy related issues TF 2.3 Issues related to TF 2.3 stat:awaiting response Status - Awaiting response from author labels Jun 11, 2020
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 23, 2020
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.


@acarl005

acarl005 commented Dec 22, 2020

I hit this error as well, so I created a reproducible example. I also have a StackOverflow post about it.

import tensorflow as tf


class Sampling(tf.keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(tf.keras.layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = tf.keras.layers.Dense(latent_dim)
        self.dense_log_var = tf.keras.layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(tf.keras.layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = tf.keras.layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(tf.keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32, name="autoencoder", **kwargs):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        self.add_metric([0.], name="foo")
        return reconstructed


(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

original_dim = 784

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    vae = VariationalAutoEncoder(original_dim, 64, 32)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())

vae.fit(x_train, x_train, epochs=3, batch_size=64)
vae.save("vae")

If I remove the self.add_metric([0.], name="foo") line, it works. We need our custom metrics, though.

As for environment, I'm using the Google AI Platform runtime version 2.3.

@nikitamaia
Member

Hi @acarl005, thanks for providing a simple, reproducible example! I took a deeper look at this issue; it seems to be a known bug that has been fixed in 2.4. Please see this gist, which runs without error. You should now be able to export a Keras model trained with a custom metric.

If you want to test on AI Platform, I think you'll need to use a custom container, since the latest runtime version is only 2.3.

@acarl005

Thanks @nikitamaia for the quick response, and for providing the gist. We'll work on upgrading to 2.4.

@gtuzi

gtuzi commented Aug 14, 2021

This also seems to be a problem with tf.keras.metrics.Mean().

I'm using TF 2.5.0 and get the same error with MirroredStrategy.

@davzaman

davzaman commented Sep 1, 2021

I'm getting this error with tf.keras.metrics.CategoricalAccuracy() (which inherits from Mean) when using MirroredStrategy

OS Ubuntu 18.04.5
TF 2.4.1
Python 3.8.3
CUDA 11.0

@ravinderkhatri

I am facing the same issue with TensorFlow 2.7.0 and MirroredStrategy, using tf.keras.metrics.Mean().

@josephdviviano

I am also seeing this issue with TensorFlow 2.7.0, MirroredStrategy and tf.keras.metrics.Mean().

@gtuzi

gtuzi commented Jun 1, 2022

Facing the same issue with TensorFlow 2.8.0 using metric = tfk.metrics.Mean with MirroredStrategy when calling metric.update_state(value). This is basic functionality in the Keras ecosystem. It is strange that this bug has persisted for so many versions.

@tylerpayne

I am also getting this error in TensorFlow 2.8.0 using keras.metrics.Mean within a tf.distribute.MirroredStrategy.

@neel04

neel04 commented Aug 13, 2022

Same here. Any workarounds?

@OrigamiDream

I'm using TensorFlow 2.9.1, and I found a workaround.
Do not define the metrics within the scope.

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ...
    loss_fn = ...
    optimizer = ...

# Out of scope, this works.
metrics = tf.keras.metrics.Mean(name='total_loss')

This works when you are using a custom training loop.
I didn't test it with Model.fit() because I didn't have enough time.

The model should not be compiled within the MirroredStrategy scope.
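A custom training loop using this workaround might look like the following. It is a sketch under stated assumptions, not the commenter's code: TF 2.x with the bundled Keras 2, a single-CPU MirroredStrategy, a toy Dense model and random data. The loss handling follows the pattern of the Custom Training Loops tutorial (per-example losses averaged with tf.nn.compute_average_loss); the metric is created outside strategy.scope() and updated once per step with the loss reduced across replicas.

```python
import tensorflow as tf

# Assumptions: a single-CPU MirroredStrategy, a toy model and random data.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
GLOBAL_BATCH_SIZE = 8

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # Per-example losses; averaging over the global batch is done explicitly below.
    loss_fn = tf.keras.losses.MeanSquaredError(reduction="none")

# The workaround: the metric is created outside strategy.scope(), so its
# variables are ordinary (non-distributed) variables.
total_loss = tf.keras.metrics.Mean(name="total_loss")


def replica_step(x, y):
    with tf.GradientTape() as tape:
        per_example_loss = loss_fn(y, model(x, training=True))
        loss = tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


@tf.function
def train_step(x, y):
    per_replica_loss = strategy.run(replica_step, args=(x, y))
    # Summing the per-replica averages gives the average over the global batch.
    loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
    # Outside strategy.run: update the metric once per step with the reduced loss.
    total_loss.update_state(loss)
    return loss


features = tf.random.normal([32, 4])
labels = tf.random.normal([32, 1])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(GLOBAL_BATCH_SIZE)

for x, y in strategy.experimental_distribute_dataset(dataset):
    train_step(x, y)

print("mean training loss:", float(total_loss.result()))
```

Because the metric only ever sees the already-reduced scalar, it never needs per-replica (sync-on-read) variables.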

@Joschua-Conrad

I'm facing the same issue under TF 2.8.0; add_metric does not work.

In my opinion, though, this makes sense. A metric object is supposed to collect the metric tensors from all replicas, so it needs to be created in the cross-replica context entered with strategy.scope. But add_metric creates the metric object if necessary and updates it at the same time. When it is used, the metric object is therefore created inside a call, which is automatically wrapped in the replica context entered by strategy.run. The metric object ends up local to one replica, and that cannot work.
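The two contexts mentioned above can be observed directly with tf.distribute.in_cross_replica_context(). A small sketch, assuming TF 2.x and a single-CPU MirroredStrategy:

```python
import tensorflow as tf

# Assumption: a single-CPU MirroredStrategy, so the sketch runs without GPUs.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    # Code directly under strategy.scope() (e.g. model and metric creation)
    # runs in the cross-replica context.
    in_scope = tf.distribute.in_cross_replica_context()

contexts = []


def replica_fn():
    # Code passed to strategy.run (e.g. the training step and the model's call)
    # runs once per replica, in a replica context.
    contexts.append(tf.distribute.in_cross_replica_context())


strategy.run(replica_fn)

print(in_scope)   # True
print(contexts)   # [False]: one entry per replica
```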

So the solution is to split metric creation from the metric update. One can create self.nicemetric = tf.keras.metrics.Mean(name="Nicemetric") in a layer's __init__ and then call self.nicemetric.update_state(some_tensor) in call. Because the model is created inside the cross-replica context, the metric is created correctly. Later, during model.fit, the metric is updated from within call on a per-replica basis.
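The split described above might look like this with model.fit. It is a sketch under assumptions: TF 2.x with the bundled Keras 2 (tf.keras), a single-CPU MirroredStrategy, and hypothetical MeanActivation and TinyModel classes. The subclassed model stays unbuilt until fit, so call() (and with it update_state) only runs inside the per-replica training step. Only training is exercised here; saving is not covered.

```python
import numpy as np
import tensorflow as tf


class MeanActivation(tf.keras.layers.Layer):
    """Hypothetical layer that reports the mean of its outputs as a metric."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)
        # Creation: happens in __init__, i.e. under strategy.scope() below.
        self.nicemetric = tf.keras.metrics.Mean(name="nicemetric")

    def call(self, inputs):
        outputs = self.dense(inputs)
        # Update: model.fit runs call() per replica via strategy.run.
        self.nicemetric.update_state(tf.reduce_mean(outputs))
        return outputs


class TinyModel(tf.keras.Model):
    """Hypothetical subclassed model; it stays unbuilt until fit() calls it."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.body = MeanActivation(1)

    def call(self, inputs):
        return self.body(inputs)


# Assumption: a single-CPU MirroredStrategy, so the sketch runs without GPUs.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
    model = TinyModel()
    model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=16, verbose=0)
print(sorted(history.history))
```

In Keras 2, metrics assigned as layer attributes are collected into model.metrics, so the metric's result is expected to show up in the fit logs next to the loss.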

The Distributed Training with Keras Tutorial uses model.fit, but does not show how to add custom metrics. The Custom Training Loops Tutorial uses custom metrics with the approach described above, but not model.fit. However, the approach described above and in the Custom Training Loops Tutorial can also be used with model.fit. That worked for me.
