[MPS] Add support for Custom Kernels #100661
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100661
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit f12a876.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The CI trunk failures seem unrelated to the PR:
btw if the answer to a lot of my questions is "there is a single stream so this is not relevant", it should be made very explicit.
Force-pushed from 8f0f34f to 5dcef82 (compare)
Thanks @albanD for the review, and @razarmehr for the PR. The change looks good.
Looks good, but please undef the define in the header. (And again, it would be nice to submit this as a separate PR, as it has nothing to do with adding support for custom kernels, does it?)
Also, not sure why `_mps_synchronize` was renamed to `_mps_DeviceSynchronize`.
The change in itself sounds OK, but I think we need to be a lot stricter on the documentation if we expect non-MPS maintainers to be able to use this or to review PRs touching MPS code.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: .github/workflows/lint.yml / lintrunner / linux-job. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "all checks except the pre-existing lint-linux failure are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Co-authored-by: albanD <desmaison.alban@gmail.com>
Force-pushed from f9b0d78 to becdec3 (compare)
Thanks @TaiPhamD. For a raw Metal kernel you have the flexibility to cache the PSO (pipeline state object) as you see fit. We end up caching the Graph because there is CPU overhead in compiling it, and we can do this since the shapes are known. For Metal, shader caching is already built into the OS for you, so you get that automatically.
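For illustration only, here is a minimal sketch of one way PSO caching could look in an Objective-C++ extension. The kernel source string, the function name, and the process-lifetime `unordered_map` caching policy are assumptions for the example and are not part of this PR; only the standard Metal calls (`newLibraryWithSource:options:error:`, `newFunctionWithName:`, `newComputePipelineStateWithFunction:error:`) are real API.

```objective-c++
// Hypothetical sketch (not from this PR): cache one compute pipeline state
// (PSO) per kernel name so the Metal source is only compiled once per process.
#import <Metal/Metal.h>
#include <stdexcept>
#include <string>
#include <unordered_map>

static id<MTLComputePipelineState> getCachedPipeline(id<MTLDevice> device,
                                                     const std::string& source,
                                                     const std::string& kernelName) {
  // One entry per kernel name, kept alive for the lifetime of the process.
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;
  auto it = cache.find(kernelName);
  if (it != cache.end()) {
    return it->second;
  }

  NSError* error = nil;
  // Compile the Metal source into a library.
  id<MTLLibrary> library =
      [device newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
                           options:nil
                             error:&error];
  if (!library) {
    throw std::runtime_error("Failed to compile Metal library");
  }
  // Look up the kernel function and build its pipeline state.
  id<MTLFunction> kernel =
      [library newFunctionWithName:[NSString stringWithUTF8String:kernelName.c_str()]];
  if (!kernel) {
    throw std::runtime_error("Kernel function not found: " + kernelName);
  }
  id<MTLComputePipelineState> pso =
      [device newComputePipelineStateWithFunction:kernel error:&error];
  if (!pso) {
    throw std::runtime_error("Failed to create compute pipeline state");
  }
  cache[kernelName] = pso;
  return pso;
}
```

A cache like this mainly avoids repeated source-to-library compilation on the CPU; as noted above, the OS-level shader cache already helps with the GPU-side compilation cost.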
@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: Meta Internal-Only Changes Check. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "Internal builds are fine"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
- This change introduces these APIs to enable developing custom kernels on the MPS Stream:
  - `torch::mps::get_command_buffer()`
  - `torch::mps::get_dispatch_queue()`
  - `torch::mps::commit()`
- Add ObjC test case

Pull Request resolved: #100661
Approved by: https://github.com/kulinseth, https://github.com/malfet
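For context, below is a minimal Objective-C++ sketch of how these three calls could be used to run a custom Metal kernel on the MPS stream. Only the `torch::mps::*` functions come from this PR; the prebuilt pipeline state `pso`, the `mtlBuffer()` helper, the buffer indices, the dispatch geometry, and the `<torch/torch.h>` include are illustrative assumptions, not part of the PyTorch API.

```objective-c++
// Sketch only: encode a custom Metal kernel onto PyTorch's MPS stream.
// Assumes a pre-built id<MTLComputePipelineState> `pso` (e.g. from a cache
// such as the one sketched earlier in this thread).
#include <torch/torch.h>
#import <Metal/Metal.h>
#include <algorithm>

// Hypothetical helper: view an MPS tensor's storage as an MTLBuffer.
static inline id<MTLBuffer> mtlBuffer(const at::Tensor& t) {
  return __builtin_bit_cast(id<MTLBuffer>, t.storage().data());
}

void runCustomKernel(id<MTLComputePipelineState> pso,
                     const at::Tensor& input, at::Tensor& output) {
  // Command buffer shared with the MPS stream (added by this PR).
  id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
  // Serial queue used by the MPS stream; encoding on it avoids racing
  // with other PyTorch ops (added by this PR).
  dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

  dispatch_sync(serialQueue, ^{
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    [encoder setComputePipelineState:pso];
    [encoder setBuffer:mtlBuffer(input)
                offset:input.storage_offset() * input.element_size()
               atIndex:0];
    [encoder setBuffer:mtlBuffer(output)
                offset:output.storage_offset() * output.element_size()
               atIndex:1];

    // One thread per element; clamp the threadgroup size to the PSO limit.
    const NSUInteger numel = input.numel();
    NSUInteger threadsPerGroup =
        std::min<NSUInteger>(pso.maxTotalThreadsPerThreadgroup, numel);
    [encoder dispatchThreads:MTLSizeMake(numel, 1, 1)
       threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)];
    [encoder endEncoding];

    // Hand the command buffer back to the MPS stream for execution (this PR).
    torch::mps::commit();
  });
}
```

Encoding inside `dispatch_sync` on the stream's queue keeps the custom work ordered with respect to regular PyTorch MPS ops, which matters because there is a single MPS stream.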
I've added an implementation of erfinv using the algorithm from https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152, so that the MPS-based algorithm matches the CPU results in the automatic tests. This PR uses the new Metal API calls from #100661.

Testing shows MPS has a decent speed-up (about 270x) compared to CPU on a tensor of ~200 million elements:

```python
import torch
x = torch.arange(-1, 1, 1e-8)  # default CPU tensor

# measure CPU compute time by calling torch.erfinv
time = %timeit -o -q -r 5 torch.erfinv(x)
cpu_time = time.average
print("CPU torch.erfinv time: ", cpu_time)

x = x.to("mps")
# measure MPS compute time
time = %timeit -o -q -r 5 torch.erfinv(x)
mps_time = time.average
print("MPS torch.erfinv time: ", mps_time)
print(f"MPS torch.erfinv is {cpu_time/mps_time*100} percent faster than CPU torch.erfinv")

# compute MSE between MPS and CPU torch.erfinv
x = x.to("cpu")
y_cpu = torch.erfinv(x)
x = x.to("mps")
y_mps = torch.erfinv(x)
y_mps = y_mps.to("cpu")
mask = torch.isfinite(y_cpu) & torch.isfinite(y_mps.to("cpu"))
y_mps = y_mps[mask]
y_cpu = y_cpu[mask]
x = x[mask]
print(f"length of y_mps: {len(y_mps)}, length of y_cpu: {len(y_cpu)}, length of x: {len(x)}")
mse = torch.square(y_cpu - y_mps).mean()
print("MSE between MPS and CPU torch.erfinv: ", mse)
diff = torch.abs(y_cpu - y_mps)
print("Largest difference")
print(f"x: {x[torch.argmax(diff)]}, y_cpu: {y_cpu[torch.argmax(diff)]}, y_mps: {y_mps[torch.argmax(diff)]} , diff = {y_cpu[torch.argmax(diff)] - y_mps[torch.argmax(diff)]}")
```

Output:

```
CPU torch.erfinv time:  2.654937833400254
MPS torch.erfinv time:  0.009831255332002912
MPS torch.erfinv is 27005.07456822776 percent faster than CPU torch.erfinv
length of y_mps: 199999992, length of y_cpu: 199999992, length of x: 199999992
MSE between MPS and CPU torch.erfinv:  tensor(4.2339e-14)
Largest difference
x: -0.9999980330467224, y_cpu: -3.363569736480713, y_mps: -3.3635685443878174 , diff = -1.1920928955078125e-06
```

Fixes #86808

Pull Request resolved: #101507
Approved by: https://github.com/kulinseth