
MHA optimizations #93234

Closed · wants to merge 5 commits

Conversation

milesial (Contributor) commented Jan 29, 2023

Slight perf optimizations for the regular MHA path by reducing the number of kernels launched.

Before:
[image]

After:
[image]

cc @ngimel
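
For context, a rough sketch (not from the PR) of how one could compare kernel counts for the regular MHA path with torch.profiler; the module config and input shapes below are purely illustrative:

import torch

mha = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8, device="cuda")
x = torch.randn(32, 16, 256, device="cuda")  # (seq, batch, embed), batch_first=False

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    mha(x, x, x, need_weights=False)
# each row is a distinct kernel/op; compare the count before and after the change
print(prof.key_averages().table(sort_by="cuda_time_total"))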

pytorch-bot bot commented Jan 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93234

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b293762:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

-    return linear(q, w, b).chunk(3, dim=-1)
+    proj = linear(q, w, b)
+    # reshape to 3, E and not E, 3 is deliberate for better memory coalescing
+    proj = proj.view(1, *proj.shape[:-1], 3, E).transpose(0, -2).squeeze(-1).contiguous()
Contributor:

Will this make old checkpoints return a different result due to a different interpretation of the output channels, or is the interpretation the same? (3, E) is not the same as (E, 3) in that respect... Also wondering if this should somehow be made into a chunk option.

Contributor (author):

chunk was doing the same thing as (3, E), so checkpoints won't be affected:

>>> E = 10
>>> proj = torch.arange(0, 3 * E)
>>> proj
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
>>> proj.chunk(3, dim=-1)
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29]))
>>> proj.view(1, *proj.shape[:-1], 3, E).transpose(0, -2).squeeze(-1).contiguous()
tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]])

In this case we do it this way because we specifically know we want the chunks to be contiguous. But in most cases it's fine, and faster, to use chunk without calling contiguous(). If there were a contiguous flag on torch.chunk, then yes, this kind of trick could also be applied there.
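
To illustrate the point, a small sketch (mine, not the PR's code; movedim is used instead of the PR's transpose-based chain, and the shapes are made up):

import torch

E = 8
proj = torch.randn(5, 4, 3 * E)      # toy (seq, batch, 3*E) packed projection output

q, k, v = proj.chunk(3, dim=-1)      # views into proj; slices along the last dim are not contiguous
print(q.is_contiguous())             # False -> a per-tensor .contiguous() later would mean 3 copy kernels

# packing the 3 chunks into a leading dim lets a single contiguous() copy all of them at once
packed = proj.unflatten(-1, (3, E)).movedim(-2, 0).contiguous()
q2, k2, v2 = packed                  # each is a contiguous (5, 4, E) tensor
print(torch.equal(q, q2))            # True: same values and same split order as chunk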

Contributor:

Yeah, maybe a copy=True/False argument could be introduced for chunk/split.

> chunk was doing the same as (3, E), so checkpoints won't be affected

I see, I guess I was misled by the original comment "reshape to 3, E and not E, 3".

Contributor (author):

Made the comment clearer. I think we can leave chunk() as it is now, and if the same pattern appears in multiple places we can think about adding a flag later.

Contributor:

Also, there might be an internal method torch.transpose_copy; not sure whether it's of any use here or whether it runs faster than transpose + contiguous.

Contributor (author):

Do you know where this is implemented? I couldn't find it.

Contributor:

Hmm, not sure. Looks like it's some sort of crutch, and not actually implemented anywhere :( https://github.com/pytorch/pytorch/search?q=transpose_copy

@albanD albanD requested review from drisspg and removed request for albanD January 30, 2023 18:19
@albanD albanD added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jan 30, 2023
@@ -5178,9 +5185,9 @@ def multi_head_attention_forward(
     #
     # reshape q, k, v for multihead attention and make em batch first
     #
-    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+    q = q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
drisspg (Contributor) commented Jan 30, 2023:

So the above changes assure us that q, k, v are contiguous? Could we call view() instead of reshape?

Contributor (author):

In the path where we call in_projection_packed, yes, we could do a view. I think it's also fine in the other paths, but to be safe in case we add other paths in the future I changed it to reshape.

Is reshape more expensive than view on the CPU side?

Contributor:

It does some checking to see whether it is safe to call view; otherwise it will clone. I think making it view is more explicit. If the projection changes are there to assure contiguity of q, k, v at this point of the computation for perf reasons, then we should be explicit. Also, I pray no new paths are added.
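
A small illustration of that difference (my sketch, not from the PR):

import torch

x = torch.randn(4, 6)
t = x.transpose(0, 1)                # non-contiguous view of x

# view() requires the new shape to be expressible over the existing strides, and raises otherwise
try:
    t.view(24)
except RuntimeError as err:
    print("view failed:", err)

# reshape() does the same check but silently falls back to a copy when a view isn't possible
y = t.reshape(24)
print(y.data_ptr() == x.data_ptr())  # False: a new allocation was made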

Contributor (author):

Got it, changed to view in ad36237

@drisspg drisspg added the module: performance (Issues related to performance, either of kernel code or framework glue) and topic: performance (topic category) labels Jan 30, 2023
@drisspg drisspg self-requested a review January 30, 2023 21:51
drisspg (Contributor) left a comment:

Thanks!

-    return linear(q, w, b).chunk(3, dim=-1)
+    proj = linear(q, w, b)
+    # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+    proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-1).contiguous()
Collaborator:

squeeze(-1) doesn't look right; the last dimension most likely isn't 1 (unless E happens to be 1).

Contributor:

Should be squeeze(-2) probably...

milesial (Contributor, author) commented Jan 30, 2023:

Oops, yes, you're right; I forgot to change it when I updated from transpose(0, -1). It still works because of the view() later. Fixed.
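
For reference, a quick shape trace of the corrected chain (toy sizes, my own example, not the PR's tests):

import torch

L, N, E = 2, 3, 4                    # toy seq length, batch size, embed dim
proj = torch.randn(L, N, 3 * E)      # packed q/k/v projection output

x = proj.unflatten(-1, (3, E))       # (L, N, 3, E)
x = x.unsqueeze(0)                   # (1, L, N, 3, E)
x = x.transpose(0, -2)               # (3, L, N, 1, E)
x = x.squeeze(-2)                    # (3, L, N, E) -- squeeze(-1) would be a no-op for E > 1
x = x.contiguous()
print(x.shape)                       # torch.Size([3, 2, 3, 4])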

milesial (Contributor, author) commented Feb 3, 2023:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Feb 3, 2023
pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels
ciflow/trunk · Merged · module: performance · open source · topic: performance · triaged
7 participants