
[MPS] Add native cumsum implementation #88319

Closed
wants to merge 2 commits into from

Conversation

@malfet (Contributor) commented Nov 2, 2022

Using https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/4057333-cumulativesumwithtensor?language=objc

Fall back to CPU if running on older macOS versions.
In `unary_op`, add the output tensor's dims/dtype to the graph key (since even in the default op we check the output graph type).
Also, upcast int16 to int32, as the MPS cumsum op on Ventura returns incorrect results for int16 (and doing the same for int8 makes total sense, since the chances of overflow are very high). A rough sketch of this behavior follows below.

TODO: Add tests, figure out what is going on with integral types support
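
For illustration only, here is a minimal Python-level sketch of the dispatch behavior described above. This is not the PR's actual Objective-C++ code; the helper `_mps_supports_cumsum` and the exact dtype set are assumptions.

```python
import platform

import torch


def _mps_supports_cumsum() -> bool:
    # Hypothetical stand-in for the PR's mpsSupportsCumsum(): assume the native
    # MPSGraph cumulative-sum op needs macOS 13 (Ventura) or newer.
    major = platform.mac_ver()[0].split(".")[0]
    return major.isdigit() and int(major) >= 13


def cumsum_mps_sketch(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Illustrative sketch only, not the PR's actual ATen implementation."""
    if not _mps_supports_cumsum():
        # On older macOS versions, fall back to the CPU implementation.
        return t.cpu().cumsum(dim).to(t.device)
    if t.dtype in (torch.int8, torch.int16):
        # Ventura's MPS cumsum is incorrect for int16, and int8 overflows easily,
        # so upcast small integer types to int32 before running the native op.
        t = t.to(torch.int32)
    # On a recent enough OS this dispatches to MPSGraph's
    # cumulativeSumWithTensor:axis:name: op.
    return t.cumsum(dim)
```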
@pytorch-bot (bot) commented Nov 2, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88319

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit ba304db:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the ciflow/mps (Run MPS tests (subset of trunk)) and release notes: mps (Release notes category) labels on Nov 2, 2022
@@ -7,6 +7,13 @@
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>

// TODO: Remove me when moved to MacOS 13
@interface MPSGraph (VenturaOps)
Collaborator:

I think we should probably move this to a header file and include all of the VenturaOps declarations in the same place.

@@ -263,5 +270,33 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
});
}


static bool mpsSupportsCumsum() {
Collaborator:

This also probably belongs in a single place.

Contributor Author (malfet):

Yes, I believe it should be a property of MPSDevice and run only once. Do you mind if I do that as a follow-up PR?
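
A minimal sketch of what that follow-up could look like, caching the capability check so it runs only once (Python illustration only; the real change would live on MPSDevice in C++, and it reuses the hypothetical `_mps_supports_cumsum` helper from the sketch above):

```python
from functools import lru_cache


@lru_cache(maxsize=1)
def mps_supports_cumsum_cached() -> bool:
    # Memoized so the OS-version probe runs only once per process, analogous
    # to exposing the result as a property on MPSDevice.
    return _mps_supports_cumsum()  # hypothetical helper from the sketch above
```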

@malfet (Contributor, Author) commented Nov 4, 2022

@pytorchbot merge -f "MPS tests are green"

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@malfet deleted the malfet/mps-add-cumsum branch on November 4, 2022 04:26
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
(same commit message as the PR description above)
Pull Request resolved: pytorch#88319
Approved by: https://github.com/kulinseth
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
(same commit message as the PR description above)
Pull Request resolved: pytorch#88319
Approved by: https://github.com/kulinseth
@qqaatw (Collaborator) commented Dec 14, 2022

@malfet Hi, do we currently have a CI flow that runs on macOS Ventura?

@blackrabbit commented:

Thanks; this is awesome.

@KayKozaronek commented Jan 24, 2023

@malfet or @kulinseth, I'm not sure who's responsible for this PR.

Is it correct to assume that 'aten::cumsum.out' has now been implemented? If so, which version of PyTorch am I supposed to run? I just updated PyTorch to version 1.13.1 and I'm still getting an error.

UserWarning: The operator 'aten::cumsum.out' is not currently supported on the MPS backend and will fall back to run on the CPU.

Should I run the Nightly version?

I hope this is the right place to ask this.

@malfet (Contributor, Author) commented Jan 24, 2023

@KayKozaronek as one can easily see, 657f2e1 is not part of https://github.com/pytorch/pytorch/tree/release/1.13, so yes, it's only available in the nightlies at the moment.
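
For anyone wanting to check their own build, here is a quick hedged sketch: print the installed version and try a small cumsum on the MPS device (the environment variable PYTORCH_ENABLE_MPS_FALLBACK controls whether unsupported ops warn and fall back or raise).

```python
import torch

print(torch.__version__)  # nightly builds look like '2.0.0.devYYYYMMDD'
x = torch.arange(6, dtype=torch.float32, device="mps").reshape(2, 3)
# On a build that includes this PR the call runs natively on MPS; on older
# builds it either raises NotImplementedError or, with
# PYTORCH_ENABLE_MPS_FALLBACK=1 set, emits the CPU-fallback warning quoted above.
print(x.cumsum(1))
```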

@mayank31398 commented Jan 25, 2023

I think cumsum supports dim=-1, but this PR's MPS implementation doesn't. I see this when trying to run generation with GPT-2 in HF Transformers.
Can we fix this? @malfet

Traceback (most recent call last):
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/mayankmishra/Desktop/IBM/transformers-bloom-inference/inference_server/cli.py", line 43, in <module>
    main()
  File "/Users/mayankmishra/Desktop/IBM/transformers-bloom-inference/inference_server/cli.py", line 36, in main
    response = model.generate(text=[input_text], generate_kwargs=generate_kwargs)
  File "/Users/mayankmishra/Desktop/IBM/transformers-bloom-inference/inference_server/model_handler/deployment.py", line 199, in generate
    raise response
  File "/Users/mayankmishra/Desktop/IBM/transformers-bloom-inference/inference_server/models/model.py", line 49, in generate
    output = self.model.generate(
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/site-packages/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/site-packages/transformers/generation/utils.py", line 2282, in greedy_search
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 999, in prepare_inputs_for_generation
    position_ids = attention_mask.long().cumsum(-1) - 1
RuntimeError: Expected dim to be between 0 and 2 but got -1
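
A minimal repro of the failure above (assumed; it mirrors the GPT-2 prepare_inputs_for_generation line from the traceback):

```python
import torch

attention_mask = torch.ones(1, 8, device="mps")
# Mirrors the GPT-2 prepare_inputs_for_generation line from the traceback.
# On builds without the follow-up fix (#94119) this raised the dim-range error
# shown above; once negative dims are wrapped it matches CPU/CUDA behavior.
position_ids = attention_mask.long().cumsum(-1) - 1
print(position_ids)
```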

malfet added a commit that referenced this pull request Feb 4, 2023
Use `wrap_dim` to get dim in range or raise IndexError

Add a test for that

Addresses feedback raised in #88319 (comment)
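
An illustrative Python analogue of the dim wrapping this commit describes (the actual change uses ATen's C++ helpers; this only shows the intended semantics):

```python
def wrap_dim(dim: int, ndim: int) -> int:
    # Map dim from [-ndim, ndim) onto [0, ndim); anything else is an IndexError.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    return dim if dim >= 0 else dim + ndim


print(wrap_dim(-1, 2))  # 1, so cumsum(-1) addresses the last axis as expected
```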
@malfet (Contributor, Author) commented Feb 4, 2023

@mayank31398 thank you for reporting, but you should have created an issue so that more people would see it. Fixing in #94119.

pytorchmergebot pushed a commit that referenced this pull request Feb 5, 2023
Use `wrap_dim` to get dim in range or raise IndexError

Add a test for that

Addresses feedback raised in #88319 (comment)

Pull Request resolved: #94119
Approved by: https://github.com/Skylion007, https://github.com/seemethere
@n-splv commented Feb 8, 2023

Hey @malfet, thanks for the fix. Sorry for the dumb question, but has it been merged already? I see that the bot attached a "Merged" label to it, but when I install pytorch-nightly the error still persists.
My torch.__version__ is '2.0.0.dev20230107'.

@malfet (Contributor, Author) commented Feb 8, 2023

@nick-maykr Not sure which fix you are referring to: the cumsum implementation or the negative-index handling? The former landed back in 2022, but the latter only a few days ago, so it couldn't have been part of the Jan 7th nightly build.

@n-splv commented Feb 9, 2023

@malfet I meant the second one, thanks! With the nightly build installed, this problem no longer occurs.

However, I'm facing a new one: RuntimeError: MPS does not support min/max ops with int64 input. This one seems to be on the HF Transformers side and, as far as I know, has not been addressed yet :(

@kulinseth (Collaborator) commented:

@nick-maykr, we are looking into working around this by casting it to int32 (with a warning raised) until we have support for it in the MPS kernels to handle it properly. cc @albanD
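
A hedged sketch of the interim workaround described here (illustrative Python only; the actual change lives in the MPS backend and is tracked in the commit referenced below):

```python
import warnings

import torch


def mps_max_with_int64_workaround(t: torch.Tensor) -> torch.Tensor:
    # Sketch of the interim behavior described above: MPS kernels don't handle
    # int64 min/max yet, so cast down to int32 (with a warning) until native
    # support lands in the OS.
    if t.device.type == "mps" and t.dtype == torch.int64:
        warnings.warn("MPS: casting int64 input of max() down to int32")
        t = t.to(torch.int32)
    return t.max()
```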

pytorchmergebot pushed a commit that referenced this pull request Feb 10, 2023
…4484)

Currently casting it as a workaround until we have full support in the OS.
Fixes #88319 (comment)

Pull Request resolved: #94484
Approved by: https://github.com/razarmehr