
Batch should be sorted by decreasing size. #95

Merged: 4 commits merged into pytorch:master on Aug 17, 2017

Conversation

@PetrochukM (Contributor) commented Aug 16, 2017

Intended effect:

  • rnn.pack_padded_sequence requires that the examples within a minibatch be sorted in decreasing order of length.
  • Curriculum learning requires that batches be produced in increasing order.

Proposed Solution:
Flip the sign of self.sort_key(x) when creating the Batch.

`rnn.pack_padded_sequence` requires that a minibatch be sorted in decreasing order.

It's important for `self.sort_key(x)` to sort the data in increasing order for curriculum learning, but for the rows within each batch to be sorted in decreasing order.
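A minimal sketch of the proposed sign flip, using toy examples and `len` as a stand-in for `self.sort_key` (not the torchtext source):

    # A key meant for increasing order, when negated, yields decreasing order.
    minibatch = ["ab", "abcd", "a", "abc"]   # hypothetical examples
    sort_key = len                           # stand-in for self.sort_key

    minibatch_decreasing_size = sorted(minibatch, key=lambda ex: -sort_key(ex))
    print(minibatch_decreasing_size)         # ['abcd', 'abc', 'ab', 'a']

Note that negating the key only works when it returns a number; sorting with the unchanged key and then reversing, as discussed below, is the more generic option.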
@nelson-liu (Contributor)

Not sure how I feel about this (and thus curriculum learning) being on by default.

@jekbradbury (Contributor)

I support the idea of intra-batch sorting being the opposite of inter-batch sorting, since the only reason for the former is to support packed sequences. It won’t turn curriculum learning on by default if you use a shuffled iterator like BucketIterator.
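
For context, a minimal usage sketch of a shuffled bucketing iterator, assuming the 0.2-era torchtext API and a hypothetical `train_dataset` whose examples have a `text` field (parameter values are illustrative):

    from torchtext import data

    # sort_key orders examples by increasing length so similarly sized examples
    # get bucketed together across batches; the decreasing order inside each
    # batch is handled by the iterator itself, and shuffling between epochs is
    # what avoids accidental curriculum learning.
    train_iter = data.BucketIterator(
        train_dataset, batch_size=32,
        sort_key=lambda ex: len(ex.text),
        device=-1)  # -1 selected the CPU in this era of torchtext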

@jekbradbury (Contributor)

I believe what should happen here is just a reverse() call, though, since the two sorts should always use the same key and just run in opposite order.
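
A self-contained illustration of sorting once with the shared key and then reversing (toy data, not the torchtext source):

    minibatch = ["ab", "abcd", "a", "abc"]   # hypothetical examples
    sort_key = len                           # the same key used for inter-batch sorting

    minibatch.sort(key=sort_key)             # increasing order (curriculum-friendly)
    minibatch.reverse()                      # in-place flip: longest first, as packing requires
    print(minibatch)                         # ['abcd', 'abc', 'ab', 'a']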

@nelson-liu (Contributor)

> It won't turn curriculum learning on by default if you use a shuffled iterator like BucketIterator.

Good point; this was my main concern.

@jekbradbury (Contributor)

@Deepblue129, if you agree that just adding minibatch.reverse() before the Batch constructor solves your issue, I think it'd be the most generic solution and I'll merge that for 0.2.

@PetrochukM (Contributor, Author)

Thank you for your comments. I made the change you suggested.

I also added a comment, because the reverse seems to come out of nowhere unless you have extra context.

@jekbradbury (Contributor)

You need to do minibatch.reverse() on a separate line, because .reverse() returns None (you could also use reversed, but there's no reason not to do it in place here).

Also, maybe the comment should just say something like "pack_padded_sequence requires that a minibatch be sorted by decreasing order, which requires reversing relative to typical sort keys" rather than asking the reader to copy and paste a GitHub URL?
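
A quick illustration of the pitfall (plain Python, not the torchtext source):

    minibatch = ["a", "ab", "abc"]

    wrong = minibatch.reverse()              # list.reverse() mutates in place and returns None
    print(wrong)                             # None, so passing this to Batch(...) would break

    minibatch = ["a", "ab", "abc"]
    minibatch.reverse()                      # correct: call it on its own line...
    print(minibatch)                         # ['abc', 'ab', 'a']

    also_reversed = list(reversed(minibatch))  # ...or use reversed() if a new list is fine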

@@ -157,8 +157,8 @@ def __iter__(self):
                continue
            self.iterations += 1
            self._iterations_this_epoch += 1
            minibatch_decreasing_size = sorted(minibatch, key=lambda x: -self.sort_key(x))
            yield Batch(minibatch_decreasing_size, self.dataset, self.device,
            # NOTE: Find out more here for why we reverse: https://github.com/pytorch/text/pull/95

@nelson-liu (Contributor)

LGTM when CI passes.

@PetrochukM (Contributor, Author)

@nelson-liu Did you know we both go to UW Comp Sci? And both worked @ Google?

@jekbradbury (Contributor)

Flake8 failed with `./torchtext/data/iterator.py:160:91: E501 line too long (107 > 90 characters)`.

@jekbradbury merged commit a5049b9 into pytorch:master on Aug 17, 2017
@JianyuZhan commented Aug 26, 2017

Hi @jekbradbury, I still stumbled on this issue even with this commit (a5049b9). See OpenNMT/OpenNMT-py#189. Below is my analysis.

I think the minibatch.reverse() fix is not correct. In my case, the iterator uses the pool() method to produce the batches it yields. The pool() method shuffles the examples within each minibatch, so when we later call minibatch.reverse() in Iterator.__iter__, the batch still does not satisfy the decreasing-size requirement, and the call to rnn.pack_padded_sequence(embedding, lengths) crashes.

I think the fix should be to explicitly sort the minibatch by decreasing size before the Batch constructor. I previously (wrongly) fixed it at the call site of my rnn.pack_padded_sequence(embedding, lengths) call, but I don't think that is the right place: this is a detail that should be hidden inside the Iterator.

I just tracked the problem down to the torchtext code and found this PR. Hope I didn't misunderstand the code.
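
A sketch of the explicit-sort fix being suggested here, assuming a minibatch whose order has been scrambled by pool() and a length-returning sort key (illustrative names, not a patch from this thread):

    # After pool() has shuffled the examples, reversing alone is not enough;
    # re-sort by decreasing size right before constructing the Batch.
    minibatch = ["ab", "abcd", "a", "abc"]   # order after a hypothetical shuffle
    sort_key = len                           # stand-in for self.sort_key

    minibatch = sorted(minibatch, key=sort_key, reverse=True)  # longest first
    print(minibatch)                         # ['abcd', 'abc', 'ab', 'a']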

@kyteague commented Sep 19, 2017

I don't think this change is intuitive, as it goes against what is expected when using a sort key in Python.

This also breaks a popular downstream library: https://github.com/IBM/pytorch-seq2seq
I made an issue about it there as well: IBM/pytorch-seq2seq#77
