
tf.keras multi input models don't work when using tf.data.Dataset #20698

Closed · lgeiger opened this issue Jul 11, 2018 · 33 comments
Labels: stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:bug (Bug)

@lgeiger
Contributor
lgeiger commented Jul 11, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.13.5 and Debian GNU/Linux 9 (stretch)
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v1.9.0-rc2-359-g95cfd8b3d9 (1.10.0-dev20180711); also reproduces on v1.9.0
  • Python version: 3.6.5 and 3.5.3
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: None
  • GPU model and memory: None
  • Exact command to reproduce: see below

You can collect some of this information using our environment capture script:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

You can obtain the TensorFlow version with

python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"

Describe the problem

tf.keras multi input models don't work when used together with tf.data.Dataset due to broken input validation checks. This problem reproduces both on tf@1.9.0 and the latest nightly.

@fchollet Do you have any ideas what's going on here, or am I missing something obvious?

Source code / logs

Multi input model

Consider the following toy model:

import numpy as np
import tensorflow as tf
from tensorflow import keras

data_a = np.array([300, 455, 350, 560, 700, 800, 200, 250], dtype=np.float32)
labels = np.array([455, 350, 560, 700, 800, 200, 250, 300], dtype=np.float32)
data_b = np.array([200, 255, 350, 470, 600, 300, 344, 322], dtype=np.float32)
data_a = np.reshape(data_a, (8, 1, 1))
data_b = np.reshape(data_b, (8, 1, 1))

x = keras.layers.Input(shape=(1, 1), name='input_x')
y = keras.layers.Input(shape=(1, 1), name='input_y')
admi = keras.layers.LSTM(40, return_sequences=False)(x)
pla = keras.layers.LSTM(40, return_sequences=False)(y)
out = keras.layers.concatenate([admi, pla], axis=-1)
output = keras.layers.Dense(1, activation='sigmoid')(out)
model = keras.models.Model(inputs=[x, y], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

Using numpy data

When fitting with numpy data, this works as expected, whether passing a list or a dictionary of inputs:

model.fit([data_a, data_b], labels, batch_size=2, epochs=10)
model.fit({'input_x': data_a, 'input_y': data_b}, labels, batch_size=2, epochs=10)

Using tf.data.Dataset.from_tensor_slices dictionary

When trying the same with a tf.data.Dataset, the following fails due to incorrect input validation:

dataset = tf.data.Dataset.from_tensor_slices(({'input_x': data_a, 'input_y': data_b}, labels)).batch(2).repeat()
model.fit(dataset, epochs=10, steps_per_epoch=4)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-d35bacd274cc> in <module>()
      1 dataset = tf.data.Dataset.from_tensor_slices(({'input_x': data_a, 'input_y': data_b}, labels)).batch(2).repeat()
----> 2 model.fit(dataset, epochs=10, steps_per_epoch=4)

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1276         steps_name='steps_per_epoch',
   1277         steps=steps_per_epoch,
-> 1278         validation_split=validation_split)
   1279 
   1280     # Prepare validation data.

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    915           feed_output_shapes,
    916           check_batch_axis=False,  # Don't enforce the batch size.
--> 917           exception_prefix='target')
    918 
    919       # Generate sample-wise weight values given the `sample_weight` and

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    180                            ': expected ' + names[i] + ' to have ' +
    181                            str(len(shape)) + ' dimensions, but got array '
--> 182                            'with shape ' + str(data_shape))
    183         if not check_batch_axis:
    184           data_shape = data_shape[1:]

ValueError: Error when checking target: expected dense to have 2 dimensions, but got array with shape (None,)

Using tf.data.Dataset.from_generator dictionary

However, using the same network with tf.data.Dataset.from_generator works, probably because less validation is done:

def generator():
    while True:
        for i in np.random.permutation(8):
            yield {'input_x': data_a[i], 'input_y': data_b[i]}, labels[i]

dataset = tf.data.Dataset.from_generator(generator, ({'input_x': tf.float32, 'input_y': tf.float32}, tf.float32)).batch(2)
model.fit(dataset, epochs=10, steps_per_epoch=4)

Using tf.data.Dataset tuple

When passing the multi-input to the model as a tuple, datasets generated with both from_tensor_slices and from_generator fail:

dataset = tf.data.Dataset.from_tensor_slices(((data_a, data_b), labels)).batch(2).repeat()
model.fit(dataset, epochs=10, steps_per_epoch=4)
def generator():
    while True:
        for i in np.random.permutation(8):
            yield (data_a[i], data_b[i]), labels[i]

dataset = tf.data.Dataset.from_generator(generator, ((tf.float32, tf.float32), tf.float32)).batch(2)
model.fit(dataset, epochs=10, steps_per_epoch=4)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-512a95f0c2a7> in <module>()
      1 dataset = tf.data.Dataset.from_tensor_slices(((data_a, data_b), labels)).batch(2).repeat()
----> 2 model.fit(dataset, epochs=10, steps_per_epoch=4)

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1276         steps_name='steps_per_epoch',
   1277         steps=steps_per_epoch,
-> 1278         validation_split=validation_split)
   1279 
   1280     # Prepare validation data.

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    876         feed_input_shapes,
    877         check_batch_axis=False,  # Don't enforce the batch size.
--> 878         exception_prefix='input')
    879 
    880     if y is not None:

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    141     data = data.values if data.__class__.__name__ == 'DataFrame' else data
    142     data = [data]
--> 143   data = [standardize_single_array(x) for x in data]
    144 
    145   if len(data) != len(names):

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in <listcomp>(.0)
    141     data = data.values if data.__class__.__name__ == 'DataFrame' else data
    142     data = [data]
--> 143   data = [standardize_single_array(x) for x in data]
    144 
    145   if len(data) != len(names):

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_single_array(x)
     79   elif tensor_util.is_tensor(x):
     80     return x
---> 81   elif x.ndim == 1:
     82     x = np.expand_dims(x, 1)
     83   return x

AttributeError: 'tuple' object has no attribute 'ndim'
@was84san

was84san commented Jul 11, 2018

I have the same problem, also with a multi-input dataset, though I'm not sure the problem is caused by the multiple inputs. I am using TensorFlow 1.9 in order to be able to pass a dataset iterator to model.fit.

So:
1- If I do the following:

dataset = tf.data.TFRecordDataset(train.tf_records).map(_parse_function).batch(20).repeat()
model.fit(dataset)

I got:
AttributeError: "'RepeatDataset' object has no attribute 'ndim'"

2- If I do the following:

dataset = tf.data.TFRecordDataset(train.tf_records).map(_parse_function).batch(20).repeat().make_initializable_iterator()
model.fit(dataset)

I got:
AttributeError: "'Iterator' object has no attribute 'ndim'"

3- If I do the following:

dataset = tf.data.TFRecordDataset(train.tf_records).map(_parse_function).batch(20).repeat().make_initializable_iterator().get_next()
model.fit(dataset)

I got:
AttributeError: "'tuple' object has no attribute 'ndim'"

Note:
If I run get_next() on the iterator, it gives me the data, the label, and the other information I put in the TFRecords. So my input pair is in iterator.get_next()[0], and the labels are in iterator.get_next()[1].
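
For reference, tf.keras expects each dataset element to be an (inputs, labels) tuple, which the map function can return directly. A sketch of what such a _parse_function could look like (the feature keys, shapes, and file path here are hypothetical, not from the post above):

import tensorflow as tf

def _parse_function(record):
    # Hypothetical feature spec; adjust the keys and shapes to the actual TFRecords.
    features = {
        'data': tf.FixedLenFeature([28], tf.float32),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features)
    # Return an (inputs, label) tuple so model.fit can unpack each element.
    return parsed['data'], parsed['label']

dataset = tf.data.TFRecordDataset('train.tfrecords').map(_parse_function).batch(20).repeat()
# With tf.keras (not standalone Keras), the batched dataset can then be passed
# directly: model.fit(dataset, steps_per_epoch=..., epochs=...)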

@lgeiger
Contributor Author

lgeiger commented Jul 12, 2018

I opened #20753 to fix the issues not related to multi input models.

@rohan100jain rohan100jain added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug labels Jul 27, 2018
@rohan100jain
Member

I could reproduce the error.

@lgeiger
Contributor Author

lgeiger commented Jul 27, 2018

Thanks for taking the time to reproduce it. Did you have a chance to check out my fix in #20753?

@lgeiger
Contributor Author

lgeiger commented Jul 27, 2018

There's also a related PR that adds support for using tuples as multi-dim inputs: #20136

@keunwoochoi

My situation seems similar. The dataset iterator fed to model.fit is made with tf.data.Dataset.zip():

xy_ds = (
        tf.data.Dataset.zip((audio_ds, label_ds))
            .batch(
            batch_size=batch_size,
            # drop_remainder=True if is_training else False
            )
        .repeat(repeat)
        .prefetch(tf.contrib.data.AUTOTUNE)
    )

Both audio_ds (input) and label_ds (output) are instances of tf.data.TextLineDataset.

tf.data.TextLineDataset(id_path)
            .map(load_audio, num_parallel_calls=N_READ_THREAD)

Before being fed to the model, its iterator is created:

tr_iterator = tr_set.make_one_shot_iterator()
tr_iterator.get_next()
(<tf.Tensor 'IteratorGetNext:0' shape=<unknown> dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=<unknown> dtype=float32>)

And this is the error message when model.fit() is called.

  File "data_io.py", line 127, in <module>
    model.fit(
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 950, in fit
    batch_size=batch_size)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 749, in _standardize_user_data
    exception_prefix='input')
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training_utils.py", line 91, in standardize_input_data
    data = [standardize_single_array(x) for x in data]
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training_utils.py", line 91, in <listcomp>
    data = [standardize_single_array(x) for x in data]
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training_utils.py", line 26, in standardize_single_array
    elif x.ndim == 1:
AttributeError: 'Iterator' object has no attribute 'ndim'

tensorflow: 1.9.0, keras:2.2.2

@was84san

was84san commented Aug 16, 2018

I think I discovered the problem for my situation: I was using standalone Keras, not the one imported from tensorflow. The new feature of feeding an iterator directly to model.fit() is valid only when you are using tf.keras, not standalone Keras.

@keunwoochoi

@was84san Wow, same here, and now it seems solved. Thanks!

@was84san

was84san commented Sep 14, 2018

@lgeiger is this issue of passing multiple inputs to a Keras model via the tf.data API fixed?

@hhwxxx

hhwxxx commented Sep 20, 2018

Hi, @was84san.
As you mentioned, I am using tf.keras, but the problem still exists. Do you have any idea?
Thanks!

@was84san

Which problem exactly: feeding multiple inputs, or feeding the iterator directly to model.fit? I figured out only the latter.

@gabrielibagon

@hhwxxx I was also unable to use model.fit() with a nested Dataset iterator for multi-input and multi-output models (while using tf.keras) on version 1.10. Installing tf-nightly-gpu (my specific version is now 1.12.0-dev20180918) seemed to resolve this problem for me.

@JanRuettinger

@gabrielibagon Could you post a snippet how you got a nested dataset iterator with multiple inputs working?

@ricoms

ricoms commented Sep 29, 2018

The final example here is interesting:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)  # num_epochs is assumed to be defined by the caller

  # Each element of `dataset` is a tuple containing a dictionary of features
  # (in which each value is a batch of values for that feature), and a batch of
  # labels.
  return dataset

Now, how to define a model that accepts and trains correctly with that dataset? Is the full example available somewhere?
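
Not the full example from the guide, but here is a sketch of a matching model. The key point is that the Input layer names must match the feature-dict keys returned by parser(); the layer sizes, loss, class count, and step counts are arbitrary assumptions, it assumes the image is cast to float32 in the pipeline, and it needs a TF version where datasets are accepted by model.fit (per the comments below, 1.12+):

import tensorflow as tf
from tensorflow import keras

# Input names must match the feature-dict keys ('image_data', 'date_time').
image_in = keras.layers.Input(shape=(299, 299, 1), name='image_data')
time_in = keras.layers.Input(shape=(), dtype=tf.int64, name='date_time')

x = keras.layers.Conv2D(16, 3, activation='relu')(image_in)
x = keras.layers.GlobalAveragePooling2D()(x)
# Cast the scalar date_time feature to float and give it a feature axis.
t = keras.layers.Lambda(lambda v: tf.expand_dims(tf.cast(v, tf.float32), -1))(time_in)
x = keras.layers.concatenate([x, t])
out = keras.layers.Dense(10, activation='softmax')(x)  # 10 classes: an arbitrary assumption

model = keras.models.Model(inputs=[image_in, time_in], outputs=out)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(dataset_input_fn(), steps_per_epoch=100, epochs=2)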

@gabrielibagon

gabrielibagon commented Oct 9, 2018

@JanRuettinger @ricoms Sorry for the delayed response.

I drafted up a toy example using MNIST in order to train a model with two inputs and two outputs. The model is simply two identical models fused together, which takes in two copies of the MNIST data (two inputs) and outputs a prediction for each (two outputs). You can adapt this to more complex models and input pipelines.

Note: This is still using tf-nightly-gpu version 1.12.0-dev20180918. I assume this will work in tensorflow 1.12 and above.

batch_size = 512

# -- Data Setup -- #
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
y_train = tf.keras.utils.to_categorical(y_train)
x_train, x_test = x_train / 255.0, x_test / 255.0
# Create two inputs and two outputs (for demonstration)
x_train1 = x_train2 = x_train
y_train1 = y_train2 = y_train

# -- Dataset API -- #
# Create a Dataset for multiple inputs and Dataset for multiple outputs
input_set = tf.data.Dataset.from_tensor_slices((x_train1, x_train2))
output_set = tf.data.Dataset.from_tensor_slices((y_train1, y_train2))
# Create Dataset pipeline
input_set = input_set.batch(batch_size).repeat()
output_set = output_set.batch(batch_size).repeat()
# Group the input and output dataset
dataset = tf.data.Dataset.zip((input_set, output_set))
# Initialize the iterator to be passed to the model.fit() function
data_iter = dataset.make_one_shot_iterator()

# -- Model Definition -- #
# Multiple Inputs
input1 = tf.keras.layers.Input(shape=(28,28))
input2 = tf.keras.layers.Input(shape=(28,28))
# Input 1 Pathway
x1 = tf.keras.layers.Flatten()(input1)
x1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)(x1)
x1 = tf.keras.layers.Dropout(0.2)(x1)
# Input 2 Pathway
x2 = tf.keras.layers.Flatten()(input2)
x2 = tf.keras.layers.Dense(512, activation=tf.nn.relu)(x2)
x2 = tf.keras.layers.Dropout(0.2)(x2)
# Multiple Outputs
output1 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(x1)
output2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(x2)
# Create Model
model = tf.keras.models.Model(inputs=[input1, input2], outputs=[output1, output2])
# Compile
model.compile(optimizer='adam', loss='categorical_crossentropy')

# -- Train -- #
model.fit(data_iter, steps_per_epoch=len(x_train)//batch_size, epochs=5)

Update: As @jashshopin mentions below, the dataset object can be passed directly to model.fit() if you have no need for an iterator.
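
For example, with the variables above, the zipped dataset can be passed straight in:

model.fit(dataset, steps_per_epoch=len(x_train) // batch_size, epochs=5)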

@jashshopin

Is it necessary to use dataset.make_one_shot_iterator()?

@gabrielibagon

@jashshopin Thanks for pointing that out, apparently you can pass the zipped dataset directly into model.fit(). The example should still work for those who might want to use a one-shot iterator or initializable iterator as well.

@ricoms

ricoms commented Oct 18, 2018

thanks, @gabrielibagon.

I have something like that here, although I used the Keras generator format because I could not deal with a video input pipeline using TensorFlow methods.

I might refactor it to tf.data someday, but it's working for now. :)

@tensorflowbutler
Member

Nagging Assignee @fchollet: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@lgeiger
Contributor Author

lgeiger commented Nov 14, 2018

This isn't an issue on tensorflow 1.12 and above anymore. Thanks for the help everybody.

@lgeiger lgeiger closed this as completed Nov 14, 2018
@was84san

was84san commented Nov 16, 2018

@lgeiger, I tried to pass multiple inputs as a list of tf.data datasets to model.fit directly, like this:
model.fit([dataset1_iterator, dataset2_iterator], ...)

then I got this error


/home/wassan/tensorflow/venv/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.pyc in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    990         x, y, sample_weight = next_element
    991     x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
--> 992                                                      class_weight, batch_size)
    993     return x, y, sample_weights
    994 

/home/wassan/tensorflow/venv/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.pyc in _standardize_weights(self, x, y, sample_weight, class_weight, batch_size)
   1115         feed_input_shapes,
   1116         check_batch_axis=False,  # Don't enforce the batch size.
-> 1117         exception_prefix='input')
   1118 
   1119     if y is not None:

/home/wassan/tensorflow/venv/lib/python2.7/site-packages/tensorflow/python/keras/engine/training_utils.pyc in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    282     data = data.values if data.__class__.__name__ == 'DataFrame' else data
    283     data = [data]
--> 284   data = [standardize_single_array(x) for x in data]
    285 
    286   if len(data) != len(names):

/home/wassan/tensorflow/venv/lib/python2.7/site-packages/tensorflow/python/keras/engine/training_utils.pyc in standardize_single_array(x)
    216   if x is None:
    217     return None
--> 218   if x.shape is not None and len(x.shape) == 1:
    219     if tensor_util.is_tensor(x):
    220       return array_ops.expand_dims(x, axis=1)

AttributeError: 'PrefetchDataset' object has no attribute 'shape'

And this is with tensorflow 1.12, so how can you pass multiple inputs using the tf.data API with model.fit rather than model.fit_generator?

@lgeiger
Contributor Author

lgeiger commented Nov 16, 2018

@lgeiger, I tried to pass multiple inputs as a list of tf.data datasets to model.fit directly, like this:
model.fit([dataset1_iterator, dataset2_iterator], ...)

Returning a list of tensors in a single dataset and then passing it to model.fit should work. Check out this example: https://colab.research.google.com/drive/1h3FUGBhVsXnj6oEE3JDnC0WRFF-Zu__c#scrollTo=cjvaKWOqAQ3e
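
One way to do this (a sketch reusing data_a, data_b, and labels from the original post, not necessarily exactly what the notebook does) is to zip the per-input datasets into a single dataset whose elements are ((input_1, input_2), label) pairs:

import tensorflow as tf

ds_a = tf.data.Dataset.from_tensor_slices(data_a)
ds_b = tf.data.Dataset.from_tensor_slices(data_b)
ds_labels = tf.data.Dataset.from_tensor_slices(labels)

# Zip into a single dataset whose elements are ((input_1, input_2), label).
dataset = tf.data.Dataset.zip(((ds_a, ds_b), ds_labels)).batch(2).repeat()
model.fit(dataset, epochs=10, steps_per_epoch=4)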

@srcolinas

@lgeiger what about using dictionaries as targets? #25299 (comment)
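
For context, the pattern discussed there keys the target dictionary by the model's output layer names, mirroring the input dictionary. A sketch with hypothetical output names 'out_a' and 'out_b' and hypothetical target arrays labels_a and labels_b (whether tf.data accepted dict targets at the time is exactly what #25299 tracks):

dataset = tf.data.Dataset.from_tensor_slices((
    {'input_x': data_a, 'input_y': data_b},   # inputs keyed by Input layer names
    {'out_a': labels_a, 'out_b': labels_b},   # targets keyed by output layer names
)).batch(2).repeat()
model.fit(dataset, epochs=10, steps_per_epoch=4)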

@kristofgiber

kristofgiber commented Sep 2, 2019

I can confirm this works in tensorflow 2.0.0-rc0. Multiple inputs and outputs, even without all the zipping:

(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

# dummydata: a second input array with the same leading dimension as train_images
ds = tf.data.Dataset.from_tensor_slices(((train_images, dummydata), train_images))
ds = ds.shuffle(TRAIN_BUF).repeat().batch(BATCH_SIZE)  # reassign; these ops don't modify in place

model.fit(ds, steps_per_epoch=n_trainsamples // BATCH_SIZE)

@drasmuss
Contributor

drasmuss commented Sep 5, 2019

This still seems broken to me (in tensorflow 2.0.0-rc0). See this snippet:

import tensorflow as tf
from tensorflow import keras

inputs = [keras.Input((1,), name="a"), keras.Input((1,), name="b")]
outputs = inputs[0] + inputs[1]
model = keras.Model(inputs=inputs, outputs=outputs)

list_input = [tf.zeros((10, 1)), tf.ones((10, 1))]
dict_input = {"a": tf.zeros((10, 1)), "b": tf.ones((10, 1))}

print(model.predict(list_input))
print(model.predict(dict_input))
print(model.predict(tf.data.Dataset.from_tensors(dict_input)))

# error here
print(model.predict(tf.data.Dataset.from_tensors(list_input)))

which gives

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [<tf.Tensor: id=47, shape=(2, 10, 1), dtype=float32, numpy=
array([[[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]],...

@mindis

mindis commented Nov 28, 2019

@drasmuss

A workaround could be to convert the list to a dictionary in the Dataset:

ds = tf.data.Dataset.from_tensors(list_input)

def to_dict(lst):
    return {'a': lst[0], 'b': lst[1]}

ds = ds.map(to_dict)
print(model.predict(ds))

@StenkinVlad

I think I discovered the problem for my situation: I was using standalone Keras, not the one imported from tensorflow. The new feature of feeding an iterator directly to model.fit() is valid only when you are using tf.keras, not standalone Keras.

Thanks, my problem is solved! I just changed
import keras
import tensorflow as tf
to
import tensorflow as tf
from tensorflow import keras

@johngrabner

I am not finding documentation for feeding models with multiple inputs of different dimensions using tf.data.
The above exchange leaves me still struggling for an understanding of how to feed such models. May I ask for clarification?

print(f"tensoflow.__version__ = {tensorflow.__version__}")
# tensoflow.__version__ = 2.1.0-rc2

# A toy keras model with 2 inputs of different size
input_1 = tensorflow.keras.layers.Input(name='input_1', shape=(2,), dtype=numpy.float32)
input_2 = tensorflow.keras.layers.Input(name='input_2', shape=(3,), dtype=numpy.float32)
output = tensorflow.keras.layers.Concatenate(name='output_1')([input_1, input_2])
toy_model = tensorflow.keras.Model(inputs=[input_1, input_2], outputs=[output])
toy_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# in memory data, 2 samples for input_1
input_1_sample_1 = numpy.asarray ( [2,2], dtype=numpy.float32 )
input_1_sample_2 = numpy.asarray ( [22,22], dtype=numpy.float32 )
input_1_data = numpy.asarray( [ input_1_sample_1, input_1_sample_2 ] )
print(f"input_1_data.shape = {input_1_data.shape}")
# input_1_data.shape = (2, 2)

# in memory data, 2 samples for input_2
input_2_sample_1 = numpy.asarray ( [3,3,3], dtype=numpy.float32 )
input_2_sample_2 = numpy.asarray ( [33,33,33], dtype=numpy.float32 )
input_2_data = numpy.asarray( [ input_2_sample_1, input_2_sample_2 ] )
print(f"input_2_data.shape = {input_2_data.shape}")
# input_2_data.shape = (2, 3)

# in memory data, 2 samples for output 1
output_1_sample_1 = numpy.asarray( [2,2,3,3,3], dtype=numpy.float32 )
output_1_sample_2 = numpy.asarray( [22,22,33,33,33], dtype=numpy.float32 )
output_1_data = numpy.asarray( [ output_1_sample_1, output_1_sample_2], dtype=numpy.float32 )
print(f"output_1_data.shape = {output_1_data.shape}")
# output_1_data.shape = (2, 5)

def toy_generator_list():
    while True:
        yield [input_1_data, input_2_data], output_1_data, []

I can use the generator directly, but my goal is to move the generator to a full tf.data pipeline; I am missing something fundamental to get started.

This works, but does not use tf.data:

toy_model.fit(x=toy_generator_list(), steps_per_epoch=3, epochs=2)

The following is as close to a solution as I have gotten, but it fails:

toy_dataset_from_generator = tensorflow.data.Dataset.from_generator(toy_generator_list, \
    output_types=(tensorflow.float32, tensorflow.float32, tensorflow.float32), \
        output_shapes=(([2,2],[2,3]), [2,5]) )


toy_model.fit(x=toy_dataset_from_generator, steps_per_epoch=3, epochs=2) 

Generates error

ValueError: The two structures don't have the same sequence length. Input structure has length 2, while the shallow structure has length 3.

I know that my request smells like "a request for help", and it is, but please interpret it as a request for improved documentation. Stack Overflow does not have anything on multiple inputs of different shapes.

btw:

  • The real model input is an image (255, 255, 3) and a document type (20,) with output 1-hot(40, 60) into a ctc.
  • The ideal tf.data chain would cache the preliminary processing, then augment this cache version for delivery to model.fit, where the model is fit across a network of servers.

@DavidPL1

@johngrabner
The problem in your code is that your output_types and output_shapes definitions differ.
Changing the output_types to ((tensorflow.float32, tensorflow.float32), tensorflow.float32) should do the trick.
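
Applied to the snippet above, and also dropping the trailing empty list from the generator's yield (assuming it was unintentional) so each element is a plain (inputs, outputs) pair, a sketch:

def toy_generator_list():
    while True:
        # Yield an (inputs, outputs) pair; the inputs are themselves a tuple.
        yield (input_1_data, input_2_data), output_1_data

toy_dataset_from_generator = tensorflow.data.Dataset.from_generator(
    toy_generator_list,
    output_types=((tensorflow.float32, tensorflow.float32), tensorflow.float32),
    output_shapes=(([2, 2], [2, 3]), [2, 5]))

toy_model.fit(x=toy_dataset_from_generator, steps_per_epoch=3, epochs=2)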

For the sake of completeness, here is a minimal example of a dataset that expects two inputs (shapes (1,32) and (1,128)):

import tensorflow as tf
import numpy as np

def random_generator():
    for i in range(100):
        x1, x2, y = np.random.random((1,32)), np.random.random((1,128)), np.random.random((1,1))
        yield (x1, x2), y
        
toy_dataset = tf.data.Dataset.from_generator(
    random_generator,
    output_types=((tf.float32, tf.float32), tf.float32),
    output_shapes=(((1,32), (1,128)), (1,1))
)

@tinmodeHuang

Hey guys!
I have been having trouble: the error below was thrown when predicting with a model that has two inputs.

Traceback (most recent call last):
  File "practice.py", line 279, in <module>
    action = np.argmax([0.1, 1, 0.2]*agent.get_qs(current_state))
  File "practice.py", line 186, in get_qs
    return self.model.predict(state)[0]
  File "C:\Users\liuzhen\.conda\envs\python37\lib\site-packages\keras\engine\training.py", line 1380, in predict
    x, _, _ = self._standardize_user_data(x)
  File "C:\Users\liuzhen\.conda\envs\python37\lib\site-packages\keras\engine\training.py", line 757, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\liuzhen\.conda\envs\python37\lib\site-packages\keras\engine\training_utils.py", line 95, in standardize_input_data
    data = [standardize_single_array(x) for x in data]
  File "C:\Users\liuzhen\.conda\envs\python37\lib\site-packages\keras\engine\training_utils.py", line 95, in <listcomp>
    data = [standardize_single_array(x) for x in data]
  File "C:\Users\liuzhen\.conda\envs\python37\lib\site-packages\keras\engine\training_utils.py", line 30, in standardize_single_array
    elif x.ndim == 1:
AttributeError: 'list' object has no attribute 'ndim'

The 'state' there is a list of two ndarrays, and the model is:

model = Model(inputs=[input1, input2], outputs=predictions)

I would really appreciate it if anyone is willing to give some tips

@MatanSandori

hey!

consider trying this:

relu = tf.keras.activations.relu
layers = tf.keras.layers

## Input 1

inputs = layers.Input(shape=(1,))  # First input, for 'data_a'
outputs = layers.Dense(128, activation=relu)(inputs)

model_data_a = tf.keras.Model(inputs, outputs)  # Build 'model_data_a' for 'data_a'

## Input 2

inputs = layers.Input(shape=(1,))  # Second input, for 'data_b'
outputs = layers.Dense(128, activation=relu)(inputs)

model_data_b = tf.keras.Model(inputs, outputs)  # Build 'model_data_b' for 'data_b'

## Model

# Get the outputs of 'model_data_a' and 'model_data_b' and combine them
inputs = layers.Concatenate()([model_data_a.output, model_data_b.output])
outputs = layers.Dense(1, activation=relu)(inputs)

model = tf.keras.Model([model_data_a.input, model_data_b.input], outputs)  # Add both inputs

## Compile and fit the model

model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.mae,
              metrics=[tf.keras.metrics.mse])

model.fit([data_a, data_b], labels, epochs=5)

hope this helps!

@jiunyen-ching

@MatanSandori I think this method will work only if the inputs are of the same shape?
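
For what it's worth, the pattern does not appear to require equal input shapes, since each branch has its own Input. A quick sketch with differently sized inputs (all sizes arbitrary):

import tensorflow as tf

layers = tf.keras.layers

in_a = layers.Input(shape=(2,))  # first input: 2 features
in_b = layers.Input(shape=(3,))  # second input: 3 features
x_a = layers.Dense(128, activation='relu')(in_a)
x_b = layers.Dense(128, activation='relu')(in_b)
x = layers.Concatenate()([x_a, x_b])  # both branches are (None, 128) at this point
out = layers.Dense(1)(x)

model = tf.keras.Model([in_a, in_b], out)
model.compile(optimizer='rmsprop', loss='mae')
# model.fit([arr_a, arr_b], labels, epochs=5)  # arr_a: (N, 2), arr_b: (N, 3)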

@rmgogogo

rmgogogo commented May 1, 2023

# Generate two parts: one is the input (a dict) and the other is the output.
# (df here is a pandas DataFrame with the columns referenced below.)
def generator():
    for index, row in df.iterrows():
        yield (
            {
                'float_input': row['float_col'],
                'int_input': row['int_col'],
                'str_input': row['str_col'],
                'list_input': row['list_col']
            },
            row['int_col']
        )

# Use output_signature to specify the inputs (maybe also multiple outputs, but I didn't try)
dataset = tf.data.Dataset.from_generator(generator, 
                                         output_signature=(
                                             {
                                                 'float_input': tf.TensorSpec(shape=(), dtype=tf.float32),
                                                 'int_input': tf.TensorSpec(shape=(), dtype=tf.int32),
                                                 'str_input': tf.TensorSpec(shape=(), dtype=tf.string),
                                                 'list_input': tf.TensorSpec(shape=(2,), dtype=tf.int32)
                                             },
                                             tf.TensorSpec(shape=(), dtype=tf.int32)
                                         ))
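
To train on that dataset, the model's Input names just need to match the dictionary keys. A sketch assuming a recent TF 2.x (the Hashing layer and all sizes here are arbitrary illustrative choices):

import tensorflow as tf

layers = tf.keras.layers

# Input names match the dict keys produced by the generator above.
float_in = tf.keras.Input(shape=(), dtype=tf.float32, name='float_input')
int_in = tf.keras.Input(shape=(), dtype=tf.int32, name='int_input')
str_in = tf.keras.Input(shape=(), dtype=tf.string, name='str_input')
list_in = tf.keras.Input(shape=(2,), dtype=tf.int32, name='list_input')

str_ids = layers.Hashing(num_bins=16)(str_in)  # crude string handling, arbitrary bin count

# Give each scalar feature a trailing feature axis and cast everything to float32.
expand = lambda t: layers.Lambda(lambda v: tf.expand_dims(tf.cast(v, tf.float32), -1))(t)
feats = layers.concatenate([
    expand(float_in),
    expand(int_in),
    expand(str_ids),
    layers.Lambda(lambda v: tf.cast(v, tf.float32))(list_in),
])
out = layers.Dense(1)(feats)  # scalar labels treated as regression targets, purely for illustration

model = tf.keras.Model([float_in, int_in, str_in, list_in], out)
model.compile(optimizer='adam', loss='mse')
model.fit(dataset.batch(32), epochs=1)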
