Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] Add repeat_interleave to MPS #88649

Closed
wants to merge 41 commits into from

Conversation

jazzysoggy
Copy link
Contributor

Fixes #87219

Implements new repeat_interleave function into aten/src/ATen/native/mps/operations/Repeat.mm
Adds it to aten/src/ATen/native/native_functions.yaml
Adds new test test_repeat_interleave to test/test_mps/py

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Nov 8, 2022
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 8, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88649

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9229982:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jazzysoggy jazzysoggy changed the title Add repeat_interleave to MPS [MPS] Add repeat_interleave to MPS Nov 8, 2022
Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. @jazzysoggy please fix the lint issue.

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jazzysoggy please fix the lint issue.

Also what you have is just a copy of Repeat.cpp core file. The backend need to implement the aforementioned interface.

aten/src/ATen/native/mps/operations/Repeat.mm Outdated Show resolved Hide resolved
@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 10, 2022
Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix the Lint error, otherwise it looks good.

@kulinseth
Copy link
Collaborator

Also there is Build failure:

/Users/runner/work/pytorch/pytorch/aten/src/ATen/native/mps/operations/Repeat.mm:216:105: error: 'NSNumber' may not respond to 'UTF8String' [-Werror]
[3056](https://github.com/pytorch/pytorch/actions/runs/3482646873/jobs/5825210002#step:10:3057)
    string key = "repeat_interleave_mps:" + getTensorsStringKey({repeats})+ ":" + string([repeats_shape UTF8String]);
[3057](https://github.com/pytorch/pytorch/actions/runs/3482646873/jobs/5825210002#step:10:3058)
                                                                                          ~~~~~~~~~~~~~ ^

@jazzysoggy
Copy link
Contributor Author

@kulinseth I think I need a little help. Not quite sure how to properly write a native metal kernel. I think I've got the framework correct, but just not accessing the correct namespaces for certain objects or defining functions in the right places.

@kulinseth
Copy link
Collaborator

@kulinseth I think I need a little help. Not quite sure how to properly write a native metal kernel. I think I've got the framework correct, but just not accessing the correct namespaces for certain objects or defining functions in the right places.

@jazzysoggy , is this still an issue. Can you please paste the error message you are running into?

@jazzysoggy
Copy link
Contributor Author

@kulinseth I think I need a little help. Not quite sure how to properly write a native metal kernel. I think I've got the framework correct, but just not accessing the correct namespaces for certain objects or defining functions in the right places.

@jazzysoggy , is this still an issue. Can you please paste the error message you are running into?

No, no longer facing this issue anymore. The entire thing compile, just with testing errors now.

‘’FAIL [13.217s]: test_no_warning_on_import (main.TestFallbackWarning)
3541

3542
Traceback (most recent call last):
3543
File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_mps.py", line 7033, in test_no_warning_on_import
3544
self.assertEqual(out, "")
3545
File "/Users/ec2-user/runner/_work/_temp/conda_environment_3509465096/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2555, in assertEqual
3546
assert_equal(
3547
File "/Users/ec2-user/runner/_work/_temp/conda_environment_3509465096/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1118, in assert_equal
3548
raise error_metas[0].to_error(msg)
3549
AssertionError: String comparison failed: '[W OperatorEntry.cpp:159] Warning: Overri[535 chars]l)\n' != ''
3550

  • [W OperatorEntry.cpp:159] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
    3551
  • operator: aten::repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
    3552
  • registered at /Users/runner/work/pytorch/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
    

3553

  • dispatch key: MPS
    3554
  • previous kernel: registered at /Users/runner/work/pytorch/pytorch/build/aten/src/ATen/RegisterCPU.cpp:30830
    3555
  •    new kernel: registered at /Users/runner/work/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:48 (function registerKernel)
    

3556’’

@jazzysoggy
Copy link
Contributor Author

@kulinseth I’m able to compile, but I currently have issues with passing the test. Here’s the error log for the failed test

FAIL [13.217s]: test_no_warning_on_import (main.TestFallbackWarning)
3541

3542
Traceback (most recent call last):
3543
File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_mps.py", line 7033, in test_no_warning_on_import
3544
self.assertEqual(out, "")
3545
File "/Users/ec2-user/runner/_work/_temp/conda_environment_3509465096/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2555, in assertEqual
3546
assert_equal(
3547
File "/Users/ec2-user/runner/_work/_temp/conda_environment_3509465096/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1118, in assert_equal
3548
raise error_metas[0].to_error(msg)
3549
AssertionError: String comparison failed: '[W OperatorEntry.cpp:159] Warning: Overri[535 chars]l)\n' != ''
3550

  • [W OperatorEntry.cpp:159] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
    3551
  • operator: aten::repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
    3552
  • registered at /Users/runner/work/pytorch/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
    

3553

  • dispatch key: MPS
    3554
  • previous kernel: registered at /Users/runner/work/pytorch/pytorch/build/aten/src/ATen/RegisterCPU.cpp:30830
    3555
  •    new kernel: registered at /Users/runner/work/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:48 (function registerKernel)
    

@changeling
Copy link

@jazzysoggy Just checking on the status of this. @kulinseth?

@changeling
Copy link

Any word?

@DenisVieriu97
Copy link
Collaborator

@jazzysoggy I've rebased your current branch against latest PyTorch master. Please check the results once the build is finished

@DenisVieriu97
Copy link
Collaborator

@jazzysoggy I've fixed the remaining failures and the kernel. Please take a look and merge it when everything is green

@DenisVieriu97
Copy link
Collaborator

@pytorchbot merge -g

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 12, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Can't merge closed PR #88649

@changeling
Copy link

Thank y'all for all your work on this one, @jazzysoggy, @DenisVieriu97, @kulinseth! This one is much appreciated. Cheers.

@changeling
Copy link

changeling commented Feb 14, 2023

@kulinseth Might be a good idea to update #77764 now that this has been implemented. It appears under both Good First Issue and Not categorized. :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: mps Release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MPS] Add support for aten::repeat_interleave for MPS backend
7 participants