
RuntimeError: merge_call called while defining a new graph or a tf.function -- Update non-trainable variable with assign under mirrored strategy scope and tf.function decorator #34203

Closed
naturomics opened this issue Nov 12, 2019 · 2 comments

naturomics commented Nov 12, 2019


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Gentoo
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary, pip
  • TensorFlow version (use command below): 2.0.0
  • Python version: 3.6
  • Bazel version (if compiling from source): -
  • GCC/Compiler version (if compiling from source): -
  • CUDA/cuDNN version:
  • GPU model and memory:


Describe the current behavior
My goal is to record some hidden results that do not need gradients but are used by the next batch. The demo code is given below.
Under the mirrored strategy context, updating a non-trainable variable with the assign method fails when fn is decorated with tf.function. If tf.function is removed, it works well. If I instead re-assign self.record = record inside the tf.function, I hit another error:
TypeError: An op outside of the function building code is being passed, the same error as this. I'm aware we have to do some all_reduce-like operation to merge the results from all replicas before updating any variable.
I tried something like tf.distribute.get_replica_context().merge_call(), but the doc is really unclear about how to implement it, and I could not find any useful example in the TensorFlow source code either.
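For reference, my (possibly wrong) reading of the merge_call docs is that the pattern looks roughly like the sketch below; assign_via_merge_call, merge_fn, var and per_replica_value are my own placeholder names, not TensorFlow API:

import tensorflow as tf

# Rough sketch of the merge_call pattern (loosely modelled on how the built-in
# optimizers aggregate gradients); untested, shown only to illustrate what I tried.
def assign_via_merge_call(var, per_replica_value):
  def merge_fn(strategy, value):
    # merge_fn runs in cross-replica context: reduce the per-replica values
    # onto the variable's devices, then update the variable on each replica.
    reduced = strategy.extended.reduce_to(
        tf.distribute.ReduceOp.MEAN, value, destinations=var)
    return strategy.extended.update(
        var, lambda v, t: v.assign(t), args=(reduced,))

  return tf.distribute.get_replica_context().merge_call(
      merge_fn, args=(per_replica_value,))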

Describe the expected behavior
Under a strategy and tf.function context, updating a non-trainable variable with the assign method should work.

Code to reproduce the issue

import numpy as np
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
  def __init__(self):
    super(MyLayer, self).__init__()

  def build(self, input_shape):
    self.w = self.add_weight("w", shape=[], dtype=tf.float32, initializer=tf.constant_initializer(np.random.uniform()))

    # record some hidden results used by next batch
    self.record = self.add_weight("record", shape=[],
                                  dtype=tf.float32,
                                  trainable=False,
                                  aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                                  initializer=tf.constant_initializer(np.random.uniform()))

  def call(self, x):
    record = self.record + self.w
    y = x*self.w + record

    # Hit TypeError: An op outside of the function building code is being passed a "Graph" tensor
    #self.record = record

    # Hit RuntimeError: `merge_call` called while defining a new graph or a tf.function
    self.record.assign(record)
    return y


class Net(tf.keras.Model):
  def __init__(self):
    super(Net, self).__init__()
    self.my_layer = MyLayer()

  def call(self, x):
    y = self.my_layer(x)
    y = y + tf.random.normal(shape=[])
    return y

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    net = Net()
    n_samples = 1000
    xs = np.random.uniform(size=[n_samples])

    # It works well without tf.function.
    # If no tf.distribute strategy is used, it also works whether or not
    # tf.function is used.
    @tf.function
    def train_step(x):
        y = net(x)
        return y

    for i in range(n_samples):
        x = xs[i]
        y = strategy.experimental_run_v2(train_step, args=(x,))

Other info / logs

 test4.py:50 train_step  *
        y = net(x)
    /usr/lib64/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py:891 __call__
        outputs = self.call(cast_inputs, *args, **kwargs)
    test4.py:28 call  *
        self.record.assign(record+0.1)
    /usr/lib64/python3.6/site-packages/tensorflow_core/python/distribute/values.py:1036 assign
        return self._assign_func(f=assign_fn, *args, **kwargs)
    /usr/lib64/python3.6/site-packages/tensorflow_core/python/distribute/values.py:1024 _assign_func
        merge_fn, args=args, kwargs=kwargs)
    /usr/lib64/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py:1917 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /usr/lib64/python3.6/site-packages/tensorflow_core/python/distribute/mirrored_strategy.py:940 _merge_call
        "`merge_call` called while defining a new graph or a tf.function. "

RuntimeError: `merge_call` called while defining a new graph or a tf.function. 
This can often happen if the function `fn` passed to `strategy.experimental_run()` 
is decorated with `@tf.function` (or contains a nested `@tf.function`), 
and `fn` contains a synchronization point, such as aggregating gradients. 
This behavior is not yet supported. Instead, please wrap the entire call `strategy.experimental_run(fn)` in a `@tf.function`, 
and avoid nested `tf.function`s that may potentially cross a synchronization boundary.
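As the error text suggests, a workaround might be to decorate the outer call that wraps experimental_run_v2 instead of train_step itself; a quick, untested sketch of what I mean (reusing net, strategy, xs and n_samples from the code above):

def train_step(x):
    y = net(x)
    return y

# tf.function now wraps the whole distributed call, so train_step is no longer
# a nested tf.function and merge_call is not invoked while one is being traced.
@tf.function
def distributed_step(x):
    return strategy.experimental_run_v2(train_step, args=(x,))

for i in range(n_samples):
    y = distributed_step(xs[i])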
naturomics commented Nov 13, 2019

Closing. Solved it after reading the doc and source code several times; I hope the doc and examples for tf.distribute will be improved.

When defining self.record, specify the synchronization argument like this:

self.record = self.add_weight("record", shape=[],
                              dtype=tf.float32,
                              trainable=False,
                              synchronization=tf.VariableSynchronization.ON_READ,
                              aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                              initializer=tf.constant_initializer(np.random.uniform()))

ON_READ synchronization allows the assign/assign_sub/assign_add methods to be used to update its value in a replica context. The complete code:

import numpy as np
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
  """A simple linear model."""

  def __init__(self):
    super(MyLayer, self).__init__()

  def build(self, input_shape):
    self.w = self.add_weight("w", shape=[], dtype=tf.float32, initializer=tf.constant_initializer(np.random.uniform()))

    # record some hidden results used by next batch
    self.record = self.add_weight("record", shape=[],
                                  dtype=tf.float32,
                                  trainable=False,
                                  synchronization=tf.VariableSynchronization.ON_READ,
                                  aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                                  initializer=tf.constant_initializer(np.random.uniform()))

  def call(self, x):
    record = self.record + self.w + x
    y = x*self.w + record

    replica_ctx = tf.distribute.get_replica_context()
    tf.print(replica_ctx.replica_id_in_sync_group, end="\t")
    self.record.assign(record)
    tf.print(self.record)
    tf.print()
    return y


class Net(tf.keras.Model):
  def __init__(self):
    super(Net, self).__init__()
    self.my_layer = MyLayer()

  def call(self, x):
    y = self.my_layer(x)
    y = y + tf.random.normal(shape=[])
    return y

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    net = Net()
    n_samples = 1000
    xs = np.random.uniform(size=[n_samples])
    dataset = tf.data.Dataset.from_tensor_slices((xs,))
    dataset = dataset.batch(2)
    dataset = strategy.experimental_distribute_dataset(dataset)

    @tf.function
    def train_step(x):
        x = tf.reshape(x, [])
        y = net(x)
        return y
    for x in dataset:
        y = strategy.experimental_run_v2(train_step, args=(x,))
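If it helps anyone else: my understanding from the docs (not verified in the source) is that reading an ON_READ variable back outside the replica context aggregates the per-replica copies according to the aggregation argument, so with ONLY_FIRST_REPLICA a plain read after the loop should return the first replica's value:

# Cross-replica read after the loop; with aggregation=ONLY_FIRST_REPLICA this
# should print the first replica's copy (my understanding, not verified).
tf.print(net.my_layer.record)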

oanush self-assigned this and added the comp:dist-strat (Distribution Strategy related issues), TF 2.0, and type:support labels on Nov 13, 2019
moono commented Feb 7, 2020

I have the exact same issue: non-trainable variable assign ops and mirrored strategy for training.
I'm wondering if this is the proper (official) solution?
Can a member of the TensorFlow team verify?
