
switch causal mask for is_causal flag #91171

Closed
wants to merge 1 commit

Conversation

mikekgfb (Contributor) commented:

Summary: switch causal mask for is_causal flag

Test Plan: sandcastle & github

Differential Revision: D42089340
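
For context, a minimal sketch of the usage change this PR is about: passing is_causal=True instead of materializing an explicit causal mask. This is my own illustration against the current public scaled_dot_product_attention API (at the time of this PR the function also took a need_attn_weights argument and returned a second value), not code from the PR:

    import torch
    import torch.nn.functional as F

    # Shapes: (batch, num_heads, seq_len, head_dim)
    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)

    # Before: materialize an explicit causal mask (True = may attend).
    seq_len = q.size(-2)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

    # After: pass the flag and let the kernel handle masking; this keeps
    # the fused attention backends eligible, which an explicit mask can block.
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # The two paths should agree numerically, up to kernel-dependent rounding.
    torch.testing.assert_close(out_masked, out_causal, rtol=1e-4, atol=1e-4)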

pytorch-bot bot commented Dec 20, 2022:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91171

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 200dc51:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D42089340

mikekgfb added a commit to mikekgfb/pytorch that referenced this pull request Dec 24, 2022
Summary:
Pull Request resolved: pytorch#91171

switch causal mask for is_causal flag

Test Plan: sandcastle & github

Differential Revision: D42089340

fbshipit-source-id: 02c012ee38591216da46b1d5bfe2ae91a7343c74
@albanD removed their request for review on December 25, 2022.

mikekgfb added further commits to mikekgfb/pytorch referencing this pull request between December 25 and December 29, 2022, each with the same summary, test plan, and Differential Revision; facebook-github-bot posted the same Phabricator export notice after each update.

    def test_train_with_is_causal(self, device):
        iters = 100
        layer = nn.TransformerEncoderLayer(
            d_model=2,

Reviewer comment (Contributor):

If we want to exercise the fused kernels, I would change d_model to 16, hence 16/2 = 8 heads. Also, the failure initially occurred for fp16, right? Maybe we should run this test with that dtype.

Would this test have failed prior to this PR? I am not 100% sure what it's testing, beyond being more of an end-to-end exercise; I'm not sure what failures it would surface.
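
A rough sketch of what the suggested configuration could look like: d_model=16 with 8 heads (head_dim 2) in fp16 on CUDA, so the fused kernels are eligible. The names and loop structure are illustrative, not the actual test in this PR, and whether is_causal alone suffices or an explicit mask is also expected has varied across versions, so treat this as a shape sketch:

    import torch
    import torch.nn as nn

    # Illustrative only; requires a CUDA device for fp16 fused kernels.
    device = "cuda"
    layer = nn.TransformerEncoderLayer(
        d_model=16,          # 16 / 8 heads -> head_dim 2
        nhead=8,
        batch_first=True,
        device=device,
        dtype=torch.float16,
    )
    layer.train()

    x = torch.randn(4, 32, 16, device=device, dtype=torch.float16)
    for _ in range(100):
        # A bare forward/backward loop just to exercise the kernels;
        # no optimizer step, since numerics are not being trained here.
        layer.zero_grad()
        out = layer(x, is_causal=True)
        out.sum().backward()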

@@ -4847,7 +4847,7 @@ def _in_projection(
     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
     need_attn_weights (bool): If true, the second return value will contain the attention weights used;
         otherwise, the second return value is unspecified
-    is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set.
+    is_causal (bool): If true, assumes causal attention masking and ignores attn_mask.

Reviewer comment (Contributor):

Just a heads up: I know there is a PR for xformers (facebookresearch/xformers#587) that I think will allow for both causal plus an additive attn bias. But I think for now this is the right way to go.
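
Until something like that lands, causal-plus-additive-bias can be emulated by folding the causal constraint into a single additive float mask. A sketch of that workaround (my own construction, using only the documented attn_mask semantics; the helper name and bias shape are illustrative):

    import torch
    import torch.nn.functional as F

    def causal_with_bias(q, k, v, bias):
        # Merge a causal constraint into an additive bias: future positions
        # get -inf, everything else keeps its bias value.
        seq_len = q.size(-2)
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device)
        )
        mask = bias.masked_fill(~causal, float("-inf"))
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    bias = torch.zeros(8, 8)  # e.g. a relative-position bias
    out = causal_with_bias(q, k, v, bias)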

    make_causal = (is_causal is True)

    if is_causal is None:
        if mask is not None:

Reviewer comment (Contributor):

nit: if causal and mask is not None
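
The snippet resolves a tri-state is_causal (True / False / None) into a concrete make_causal decision. A hedged reconstruction of the overall shape of that logic, with the reviewer's nit of collapsing the nested conditions applied; the function name and the assumed canonical mask (additive -inf upper triangle) are my own, not the exact PR code:

    import torch

    def resolve_make_causal(is_causal, mask, seq_len):
        # True means the caller asserts causal masking; trust the flag.
        make_causal = (is_causal is True)
        # None means "infer": check whether the supplied mask is exactly
        # the canonical additive causal mask (nit applied: one combined if).
        if is_causal is None and mask is not None:
            causal_template = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=mask.device),
                diagonal=1,
            )
            make_causal = torch.equal(mask, causal_template)
        return make_causal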

drisspg (Contributor) left a review:

One small comment about the test, but this looks good.

mikekgfb (Contributor, Author) commented:

@pytorchbot merge -f "merge"

pytorch-bot bot commented Dec 30, 2022:

You need to provide a reason for using force merge, in the format @pytorchbot merge -f 'Explanation'.
The explanation needs to be clear on why this is needed. Here are some good examples:

  • Bypass checks due to unrelated upstream failures from ...
  • This is a minor fix to ..., which shouldn't break anything
  • This is pre-tested in a previous CI run
  • Bypass flaky ... check

mikekgfb (Contributor, Author) commented:
@pytorchbot merge -f "This is pre-tested in a previous CI run"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status.
