
[SDPA] update type hint for scaled_dot_product_attention and documentation #94008

Closed
wants to merge 21 commits into from

Conversation

drisspg
Contributor

@drisspg drisspg commented Feb 2, 2023

Summary

  • Adds type hinting support for SDPA
  • Updates the documentation, adding warnings and notes on the context manager (see the usage sketch below)
  • Adds scaled_dot_product_attention to the non-linear activation function section of nn.functional docs
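
For context, a minimal usage sketch of the surface being documented. This assumes the 2.0-era API, i.e. `torch.nn.functional.scaled_dot_product_attention` and the `torch.backends.cuda.sdp_kernel` context manager; shapes and flag values are illustrative only:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)

# Plain call; with this PR the function carries a proper type hint (returns a Tensor).
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# The context manager covered by the new documentation notes restricts which
# backends the dispatcher may pick (flash / memory-efficient / math).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
```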

cc @svekars @carljparker

@pytorch-bot

pytorch-bot bot commented Feb 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94008

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 2a93079:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Skylion007 Skylion007 changed the title from "[SDPA] update type hint structure for scaled_dot_proudct_attention" to "[SDPA] update type hint structure for scaled_dot_product_attention" Feb 2, 2023
@pytorch-bot

pytorch-bot bot commented Feb 6, 2023

This PR needs to be approved by an authorized maintainer before merge.

@drisspg drisspg added the module: docs Related to our documentation, both in docs/ and docblocks label Feb 6, 2023
@drisspg drisspg changed the title from "[SDPA] update type hint structure for scaled_dot_product_attention" to "[SDPA] update type hint for scaled_dot_product_attention and documentation" Feb 6, 2023
@drisspg drisspg force-pushed the add_typehints_to_sdpa branch 3 times, most recently from 0c69708 to 86fe908 Compare February 6, 2023 23:00
@drisspg
Contributor Author

drisspg commented Feb 7, 2023

$$\text{attn\_mask} = \text{torch.ones}(L, S, \text{dtype=torch.bool}).\text{tril}(\text{diagonal}=0) \quad \text{if is\_causal}$$

$$\text{attn\_mask} = \text{attn\_mask.masked\_fill}(\lnot\,\text{attn\_mask}, -\infty) \quad \text{if attn\_mask.dtype} == \text{torch.bool}$$

$$\text{attn\_weight} = \text{torch.softmax}\left(\frac{QK^T}{\sqrt{d_k}} + \text{attn\_mask}\right)$$

$$\text{attn\_weight} = \text{torch.dropout}(\text{attn\_weight}, \text{dropout\_p})$$

$$\text{return } \text{torch.matmul}(\text{attn\_weight}, V)$$

[Screenshot of the rendered math, 2023-02-06]

Curious whether we think this LaTeX math should be added to the (already long) docstring to explain the math fallback.
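
For reference, a straight-line Python sketch of the same math. This is just the formulas above transcribed into eager ops, not the actual math backend, and `sdpa_math_fallback` is a made-up name:

```python
import math
import torch

def sdpa_math_fallback(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    L, S = query.size(-2), key.size(-2)
    # Build a lower-triangular boolean mask when is_causal is set.
    if is_causal:
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
    # Convert a boolean mask into an additive float mask (-inf where masked out).
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(~attn_mask, float("-inf"))
    # softmax(QK^T / sqrt(d_k) + attn_mask)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
```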

@drisspg drisspg force-pushed the add_typehints_to_sdpa branch 2 times, most recently from 61983fb to 36f569f Compare February 7, 2023 04:33
@drisspg
Contributor Author

drisspg commented Feb 7, 2023

Should I add a warn-once to scaled_dot_product_efficient_attention and scaled_dot_product_flash_attention saying:
"Memory-efficient/FlashAttention SDPA is a beta feature. See the documentation for torch.nn.scaled_dot_product_attention for further information"

Warn-onces are kind of annoying, but we might want to point users very directly to the docs.
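
If we go that route, something along these lines is what I have in mind (a Python-level sketch only; the real guard would more likely use `TORCH_WARN_ONCE` on the C++ side, and the helper name here is made up):

```python
import warnings
from functools import lru_cache

@lru_cache(maxsize=None)
def _warn_sdpa_beta_once(backend: str) -> None:
    # lru_cache makes this fire at most once per backend name per process.
    warnings.warn(
        f"{backend} SDPA is a beta feature. See the documentation for "
        "torch.nn.functional.scaled_dot_product_attention for further information.",
        UserWarning,
    )

# e.g. _warn_sdpa_beta_once("Memory-efficient") or _warn_sdpa_beta_once("FlashAttention")
```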

@drisspg drisspg force-pushed the add_typehints_to_sdpa branch 3 times, most recently from f744c3a to 55cbd55 Compare February 8, 2023 03:04
[8 resolved inline review threads on torch/nn/functional.py]
Contributor

@jbschlosser jbschlosser left a comment


Nice work! I added the usual nitpicky grammar / formatting comments for public-facing docs, nothing major.

FYI there's a style guide here for consistency in formatting module docs. I realize this is not a module, but maybe some of the content there will be useful for maintaining consistency.

[8 resolved inline review threads on torch/nn/functional.py and torch/backends/cuda/__init__.py]
@drisspg drisspg force-pushed the add_typehints_to_sdpa branch 2 times, most recently from 6c25dcc to d6c8538 Compare February 8, 2023 22:23
@drisspg drisspg requested a review from cpuhrsch February 9, 2023 21:11
@drisspg
Contributor Author

drisspg commented Feb 10, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 10, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@drisspg
Contributor Author

drisspg commented Feb 10, 2023

@pytorchbot merge -f "all checks are passing"

@drisspg drisspg added the release notes: nn release notes category label Feb 10, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Labels: ciflow/trunk, Merged, module: docs, module: multi-headed-attention, release notes: nn