MHA optimizations #93234
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93234
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit b293762. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/nn/functional.py
Outdated
-    return linear(q, w, b).chunk(3, dim=-1)
+    proj = linear(q, w, b)
+    # reshape to 3, E and not E, 3 is deliberate for better memory coalescing
+    proj = proj.view(1, *proj.shape[:-1], 3, E).transpose(0, -2).squeeze(-1).contiguous()
will this make old checkpoints return a different result due to a different interpretation of output channels? or is the interpretation the same? 3, E is not the same as E, 3 in that respect... also wondering if this should be made into a `chunk` option somehow
`chunk` was doing the same as (3, E), so checkpoints won't be affected:
>>> E = 10
>>> proj = torch.arange(0, 3 * E)
>>> proj
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
>>> proj.chunk(3, dim=-1)
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29]))
>>> proj.view(1, *proj.shape[:-1], 3, E).transpose(0, -2).squeeze(-1).contiguous()
tensor([[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]])
In this case we do it this way because we specifically know we want the chunks to be contiguous. But in most cases it's fine and faster to use `chunk` without calling `contiguous()`. If there were a `contiguous` flag to `torch.chunk` then yes, this kind of trick could also be applied there.
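For illustration, a minimal sketch of the layout point above. The sizes L, N, E here are made up, and the sketch uses `squeeze(-2)` (the corrected index discussed further down in this thread) so the shapes line up:

```python
import torch

# Illustrative sizes only; not the shapes used inside multi_head_attention_forward.
L, N, E = 4, 2, 8
proj = torch.randn(L, N, 3 * E)          # packed q/k/v projection

# chunk() returns three narrow views along the last dim; for a multi-dim input
# they share storage with `proj` and are not contiguous.
q_c, k_c, v_c = proj.chunk(3, dim=-1)
print(q_c.is_contiguous())               # False

# The reshape trick from the diff: view as (..., 3, E), move the 3 to the front,
# and materialize everything in one contiguous copy, so each chunk is contiguous.
q_t, k_t, v_t = (
    proj.view(1, *proj.shape[:-1], 3, E)
        .transpose(0, -2)
        .squeeze(-2)                     # drops the singleton dim left by the leading 1
        .contiguous()
)
print(q_t.is_contiguous())               # True
print(torch.equal(q_c, q_t))             # True: same values, different memory layout
```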
yeah, maybe a `copy=True/False` argument could be introduced for `chunk`/`split`
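A rough sketch of what such a flag could behave like; `chunk_contiguous` below is a hypothetical helper, not an existing PyTorch API, and it pays for three separate copies rather than the single fused copy of the reshape trick above:

```python
import torch

def chunk_contiguous(t: torch.Tensor, chunks: int, dim: int = -1):
    """Hypothetical helper: like t.chunk(chunks, dim), but returns contiguous
    copies, roughly what a copy=True argument on chunk/split might do."""
    return tuple(c.contiguous() for c in t.chunk(chunks, dim=dim))

proj = torch.randn(4, 2, 24)
q, k, v = chunk_contiguous(proj, 3, dim=-1)
print(all(x.is_contiguous() for x in (q, k, v)))  # True
```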
> `chunk` was doing the same as (3, E), so checkpoints won't be affected

I see, I guess I was misled by the original comment "reshape to 3, E and not E, 3".
Made the comment clearer. I think we can leave `chunk()` like it is now, and if the same pattern appears in multiple places we can think about adding a flag later.
Also, there might exist an internal method `torch.transpose_copy`; not sure if it's of any use here or whether it runs faster than transpose + contiguous.
Do you know where this is implemented? I couldn't find it.
Hmm, not sure. Looks like it's some sort of crutch, and not actually implemented anywhere :( https://github.com/pytorch/pytorch/search?q=transpose_copy
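For reference, a quick probe for whether such an op is exposed in a given build (availability varies across PyTorch versions; the fallback is the explicit pattern used in this PR):

```python
import torch

x = torch.randn(3, 5)
if hasattr(torch, "transpose_copy"):
    # If present, transpose_copy materializes the transposed result as a new tensor.
    y = torch.transpose_copy(x, 0, 1)
else:
    # Otherwise, fall back to the explicit transpose + contiguous pattern.
    y = x.transpose(0, 1).contiguous()
print(y.shape)  # torch.Size([5, 3])
```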
torch/nn/functional.py
Outdated
@@ -5178,9 +5185,9 @@ def multi_head_attention_forward(
     #
     # reshape q, k, v for multihead attention and make em batch first
     #
-    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+    q = q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
So the above changes assure us that q, k, v are contiguous? Could we call view() instead of reshape?
In the path where we call `in_projection_packed`, yes, we could do a view. I think it's also fine in the other paths, but to be safe in case we add other paths in the future I changed to reshape.

Is `reshape` more expensive than `view` on the CPU side?
It does some checking to see if it is safe to call view, otherwise it will clone. I think making it view is more explicit. If the projection changes are meant to assure contiguity of q, k, v at this point of the computation for perf reasons, then we should be explicit. Also, I pray no new paths are added.
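A small sketch of the difference being discussed: `view` requires a compatible memory layout and raises otherwise, while `reshape` silently falls back to a copy (shapes here are arbitrary):

```python
import torch

x = torch.randn(4, 6)

# Contiguous input: view and reshape behave the same and neither copies.
a = x.view(2, 12)
b = x.reshape(2, 12)

# Non-contiguous input: view raises, reshape quietly allocates a new tensor.
xt = x.t()                       # transposed view, not contiguous
try:
    xt.view(24)
except RuntimeError as err:
    print("view failed:", err)
c = xt.reshape(24)               # works, but copies
print(c.is_contiguous())         # True
```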
Got it, changed to view in ad36237
Thanks!
torch/nn/functional.py
Outdated
-    return linear(q, w, b).chunk(3, dim=-1)
+    proj = linear(q, w, b)
+    # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+    proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-1).contiguous()
`squeeze(-1)` doesn't look right; the last dimension most likely isn't 1 (unless E happens to be 1).
Should be `squeeze(-2)` probably...
Oops, yes you're right, I forgot to change it when I updated from `transpose(0, -1)`. It still works because of the view() later. Fixed.
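A quick shape check of the squeeze index (illustrative sizes; E is arbitrary):

```python
import torch

L, N, E = 4, 2, 8
proj = torch.randn(L, N, 3 * E)

x = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2)
print(x.shape)              # torch.Size([3, 4, 2, 1, 8])

print(x.squeeze(-1).shape)  # torch.Size([3, 4, 2, 1, 8]) -- no-op unless E == 1
print(x.squeeze(-2).shape)  # torch.Size([3, 4, 2, 8])    -- removes the singleton dim
```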
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Slight perf optimizations for regular MHA by reducing the number of kernels called
Before:
![image](https://user-images.githubusercontent.com/30204471/215349212-172c6364-9e3c-4fd1-92b6-8ddd9931613e.png)
After:
![image](https://user-images.githubusercontent.com/30204471/215349247-021dd9e6-f6ca-40a2-8de8-0805af001f69.png)
cc @ngimel
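For readers who want to reproduce the kernel-count comparison, a rough sketch using torch.profiler; the module size, input shape, and device below are illustrative and not necessarily what produced the screenshots above:

```python
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
mha = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8).to(device)
x = torch.randn(32, 64, 256, device=device)   # (seq_len, batch, embed_dim)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

# The number of distinct kernels launched per forward pass is what the
# Before/After screenshots compare.
with profile(activities=activities) as prof:
    with torch.no_grad():
        mha(x, x, x)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
```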