Add MPS kernels #7643
Conversation
Force-pushed from dd5f42a to 6f32285.
Hi @qqaatw, I saw this PR isn't in draft state anymore. Is this ready for review?
Hi @NicolasHug, yes, please. There is an issue with f16 inputs for RoI ops, which doesn't have test coverage. Otherwise the added ops are tested.
Thanks a lot @qqaatw. I gave a quick first glance at the tests and made some minor comments / suggestions, but this looks great overall.
As discussed offline with @albanD, we're OK to introduce these new MPS kernels in torchvision, with the shared understanding that MPS-related support (typically bug reports and fixes) will be the responsibility of the MPS team.
> There is an issue with f16 inputs for RoI ops, which doesn't have test coverage. Otherwise the added ops are tested.

What's the issue? If float16 isn't supported for MPS that's OK, but maybe we should write a small test asserting the error message?
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();

at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
I'm curious, what makes this kernel and the other roi align / pool kernels non-deterministic? For the CUDA kernels, it's the calls to atomicAdd, but I'm curious what the reason is here.
MPS kernels also make use of atomic addition, which is either provided by the Metal library or by a custom implementation, depending on the Metal version (see the atomic_add_float function in mps_kernels.h).
I've added a note in the PR description. Hope it properly explains the non-determinism.
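(Not from the PR: a small Python illustration of the effect being described. Floating-point addition is not associative, so when many threads atomically add into the same output element in an unspecified order, the low bits of the result can change from run to run.)

```python
import torch

torch.manual_seed(0)
vals = torch.randn(100_000, dtype=torch.float32)
shuffled = vals[torch.randperm(vals.numel())]

# Accumulating the same numbers in two different orders generally does not
# give bit-identical results -- the same thing happens when atomic adds from
# many threads race on the same output location.
s1 = vals.cumsum(0)[-1]
s2 = shuffled.cumsum(0)[-1]
print(s1.item(), s2.item(), (s1 - s2).abs().item())
```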
@@ -271,6 +277,8 @@ def test_jit_boxes_list(self):


class TestPSRoIPool(RoIOpTester):
    mps_backward_atol = 5e-2
@albanD, any thought regarding this atol value for gradcheck()?
For ref we typically use 1e-5 for CPU/CUDA, although we seem to be testing on float64 while the MPS tests are currently running on float32.
The gradcheck is a bit tricky here, as we usually only run it in fp64 precision to get accurate results.
Unfortunately, MPS doesn't support fp64, so we can only resort to comparing with CPU results or increasing the tolerance significantly.
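(A rough sketch of the "compare with CPU results" option, not the actual test code; the op, shapes, and the 5e-2 tolerance are placeholders mirroring the diff above.)

```python
import torch
from torchvision.ops import ps_roi_pool

x_cpu = torch.rand(1, 8, 8, 8, requires_grad=True)
rois_cpu = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]])  # (batch_idx, x1, y1, x2, y2)

# Same inputs on MPS; float64 is unavailable there, so everything stays float32.
x_mps = x_cpu.detach().to("mps").requires_grad_()
rois_mps = rois_cpu.to("mps")

ps_roi_pool(x_cpu, rois_cpu, output_size=2).sum().backward()
ps_roi_pool(x_mps, rois_mps, output_size=2).sum().backward()

# With float32 inputs (and non-deterministic atomic adds) a loose absolute
# tolerance is needed, well beyond the usual 1e-5 used for fp64 gradcheck.
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, atol=5e-2, rtol=0)
```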
Thank you for reviewing @NicolasHug!

> What's the issue? If float16 isn't supported for MPS that's OK, but maybe we should write a small test asserting the error message?

The issue is that the atomic operations on MPS do not support half, and the RoI backward kernels make use of atomic addition. I've added checks to the RoI backward kernels. The forward kernels work fine!
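(For reference, a minimal sketch of the kind of test suggested above; the exception type and the choice of roi_align are assumptions here, and a match= string would need to mirror whatever message the added checks actually raise.)

```python
import pytest
import torch
from torchvision.ops import roi_align

@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS")
def test_roi_align_backward_half_errors_on_mps():
    x = torch.rand(1, 3, 8, 8, dtype=torch.half, device="mps", requires_grad=True)
    rois = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]], dtype=torch.half, device="mps")

    # Forward in float16 is expected to work; only the backward kernel,
    # which relies on atomic addition, should reject half inputs.
    out = roi_align(x, rois, output_size=(2, 2))
    with pytest.raises(RuntimeError):
        out.sum().backward()
```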
@@ -158,12 +158,12 @@ def from_K(t):
         y = (
             from_K(roi_start_h)
             + ph[None, :, None] * from_K(bin_size_h)
-            + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
+            + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
0.5 is float32 by default, so we cast the result to the input dtype.
         ) # [K, PH, IY]
         x = (
             from_K(roi_start_w)
             + pw[None, :, None] * from_K(bin_size_w)
-            + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
+            + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
Same as above.
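(A quick illustration of the promotion rule behind these two changes, not code from the PR: adding a Python float to an integer index tensor produces a tensor in the default dtype, float32, so the explicit cast keeps the reference implementation in the input's dtype, e.g. float16 or float64.)

```python
import torch

iy = torch.arange(4)                        # int64 index tensor
print((iy + 0.5).dtype)                     # torch.float32 -- the default dtype
print((iy + 0.5).to(torch.float16).dtype)   # torch.float16 after the explicit cast
```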
Hi @NicolasHug, sorry for the delayed update. I've applied all the suggestions.
Gently pinging @NicolasHug.
Sorry for the delay @qqaatw. I'll provide another round tomorrow.
Looks great, thanks @qqaatw
Thanks @qqaatw!!
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
Reviewed By: matteobettini
Differential Revision: D48642285
fbshipit-source-id: 00534d4080565eb66ed6b2dbb8416f8d7526687e
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Summary:
- Adds nms, roi_align, roi_pool, ps_roi_align, ps_roi_pool, and their corresponding backward kernels if any. Most implementations are inspired by the CUDA implementations.
- The Metal kernels live in mps_kernels.h for the ease of sharing helper functions and macros, as well as caching PSOs.
- atomic_float is supported in Metal 3 (macOS Ventura, MSL specs, section 2.6) and later only; for systems with Metal 2.x, we implement a custom atomic addition function.
- MPS doesn't support float64, thus the absolute tolerances of gradcheck in RoI backward tests are adjusted accordingly.

cc @NicolasHug @pmeier @albanD @kulinseth