Batch should be sorted by decreasing size. #95
Conversation
`rnn.pack_padded_sequence` requires that a minibatch be sorted in decreasing order of length. It's important for `self.sort_key(x)` to sort the data in increasing order for curriculum learning, but for the rows within each batch to be sorted in decreasing order.
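As a rough sketch of the two sort directions coexisting (hypothetical toy data and a stand-in `sort_key`, not the torchtext API), the dataset can be ordered increasing while each minibatch is re-sorted decreasing:

```python
# Hypothetical toy data: sort the whole dataset in increasing order of
# length (curriculum-style), then re-sort each minibatch in decreasing
# order, which is the order rnn.pack_padded_sequence expects.
examples = ["a", "bb ccc", "dd", "eee fff ggg", "h i"]
sort_key = lambda ex: len(ex.split())

# Inter-batch: increasing order across the dataset.
ordered = sorted(examples, key=sort_key)

batch_size = 2
batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

# Intra-batch: decreasing order inside each minibatch, ready for packing.
packed_ready = [sorted(b, key=sort_key, reverse=True) for b in batches]
```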
Not sure how I feel about this (and thus curriculum learning) being on by default.
I support the idea of intra-batch sorting being the opposite of inter-batch sorting, since the only reason for the former is to support packed sequences. It won't turn curriculum learning on by default if you use a shuffled iterator like BucketIterator.
I believe that what should happen here is just a `reverse()` call, though, since the sorts should always use the same key and just have opposite order.
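A quick check in plain Python (hypothetical length keys) that negating the key and sorting with the same key in the opposite direction produce the same order:

```python
# Hypothetical length keys: both forms yield the same decreasing order,
# so a plain reverse avoids flipping signs inside the sort key.
lengths = [3, 1, 4, 1, 5]
by_negated_key = sorted(lengths, key=lambda x: -x)
by_reverse = sorted(lengths, reverse=True)  # same key, opposite order
```

Note that `sorted(..., reverse=True)` is still stable, so ties keep their original relative order, matching the negated-key version.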
Good point; this was my main concern.
@Deepblue129 if you agree that just adding
Thank you for your comments. Made the change you suggested. Added a comment because that
You need to do Also, maybe the comment should just say something like
torchtext/data/iterator.py
@@ -157,8 +157,8 @@ def __iter__(self):
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                minibatch_decreasing_size = sorted(minibatch, key=lambda x: -self.sort_key(x))
                yield Batch(minibatch_decreasing_size, self.dataset, self.device,
                # NOTE: Find out more here for why we reverse: https://github.com/pytorch/text/pull/95
LGTM when CI passes.
@nelson-liu Did you know we both go to UW Comp Sci? And both worked @ Google?
Flake8 failed with
Hi @jekbradbury, I still stumbled on this issue even with this commit (a504b9); see OpenNMT/OpenNMT-py#189. Below is my analysis. I think the fix should be explicitly sorting the batch. I just tracked down the issue to the torchtext code and found this discussion. Hope I didn't misunderstand the code.
I don't think this change is intuitive, as it goes against what is expected when using a sort key in Python. This also breaks a popular downstream library: https://github.com/IBM/pytorch-seq2seq
Intended effect:
`rnn.pack_padded_sequence` requires that a minibatch be sorted by decreasing order.

Proposed Solution:
Flip the sign of the `self.sort_key(x)` when creating the Batch.
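A minimal sketch of the proposed change, using a plain list and `len` as stand-ins for torchtext's minibatch and `self.sort_key` (the `Batch` construction itself is omitted):

```python
# Plain-Python stand-in for the PR's change: negate the sort key so the
# minibatch comes out in decreasing order of size, as
# rnn.pack_padded_sequence requires.
minibatch = [[1], [1, 2, 3], [1, 2]]
sort_key = len  # stand-in for self.sort_key

minibatch_decreasing_size = sorted(minibatch, key=lambda x: -sort_key(x))
```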