torch.mps.*Tensor datatypes #82296

Closed
TV4Fun opened this issue Jul 27, 2022 · 38 comments
Labels
high priority module: mps Related to Apple Metal Performance Shaders framework triage review triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@TV4Fun

TV4Fun commented Jul 27, 2022

🚀 The feature, motivation and pitch

An issue that has been debated ad nauseam and apparently still doesn't have an agreed-upon answer as of PyTorch 1.12 is how, or whether, to set a default device for Torch operations. See for example #27878, which has been open for nearly 3 years now. A method many libraries use, and which is recommended in at least a few places on Stack Overflow, is to call set_default_tensor_type with a CUDA tensor type. This works for CUDA, but not for MPS, as there do not appear to be equivalent tensor types for that backend. It also creates problems with unclear type names: if, for example, I create a tensor on MPS with w = torch.tensor([1.0], device='mps'), then w.type() returns 'torch.mps.FloatTensor', but this is not actually a valid type. There is no torch.mps module, and if I try to pass the name back as a string, say with x = w.type('torch.mps.FloatTensor'), I get an error that it is an invalid type.

As near as I can tell, there is no way to create a tensor directly on the MPS device without specifying device='mps' on every single call, which not only clutters code but also makes it very brittle if I happen to miss one call (a minimal reproduction follows below). Please correct me if I am wrong on this. My particular use case is that I would like to add MPS support to ML-Agents, which already supports CUDA by calling torch.set_default_tensor_type(torch.cuda.FloatTensor) when a CUDA device is available. There does not appear to be an equivalent way to do this for MPS, and I have no desire to track down however many hundreds of tensor-creation calls there are in their code. I know there have been calls to deprecate set_default_tensor_type, e.g. #53124, but I would recommend not doing that without providing some other way to set a default device. In the meantime, I would really like an easy way to set the default tensor type to torch.mps.FloatTensor.
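
For reference, a minimal sketch reproducing the naming inconsistency described above (assuming an MPS-enabled build):

import torch

w = torch.tensor([1.0], device='mps')
print(w.type())                      # prints 'torch.mps.FloatTensor'

# Passing the reported name back in fails, since no such type actually exists:
x = w.type('torch.mps.FloatTensor')  # raises an invalid-type error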

Alternatives

Provide a way to set the default device for torch.tensor() and similar calls to MPS.

Additional context

See also #260 and probably others.

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

@TV4Fun TV4Fun changed the title torch.mps.*tensor datatypes torch.mps.*Tensor datatypes Jul 27, 2022
@zou3519 zou3519 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: mps Related to Apple Metal Performance Shaders framework labels Jul 27, 2022
@albanD
Collaborator

albanD commented Jul 28, 2022

cc @ezyang

@ezyang
Contributor

ezyang commented Jul 28, 2022

With TorchFunctionMode we can do this in a few dozen lines of code. I'll put up a PoC later today.

@soumith
Member

soumith commented Jul 28, 2022

@ezyang
@pbelevich and I were discussing this in a more general context.

With your PoC, can we for example make the default device be a meta device?

@ezyang
Contributor

ezyang commented Jul 28, 2022

yes; it will basically be similar to how torchdistx does it

@ezyang
Contributor

ezyang commented Jul 28, 2022

This is totally untested

import torch
from torch.overrides import TorchFunctionMode

# Set of factory functions that accept a `device` kwarg.
_DEVICE_CONSTRUCTOR = {
    # standard ones
    torch.empty,
    torch.empty_strided,
    torch.empty_quantized,
    torch.ones,
    torch.arange,
    torch.bartlett_window,
    torch.blackman_window,
    torch.eye,
    torch.fft.fftfreq,
    torch.fft.rfftfreq,
    torch.full,
    torch.fill,
    torch.hamming_window,
    torch.hann_window,
    torch.kaiser_window,
    torch.linspace,
    torch.logspace,
    torch.nested_tensor,
    # torch.normal,
    torch.rand,
    torch.randn,
    torch.randint,
    torch.randperm,
    torch.range,
    torch.sparse_coo_tensor,
    torch.sparse_compressed_tensor,
    torch.sparse_csr_tensor,
    torch.sparse_csc_tensor,
    torch.sparse_bsr_tensor,
    torch.sparse_bsc_tensor,
    torch.tril_indices,
    torch.triu_indices,
    torch.vander,
    torch.zeros,
    torch.asarray,
    # weird ones
    torch.tensor,
    torch.as_tensor,
}

class DeviceMode(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Only fill in the device for factory functions, and only when the
        # caller did not pass an explicit device (so an explicit device wins).
        if func in _DEVICE_CONSTRUCTOR and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)


with DeviceMode(torch.device("meta")):
    print(torch.empty(3))

I'd like to put this into core, but before we can get there, we have to make some policy decisions. For example, if you turn on DeviceMode("mps") and then someone writes torch.randn(2, device="cpu"), does mps override the CPU device, or does the explicit device win out? It would be nice for a limited set of power users to try this out, make behavior modifications based on what they observe in the wild, and then we ship that.
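
For what it's worth, the sketch above already lets an explicit device win, since the mode only fills in device when the caller passed none:

with DeviceMode(torch.device("mps")):
    a = torch.randn(2)                # -> mps (default filled in by the mode)
    b = torch.randn(2, device="cpu")  # -> cpu (explicit device left untouched)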

@TV4Fun
Author

TV4Fun commented Jul 29, 2022

I would definitely suggest an explicit device overriding the default device mode. Though if you think this could cause problems, you could always raise a warning the first time this happens.

@TV4Fun
Author

TV4Fun commented Aug 3, 2022

Is this in the development branch yet? I'd like to be able to try it out.

@ezyang
Contributor

ezyang commented Aug 3, 2022

it's definitely in the nightly, and it may also work on the most recent official release

@TV4Fun
Author

TV4Fun commented Aug 4, 2022

Thank you @ezyang. Due to #78681 and #79337, I am unable to test this on Torch nightly with my M1 Mac. Trying this on Torch 1.12.0 gives this error:

TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 61>()
     57             return func(*args, **kwargs)
     58         return func(*args, **kwargs)
---> 61 with DeviceMode(torch.device("meta")):
     62     print(torch.empty(3))

File ~/miniforge3/envs/torch-nightly/lib/python3.10/site-packages/torch/utils/_mode_utils.py:28, in _wrap_init.<locals>.wrapped(self, inner, *args, **kwargs)
     25 @functools.wraps(f)
     26 def wrapped(self, *args, inner=undef, **kwargs):
     27     if inner is undef:
---> 28         raise TypeError(
     29             f"missing inner keyword argument; instead of constructing a {meta_init_error_info.mode_class_name} "
     30             f"directly, pass the constructor to push_{meta_init_error_info.mode_name}_mode"
     31         )
     32     self.inner = inner
     33     return f(self, *args, **kwargs)

TypeError: missing inner keyword argument; instead of constructing a TorchDispatchMode directly, pass the constructor to push_torch_dispatch_mode

@ezyang
Contributor

ezyang commented Aug 4, 2022

Ok, you are missing a bugfix that makes the syntax work; try writing the context manager as MyMode.push(device) instead

@TV4Fun
Author

TV4Fun commented Aug 4, 2022

Okay, thank you, that works. If I create it with

with DeviceMode.push(torch.device("mps")):
    print(torch.empty(3))

the created tensor appears to be on the MPS device. Is there a way to do this without having to move all my code inside a with statement, though? It seems to do nothing if I call DeviceMode.push(torch.device("mps")) bare, or even DeviceMode.push(torch.device("mps")).__enter__() manually. I understand the idea of using a context manager here, but in practice it still means a lot of work porting code that just assumes you can set a default.

@ezyang
Contributor

ezyang commented Aug 4, 2022

# Enter the mode at global scope and keep a reference so it can be exited later:
g = DeviceMode.push(torch.device("mps"))
g.__enter__()

@TV4Fun
Author

TV4Fun commented Aug 5, 2022

Okay, that works in the simple case. Trying a more complex example causes an internal error (again on Torch 1.12.0):

Traceback (most recent call last):
  File "/Users/jcroteau/miniforge3/envs/ml-agents/bin/mlagents-learn", line 33, in <module>
    sys.exit(load_entry_point('mlagents', 'console_scripts', 'mlagents-learn')())
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 260, in main
    run_cli(parse_command_line())
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 256, in run_cli
    run_training(run_seed, options, num_areas)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 132, in run_training
    tc.start_learning(env_manager)
  File "/Users/jcroteau/code/ml-agents/ml-agents-envs/mlagents_envs/timers.py", line 305, in wrapped
    return func(*args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 173, in start_learning
    self._reset_env(env_manager)
  File "/Users/jcroteau/code/ml-agents/ml-agents-envs/mlagents_envs/timers.py", line 305, in wrapped
    return func(*args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 107, in _reset_env
    self._register_new_behaviors(env_manager, env_manager.first_step_infos)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 268, in _register_new_behaviors
    self._create_trainers_and_managers(env_manager, new_behavior_ids)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 166, in _create_trainers_and_managers
    self._create_trainer_and_manager(env_manager, behavior_id)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 142, in _create_trainer_and_manager
    trainer.add_policy(parsed_behavior_id, policy)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/trainer.py", line 352, in add_policy
    self.optimizer = self.create_sac_optimizer()
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/trainer.py", line 333, in create_sac_optimizer
    return TorchSACOptimizer(  # type: ignore
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/optimizer_torch.py", line 159, in __init__
    torch.log(
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1738, in wrapped
    return f(self, *args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/torch_utils/torch.py", line 93, in __torch_function__
    return func(*args, **kwargs)
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1738, in wrapped
    return f(self, *args, **kwargs)
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1831, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm":363, please report a bug to PyTorch. Placeholder tensor is empty!

The __torch_function__ in "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/torch_utils/torch.py", line 93 is from your DeviceMode class, which I have put into its own module that is loaded by all of my other modules that use Torch (none of them imports torch directly).

This module sets the default device mode with the lines

# Module-level references keep the device mode active for the whole program.
global _device
global _device_mode
_device = torch.device(device_str)
_device_mode = DeviceMode.push(_device)
_device_mode.__enter__()

The line that triggers this error is https://github.com/Unity-Technologies/ml-agents/blob/main/ml-agents/mlagents/trainers/sac/optimizer_torch.py#L159, which calls torch.as_tensor with a simple float-array argument. Looking at the call stack, it appears other tensors have already been created on the MPS device successfully, so I am not sure why this particular call is causing problems. I will continue to investigate.

@ezyang
Contributor

ezyang commented Aug 5, 2022

It's possible torch.as_tensor doesn't actually work with meta tensors. You could stub out the implementation in __torch_function__, bypassing the func call with a call to torch.empty with appropriate types

@TV4Fun
Author

TV4Fun commented Aug 5, 2022

No, that was not the problem: torch.as_tensor worked in my toy example, and the error still came up when I replaced the above call with torch.tensor. On closer inspection, the problem was that torch.as_tensor was being called with an empty list, and the internal assertion wasn't actually checking whether the argument was empty before raising the error. Using a non-empty list fixed the issue. I am not sure whether this bug still exists in the dev build or not.
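
A minimal sketch of the failing pattern on Torch 1.12 (assuming an MPS-enabled build; the exact assert message may vary):

import torch

x = torch.as_tensor([], device="mps")     # tripped the internal assert on 1.12
y = torch.as_tensor([0.0], device="mps")  # a non-empty input works as expected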

@TV4Fun
Author

TV4Fun commented Aug 9, 2022

Confirmed this bug does not exist in the dev build. I was able to implement your code above with a few tweaks here, and it appears to set MPS as the default device. Thank you.

@kulinseth
Collaborator

@ezyang and @TV4Fun , what are the next steps here?

@ezyang
Contributor

ezyang commented Oct 5, 2022

we merge it to master 👀

@TV4Fun
Author

TV4Fun commented Oct 5, 2022

@kulinseth, I would call @ezyang's solution here a hack. It works, but it would still be nice if there were a simpler way to set the default device to MPS.

@ezyang
Contributor

ezyang commented Oct 5, 2022

What did you find hacky about it in the end?

@TV4Fun
Author

TV4Fun commented Oct 5, 2022

Just that it involves creating a long list of every individual tensor constructor, then creating a context manager at global scope and manually calling __enter__() on it. It could work if you integrate it into Torch and expose a simple set_default function that user code can call, but the finer details should be internal to Torch.

@ezyang
Contributor

ezyang commented Oct 5, 2022

Yes, so supposing that PyTorch core maintained the internal implementation details, and we gave a "global state" function API matching the old API, would that be fine?
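
A hypothetical sketch of such a wrapper over the DeviceMode above (set_default_device here is an illustrative name, not a final API):

_device_mode = None

def set_default_device(device):
    # Hypothetical global-state entry point wrapping DeviceMode.
    global _device_mode
    if _device_mode is not None:
        _device_mode.__exit__(None, None, None)
        _device_mode = None
    if device is not None:
        _device_mode = DeviceMode.push(torch.device(device))
        _device_mode.__enter__()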

@TV4Fun
Author

TV4Fun commented Oct 6, 2022

@ezyang I'd have to see exactly what you are proposing, but that sounds like a better solution.

@ezyang
Contributor

ezyang commented Oct 6, 2022

Err, it'd be exactly the same code you're running, just inside the PyTorch library so you don't have to see the sausage 😛

@TV4Fun
Author

TV4Fun commented Oct 6, 2022 via email

ezyang added a commit that referenced this issue Dec 30, 2022
Fixes #82296
Fixes #27878
Fixes #260

Open to bikeshedding the module name.  Open to bikeshedding if any of
the identifiers should get re-exported from the torch top level.  Open to
bikeshedding the global setter API name / location.

Tests coming later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 3f678907f143125bf0931d2c86eae7c9eb8ee156
Pull Request resolved: #91525
@mattiasu96

Is this fixed? The error also propagates to PyTorch-dependent libraries such as Speechbrain (speechbrain/speechbrain#1794)

@TV4Fun
Author

TV4Fun commented Jan 10, 2023 via email

You have to use the new torch.utils.device_mode to set your default device.

@mattiasu96

You have to use the new torch.utils.device_mode to set your default device.

That feature doesn't look like it's available yet in the PyTorch docs (or in my local package installation), so I guess it hasn't been released yet, right?

@ezyang
Contributor

ezyang commented Jan 10, 2023

The feature as landed in the PR has some API changes; in particular, you can now just use torch.device as a context manager, and there is also torch.set_default_device
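
For reference, a short usage sketch of the landed API (assuming a build that includes #91525 and has MPS available):

import torch

torch.set_default_device("mps")  # global default for factory functions
t = torch.ones(3)                # created on mps

with torch.device("cpu"):        # context-manager form; overrides the global default
    u = torch.zeros(3)           # created on cpu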

@mattiasu96

But this change is only available in unstable releases from master, right? It doesn't look like it is available in a stable release, am I wrong?

@ezyang
Contributor

ezyang commented Jan 10, 2023

No, it's never been in stable. However, the snippet in this issue is self-contained, so you can backport it to a sufficiently recent stable release (I think 1.13 only)

@mattiasu96

No, it's never been in stable. However, the snippet in this issue is self-contained, so you can backport it to a sufficiently recent stable release (I think 1.13 only)

Do you mean this one? #82296 (comment)

@ezyang
Contributor

ezyang commented Jan 10, 2023

yup

@chelseas

I am still having this issue ://
