
tf.contrib.data.Dataset does not handle the final batch well when it has fewer elements than the batch size #13745

Closed
zh794390558 opened this Issue Oct 16, 2017 · 9 comments

8 participants
@zh794390558

zh794390558 commented Oct 16, 2017

tf.contrib.data.Dataset does not handle the final batch well when it has fewer elements than the batch size.

For example, batch_size = 10, but the last batch has only 9 elements.

@mrry

Contributor

mrry commented Oct 16, 2017

When this is the case, dataset.batch(10) will give you zero or more batches of 10 elements, followed by a single batch of 9 elements. If you don't want the smaller batch, you can use dataset.apply(tf.contrib.data.batch_and_drop_remainder(10)).
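A minimal sketch of this behaviour, assuming TensorFlow 2.x with eager execution (where the dataset API lives in tf.data rather than tf.contrib.data) and a hypothetical toy dataset of 19 integers:

```python
import tensorflow as tf

# Hypothetical toy data: the integers 0, 1, ..., 18 (19 elements).
dataset = tf.data.Dataset.range(19)

# Default batching keeps the incomplete final batch.
batched = dataset.batch(10)

# Collect the size of each batch (the length of its first dimension).
sizes = [b.shape[0] for b in batched]
print(sizes)  # [10, 9]
```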

What would you recommend instead?

@zh794390558

zh794390558 commented Oct 17, 2017

I think dataset.batch() could take a parameter like the allow_smaller_final_batch argument of tf.train.batch().

Also, tf.contrib.data.batch_and_drop_remainder is built on dataset.batch(), but I want to use dataset.padded_batch().

@mrry

Contributor

mrry commented Oct 17, 2017

If you're using Dataset.padded_batch(), you can compose it with a Dataset.filter() to exclude smaller batches. For example, if your dataset has three components in each element (x, y, z):

dataset = dataset.padded_batch(batch_size, ...).filter(lambda x, y, z: tf.equal(tf.shape(x)[0], batch_size))
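A runnable sketch of the padded_batch() plus filter() approach, assuming TensorFlow 2.x and a hypothetical toy dataset of 19 elements whose x component is a variable-length vector. The filter predicate has to return a scalar boolean, so this sketch compares only the batch dimension, tf.shape(x)[0], with batch_size:

```python
import tensorflow as tf

batch_size = 10

# Hypothetical toy data: 19 elements (x, y, z), where x has length 1 to 4,
# so padded_batch() is needed to stack elements of different lengths.
dataset = tf.data.Dataset.range(19).map(
    lambda i: (tf.fill([i % 4 + 1], i), i, tf.cast(i, tf.float32))
)

# Pad x to the longest sequence in each batch, then keep only full batches.
# tf.shape(x)[0] is the batch dimension, giving a scalar bool predicate.
dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [], [])).filter(
    lambda x, y, z: tf.equal(tf.shape(x)[0], batch_size)
)

sizes = [x.shape[0] for x, y, z in dataset]
print(sizes)  # [10]
```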
@zh794390558

zh794390558 commented Oct 20, 2017

@mrry Thanks, it works.

@chenghuige

chenghuige commented Jul 27, 2018

What if I want to extend the final batch from 9 to 10 elements by repeating one of the existing 9 elements?

@mrry

Contributor

mrry commented Jul 27, 2018

@chenghuige You can handle that by adding a Dataset.map() that uses tf.pad() to pad each batch up to 10 elements.
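A sketch of the Dataset.map() plus tf.pad() approach, assuming TensorFlow 2.x and hypothetical toy data of 19 feature vectors of length 3. The helper name pad_to_batch_size is illustrative; it appends zero rows along the batch dimension only:

```python
import tensorflow as tf

batch_size = 10

# Hypothetical toy data: 19 rows of 3 features, so the last batch has 9 rows.
features = tf.reshape(tf.range(19 * 3, dtype=tf.float32), [19, 3])
dataset = tf.data.Dataset.from_tensor_slices(features).batch(batch_size)

def pad_to_batch_size(batch):
    # Number of missing rows (0 for batches that are already full).
    missing = batch_size - tf.shape(batch)[0]
    # Pad with zeros at the end of the batch dimension; leave features unpadded.
    return tf.pad(batch, [[0, missing], [0, 0]])

dataset = dataset.map(pad_to_batch_size)
batches = list(dataset)
print([tuple(b.shape) for b in batches])  # [(10, 3), (10, 3)]
```

The padded rows are zeros rather than real examples, so you may want to mask them out of any loss or metric. If downstream code needs a fully defined static shape, tf.ensure_shape(batch, [batch_size, 3]) inside the mapped function is one way to assert it.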

@lixiang-ucas

lixiang-ucas commented Oct 5, 2018

@mrry Can you give an example of extending the final batch by repeating one of the existing 9 elements? What if my input tensor has shape [batch_size, 10, 3]?
Thank you.
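One possible approach, sketched under the assumption of TensorFlow 2.x and hypothetical toy data shaped like the question ([19, 10, 3], so the last batch has 9 elements): instead of zero padding, fill the missing rows with copies of existing elements using tf.gather() and tf.concat(). The helper fill_by_repeating is illustrative. Because tf.gather() selects along axis 0, the same code works for elements of any shape, including [10, 3]:

```python
import tensorflow as tf

batch_size = 10

# Hypothetical toy data: 19 elements, each of shape [10, 3].
data = tf.reshape(tf.range(19 * 10 * 3, dtype=tf.float32), [19, 10, 3])
dataset = tf.data.Dataset.from_tensor_slices(data).batch(batch_size)

def fill_by_repeating(batch):
    n = tf.shape(batch)[0]
    missing = batch_size - n
    # Indices 0, 1, ..., missing - 1 (wrapping around if missing > n) pick
    # existing elements to duplicate; full batches get an empty selection.
    extra = tf.gather(batch, tf.range(missing) % n)
    return tf.concat([batch, extra], axis=0)

dataset = dataset.map(fill_by_repeating)
batches = list(dataset)
print([tuple(b.shape) for b in batches])  # [(10, 10, 3), (10, 10, 3)]
```

The duplicated element appears twice in the final batch, which slightly over-weights it in any loss or metric computed over that batch.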

@NikkiZy

NikkiZy commented Oct 24, 2018

I ran into the same problem. This is what I tried:
dataset = dataset.padded_batch(....., drop_remainder=True)
dataset = dataset.repeat()
and it works for me. Good luck!
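A self-contained version of this approach, assuming TensorFlow 2.x. The padded_batch() arguments in the comment are not specified, so this sketch uses hypothetical variable-length integer sequences with padded_shapes=[None]:

```python
import tensorflow as tf

batch_size = 10

# Hypothetical toy data: 19 variable-length sequences of length 1 to 4.
dataset = tf.data.Dataset.range(19).map(lambda i: tf.fill([i % 4 + 1], i))

# drop_remainder=True discards the final 9-element batch, so every batch
# has exactly batch_size rows; repeat() then cycles over the data forever.
dataset = dataset.padded_batch(batch_size, padded_shapes=[None], drop_remainder=True)
dataset = dataset.repeat()

# The repeated dataset is infinite, so only inspect the first few batches.
sizes = [b.shape[0] for b in dataset.take(3)]
print(sizes)  # [10, 10, 10]
print(dataset.element_spec.shape)  # (10, None)
```

Because repeat() comes after padded_batch(), the remainder is dropped on every pass over the data. Calling repeat() before batching would instead let batches span epoch boundaries.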

@JorgeCeja

JorgeCeja commented Oct 30, 2018

tf.contrib.data.batch_and_drop_remainder() is DEPRECATED

Use tf.data.Dataset.batch(..., drop_remainder=True) INSTEAD

Source: TensorFlow docs
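A short sketch of the replacement API, assuming TensorFlow 2.x and a hypothetical 19-element dataset. It also shows a practical side effect: with drop_remainder=True the batch dimension is statically known, which matters for code that requires static shapes (for example, training on TPUs):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(19)

# Default: the final batch may be smaller, so the static batch dimension is unknown.
kept = dataset.batch(10)
# drop_remainder=True drops the 9-element final batch; the batch dimension is 10.
dropped = dataset.batch(10, drop_remainder=True)

print(kept.element_spec.shape)     # (None,)
print(dropped.element_spec.shape)  # (10,)
print([b.numpy().tolist() for b in dropped])  # [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
```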
