
TypeError: An op outside of the function building code is being passed a Graph tensor #620

Open
nbro opened this issue Oct 25, 2019 · 8 comments
Labels: layers, tensorflow 2.0 (Issues related to TF 2.0.)

nbro (Contributor) commented Oct 25, 2019

System information

  • Have I written custom code: Yes.
  • OS Platform and Distribution: macOS Catalina 10.15 (19A602)
  • TensorFlow installed from: binary
  • TensorFlow version: 2.0.0
  • Python version: 3.7.4
  • GPU model and memory: Intel Iris Pro 1536 MB

Describe the current behavior

Training fails with the error

tensorflow.python.eager.core._SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'conv2d_flipout/divergence_kernel:0' shape=() dtype=float32>]

which is raised while handling an earlier exception:

TypeError: An op outside of the function building code is being passed a Graph tensor

See the detailed traceback below.

Describe the expected behavior

No error.

Code to reproduce the issue

from __future__ import print_function

import tensorflow as tf
import tensorflow_probability as tfp

# Workaround: uncommenting the next line (switching back to graph mode) avoids the error.
# tf.compat.v1.disable_eager_execution()

def get_bayesian_model(input_shape=None, num_classes=10):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Input(shape=input_shape))
    model.add(tfp.layers.Convolution2DFlipout(6, kernel_size=5, padding="SAME", activation=tf.nn.relu))
    model.add(tf.keras.layers.Flatten())
    model.add(tfp.layers.DenseFlipout(84, activation=tf.nn.relu))
    model.add(tfp.layers.DenseFlipout(num_classes))
    return model

def get_mnist_data(normalize=True):
    img_rows, img_cols = 28, 28
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    if tf.keras.backend.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    if normalize:
        x_train /= 255
        x_test /= 255

    return x_train, y_train, x_test, y_test, input_shape


def train():
    # Hyper-parameters.
    batch_size = 128
    num_classes = 10
    epochs = 1

    # Get the training data.
    x_train, y_train, x_test, y_test, input_shape = get_mnist_data()

    # Get the model.
    model = get_bayesian_model(input_shape=input_shape, num_classes=num_classes)

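    # Prepare the model for training. The last DenseFlipout layer has no
    # activation, so it outputs logits rather than probabilities.
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])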

    # Train the model.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1)
    model.evaluate(x_test, y_test, verbose=0)


if __name__ == "__main__":
    train()

Other info / logs

WARNING:tensorflow:From /Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_probability/python/layers/util.py:104: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
2019-10-25 20:38:32.504579: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-10-25 20:38:32.517426: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe25e59f290 executing computations on platform Host. Devices:
2019-10-25 20:38:32.517438: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
Train on 60000 samples
  128/60000 [..............................] - ETA: 7:32
Traceback (most recent call last):
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 61, in quick_execute
    num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: conv2d_flipout/divergence_kernel:0

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nbro/Desktop/my_project/my_module.py", line 63, in <module>
    train()
  File "/Users/nbro/Desktop/my_project/my_module.py", line 58, in train
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
    total_epochs=epochs)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
    distributed_function(input_fn))
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 520, in _call
    return self._stateless_fn(*args, **kwds)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1823, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1141, in _filtered_call
    self.captured_inputs)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/Users/nbro/Desktop/my_project/venv/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 75, in quick_execute
    "tensors, but found {}".format(keras_symbolic_tensors))
tensorflow.python.eager.core._SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'conv2d_flipout/divergence_kernel:0' shape=() dtype=float32>]

The problem appears to be related to the tfp.layers.Convolution2DFlipout layer. I know that calling tf.compat.v1.disable_eager_execution() right after importing TensorFlow makes the error go away, but I would like to keep TensorFlow's eager execution and avoid sessions and placeholders.

I opened the same issue here: tensorflow/tensorflow#33729.

srvasude (Member) commented Feb 8, 2020

I can still replicate this. @jburnim to triage.

nbro (Contributor, Author) commented Feb 8, 2020

@srvasude Yes, this bug still exists in TensorFlow 2.1 and TFP 0.9. It is a serious issue that should be fixed as soon as possible, because the only known workaround (described in this comment tensorflow/tensorflow#33729 (comment)) will apparently not be available in TensorFlow 2.2 (see tensorflow/tensorflow#35138).

davmre (Contributor) commented Feb 11, 2020

We're looking into this. As a potential workaround: in at least some cases, calling model.build(input_shape) before model.fit() appears to avoid the problem.
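A minimal sketch of the build-before-fit suggestion, assuming the tf.keras API of the TF 2.x releases discussed in this thread. build_shape and build_before_fit are hypothetical helper names, not TensorFlow or TFP APIs. One detail worth making explicit: Model.build takes the full input shape including the batch dimension, so the per-example input_shape returned by get_mnist_data() needs a leading None.

```python
def build_shape(input_shape):
    """Return the shape to pass to Model.build.

    Keras expects the full input shape, so a leading None (unspecified
    batch size) is prepended to the per-example shape, e.g.
    (28, 28, 1) -> (None, 28, 28, 1).
    """
    return (None,) + tuple(input_shape)


def build_before_fit(model, input_shape):
    """Hypothetical helper: create the model's weights before model.fit().

    `model` is assumed to be a tf.keras model. If it already reports
    built == True, build() is not called again and the model is returned
    unchanged.
    """
    if not model.built:
        model.build(build_shape(input_shape))
    return model
```

In the reproduction script this would go right before model.fit(...) in train(). Note that the Sequential model there starts with an Input layer; in tf.keras such a model is normally built as soon as it is defined, so this helper would skip the build call. That may be related to the report below that the trick did not help for this example, but this is only a guess.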

nbro (Contributor, Author) commented Apr 23, 2020

@davmre That doesn't seem to work with my example above using TF 2.1 and TFP 0.9.0.

Can you provide an example where that trick avoids the problem?

nbro referenced this issue Apr 26, 2020
…_tf_function=True` - Make `_DenseVariational` layers eager compatible.

PiperOrigin-RevId: 264872770
nbro (Contributor, Author) commented Apr 28, 2020

This error doesn't occur anymore in the nightly versions of TF and TFP.

daniel-muthukrishna added a commit to daniel-muthukrishna/transomaly that referenced this issue May 13, 2020
…work.

Need to fix Convolutional1D not using 'causal' padding. Also need to fix disabling eager_execution required because of known issue with Conv1DFlipout in current tf 2.2 and tfp 0.9 versions. Problem identified here: tensorflow/probability#620
mcmar commented Jul 21, 2020

@nbro I'm still getting this with TensorFlow 2.2.0, which should have included the nightly fixes as of Apr 28, 2020.

SiegeLordEx (Member) commented

@mcmar Do you have a small, self-contained example of this? Which layers were you using?

prerakmody commented Sep 30, 2021

This issue with eager execution still persists with TensorFlow 2.4.0 and TensorFlow Probability 0.12.1.
See this gist, in which I had to disable eager execution.

Update

  • In the above gist, calling model.build(input_shape) does not solve the problem, but initializing the weights by passing a dummy (all-ones) tensor of the appropriate shape through the model removes the need to disable eager execution: model(tf.ones((1,) + input_shape)).shape
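The dummy-forward-pass workaround can be wrapped in a small helper. This is a sketch assuming TF 2.x with Keras 2 as tf.keras (the setup used throughout this thread); dummy_batch_shape and warm_up are hypothetical names, not TensorFlow or TFP APIs. Converting input_shape with tuple() guards against a list-valued shape, for which (1,) + input_shape would raise a TypeError.

```python
def dummy_batch_shape(input_shape, batch_size=1):
    """Prepend a batch dimension to a per-example shape.

    tuple() makes this work for list shapes too; (1,) + [28, 28, 1]
    would raise a TypeError.
    """
    return (batch_size,) + tuple(input_shape)


def warm_up(model, input_shape):
    """Hypothetical helper: run one eager forward pass on an all-ones batch.

    Per the comment above, doing this before training removed the need to
    call tf.compat.v1.disable_eager_execution(). Assumes TF 2.x with
    Keras 2 as tf.keras.
    """
    import tensorflow as tf  # lazy import keeps dummy_batch_shape TF-free

    dummy = tf.ones(dummy_batch_shape(input_shape))
    return model(dummy).shape
```

In the reproduction script from the issue, this would be called as warm_up(model, input_shape) right after get_bayesian_model(...) and before model.fit(...); whether it helps with other TF/TFP versions is not established in this thread.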

8 participants