Allow datasets to provide the number of examples they contain #36531

Closed
Flamefire opened this issue Feb 7, 2020 · 2 comments
Labels
comp:data tf.data related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.1 for tracking issues in 2.1 release type:feature Feature requests

Comments

@Flamefire
Contributor

System information

  • TensorFlow version (you are using): 2.1.0
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.

Currently there is no good way to get the number of samples or batches contained in a dataset, although the information is usually available.

What you can do: sum(1 for _ in dataset), but this might not do what one wants. It has to iterate over the entire dataset just to count it, and
when the dataset is batched it returns the number of batches including the trailing partial one. MultiWorkerMirroredStrategy can't handle that.
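To make the trailing-batch problem concrete, here is a pure-Python sketch that imitates batching a stream of n examples with batch size b. The helper `batches` is hypothetical and only mimics the semantics of batch(batch_size, drop_remainder=...): counting by iteration yields ceil(n / b) batches, while dropping the remainder yields n // b.

```python
import math
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batches(
    items: Iterable[T], batch_size: int, drop_remainder: bool = False
) -> Iterator[List[T]]:
    """Hypothetical stand-in for Dataset.batch(): yields lists of up to batch_size items."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # The trailing partial batch is emitted unless drop_remainder is set.
    if batch and not drop_remainder:
        yield batch


n, b = 10, 3
# Counting by iteration, as with sum(1 for _ in dataset), visits every batch.
with_trailing = sum(1 for _ in batches(range(n), b))
without_trailing = sum(1 for _ in batches(range(n), b, drop_remainder=True))

print(with_trailing, math.ceil(n / b))  # 4 4
print(without_trailing, n // b)         # 3 3
```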

Usually this information is already available; see e.g. tensorflow/datasets#1403.

Will this change the current api? How?

Add a num_examples member and/or a __len__ overload.
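A minimal sketch of what the proposed interface could look like, written in plain Python. `SizedDataset` is a hypothetical wrapper, not a tf.data class: it stores a count known up front (for example from dataset metadata) and answers len() without iterating.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class SizedDataset(Generic[T]):
    """Hypothetical wrapper illustrating a num_examples member plus __len__.

    This is NOT a tf.data API; it only sketches the proposal.
    """

    def __init__(self, elements: Iterable[T], num_examples: int) -> None:
        self._elements = elements
        self.num_examples = num_examples  # known up front, e.g. from metadata

    def __iter__(self) -> Iterator[T]:
        return iter(self._elements)

    def __len__(self) -> int:
        # Returns the stored count; no iteration over the data is needed.
        return self.num_examples


ds = SizedDataset(range(60000), num_examples=60000)
print(len(ds))  # 60000
```

One design wrinkle: Python's len() rejects negative return values from __len__ (it raises ValueError), so a dataset whose size is unknown or infinite could not report that through __len__ with a sentinel such as -1 and would have to raise instead. That is one argument for also exposing a separate attribute.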

Who will benefit with this feature?

  • Everyone using MultiWorkerMirroredStrategy
  • Everyone using steps_per_epoch
  • TF itself: if the number of samples/batches is known before the training loop starts, progress reports like 10/Unknown can be avoided
  • This would help to provide correct behavior in 6be131d#diff-f8dd40712ac721c1b363e1a1ec44c1a3R741-R747
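For illustration, once the example count is known, steps_per_epoch can be computed rather than guessed. This is a sketch under stated assumptions: each of num_workers workers consumes per_worker_batch examples per step, and the trailing partial global batch is dropped so every worker runs the same number of steps. The function name is hypothetical.

```python
def steps_per_epoch(num_examples: int, per_worker_batch: int, num_workers: int = 1) -> int:
    """Hypothetical helper: number of full global batches in one epoch.

    Assumes the global batch is per_worker_batch * num_workers and that the
    trailing partial batch is dropped.
    """
    if num_examples < 0:
        raise ValueError("num_examples must be non-negative")
    if per_worker_batch <= 0 or num_workers <= 0:
        raise ValueError("batch size and worker count must be positive")
    global_batch = per_worker_batch * num_workers
    # Floor division discards the incomplete final batch.
    return num_examples // global_batch


print(steps_per_epoch(60000, 64))                 # 937
print(steps_per_epoch(60000, 64, num_workers=4))  # 234
```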

Any Other info.

There is an experimental op, cardinality, which might be closely related. However, it often (always?) returns "Unknown". Tested with MNIST from TFDS.
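tf.data reports cardinality as an integer, with the sentinels tf.data.experimental.INFINITE_CARDINALITY (-1) and tf.data.experimental.UNKNOWN_CARDINALITY (-2). The pure-Python sketch below is a simplified model, not TF's implementation, of why a single unknown stage (a file-based source, or a transformation such as filter whose output size cannot be known statically) makes everything downstream unknown, while a known size propagates through batch.

```python
import math

# Sentinels mirroring tf.data.experimental.INFINITE_CARDINALITY and
# tf.data.experimental.UNKNOWN_CARDINALITY.
INFINITE = -1
UNKNOWN = -2


def batch_cardinality(n: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Simplified model of how batching changes a dataset's cardinality."""
    if n in (INFINITE, UNKNOWN):
        return n  # sentinels pass through unchanged
    return n // batch_size if drop_remainder else math.ceil(n / batch_size)


def filter_cardinality(n: int) -> int:
    """Simplified model: a predicate may drop any number of elements, so the
    result is treated as unknown regardless of the input size n."""
    return UNKNOWN


known = 10          # e.g. a source built from in-memory data
from_files = UNKNOWN  # e.g. a source whose size requires reading the files

print(batch_cardinality(known, 3))                       # 4
print(batch_cardinality(known, 3, drop_remainder=True))  # 3
print(batch_cardinality(from_files, 3))                  # -2
print(batch_cardinality(filter_cardinality(known), 3))   # -2
```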

@Conchylicultor
Member

Conchylicultor commented Feb 7, 2020

For more context, TFDS cannot provide the tf.data.Dataset cardinality because it is not supported by TFRecordDataset and (maybe) the interleave op. If there were a way to manually override the cardinality of a tf.data.Dataset, we could forward the number of examples to the tf.data.Dataset.

Related issue: tensorflow/datasets#1456

@ravikyram ravikyram added comp:data tf.data related issues TF 2.1 for tracking issues in 2.1 release type:feature Feature requests labels Feb 10, 2020
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 10, 2020
@Conchylicultor
Member

Conchylicultor commented Feb 18, 2020

Thanks to jsmira, this should be fixed in d25235b with tf.data.experimental.assert_cardinality(123)

import tensorflow as tf

ds = tf.data.TFRecordDataset("examples.tfrecord")
tf.data.experimental.cardinality(ds)  # tf.data.experimental.UNKNOWN_CARDINALITY

ds = ds.apply(tf.data.experimental.assert_cardinality(42))
tf.data.experimental.cardinality(ds).numpy()  # 42

I'll update the TFDS side, but this issue can be closed.
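As we understand it, assert_cardinality does two things: it makes the declared count available up front, and it checks that count while the data is consumed, failing if the actual number of elements differs. The pure-Python sketch below models only those semantics; `AssertedCount` is a hypothetical class, not the TF implementation, and it raises ValueError where TF would raise its own error type.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class AssertedCount(Generic[T]):
    """Hypothetical model: declare a count up front, verify it during iteration."""

    def __init__(self, elements: Iterable[T], expected: int) -> None:
        self._elements = elements
        self.cardinality = expected  # answered without reading the data

    def __iter__(self) -> Iterator[T]:
        seen = 0
        for element in self._elements:
            seen += 1
            if seen > self.cardinality:
                raise ValueError(f"expected {self.cardinality} elements but got more")
            yield element
        if seen < self.cardinality:
            raise ValueError(f"expected {self.cardinality} elements but got only {seen}")


ok = AssertedCount(range(42), expected=42)
print(ok.cardinality)         # 42
print(sum(1 for _ in ok))     # 42

try:
    list(AssertedCount(range(10), expected=42))
except ValueError as err:
    print(err)  # expected 42 elements but got only 10
```

The design choice is that the count is trusted for planning (for example steps_per_epoch) but still verified lazily, so a wrong declaration surfaces as an error instead of silently skewing training.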


5 participants