[MPS] grad scaler #150255
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150255
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 70 Pending as of commit 8d72182 with merge base 7ac8186. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Okay, that fixed the errors; ready for review now.
```cpp
#define lib _ignored_lib_name_for_fused
#include <ATen/native/mps/FusedOptimizerOps_metallib.h>
#undef lib
```
Not sure if there is a better way to do this 🤔
Please don't move library instantiation to header (and never import two libraries at once)
A few meta points about this PR:
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge -f "Lint + MPS are green, hopefully trunk as well"
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Awesome, tysm! ❤️
Fixes pytorch#142397

Basic implementation is done. What's left:

- [x] Different dtype/device tensors in the TensorList
- [x] Fast path for grouping the foreach kernel
- [x] Tests

Regarding tests, I found some tests in `test/test_torch.py` for GradScaler, but I couldn't figure out the best way to enable them for the MPS device. Removing `@onlyNativeDeviceTypes` enables the tests for MPS, but it also enables tests for all other devices that are not included in the native device types. If I put:

`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

this enables lots of tests in that class for MPS which were not(?) being tested before. This part needs some clarification.

Pull Request resolved: pytorch#150255
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
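One checklist item above, the fast path for grouping the foreach kernel, hinges on bucketing a mixed TensorList by device and dtype before dispatch, since a foreach kernel can only batch homogeneous tensors. A minimal sketch of that bucketing idea in plain Python (illustrative only; `FakeTensor` and the function name are hypothetical stand-ins, not the PyTorch internals):

```python
from collections import defaultdict, namedtuple

# Stand-in for a real tensor; only device/dtype matter for grouping.
FakeTensor = namedtuple("FakeTensor", ["device", "dtype", "name"])

def group_by_device_and_dtype(tensors):
    """Bucket tensors so each group is homogeneous: foreach kernels
    can only batch tensors sharing a device and a dtype."""
    groups = defaultdict(list)
    for t in tensors:
        groups[(t.device, t.dtype)].append(t)
    return dict(groups)

tensors = [
    FakeTensor("mps", "float32", "a"),
    FakeTensor("mps", "float16", "b"),
    FakeTensor("mps", "float32", "c"),
]
groups = group_by_device_and_dtype(tensors)
# Two groups: ("mps", "float32") holds a and c; ("mps", "float16") holds b.
```

Each resulting group can then take the fused/foreach path, while stragglers fall back to the slow per-tensor path.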
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen
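For context on what a grad scaler does, the scale/unscale/update cycle this PR brings to MPS can be sketched in plain Python. This is a toy model under assumed simplifications (scalar gradients, illustrative constants), not the actual PyTorch implementation:

```python
import math

class ToyGradScaler:
    """Toy sketch of GradScaler's scale/unscale/update cycle."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=3):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0

    def scale_loss(self, loss):
        # Multiply the loss up so small fp16 gradients don't underflow.
        return loss * self.scale

    def unscale(self, grads):
        # Divide gradients back down and flag any inf/nan, roughly what
        # the per-group non-finite check-and-unscale kernel does.
        found_inf = any(not math.isfinite(g) for g in grads)
        return [g / self.scale for g in grads], found_inf

    def update(self, found_inf):
        # Back off on overflow; grow after a streak of clean steps.
        if found_inf:
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self.scale *= self.growth_factor
                self._growth_tracker = 0

scaler = ToyGradScaler(init_scale=4.0)
grads, found_inf = scaler.unscale([8.0, 4.0])
scaler.update(found_inf)            # clean step: grads == [2.0, 1.0]

_, found_inf = scaler.unscale([float("inf")])
scaler.update(found_inf)            # overflow: scale backs off 4.0 -> 2.0
```

The optimizer step is skipped whenever `found_inf` is set, so an occasional overflow costs one iteration rather than corrupting the weights.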