Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] Pixel shuffle unshuffle support #99306

Closed
wants to merge 8 commits into from

Conversation

alexdremov
Copy link
Contributor

@alexdremov alexdremov commented Apr 17, 2023

Fixes #83196

Now, MPS implementation is blazingly fast.

Though, I have several questions on improving this PR:

  1. I copied code from test_nn.py. Is there better way to test this?
  2. I decided to use usepixelshuffleorder:YES. Am I right performance-wise? According to docs:
`usePixelShuffleOrder` can be
used to control how the data within spatial blocks is ordered in the
`depthAxis` dimension: with `usePixelShuffleOrder=YES` the values within the
spatial blocks are stored contiguosly within the `depthAxis` dimension whereas
otherwise they are stored interleaved with existing values in the `depthAxis` dimension.

cc: @razarmehr @kulinseth

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99306

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge Blocking SEVs

There is 1 active merge blocking SEVs. Please view them below:

If you must merge, use @pytorchbot merge -f.

✅ You can merge normally! (1 Unrelated Failure)

As of commit dc8bab3 with merge base 781b7eb (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Apr 17, 2023
@alexdremov
Copy link
Contributor Author

This fails on MacOS 12. Any ideas? Should I just make MPS implementation unavailable on MacOS 12?

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great . Can you please refactor the upscale and downscale code for pixel shuffle and unshuffle.

@kulinseth
Copy link
Collaborator

This fails on MacOS 12. Any ideas? Should I just make MPS implementation unavailable on MacOS 12?

You can use fallback for macOS 12 . The api should work but seems like there are bugs ..

using CachedGraph = MPSUnaryCachedGraph;

if (upscale_factor == 1) {
return self;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kulinseth should this return self.clone()? What is expected behaviour?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figured out it should cause we don't want this:

t = torch.randn(1, 4, 3, 3, device='mps')
res = F.pixel_shuffle(t, 1)
t -= 10000  # res modified too

@alexdremov
Copy link
Contributor Author

Looks great . Can you please refactor the upscale and downscale code for pixel shuffle and unshuffle.

Could you please clarify which code specifically? Is there something wrong in PixelShuffle.mm?

@alexdremov
Copy link
Contributor Author

@kulinseth any updates?

out_shape.insert(out_shape.end(), {oc, oh, ow});

Tensor output = at::empty(out_shape, self.options());

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be added...

  auto output = at::empty({0}, self.options());
  if (output.numel() == 0) {
    return output;
  }

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modulo this looks good.

@kulinseth
Copy link
Collaborator

Looks great . Can you please refactor the upscale and downscale code for pixel shuffle and unshuffle.

Could you please clarify which code specifically? Is there something wrong in PixelShuffle.mm?

I meant if you start the upscale and downscale code, the calculations are same. The Graph and other parts can be refactored:

const int64_t c = self.size(-3);
   const int64_t h = self.size(-2);
   const int64_t w = self.size(-1);
   constexpr auto NUM_NON_BATCH_DIMS = 3;
   const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;

   const int64_t downscale_factor_squared = downscale_factor * downscale_factor;
   const int64_t oc = c * downscale_factor_squared;
   const int64_t oh = h / downscale_factor;
   const int64_t ow = w / downscale_factor;

   std::vector<int64_t> out_shape(self.sizes().begin(), self_sizes_batch_end);
   out_shape.insert(out_shape.end(), {oc, oh, ow});

I would encourage you move all the code to a helper function and then call it from Shuffle and Unshuffle.

@alexdremov
Copy link
Contributor Author

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 22, 2023
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: Approval needed from one of the following:
YXIE14, dulinriley, Hangjun, eprivezentsev, govardhan, ...

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@alexdremov
Copy link
Contributor Author

Merge failed

Reason: Approval needed from one of the following: YXIE14, dulinriley, Hangjun, eprivezentsev, govardhan, ...

Details for Dev Infra team
Raised by workflow job
Failing merge rule: Core Maintainers

@kulinseth, could you tag an appropriate core maintainer, please? ;)

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD can you please take a look ?

@kulinseth kulinseth requested a review from albanD April 22, 2023 14:50
@kulinseth
Copy link
Collaborator

Merge failed

Reason: Approval needed from one of the following: YXIE14, dulinriley, Hangjun, eprivezentsev, govardhan, ...
Details for Dev Infra team
Raised by workflow job
Failing merge rule: Core Maintainers

@kulinseth, could you tag an appropriate core maintainer, please? ;)

Done , added @albanD

@alexdremov
Copy link
Contributor Author

@albanD, could you please take a look?

@alexdremov
Copy link
Contributor Author

@albanD, @kulinseth this PR is almost a month old

@kulinseth
Copy link
Collaborator

@albanD, @kulinseth this PR is almost a month old

Pinged @albanD.

TORCH_WARN_ONCE("MPS: pixel_shuffle op is supported starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");

return at::native::pixel_shuffle_cpu(self.to("cpu"), upscale_factor).clone().to("mps");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why clone the output here and below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, unneeded

test/test_mps.py Outdated
@@ -934,6 +934,110 @@ def leak_gpu0():
with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
leak_gpu0()


# These tests were taken from test/test_nn.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have OpInfo based tests for this. So this should already properly covered?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s a good question . This is a good additional test which tests different ranges of upscale_factor . Although we will enable test_nn in near future, then we can remove these duplicates

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant there is a regular OpInfo:

That one should already be ran by the existing tests in this file no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these tests already run, I'll be happy to remove them. Is there a way to check it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the CI logs on the mps machine, you can look for nn_functional_pixel_shuffle

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found these lines

Did you find similar ones for unshuffle there?

Yes

  test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float16 PASSED [0.1346s] [ 21%]
  test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float32 PASSED [0.1138s] [ 21%]
  test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float16 PASSED [0.0176s] [ 21%]
  test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float32 PASSED [0.0179s] [ 21%]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these already cover this test you just copied here? (and the test in test_nn is just redundant)
Or they test different things?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're different

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it's a bid redundant, but can be done as a followup PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexdremov can you please remove the test_nn copied test in this PR to avoid duplication? We are planning to enable test_nn testing with 'mps' device soon and will have this enabled there.

@github-actions github-actions bot added the Stale label Jul 15, 2023
@kulinseth
Copy link
Collaborator

@alexdremov , can we resurrect this PR ?

@kulinseth
Copy link
Collaborator

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased mps_channel_shuffle onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout mps_channel_shuffle && git pull --rebase)

@alexdremov
Copy link
Contributor Author

@alexdremov , can we resurrect this PR ?

@kulinseth Seems like code side is fine and it successfully rebased.

There’s only @malfet questions that need to be checked. I

#99306 (review)

@github-actions github-actions bot closed this Aug 20, 2023
@kulinseth kulinseth reopened this Aug 21, 2023
@alexdremov
Copy link
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased mps_channel_shuffle onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout mps_channel_shuffle && git pull --rebase)

@alexdremov
Copy link
Contributor Author

alexdremov commented Aug 26, 2023

@kulinseth I've ressurected this PR. Merging?

  • I added non-contiguous tests
  • Added style fixes noted above
  • This works with integers. Tests run checks for various dtypes:
test/test_mps.py::TestPixelShuffle::test_pixel_shuffle_unshuffle PASSED [5.7919s]                                                                                                                        [  4%]
test/test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float16 PASSED [0.0728s]                                                                                    [  9%]
test/test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_shuffle_cpu_float32 PASSED [0.0689s]                                                                                    [ 14%]
test/test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float16 PASSED [0.0044s]                                                                                  [ 19%]
test/test_mps.py::TestConsistencyCPU::test_output_grad_match_nn_functional_pixel_unshuffle_cpu_float32 PASSED [0.0041s]                                                                                  [ 23%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_bool PASSED [0.0355s]                                                                                            [ 28%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_float16 PASSED [0.0025s]                                                                                         [ 33%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_float32 PASSED [0.0027s]                                                                                         [ 38%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_int16 PASSED [0.0375s]                                                                                           [ 42%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_int32 PASSED [0.0357s]                                                                                           [ 47%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_int64 PASSED [0.0360s]                                                                                           [ 52%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_int8 PASSED [0.0346s]                                                                                            [ 57%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_shuffle_cpu_uint8 PASSED [0.0369s]                                                                                           [ 61%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_bool PASSED [0.0360s]                                                                                          [ 66%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_float16 PASSED [0.0023s]                                                                                       [ 71%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_float32 PASSED [0.0020s]                                                                                       [ 76%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_int16 PASSED [0.0371s]                                                                                         [ 80%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_int32 PASSED [0.0383s]                                                                                         [ 85%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_int64 PASSED [0.0357s]                                                                                         [ 90%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_int8 PASSED [0.0344s]                                                                                          [ 95%]
test/test_mps.py::TestConsistencyCPU::test_output_match_nn_functional_pixel_unshuffle_cpu_uint8 PASSED [0.0352s]                                                                                         [100%]

@alexdremov
Copy link
Contributor Author

@kulinseth could you take a look please? Anything else to fix?

@alexdremov
Copy link
Contributor Author

@kulinseth any updates?

@alexdremov
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: mps Release notes category Stale
Projects
None yet
Development

Successfully merging this pull request may close these issues.

PixelShuffle is very slow on MPS compared to ConvTranspose2d and PixelShuffle on cuda
6 participants