[Feature request] tf.data.Dataset sort and skip buckets #14250

Closed
georgesterpu opened this Issue Nov 4, 2017 · 8 comments

@georgesterpu
Contributor

georgesterpu commented Nov 4, 2017

Hi
It would be useful to sort variable-length inputs by their lengths in order to accelerate training. However, I cannot find this functionality yet.

In [1], @guillaumekln already suggested something similar in a code snippet, but the feature requested there was batching inputs of similar length together, regardless of the processing order of the batches, and @mrry's solution in [2] using group_by_window() addressed that request just fine. First question: would it be possible to make the iterator return the batches in ascending order of their ids (given by key_func), while maintaining the shuffling operation applied before batching?
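
For context, the bucketing pattern from [2] looks roughly like this (bucket_width and batch_size are illustrative placeholder values, not part of the original snippet):

import tensorflow as tf

bucket_width = 10  # placeholder: bucket granularity in tokens
batch_size = 32    # placeholder

def key_func(seq):
    # assign each sequence to a bucket according to its length
    return tf.cast(tf.shape(seq)[0] // bucket_width, tf.int64)

def reduce_func(key, windowed):
    # pad-batch the sequences that fell into the same bucket
    return windowed.padded_batch(batch_size, padded_shapes=[None])

dataset = dataset.apply(
    tf.contrib.data.group_by_window(key_func, reduce_func, window_size=batch_size))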

Additionally, I would like to skip the longer sentences early in training, with a length threshold that gradually increases with the global_step. Second question: could you reserve one batch id (e.g. -1) in group_by_window to tag the batches that should be skipped? At the moment it seems that all ids are considered, even negative values, and it would not be restrictive at all to allow only positive ones (there would still be 63 bits left to group the inputs). In key_func we could then simply compare the input length with the threshold and return a negative value when the length exceeds it.
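
To make the idea concrete, key_func would look something like this (a sketch of the proposed convention, not something that works today; bucket_width and max_length are placeholders):

import tensorflow as tf

def key_func(seq):
    length = tf.shape(seq)[0]
    bucket = tf.cast(length // bucket_width, tf.int64)
    # -1 would be the reserved key tagging sequences to drop
    return tf.cond(length < max_length,
                   lambda: bucket,
                   lambda: tf.constant(-1, tf.int64))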

Apologies if both functionalities are already available; feel free to stackoverflow me.

@jart

Member

jart commented Nov 6, 2017

feel free to stackoverflow me

I haven't heard that one before. StackOverflow is probably your best bet when there's any uncertainty about whether or not a feature is necessary. We triagers can also take care of the work of tagging folks on the team for you in the future. I will note that you made me smile, so I'm going to try my best to fill in for the community.

Speaking as someone who only recently began learning contrib.data, from what I understand these are stateful iterative ops, so I'm not sure the kind of sorting you described is possible. Have you considered collecting all the data and using tf.nn.top_k? I'm also assuming you're familiar with tf.cond, which has traditionally been used to implement skipping-type logic.
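
For illustration, the tf.cond pattern looks roughly like this (toy shapes and threshold):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None])
# evaluate the "real" branch only for short inputs; both branches
# must return tensors of matching type and shape
y = tf.cond(tf.shape(x)[0] < 10,
            lambda: tf.reduce_sum(x),
            lambda: tf.constant(0.0))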

If those things don't help, please give StackOverflow a try. I'll be happy to reopen this issue if provided new information.

@jart jart closed this Nov 6, 2017

@georgesterpu

Contributor

georgesterpu commented Nov 6, 2017

You are completely right, @jart. It would be more appropriate to rephrase the last sentence as: "After going through the entire support thread in [1] and searching related content on StackOverflow too, to the best of my knowledge these functionalities have not been implemented yet."

If you are unsure about the possibility of sorting a Dataset after applying the group_by_window() transformation, perhaps it would be fair to re-open the issue. Could you also help me find out what the implicit order of the batches is when this transformation is applied?

On batch skipping, could you please provide an example of ignoring the batches grouped under a range of ids within group_by_window? I am probably overthinking this, yet it seems that the id returned by key_func is no longer available at runtime. Maybe reduce_func could discard some batches instead?

Thank you

@jart

Member

jart commented Nov 6, 2017

Just so there's no misunderstanding, our team is always grateful when members of the community take the time to leave us feedback. My goal is to be friendly and helpful in this process. Per your request, I'm happy to reopen this issue, so the next triager can take a look in a day or so.

@jart jart reopened this Nov 6, 2017

@georgesterpu

Contributor

georgesterpu commented Nov 7, 2017

Thanks, Justine. One workaround that crossed my mind is storing the inputs in separate TFRecords based on their lengths and concatenating several Datasets when creating the iterator. Having this feature on the iterator would only make things look more compact.
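
Roughly, that workaround would look like this (file names are placeholders):

import tensorflow as tf

# hypothetical per-bucket record files, shortest sequences first
files = ["short.tfrecord", "medium.tfrecord", "long.tfrecord"]
datasets = [tf.data.TFRecordDataset(f) for f in files]

dataset = datasets[0]
for d in datasets[1:]:
    dataset = dataset.concatenate(d)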

@fanlu

Contributor

fanlu commented Dec 14, 2017

I have also met this problem. In addition, I want to use distributed TensorFlow to train my DeepSpeech2 model, which requires dataset.get_next() to yield data with the same bucket_id on every distributed machine. I also tried dataset.shard(), but the bucket_ids of the dataset it returns are not the same within one batch. How do I implement this? @mrry
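
What I tried is something like this (num_workers and worker_index are placeholders):

# give each worker a shard, then bucket by length as above
dataset = dataset.shard(num_workers, worker_index)
dataset = dataset.apply(
    tf.contrib.data.group_by_window(key_func, reduce_func, window_size))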

@tensorflowbutler

Member

tensorflowbutler commented Jan 9, 2018

Nagging Awaiting TensorFlower: It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

@tensorflowbutler

Member

tensorflowbutler commented Jan 24, 2018

Nagging Awaiting TensorFlower: It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

@georgesterpu

Contributor

georgesterpu commented May 22, 2018

Hello, for future reference, skipping examples longer than some max_length can be done like this:

dataset = dataset.filter(lambda elem: tf.shape(elem)[0] < max_length)
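
To grow the threshold with training progress, one option (a sketch, assuming TF 1.x and an initializable iterator) is to feed max_length through a placeholder and re-initialize the iterator whenever the schedule raises it:

import tensorflow as tf

max_length = tf.placeholder(tf.int32, shape=[])
dataset = dataset.filter(lambda elem: tf.shape(elem)[0] < max_length)
iterator = dataset.make_initializable_iterator()
# later, each time the schedule raises the threshold:
# sess.run(iterator.initializer, feed_dict={max_length: current_max})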
