[MPS] Speedup addmm #116548
Conversation
- Do not copy bias to output
- Skip the respective multiplication op if either alpha or beta is equal to 1.0

[ghstack-poisoned]
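A minimal Python-level sketch of the logic this description implies (illustrative only — the actual change is in the Objective-C++ MPS kernel; the name `addmm_sketch` and its structure are assumptions, not PyTorch code):

```python
import torch

def addmm_sketch(bias, m1, m2, *, beta=1.0, alpha=1.0):
    # beta * bias + alpha * (m1 @ m2), with no-op scalar multiplies skipped
    product = m1 @ m2
    scaled = product if alpha == 1.0 else alpha * product   # skipped when alpha == 1
    if beta == 0.0:
        return scaled                                       # bias is never read
    scaled_bias = bias if beta == 1.0 else beta * bias      # skipped when beta == 1
    return scaled + scaled_bias                             # no upfront copy of bias

bias, m1, m2 = torch.randn(3, 3), torch.randn(3, 4), torch.randn(4, 3)
torch.testing.assert_close(addmm_sketch(bias, m1, m2), torch.addmm(bias, m1, m2))
```

With the default `alpha == beta == 1.0`, this reduces to one matmul plus one add, which is the fast path the change carves out.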
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116548
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit a6a34e5 with merge base 97891b1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```objc
if (&output != &self) {
  output.resize_(bias_sizes);
  if (beta.toComplexDouble() != 0.0) {
    output.copy_(*bias_);
```
If this is an out variant, overriding the output completely is not ok. We should add into it.
Yes, and it is (all `torch.addmm` OpInfo tests use `alpha` and `beta` not equal to 1); that is why I'm removing this one, as the copy is redundant.
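For context, the semantics under discussion, checked with plain `torch.addmm` (the values are arbitrary):

```python
import torch

bias = torch.randn(3, 3)
m1, m2 = torch.randn(3, 4), torch.randn(4, 3)

# OpInfo-style case with alpha and beta not equal to 1:
# addmm computes beta * bias + alpha * (m1 @ m2)
expected = 0.5 * bias + 2.0 * (m1 @ m2)
actual = torch.addmm(bias, m1, m2, beta=0.5, alpha=2.0)
torch.testing.assert_close(actual, expected)
```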
So I guess the next steps here would be to add this case to OpInfo and fix this kernel to have the appropriate behavior?
I don't think anything is needed here: `addmm_out_mps` copied the bias to the output for some reason, even though it wasn't needed at all, as MPSGraph always overwrites the output:

pytorch/aten/src/ATen/native/mps/operations/LinearAlgebra.mm, lines 302 to 305 in 035e558:
```objc
MPSGraphTensor* outputTensor = productTimesAlphaTensor;
if (is_beta_non_zero) {
  outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
                                     secondaryTensor:biasTimesBetaTensor
                                                name:nil];
}
```
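The same point can be checked directly at the Python level (an illustration, not the MPS test suite): the `out=` variant fully overwrites `out`, so a prior copy of the bias into it is never observable:

```python
import torch

out = torch.full((2, 2), 42.0)   # pre-filled with garbage values
bias = torch.zeros(2, 2)
torch.addmm(bias, torch.eye(2), torch.ones(2, 2), out=out)
assert torch.equal(out, torch.ones(2, 2))  # prior contents of out never leak through
```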
Hmm, or perhaps I did not understand your original question: this override only happens if `output != self`. Though yes, I'm not sure this kernel will work as expected if `output == self`; IMO that should be done as a separate PR. This one just eliminates the unneeded multiplications when `alpha` and `beta` are 1, and the unneeded override, as the function always writes to the output.
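The aliasing case in question, illustrated with the Python in-place variant (where output and self are the same tensor):

```python
import torch

b = torch.ones(2, 2)
# In-place variant: the output aliases self (b) — the output == self case above
b.addmm_(torch.eye(2), torch.ones(2, 2), beta=2.0, alpha=1.0)
assert torch.equal(b, torch.full((2, 2), 3.0))  # 2 * 1 + 1 * (eye @ ones) == 3
```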
Sounds good!
@pytorchbot merge -f "Lint and MPS are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.