RuntimeError: merge_call called while defining a new graph or a tf.function -- updating a non-trainable variable with assign under MirroredStrategy scope and a tf.function decorator
#34203
Labels
comp:dist-strat (Distribution Strategy related issues)
TF 2.0 (Issues relating to TensorFlow 2.0)
type:support (Support issues)
Describe the current behavior
My goal is to record some hidden results that do not need gradient computation but are used for the next batch. The demo code is given below.

Under the mirrored strategy context, updating a non-trainable variable with the assign method inside a tf.function-decorated fn fails. If the tf.function decorator is removed, it works well. If I instead re-assign the attribute directly, self.record = record, within the tf.function, I hit another error: TypeError: An op outside of the function building code is being passed -- the same error as this. I am aware that we have to do some all_reduce-like operation to merge the results from all replicas before updating any variable. I tried tf.distribute.get_replica_context().merge_call(), but the documentation is really unclear on how to use it, and I could not find any useful example in the TensorFlow source code either.

Describe the expected behavior
Under a strategy scope and tf.function context, updating a non-trainable variable with the assign method should work.
Code to reproduce the issue
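The original demo code is not attached here. Below is a minimal sketch of the intent, under assumptions: the variable name (record), its shape, and the MEAN aggregation are hypothetical choices for illustration, and the newer strategy.run API is used (in TF 2.0 this was strategy.experimental_run_v2). A plain non-trainable variable assigned from replica context inside a tf.function triggers the merge_call RuntimeError; creating the variable with ON_READ synchronization and an explicit aggregation is one workaround, since each replica then assigns its local copy and the values are merged only when the variable is read in cross-replica context.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Failing pattern (for reference): a plain non-trainable variable,
    #   record = tf.Variable(tf.zeros([2]), trainable=False)
    # assigned inside a tf.function under strategy.run raises
    # "RuntimeError: merge_call called while defining a new graph or a tf.function".
    #
    # Workaround sketch: ON_READ synchronization plus an aggregation mode
    # (MEAN here is an arbitrary illustrative choice) allows per-replica assign.
    record = tf.Variable(
        tf.zeros([2]),
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.MEAN,
    )

@tf.function
def step(x):
    def replica_fn(x):
        # Allowed in replica context for ON_READ variables; no merge_call is needed.
        record.assign(x)
    strategy.run(replica_fn, args=(x,))

step(tf.constant([1.0, 2.0]))
```

Reading record afterwards in cross-replica context (e.g. record.numpy()) aggregates the per-replica copies according to the chosen aggregation mode. This is the same mechanism Keras metric variables use to stay updatable inside distributed tf.functions.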