switch causal mask for is_causal flag #91171
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91171
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 200dc51.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D42089340
Summary: Pull Request resolved: pytorch#91171. Switch causal mask for is_causal flag.
Test Plan: sandcastle & github
Differential Revision: D42089340

The branch was subsequently force-pushed through a series of re-exports (7c975d8, 58177dc, 4c47d16, 030eae5, 97b6440, f955e6b, fe5fc74, db5835a, 79d8d2e, 7282d68, 7d84fb6, 9fd432c, 200dc51), each accompanied by the same summary and the same "This pull request was exported from Phabricator. Differential Revision: D42089340" comment.
def test_train_with_is_causal(self, device):
    iters = 100
    layer = nn.TransformerEncoderLayer(
        d_model=2,
If we want to exercise the fused kernels, I would change d_model to 16, hence 16/2 = 8 heads. Also, the failure initially occurred for fp16, right? Maybe we should run this test with that dtype.

Would this test fail prior to this PR? I'm not 100% sure what it's testing beyond being an end-to-end exercise; it's not clear which failures it would surface.
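For illustration, a minimal sketch of the suggested configuration (the shapes, batch layout, and use of the layer-level is_causal kwarg here are assumptions, not the test's actual contents):

import torch
import torch.nn as nn

# Hypothetical sketch, not the real test: d_model=16 with nhead=8 gives
# head_dim=2, which should make the fused SDPA kernels eligible, and fp16
# on CUDA matches where the failure was reportedly first seen.
device = torch.device("cuda")
layer = nn.TransformerEncoderLayer(
    d_model=16,
    nhead=8,
    batch_first=True,
    device=device,
    dtype=torch.float16,
)
src = torch.randn(4, 32, 16, device=device, dtype=torch.float16)
out = layer(src, is_causal=True)  # relies on the is_causal plumbing this PR adds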
@@ -4847,7 +4847,7 @@ def _in_projection(
     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
     need_attn_weights (bool): If true, the second return value will contain the attention weights used;
         otherwise, the second return value is unspecified
-    is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set.
+    is_causal (bool): If true, assumes causal attention masking and ignores attn_mask.
Just a heads up: I know there is a PR for xformers (facebookresearch/xformers#587) that I think will allow both causal masking plus an additive attn bias. But I think for now this is the right way to go.
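To make the documented semantics concrete, here is a hedged sketch comparing the flag against an explicit additive causal bias (this assumes the public F.scaled_dot_product_attention signature from later releases; at the time of this PR the op was still being stabilized):

import torch
import torch.nn.functional as F

# is_causal=True should match an explicit lower-triangular attention bias:
# -inf strictly above the diagonal, 0 elsewhere.
q, k, v = (torch.randn(2, 8, 32, 16) for _ in range(3))
L = q.shape[-2]
bias = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
torch.testing.assert_close(out_flag, out_bias)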
make_causal = (is_causal is True)

if is_causal is None:
    if mask is not None:
nit: if causal and mask is not None
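One reading of the nit, shown as a purely illustrative, self-contained before/after (the function bodies are elided in this diff view, so the real control flow may differ):

def check_nested(mask, is_causal):
    # Nested form, as in the snippet above.
    if is_causal is None:
        if mask is not None:
            return True
    return False

def check_combined(mask, is_causal):
    # The nit's single-condition form; equivalent here because the nested
    # version has no else branches.
    return is_causal is None and mask is not None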
One small comment about the test, but this looks good.
@pytorchbot merge -f "merge" |
You need to provide a reason for using force merge, in the format @pytorchbot merge -f 'Explanation'.
|
@pytorchbot merge -f "This is pre-tested in a previous CI run" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary: switch causal mask for is_causal flag
Test Plan: sandcastle & github
Differential Revision: D42089340
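In user-facing terms, "switch causal mask for is_causal flag" amounts to preferring the flag over a materialized mask. A minimal sketch (assuming the staticmethod form of generate_square_subsequent_mask and the layer-level is_causal kwarg available in later releases):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=16, nhead=8, batch_first=True)
src = torch.randn(4, 32, 16)

# Before: build an explicit causal mask and pass it as src_mask.
mask = nn.Transformer.generate_square_subsequent_mask(src.size(1))
out_masked = layer(src, src_mask=mask)

# After: pass the is_causal flag so the kernel can apply causal masking
# itself, without materializing an L x L mask tensor.
out_causal = layer(src, is_causal=True)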