
Conversation

Isalia20
Collaborator

@Isalia20 Isalia20 commented Mar 29, 2025

Fixes #142397

Basic implementation is done. What's left:

  • Different dtype/device tensors in the TensorList
  • fast path for grouping the foreach kernel
  • Tests

Regarding tests, I found some tests in `test/test_torch.py` for GradScaler, but I couldn't figure out the best way to enable them for the MPS device.

Removing `@onlyNativeDeviceTypes` enables the tests for MPS, but it also enables them for all other devices outside the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

this enables many tests in that class for MPS which were (apparently) not being tested before. This part needs some clarification.
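For context on what the kernel behind GradScaler has to do, here is a minimal pure-Python sketch of the scale-update rule that `torch.amp.GradScaler` applies each step (the function name and signature here are illustrative, not the actual PyTorch API; the real logic lives in the `_amp_update_scale_` kernel this PR ports to MPS):

```python
# Hedged sketch of GradScaler's scale-update rule. Not the PyTorch API:
# the real implementation operates in-place on device tensors.

def amp_update_scale(current_scale: float,
                     growth_tracker: int,
                     found_inf: bool,
                     growth_factor: float = 2.0,
                     backoff_factor: float = 0.5,
                     growth_interval: int = 2000) -> tuple[float, int]:
    """Return the updated (scale, growth_tracker) pair."""
    if found_inf:
        # Overflow detected: back the scale off and restart the streak.
        return current_scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # A full streak of finite steps: grow the scale.
        return current_scale * growth_factor, 0
    return current_scale, growth_tracker

# One overflow halves the scale; a full clean streak doubles it.
scale, tracker = amp_update_scale(65536.0, 0, found_inf=True)
print(scale, tracker)  # 32768.0 0
```

The defaults shown (2.0 / 0.5 / 2000) mirror GradScaler's documented constructor defaults; the point is that the update is a tiny per-step state machine, which is why it is worth a fused device kernel.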

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen


pytorch-bot bot commented Mar 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150255

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 70 Pending

As of commit 8d72182 with merge base 7ac8186:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) module: amp (automated mixed precision) autocast release notes: mps Release notes category labels Mar 29, 2025
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@Isalia20 Isalia20 added the module: mps Related to Apple Metal Performance Shaders framework label Mar 29, 2025
@Isalia20 Isalia20 marked this pull request as draft March 29, 2025 12:20
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) and removed ciflow/mps Run MPS tests (subset of trunk) labels Mar 29, 2025
@Isalia20 Isalia20 added the ciflow/mps Run MPS tests (subset of trunk) label Mar 29, 2025
@Isalia20 Isalia20 marked this pull request as ready for review March 29, 2025 22:22
@Isalia20
Collaborator Author

Something seems to be wrong with the MPS tests; they are green when I build locally 🤔
[Screenshot 2025-03-30 at 02:20:45]

@Isalia20 Isalia20 added the ciflow/mps Run MPS tests (subset of trunk) label Mar 30, 2025
@Isalia20
Collaborator Author

Okay, that fixed the errors; ready for review now.

Comment on lines 15 to 17
#define lib _ignored_lib_name_for_fused
#include <ATen/native/mps/FusedOptimizerOps_metallib.h>
#undef lib
Collaborator Author

Not sure if there is a better way to do this 🤔

Contributor

Please don't move library instantiation to the header (and never import two libraries at once)

@malfet malfet added ciflow/trunk Trigger trunk jobs on your pull request ciflow/mps Run MPS tests (subset of trunk) labels Apr 6, 2025
@malfet
Contributor

malfet commented Apr 6, 2025

A few meta points about this PR:

  • Pay attention to the difference between signed and unsigned types. For example, looking at

    ```cuda
    __global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                                 int* growth_tracker,
                                                 const float* found_inf,
                                                 double growth_factor,
                                                 double backoff_factor,
                                                 int growth_interval)
    ```

    both `growth_tracker` and `growth_interval` are signed types, but in your PR `growth_interval` turned into an unsigned one. That might be fine semantically, but in that case it would be good to add a check that the value is positive before casting it to the unsigned type.
  • I'm not sure AMP testing is fully contained in the ciflow/mps workflow, so getting a signal from trunk would be great.
  • When you see too many known failures in the pytorch-bot/Dr. CI comments, rebase onto stable; otherwise the signal can be occluded.
  • Pay attention to error checking and do it as early as possible during kernel execution.
  • If you opt to allocate GPU memory manually, make sure you free it after use.
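To make the signed/unsigned point concrete, here is a small hedged illustration (the helper names are invented for this sketch): reinterpreting a negative 32-bit integer's bit pattern as unsigned, as an implicit C/Metal cast would, silently produces a huge value, which is why validating positivity before the cast matters.

```python
# Illustrative only: simulates a 32-bit signed-to-unsigned reinterpretation
# in Python, where ints are arbitrary precision, by masking to 32 bits.

def to_uint32(value: int) -> int:
    """Reinterpret a signed 32-bit integer's bit pattern as unsigned."""
    return value & 0xFFFFFFFF

def checked_uint32(value: int) -> int:
    """Validate before casting, as the review comment suggests."""
    if value <= 0:
        raise ValueError(f"growth_interval must be positive, got {value}")
    return to_uint32(value)

print(to_uint32(-1))         # 4294967295: silent wraparound
print(checked_uint32(2000))  # 2000: safe after the check
```

A `growth_interval` of -1 smuggled through an unchecked cast would make the grow-the-scale branch effectively unreachable for ~4 billion steps, so the early check fails loudly instead.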

@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/mps Run MPS tests (subset of trunk) labels Apr 6, 2025
@malfet malfet added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 6, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Apr 6, 2025
@malfet
Contributor

malfet commented Apr 6, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 6, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@malfet
Contributor

malfet commented Apr 6, 2025

@pytorchbot merge -f "Lint + MPS are green, hopefully trunk as well"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; consider -i/--ignore-current instead to continue the merge while ignoring current failures. That allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@kurzdev

kurzdev commented Apr 6, 2025

Awesome, tysm! ❤️

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
Fixes pytorch#142397

Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests

Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device.

By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification

Pull Request resolved: pytorch#150255
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: amp (automated mixed precision) autocast module: mps Related to Apple Metal Performance Shaders framework open source release notes: mps Release notes category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Enable GradScaler for MPS devices
6 participants