Conversation

jbschlosser
Contributor

@jbschlosser jbschlosser commented Mar 27, 2024

As per title. The idea is to show how to implement MHA using NJT + torch.compile and get nice speedups.

TODO (future work):

  • Link to more thorough comparison of strided / jagged NTs (write this)
  • Update part before MHA to focus on jagged?
  • Update SDPA tutorial


pytorch-bot bot commented Mar 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2813

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a4feb18 with merge base d3cf027 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser
Contributor Author

@svekars is it possible to see a rendered version of the docs? I tried clicking the `Preview Python docs built from this PR` link, but I get:

<Error>
  <Code>AccessDenied</Code>
  <Message>Access Denied</Message>
  <RequestId>NJK553V5M50YSYCH</RequestId>
  <HostId>x0Kg4Bq9i51Qa+uIaXFdTbhzI6iAGdRlFYfDhjYjo1vUyaLDXESxv/8jEJMVjFHDU6LUUPFkNsI=</HostId>
</Error>

@svekars
Contributor

svekars commented Mar 28, 2024

@jbschlosser - there was an error on one of the workers. The preview only becomes available after the manager finishes building. Now it passes and the preview is available.

Contributor

@svekars svekars left a comment


An editorial pass. Looks good overall.

@drisspg
Contributor

drisspg commented Mar 29, 2024

@jbschlosser
Contributor Author

Should we also update the usage here: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html?

Good call, yes.

@svekars
Contributor

svekars commented Apr 30, 2024

@jbschlosser the PR looks good - do you want to finish up the TODOs?

@jbschlosser
Contributor Author

do you want to finish up the TODOs?

This will take a fair amount of work; it might be worth landing this for now and addressing the rest later, when I can carve out some time.

@jbschlosser jbschlosser marked this pull request as ready for review April 30, 2024 17:30
@jbschlosser
Contributor Author

@pytorchbot merge

@jbschlosser
Contributor Author

@svekars what's the merging procedure for this repo?

@kkt-cohere

Hi, sorry to resurrect this. A quick question: is it still the case that fused implementations of SDPA (like the FlashAttention kernels) don't support NestedTensor for training? It was mentioned in one of the older tutorials, but it's not clear to me whether that still holds. Thanks in advance! cc @jbschlosser @drisspg

@jbschlosser
Contributor Author

@kkt-cohere sorry for the delay, I was out the last couple weeks.

is it still the case that fused implementations of sdpa (like the FA kernels) don't support NestedTensor for training?

This isn't really true anymore. Nested tensors with layout=torch.jagged (AKA NJTs) do support flash attention, etc. for training. I'll make a note to update the out-of-date SDPA docs to reflect this.
