
Custom Model Data Cardinality Check Ambiguous #47116

Open
whitead opened this issue Feb 12, 2021 · 6 comments
Labels
comp:keras Keras related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.4 for issues related to TF 2.4 type:bug Bug

Comments

@whitead

whitead commented Feb 12, 2021

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution: Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 2.3.2/2.4.1
  • Python version: 3.7.8

Describe the current behavior

Training a custom Keras model whose inputs are a list fails the data cardinality check if the leading axis of an input is not the batch axis. The code below works in 2.3.2 but not in 2.4.1. One reason to have an input without a leading batch axis is a per-batch weight or dynamic parameter: in the example code, s has shape (1,) while x has shape (10, 2).

Describe the expected behavior

train_on_batch should accept the list [x, s], as it did in 2.3.2. I read the 2.4 release notes but could not find why this behavior changed, and I am unsure how to fix it. Any of the following would help: a flag to disable this check, a way to change the input spec of the model's call, or clarification on how to allow inputs that do not have a leading batch axis.
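For reference, one workaround that satisfies a check on shape[0] is to give the per-batch parameter a leading batch axis by repeating it, since x * s broadcasts the same way whether s has shape (1,) or (10, 1). This is a numpy sketch of the data preparation only, assuming s is shared by every row of x; `tile_per_batch` is a hypothetical helper, and the sketch has not been run through train_on_batch:

```python
import numpy as np

def tile_per_batch(param, batch_size):
    """Hypothetical helper: repeat a per-batch parameter along a new leading
    axis so that its shape[0] equals the batch size."""
    return np.broadcast_to(param, (batch_size,) + param.shape).copy()

x = np.arange(20.0).reshape(10, 2)  # batched input, shape (10, 2)
s = np.array([0.5])                 # per-batch parameter, shape (1,)
s_batched = tile_per_batch(s, x.shape[0])  # shape (10, 1)

# The model's computation gives the same result with either form of s:
# (10, 2) * (1,) and (10, 2) * (10, 1) both broadcast to (10, 2).
original = np.mean(x * s, axis=1)
batched = np.mean(x * s_batched, axis=1)
print(np.allclose(original, batched))        # True
print([a.shape[0] for a in (x, s_batched)])  # [10, 10]: leading sizes now agree
```

The trade-off is that s is stored once per sample, which is negligible for a scalar but wasteful for a large per-batch tensor.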

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np

class MyModel(tf.keras.Model):    
    def call(self, inputs, training):
        # make a model with a batched input (x) and per-batch tensor (s)
        x = inputs[0]
        s = inputs[1]
        return tf.reduce_mean(x * s, axis=1)
    
m = MyModel()
x = np.ones((10, 2))
s = np.ones(1)
y = np.ones(10)

m.compile('sgd', loss='mean_squared_error')

# works as call
m([x, s])

# fails on TF 2.4.1
# succeeds on TF 2.3.2
m.train_on_batch(x=[x, s], y=y)

Error Message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-1e85b8f34d3e> in <module>
     21 # fails on 2.4.1
     22 # succeeds on 2.3.2
---> 23 m.train_on_batch(x=[x, s], y=y)

~/miniconda3/envs/mmm/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics, return_dict)
   1723       iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
   1724                                                     y, sample_weight,
-> 1725                                                     class_weight)
   1726       self.train_function = self.make_train_function()
   1727       logs = self.train_function(iterator)

~/miniconda3/envs/mmm/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py in single_batch_iterator(strategy, x, y, sample_weight, class_weight)
   1511     data = (x, y, sample_weight)
   1512 
-> 1513   _check_data_cardinality(data)
   1514   dataset = dataset_ops.DatasetV2.from_tensors(data)
   1515   if class_weight:

~/miniconda3/envs/mmm/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py in _check_data_cardinality(data)
   1527           label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data)))
   1528     msg += "Make sure all arrays contain the same number of samples."
-> 1529     raise ValueError(msg)
   1530 
   1531 

ValueError: Data cardinality is ambiguous:
  x sizes: 10, 1
  y sizes: 10
Make sure all arrays contain the same number of samples.
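From the traceback, single_batch_iterator calls _check_data_cardinality on (x, y, sample_weight), and the error text is built from shape[0] of every flattened array. A simplified, numpy-only sketch of that logic (a hypothetical re-implementation for illustration, not the Keras source; sample_weight is left out) reproduces the message for the arrays in the example, because s contributes a leading size of 1 next to the batch size of 10:

```python
import numpy as np

def check_cardinality(x_list, y):
    """Hypothetical stand-in for the check in the traceback: every array in x
    and y must share one leading-axis length, else the batch size is ambiguous."""
    x_sizes = [a.shape[0] for a in x_list]
    y_size = y.shape[0]
    if len(set(x_sizes + [y_size])) > 1:
        raise ValueError(
            "Data cardinality is ambiguous:\n"
            f"  x sizes: {', '.join(str(n) for n in x_sizes)}\n"
            f"  y sizes: {y_size}\n"
            "Make sure all arrays contain the same number of samples."
        )
    return y_size  # the single, unambiguous number of samples

x = np.ones((10, 2))
s = np.ones(1)  # per-batch parameter: leading axis is 1, not 10
y = np.ones(10)

try:
    check_cardinality([x, s], y)
except ValueError as err:
    print(err)  # same text as the 2.4.1 error above
```

Nothing in a check like this distinguishes a batched input from a per-batch parameter, which is why s is rejected even though m([x, s]) runs fine.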
@whitead whitead added the type:bug Bug label Feb 12, 2021
@Saduf2019 Saduf2019 added the TF 2.4 for issues related to TF 2.4 label Feb 15, 2021
@Saduf2019 Saduf2019 assigned Saduf2019 and unassigned ravikyram Feb 15, 2021
@Saduf2019 Saduf2019 added the comp:keras Keras related issues label Feb 15, 2021
@Saduf2019
Contributor

Saduf2019 commented Feb 15, 2021

I have replicated the issue on TF 2.4 and nightly; please find the gist here.

@Saduf2019 Saduf2019 assigned rmothukuru and unassigned Saduf2019 Feb 15, 2021
@rmothukuru rmothukuru assigned Saduf2019 and unassigned rmothukuru Feb 15, 2021
@Saduf2019
Contributor

@whitead
I ran the shared code on TF 2.3 and got a different error, whereas the issue description says the code works on TF 2.3. Please find the gist here.

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Feb 15, 2021
@whitead
Author

whitead commented Feb 15, 2021

Sorry about that, @Saduf2019. I've edited the post to fix the shape error. I've attached the output for 2.3.2 below, and here is a gist.

[Screenshot: output of the example on TF 2.3.2]

@Saduf2019
Contributor

@whitead
You can safely ignore the warning. To suppress the warnings, please use:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

@Saduf2019 Saduf2019 added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author labels Feb 16, 2021
@whitead
Author

whitead commented Feb 16, 2021

Hi @Saduf2019, that was the message from 2.3.2. I will repost the error message for 2.4.1 below. My issue is that the code works in 2.3.2 but not in 2.4.1, and I do not know how to fix it.

[Screenshot: error message on TF 2.4.1]

@Saduf2019 Saduf2019 added regression issue To spot regression issues in latest version and removed stat:awaiting response Status - Awaiting response from author labels Feb 17, 2021
@Saduf2019 Saduf2019 assigned ymodak and unassigned Saduf2019 Feb 17, 2021
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 18, 2021
@ymodak ymodak removed the regression issue To spot regression issues in latest version label Mar 12, 2021
@saikumarchalla

Issue still exists in TF 2.5. Please find the gist here. Thanks!
