[MPS] Pixel shuffle unshuffle support #99306
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/99306. Note: links to docs will display an error until the docs builds have been completed.
❗ There is 1 active merge-blocking SEV. Please view it below.
✅ You can merge normally! (1 unrelated failure.) As of commit dc8bab3 with merge base 781b7eb: UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 4534efc to 6acce78 (compare).
This fails on macOS 12. Any ideas? Should I just make the MPS implementation unavailable on macOS 12?
Looks great. Can you please refactor the upscale and downscale code for pixel shuffle and unshuffle?
You can use a fallback for macOS 12. The API should work, but it seems there are bugs.
```cpp
using CachedGraph = MPSUnaryCachedGraph;
// ...
if (upscale_factor == 1) {
  return self;
```
@kulinseth should this return `self.clone()`? What is the expected behaviour?
Figured out that it should, because we don't want this:

```python
t = torch.randn(1, 4, 3, 3, device='mps')
res = F.pixel_shuffle(t, 1)
t -= 10000  # res modified too
```
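The hazard can be shown without MPS at all: if an op's fast path returns its input object unchanged, the result aliases the input, and later in-place edits leak into it. A pure-Python sketch (the `identity_op` helper is illustrative, not the PyTorch code):

```python
import copy

def identity_op(x, clone=True):
    # Sketch of the upscale_factor == 1 fast path: returning the input
    # directly would alias it; returning a copy keeps the result stable.
    return copy.deepcopy(x) if clone else x

t = [[1.0, 2.0], [3.0, 4.0]]
res = identity_op(t)                 # copy: safe against later mutation
alias = identity_op(t, clone=False)  # aliases t
t[0][0] -= 10000.0                   # mutate the input in place
# res[0][0] stays 1.0, while alias[0][0] becomes -9999.0
```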
Could you please clarify which code specifically? Is there something wrong in …?
@kulinseth any updates?
```cpp
out_shape.insert(out_shape.end(), {oc, oh, ow});
// ...
Tensor output = at::empty(out_shape, self.options());
```
I think this can be added:

```cpp
auto output = at::empty({0}, self.options());
if (output.numel() == 0) {
  return output;
}
```
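The intent of the suggestion is an early return that skips building the MPS graph when the output holds no elements. A pure-Python sketch of that guard (the `allocate_or_bail` name is hypothetical; the real code is Objective-C++ using `at::empty`):

```python
def allocate_or_bail(out_shape):
    # Compute the element count of the requested output shape; if it is
    # zero (any dimension is 0), return the empty result immediately
    # instead of constructing and running a compute graph.
    numel = 1
    for d in out_shape:
        numel *= d
    if numel == 0:
        return []              # nothing to compute, skip graph work
    return [0.0] * numel       # stand-in for the real allocation

empty = allocate_or_bail([1, 0, 6, 6])  # a zero-sized dim short-circuits
full = allocate_or_bail([1, 1, 4, 4])   # 16 zero-initialized elements
```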
Modulo this, it looks good.
I meant that if you compare the upscale and downscale code, the calculations are the same. The graph and other parts can be refactored:
I would encourage you to move all the code to a helper function and then call it from shuffle and unshuffle.
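The requested refactor, sketched in pure Python on a `[C][H][W]` nested list: shuffle and unshuffle can share one helper because the index mapping is identical in both directions, only the read and write sides swap. All names here are illustrative; the actual PR builds MPSGraph ops instead:

```python
def _shuffle_impl(x, r, upscale):
    """Shared index math for pixel shuffle/unshuffle on a [C][H][W]
    nested list (toy sketch; the real op handles N-d tensors)."""
    if upscale:
        C = len(x) // (r * r)            # output channels shrink by r^2
        h, w = len(x[0]), len(x[0][0])   # input spatial size
        out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(C)]
    else:
        C = len(x)                       # coarse channel count
        h, w = len(x[0]) // r, len(x[0][0]) // r
        out = [[[0.0] * w for _ in range(h)] for _ in range(C * r * r)]
    for c in range(C):
        for i in range(r):
            for j in range(r):
                for hh in range(h):
                    for ww in range(w):
                        # same mapping both ways; only read/write swap
                        if upscale:
                            out[c][hh * r + i][ww * r + j] = x[(c * r + i) * r + j][hh][ww]
                        else:
                            out[(c * r + i) * r + j][hh][ww] = x[c][hh * r + i][ww * r + j]
    return out

def pixel_shuffle(x, r):
    return _shuffle_impl(x, r, upscale=True)

def pixel_unshuffle(x, r):
    return _shuffle_impl(x, r, upscale=False)
```

With `r = 2`, four input channels interleave into one output channel of twice the spatial size, and unshuffle inverts the mapping exactly.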
@pytorchmergebot merge
Merge failed. Reason: Approval needed from one of the following:
@kulinseth, could you tag an appropriate core maintainer, please? ;)
@albanD can you please take a look?
Done, added @albanD
@albanD, could you please take a look?
@albanD, @kulinseth this PR is almost a month old.
Pinged @albanD.
```cpp
TORCH_WARN_ONCE("MPS: pixel_shuffle op is supported starting from macOS 13.0. ",
                "Falling back on CPU. This may have performance implications.");

return at::native::pixel_shuffle_cpu(self.to("cpu"), upscale_factor).clone().to("mps");
```
Why clone the output here and below?
True, unneeded
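For context, the pattern being reviewed is a CPU fallback for macOS 12: move the input to the CPU, run the existing CPU kernel, and move the result back. The device transfer already materializes a fresh tensor, which is why the extra `.clone()` is redundant. A device-free Python sketch of the dispatch idea (all names hypothetical):

```python
def with_cpu_fallback(native_impl, cpu_impl, native_supported):
    """Sketch of the fallback wiring: use the native implementation when
    the OS supports it, otherwise route through the CPU kernel."""
    def op(x, factor):
        if native_supported:
            return native_impl(x, factor)
        # list(...) stands in for the .to("cpu") / .to("mps") transfers;
        # each transfer already copies, so no extra clone is needed
        return list(cpu_impl(list(x), factor))
    return op

# wiring a trivial scaling op through the fallback path
op = with_cpu_fallback(lambda x, f: [v * f for v in x],
                       lambda x, f: [v * f for v in x],
                       native_supported=False)
result = op([1, 2, 3], 2)
```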
test/test_mps.py (outdated)

```diff
@@ -934,6 +934,110 @@ def leak_gpu0():
         with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
             leak_gpu0()

+    # These tests were taken from test/test_nn.py
```
We have OpInfo-based tests for this, so this should already be properly covered?
That’s a good question. This is a good additional test, which tests different ranges of upscale_factor. Although, we will enable test_nn in the near future; then we can remove these duplicates.
I meant there is a regular OpInfo entry:
`OpInfo(`
That one should already be run by the existing tests in this file, no?
If these tests already run, I'll be happy to remove them. Is there a way to check it?
In the CI logs on the mps machine, you can look for nn_functional_pixel_shuffle
Found these lines
Did you find similar ones for unshuffle there?
Yes
```
test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float16 PASSED [0.1346s] [ 21%]
test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float32 PASSED [0.1138s] [ 21%]
test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float16 PASSED [0.0176s] [ 21%]
test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float32 PASSED [0.0179s] [ 21%]
```
Do these already cover this test you just copied here? (and the test in test_nn is just redundant)
Or do they test different things?
They're different
IMO it's a bit redundant, but it can be done as a follow-up PR.
@alexdremov can you please remove the test copied from test_nn in this PR, to avoid duplication? We are planning to enable test_nn testing with the 'mps' device soon and will have this enabled there.
Force-pushed from 5f21e98 to ee3aee8 (compare).
@alexdremov, can we resurrect this PR?
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from d161002 to 1abfa6c (compare).
@kulinseth Seems like the code side is fine and it successfully rebased. There are only @malfet's questions that need to be checked.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 1abfa6c to b6be7b6 (compare).
@kulinseth I've resurrected this PR. Merging?
@kulinseth could you take a look please? Anything else to fix?
@kulinseth any updates?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #83196

Now, the MPS implementation is blazingly fast.

Though, I have several questions on improving this PR:

- Tests were taken from test_nn.py. Is there a better way to test this?
- I use `usePixelShuffleOrder:YES`. Am I right performance-wise? According to the docs: …

cc: @razarmehr @kulinseth