Dataset.map(tf.keras.applications.vgg16.preprocess_input) -> AttributeError: 'Tensor' object has no attribute '_datatype_enum' #29931

Closed
CJMenart opened this issue Jun 18, 2019 · 14 comments
Labels
comp:keras (Keras related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:support (Support issues)

Comments

@CJMenart

System information

Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes. Small error-reproducing script provided below
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
TensorFlow installed from (source or binary): Binary (Python 3.6)
TensorFlow version (use command below): 2.0, nightly installed 6/18/19
Bazel version (if compiling from source):
CUDA/cuDNN version: CPU only
GPU model and memory:
Exact command to reproduce: python map_bug.py, script provided below

Describe the problem

I just installed the nightly build to make sure this hadn't been caught yet. I am trying to do some map() operations on a dataset, nothing fancy. If I build only one dataset in a script, it works fine. If I do it twice, however (for instance, to make a train set and a hold-out set using the same operations), I get this mysterious error message. I'm pretty sure this is not intended behavior.

Source code / logs

This originally came up in a large project, but I whittled it down to the following script, which just uses MNIST to reproduce the error:

import tensorflow as tf
from tensorflow.keras import datasets
import numpy as np
BATCH_SIZE = 128


def size_image_for_vgg(image):
    return tf.image.resize(image, [224, 224])


if __name__ == '__main__':
    # punch up mnist
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
    train_images = train_images.reshape((60000, 28, 28, 1)).astype(np.float32) / 255.0
    test_images = test_images.reshape((10000, 28, 28, 1)).astype(np.float32) / 255.0

    # Now create a dataset
    im_ds = tf.data.Dataset.from_tensor_slices(train_images)
    label_ds = tf.data.Dataset.from_tensor_slices(train_labels)
    im_ds_t = tf.data.Dataset.from_tensor_slices(test_images)
    label_ds_t = tf.data.Dataset.from_tensor_slices(test_labels)

    # If this block is commented out, the block below it will NOT throw any error
    # do some normal Dataset operations on test and train data like we're getting ready to fit a model
    im_ds = im_ds.map(size_image_for_vgg, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    im_ds = im_ds.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_ds = tf.data.Dataset.zip((im_ds, label_ds))
    train_ds = train_ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=1000))
    train_ds = train_ds.batch(batch_size=BATCH_SIZE)

    im_ds_t = im_ds_t.map(size_image_for_vgg, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # TODO throws error if you do it a second time?
    im_ds_t = im_ds_t.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    test_ds = tf.data.Dataset.zip((im_ds_t, label_ds_t))
    test_ds = test_ds.batch(batch_size=BATCH_SIZE)

Stack trace

Traceback (most recent call last):
  File "map_bug.py", line 33, in <module>
    im_ds_t = im_ds_t.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 1141, in map
    self, map_func, num_parallel_calls, preserve_cardinality=True)
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3320, in __init__
    **flat_structure(self))
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 4141, in parallel_map_dataset
    preserve_cardinality=preserve_cardinality, name=name, ctx=_ctx)
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 4224, in parallel_map_dataset_eager_fallback
    _attr_Targuments, other_arguments = _execute.convert_to_mixed_eager_tensors(other_arguments, _ctx)
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py", line 210, in convert_to_mixed_eager_tensors
    types = [t._datatype_enum() for t in v] # pylint: disable=protected-access
  File "/home/menarcj/OtherSoftware/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py", line 210, in <listcomp>
    types = [t._datatype_enum() for t in v] # pylint: disable=protected-access
AttributeError: 'Tensor' object has no attribute '_datatype_enum'

@dynamicwebpaige added the 2.0.0-alpha0 and TF 2.0 (Issues relating to TensorFlow 2.0) labels Jun 18, 2019
@CJMenart
Author

Zooming in a bit: the error disappears if either of the two preprocess_input map calls, for example

im_ds = im_ds.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)

is commented out. You don't have to comment out the whole block.

@CJMenart
Author

So I can reproduce the error with the following even smaller script:

import tensorflow as tf
from tensorflow.keras import datasets
import numpy as np

if __name__ == '__main__':
    # punch up mnist
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
    train_images = train_images.reshape((60000, 28, 28, 1)).astype(np.float32) / 255.0
    test_images = test_images.reshape((10000, 28, 28, 1)).astype(np.float32) / 255.0

    # Now create a dataset
    im_ds = tf.data.Dataset.from_tensor_slices(train_images)
    im_ds_t = tf.data.Dataset.from_tensor_slices(test_images)

    im_ds = im_ds.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # TODO throws error if you do it a second time?
    im_ds_t = im_ds_t.map(tf.keras.applications.vgg16.preprocess_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)

@tchinen

tchinen commented Jun 20, 2019

Try this as a workaround:

# Add before any TF calls
# Initialize the keras global outside of any tf.functions
temp = tf.random_uniform([4, 32, 32, 3])  # Or tf.zeros
tf.keras.applications.vgg16.preprocess_input(temp)

@CJMenart
Author

tf.random_uniform doesn't exist, but if I use tf.zeros, adding those two lines seems to prevent the issue. I'm not sure why.
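A minimal sketch of the workaround with the TF 2.x spelling (`tf.random.uniform` instead of `tf.random_uniform`), assuming the affected keras-applications versions discussed in this thread. The all-zero 32x32 RGB inputs are placeholders standing in for the resized MNIST images; on releases that already contain the keras-applications fix, the warm-up call should simply be unnecessary.

```python
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Warm-up: call preprocess_input once, eagerly, before any Dataset.map call,
# so that (on affected versions) its lazily created mean constant is built
# outside of any traced function. tf.random.uniform is the TF 2.x name for
# tf.random_uniform; tf.zeros works equally well, since only the call matters.
tf.keras.applications.vgg16.preprocess_input(tf.random.uniform([4, 32, 32, 3]))

# Two datasets mapped with the same preprocessing, as in the report above.
# Placeholder all-zero images stand in for the real data.
train_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 32, 32, 3]))
test_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 32, 32, 3]))
train_ds = train_ds.map(tf.keras.applications.vgg16.preprocess_input,
                        num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(tf.keras.applications.vgg16.preprocess_input,
                      num_parallel_calls=AUTOTUNE)
```

For zero images, VGG16's "caffe"-style preprocessing leaves each pixel equal to minus the ImageNet channel mean in BGR order, which is an easy way to confirm the second map ran.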

@rachellim
Contributor

Why the workaround works:

(NOTE: the following involves tf.data internal implementation details)

The error occurs because keras' vgg16.preprocess_input uses a python global (_IMAGENET_MEAN here).

Each time you create a map dataset, it traces the user-defined function in a tf.function. When the function is traced in dataset.map for the first time, it creates the global (a tf.constant) in the first map function's scope. The second time it is traced, it uses the tf.constant created earlier, which is captured as an input to the second map function. However, this is a symbolic tensor (non-eager tensor) that belongs to a different tf.function, resulting in the error.

Adding the two lines works because running the preprocess_input function before creating the dataset initializes the _IMAGENET_MEAN tf.constant outside of any traced functions.
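To make the explanation above concrete, here is a pure-Python toy model of the failure with no TensorFlow involved. `Graph`, `Tensor`, `constant`, `read`, `preprocess`, `trace` and `reset_global` are hypothetical stand-ins for a tracing scope, a graph-owned tensor, and the old lazily initialized `_IMAGENET_MEAN` global. It sketches the mechanism described above; it is not TF's actual implementation.

```python
class Graph:
    """Toy stand-in for the graph a traced map function is built in."""

    _active = []  # stack of graphs being traced; empty means eager mode

    def __enter__(self):
        Graph._active.append(self)
        return self

    def __exit__(self, *exc_info):
        Graph._active.pop()
        return False  # never swallow exceptions


def current_graph():
    """Return the graph being traced, or None when running eagerly."""
    return Graph._active[-1] if Graph._active else None


class Tensor:
    def __init__(self, value, graph):
        self.value = value
        self.graph = graph  # owning graph; None means an eager tensor


def constant(value):
    # Like tf.constant: belongs to whatever graph is active right now.
    return Tensor(value, current_graph())


def read(t):
    # A symbolic tensor may only be used inside the graph that created it.
    if t.graph is not None and t.graph is not current_graph():
        raise RuntimeError("symbolic tensor from another graph was captured")
    return t.value


_MEAN = None  # lazily created module-level global, like the old _IMAGENET_MEAN


def preprocess(x):
    global _MEAN
    if _MEAN is None:
        _MEAN = constant(100.0)  # owned by whichever scope makes the first call
    return x - read(_MEAN)


def trace(fn, x):
    # Roughly what each Dataset.map call does: run fn inside a fresh graph.
    with Graph():
        return fn(x)


def reset_global():
    global _MEAN
    _MEAN = None
```

Tracing `preprocess` twice fails on the second trace because `_MEAN` belongs to the first graph. Calling `preprocess` once eagerly first creates `_MEAN` with no owning graph, so every later trace can use it. That is the toy equivalent of the two-line workaround.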

@achandraa self-assigned this Jun 27, 2019
@achandraa added the comp:keras (Keras related issues) and type:support (Support issues) labels Jun 27, 2019
@achandraa

@CJMenart: Did you get a chance to look at @rachellim's response? Please let us know if that resolves the issue. Thanks!

@achandraa added the stat:awaiting response (Status - Awaiting response from author) label Jun 27, 2019
@CJMenart
Author

@achandraa It worked as soon as tchinen recommended I try it. I was able to complete the tasks I was working on.

But we're not just going to leave this behavior here, right? The fact that such an awkward workaround is required if you want to Dataset.map the same function twice seems like a bug to me.

@karmel

karmel commented Jul 8, 2019

This should be fixed now. Can you try with tf-nightly?

@CJMenart
Copy link
Author

CJMenart commented Jul 9, 2019

The script above still throws the same error. I'm now on 2.0.0-dev20190709.

@tanzhenyu
Contributor

@CJMenart Sorry, I should have been clearer about this. The fix is not in TensorFlow itself; it's in keras-applications. Can you git clone keras-applications from GitHub and install it with "pip install -e ."?

@CJMenart
Author

Ah OK. Just pulled keras-applications, and the snippet now runs without error.
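Another way to sidestep the shared global entirely, sketched here as an assumption and not as the actual keras-applications patch, is a preprocessing function that creates its constants inside its own body, so every trace gets a fresh constant. `caffe_style_preprocess` is a hypothetical name. It mirrors VGG16's "caffe" preprocessing: RGB to BGR, then subtract the ImageNet per-channel mean, with no scaling.

```python
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def caffe_style_preprocess(images):
    """Hypothetical stand-in for vgg16.preprocess_input ("caffe" mode).

    Converts RGB to BGR and subtracts the ImageNet per-channel mean (BGR
    order), without scaling. The mean constant is created on every call, so
    each traced map function owns its own copy and nothing is shared.
    """
    images = tf.cast(images, tf.float32)
    images = images[..., ::-1]  # RGB -> BGR (channels-last layout assumed)
    mean = tf.constant([103.939, 116.779, 123.68], dtype=tf.float32)
    return images - mean


# Mapping the same function over two datasets: no global state to collide on.
train_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 32, 32, 3]))
test_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 32, 32, 3]))
train_ds = train_ds.map(caffe_style_preprocess, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(caffe_style_preprocess, num_parallel_calls=AUTOTUNE)
```

Recreating a three-element constant on each trace costs next to nothing, whereas caching it in a module-level global is exactly what tied it to the first traced function.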

@tanzhenyu
Contributor

Great. Closing now.

@Bornlex

Bornlex commented Sep 9, 2019

@tanzhenyu @CJMenart Sorry guys, I am running into the same issue when I run the following code on Google Colab:

import os
import tensorflow as tf


def build_dataset(boxes_df, data_directory='/content'):
    # Frozen VGG19 feature extractor that outputs the block5_pool activations
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False
    outputs = [vgg.get_layer(name).output for name in ['block5_pool']]
    vgg = tf.keras.Model([vgg.input], outputs)

    filenames_ds = tf.data.Dataset.from_tensor_slices(boxes_df['image_name'].apply(lambda path: os.path.join(data_directory, path)))
    x1_ds        = tf.data.Dataset.from_tensor_slices(boxes_df['x_1'])
    x2_ds        = tf.data.Dataset.from_tensor_slices(boxes_df['x_2'])
    y1_ds        = tf.data.Dataset.from_tensor_slices(boxes_df['y_1'])
    y2_ds        = tf.data.Dataset.from_tensor_slices(boxes_df['y_2'])
    tmp_ds       = tf.data.Dataset.zip((filenames_ds, x1_ds, x2_ds, y1_ds, y2_ds))

    # Decode, crop to the bounding box, and resize to VGG's 224x224 input.
    # crop_to_bounding_box takes (offset_height, offset_width, target_height,
    # target_width); this assumes x_* are column and y_* are row coordinates.
    images_ds    = tmp_ds.map(
        lambda path, x1, x2, y1, y2: tf.image.resize(
            tf.image.crop_to_bounding_box(
                tf.image.decode_jpeg(tf.io.read_file(path), channels=3),
                tf.cast(y1, tf.int32),
                tf.cast(x1, tf.int32),
                tf.cast(y2 - y1, tf.int32),
                tf.cast(x2 - x1, tf.int32)
            ),
            (224, 224)
        )
    )
    images_ds = images_ds.map(
        lambda img: tf.keras.applications.vgg19.preprocess_input(img)
    )
    # Tensors have no .reshape() method; tf.reshape flattens the 7x7x512 features
    features_ds = images_ds.map(
        lambda img: tf.reshape(vgg(tf.expand_dims(img, axis=0)), [7 * 7 * 512])
    )
    return features_ds

@ISipi

ISipi commented Jun 26, 2020

I came across this same error message when using tf.keras.applications.densenet.preprocess_input, so this issue isn't just related to VGG model preprocessing. It was fixed by placing the following at the beginning of the file, as suggested by @tchinen and @rachellim:

temp = tf.zeros([4, 32, 32, 3])
tf.keras.applications.densenet.preprocess_input(temp)
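Since more than one `preprocess_input` in `tf.keras.applications` appears to be affected, the warm-up can be wrapped in a small helper and run once at startup. `warm_up_preprocessors` is a hypothetical name, and the list of modules passed to it is illustrative (VGG19 is included because of the Colab report above). The sketch assumes only that each function accepts a float32 image batch.

```python
import tensorflow as tf


def warm_up_preprocessors(*preprocess_fns, shape=(4, 32, 32, 3)):
    """Hypothetical helper: call each preprocess_input once, eagerly.

    On keras-applications versions that lazily create module-level constants,
    this builds them outside of any traced function, before Dataset.map runs.
    """
    dummy = tf.zeros(shape)
    for fn in preprocess_fns:
        fn(dummy)


warm_up_preprocessors(
    tf.keras.applications.vgg16.preprocess_input,
    tf.keras.applications.vgg19.preprocess_input,
    tf.keras.applications.densenet.preprocess_input,
)

# The same DenseNet preprocessing can now be mapped over more than one dataset.
images = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 32, 32, 3]))
densenet_a = images.map(tf.keras.applications.densenet.preprocess_input)
densenet_b = images.map(tf.keras.applications.densenet.preprocess_input)
```

DenseNet uses "torch"-style preprocessing (scale to [0, 1], subtract the ImageNet mean, divide by the standard deviation), so a zero pixel maps to -0.485 / 0.229 in its first channel.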

Projects
TensorFlow 2.0
  
Awaiting triage