-
Notifications
You must be signed in to change notification settings - Fork 814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add batch_first support in MHA and update docs #839
Conversation
Codecov Report
@@ Coverage Diff @@
## master #839 +/- ##
=======================================
Coverage 77.43% 77.44%
=======================================
Files 43 44 +1
Lines 3045 3055 +10
=======================================
+ Hits 2358 2366 +8
- Misses 687 689 +2
Continue to review full report at Codecov.
|
r""" A multi-head attention container | ||
|
||
Args: | ||
nhead: the number of heads in the multiheadattention model | ||
in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear). | ||
attention_layer: The attention layer. | ||
attention_layer: The custom attention layer. The input sent from MHA container to the attention layer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this also take care of broadcasting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The custom attention layer needs to take care of broadcasting. Updated the doc to reflect this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd then augment the shape to (..., seq, batch, feature)
and explain what that means and also that it's optional, i.e. enough to only handle 3-dim.
|
||
Examples:: | ||
>>> SDP = torchtext.models.ScaledDotProduct(0.1) | ||
>>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we mirror the pytorch path conventions here?
torchtext.nn and torchtext.nn.functional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other two domains use torchvision/audio.models
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but this isn't a model, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, Will fix it.
No description provided.