Generators #3

Closed
esparza83 opened this issue Sep 4, 2019 · 10 comments

@esparza83

Is there a particular format when combining with ImageDataGenerator?
Error:
OperatorNotAllowedInGraphError: iterating over tf.Tensor is not allowed in Graph execution

@arjung
Collaborator

arjung commented Sep 5, 2019

@esparza83, could you give us more information on what you are trying to do, how you are using NSL, and which version of TF you are using?

It would also be great if you could paste a minimal code snippet that reproduces this problem.

@esparza83
Author

I'm using TF 2.0 RC for image classification. I'm trying to use a generator to fit the model, but it doesn't work when I use fit_generator.

train_gen = ImageDataGenerator(
    rotation_range=5,
    width_shift_range=0.10,
    height_shift_range=0.10,
    horizontal_flip=True,
    zoom_range=0.2,
    fill_mode='nearest').flow_from_dataframe(
        train_data, directory=HPARAMS.directory,
        x_col='filename', y_col='label',
        batch_size=HPARAMS.batch_size, target_size=HPARAMS.image_size,
        shuffle=True, class_mode='categorical')

@arjung
Collaborator

arjung commented Sep 5, 2019

Are you using Neural Structured Learning?

@esparza83
Author

I tried to follow the MNIST example from https://www.tensorflow.org/neural_structured_learning, but using my own data.

@csferng
Collaborator

csferng commented Sep 5, 2019

@esparza83, AdversarialRegularization assumes that its input batch is a dictionary containing both x and y, while ImageDataGenerator generates input as (x, y) tuples. The mismatched format might cause the error.

One way to fill the gap is to create an intermediate generator function which converts tuples into dictionaries, as demonstrated below:

import numpy as np
import tensorflow as tf
import neural_structured_learning as nsl

# HPARAMS is assumed to be defined elsewhere, with a num_classes field.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel dimension, as expected by ImageDataGenerator.
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = np.reshape(x_test, (-1, 28, 28, 1))
# One-hot encode the labels to match the categorical_crossentropy loss.
y_train = tf.keras.utils.to_categorical(y_train, HPARAMS.num_classes)
y_test = tf.keras.utils.to_categorical(y_test, HPARAMS.num_classes)

def generator(x, y):
  datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      featurewise_center=True,
      featurewise_std_normalization=True,
      rotation_range=20,
      width_shift_range=0.2,
      height_shift_range=0.2,
      horizontal_flip=True)
  datagen.fit(x)
  for x_batch, y_batch in datagen.flow(x, y, batch_size=32):
    yield {'feature': x_batch, 'label': y_batch}

train_gen = generator(x_train, y_train)
test_gen = generator(x_test, y_test)

model = tf.keras.Sequential(...)
adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)
adv_model = nsl.keras.AdversarialRegularization(model, adv_config=adv_config)

adv_model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
adv_model.fit_generator(train_gen, steps_per_epoch=100, epochs=1)
adv_model.evaluate_generator(test_gen, steps=100)
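
The tuple-to-dictionary step in the snippet above can also be pulled out into a small reusable wrapper. This is only a minimal sketch: the function name tuples_to_dicts and the fake_batches stand-in are hypothetical, and the 'feature' and 'label' keys are simply the ones used in this thread. Plain Python lists stand in for the image and label arrays so that the example runs without TensorFlow.

```python
def tuples_to_dicts(batches, feature_key="feature", label_key="label"):
    """Wrap an iterable of (x, y) batches and yield dictionary batches instead."""
    for x_batch, y_batch in batches:
        yield {feature_key: x_batch, label_key: y_batch}


# Hypothetical stand-in for datagen.flow(...) or flow_from_directory(...),
# which produce (x_batch, y_batch) tuples.
fake_batches = [
    ([0.1, 0.2], [0, 1]),
    ([0.3, 0.4], [1, 0]),
]

dict_batches = list(tuples_to_dicts(fake_batches))
print(dict_batches[0])
```

With a real Keras iterator the wrapper would be used as tuples_to_dicts(datagen.flow(x, y, batch_size=32)). Since such iterators loop forever, steps_per_epoch still has to be passed when fitting.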

@csferng csferng closed this as completed Sep 9, 2019
@aheydon-google aheydon-google added the question Further information is requested label Oct 16, 2019
@othmaneDaanouni

I spent two days trying to adapt Neural Structured Learning to a CNN model. I use ImageDataGenerator and flow_from_directory, and when I call model.fit_generator I get this error message:

ValueError: Layer model_1 was called with an input that isn't a symbolic tensor. Received type: <class 'dict'>. Full input: [{'feature': <tf.Tensor 'feature:0' shape=(None, 224, 224, 3) dtype=float32>}]. All inputs to the layer should be tensors.

I use Keras 2.3.1 with TensorFlow 2.0 as the backend.

This is a snippet of my code:

num_classes = 4
img_rows, img_cols = 224, 224
batch_size = 16

train_datagen = ImageDataGenerator(
  rescale=1./255,
  rotation_range=30,
  width_shift_range=0.3,
  height_shift_range=0.3,
  horizontal_flip=True,
  fill_mode='nearest')

validation_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_rows, img_cols),
    batch_size=batch_size, shuffle=True,
    class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_rows, img_cols),
    batch_size=batch_size, shuffle=True,
    class_mode='categorical')

def vgg():
    model1 = Sequential([ ])
    return model1

base_model = vgg()

I adapt the generated data from the (x, y) format to a dictionary format:

def convert_training_data_generator():
   for x ,y in train_generator:
    return {'feature': x, 'label':y}

def convert_testing_data_generator():
   for x ,y in validation_generator:
    return {'feature': x, 'label': y}

adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)
model = nsl.keras.AdversarialRegularization(base_model, adv_config=adv_config)
train= convert_training_data_generator()
test= convert_testing_data_generator()

history = model.fit_generator(train,
   steps_per_epoch= nb_train_samples // batch_size,
    epochs = epochs,
    callbacks = callbacks,
    validation_data = test,
    validation_steps = nb_validation_samples // batch_size)

Your help is very much appreciated.

@csferng
Collaborator

csferng commented Dec 13, 2019

@othmaneDaanouni, thanks for your interest. Please provide a few more details to help us reproduce the error.

The base model in your example is empty. Could you share the model architecture that you actually use?

    def vgg():
      model1 = Sequential([ ])
      return model1

Also the model has to be compiled before running .fit_generator. What are the arguments you use to compile the model, namely the optimizer and the loss?

model.compile(optimizer=..., loss=...)

@csferng csferng reopened this Dec 13, 2019
@othmaneDaanouni

othmaneDaanouni commented Dec 13, 2019

@csferng @arjung thank you very much for your reply. This is the base model:

def vgg():
    model1 = Sequential([

    # 1st CONV-ReLU Layer
    Conv2D(64, (3, 3), activation="relu",padding = 'same', input_shape = (img_rows, img_cols, 3)),
    BatchNormalization(),

    # 2nd CONV-ReLU Layer
    Conv2D(64, (3, 3), activation="relu", padding = "same"),
    BatchNormalization(),

    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),

    # 3rd set of CONV-ReLU Layers
    Conv2D(128, (3, 3), activation="relu",padding="same"),
    BatchNormalization(),

    # 4th Set of CONV-ReLU Layers
    Conv2D(128, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),

    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),

    # 5th Set of CONV-ReLU Layers
    Conv2D(256, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),

    # 6th Set of CONV-ReLU Layers
    Conv2D(256, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),

    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),
        
     # 7th Set of CONV-ReLU Layers
    Conv2D(256, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),

    # 8th Set of CONV-ReLU Layers
    Conv2D(256, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
    
    # 9th Set of CONV-ReLU Layers
    Conv2D(256, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),

    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),
    
    # 10th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # 11th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # 12th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),
        
    # 13th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # 14th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # 15th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
        
    # Max Pooling with Dropout 
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),
        
    # 16th Set of CONV-ReLU Layers
    Conv2D(512, (3, 3), activation="relu", padding="same"),
    BatchNormalization(),
    
    # Global Average Pooling
    GlobalAveragePooling2D(),

    # Final Dense Layer
    Dense(num_classes,activation="softmax"),
    ])
    return model1

and these are the model.compile arguments:

model.compile(loss = 'categorical_crossentropy',
              optimizer = "adam",
              metrics = ['accuracy'])

@csferng
Collaborator

csferng commented Dec 19, 2019

@othmaneDaanouni, I haven't been able to fully reproduce the error you encountered yet, but I noticed that the convert_*_generator functions should yield the results rather than return them (so that the return values are generators):

def convert_training_data_generator():
   for x ,y in train_generator:
    yield {'feature': x, 'label':y}  # not return

train= convert_training_data_generator()  # train is a generator, not a dict

More information about the yield keyword and Python generators can be found here.
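
To make the difference concrete, here is a minimal sketch that runs without TensorFlow. The batches list is a hypothetical stand-in for train_generator, and the two function names are made up for illustration.

```python
import types

# Hypothetical stand-in for train_generator, which yields (x, y) tuples.
batches = [("x0", "y0"), ("x1", "y1")]


def convert_with_return():
    for x, y in batches:
        # return leaves the function on the first iteration,
        # so the caller receives a single dict.
        return {'feature': x, 'label': y}


def convert_with_yield():
    for x, y in batches:
        # yield turns the function into a generator function,
        # so calling it produces one dict per batch.
        yield {'feature': x, 'label': y}


returned = convert_with_return()
yielded = convert_with_yield()

print(type(returned).__name__)                      # dict
print(isinstance(yielded, types.GeneratorType))     # True
all_batches = list(yielded)
print(len(all_batches))                             # 2
```

The return version hands a plain dict to fit_generator, which is not something it can iterate over batch by batch.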

@csferng
Collaborator

csferng commented Feb 17, 2020

The problem turned out to be the difference between keras.models.Sequential and tf.keras.models.Sequential. The keras version requires the input to be a tensor or a list of tensors, while the tf.keras version flattens the input internally, so it can handle dictionary-style input.

Your code should work after changing each import line of the form "from keras.X import Y" to "from tensorflow.keras.X import Y".
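
As an illustration of that change, the sketch below rewrites import lines in a piece of source text. The helper to_tf_keras_imports is hypothetical (it is not part of Keras, TensorFlow, or NSL) and only handles lines of the form "from keras... import ...". The sample imports are assumed from the names used earlier in this thread.

```python
import re

# Matches "from keras" at the start of a line, followed by "." or whitespace,
# so that unrelated packages such as keras_preprocessing are left alone.
_FROM_KERAS = re.compile(r"^([ \t]*)from[ \t]+keras(\.|[ \t])", re.MULTILINE)


def to_tf_keras_imports(source):
    """Rewrite 'from keras[.X] import Y' as 'from tensorflow.keras[.X] import Y'."""
    return _FROM_KERAS.sub(r"\1from tensorflow.keras\2", source)


before = (
    "from keras.models import Sequential\n"
    "from keras.layers import Conv2D, Dense\n"
    "from keras.preprocessing.image import ImageDataGenerator\n"
)
after = to_tf_keras_imports(before)
print(after)
```

For a handful of imports, editing the lines by hand is just as easy. The point is only that every keras import in the script needs to change, not just the one for Sequential.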
