torch.mps.*Tensor datatypes #82296
cc @ezyang
With TorchFunctionMode we can do this in a few dozen lines of code. I'll put up a PoC later today.
@ezyang With your PoC, can we, for example, make the default device be a […]?
Yes; it will basically be similar to how torchdistx does it.
This is totally untested.
I'd like to put this into core, but before we can get there, we have to make some policy decisions. For example, if you turn on DeviceMode("mps") and someone then writes torch.randn(2, device="cpu"), does mps override the CPU device, or does the explicit device win out? It would be nice for a limited set of power users to try this out, adjust the behavior based on what they observe in the wild, and then we ship that.
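A rough sketch (my guess at the shape of the PoC, not the actual code from this thread) of what such a device mode could look like with TorchFunctionMode. The DeviceMode name and the small _FACTORIES set are assumptions, a real implementation would enumerate many more constructors, and "meta" stands in for "mps" so the example runs on any machine. It also illustrates one answer to the policy question: an explicit device argument wins over the mode.

```python
import torch
from torch.overrides import TorchFunctionMode

# Illustrative subset; the real list of factory functions is much longer.
_FACTORIES = {torch.empty, torch.ones, torch.zeros, torch.randn}

class DeviceMode(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        # Inject the default only when the caller gave no explicit device,
        # so torch.ones(2, device="cpu") keeps the CPU device.
        if func in _FACTORIES and kwargs.get("device") is None:
            kwargs["device"] = self.device
        return func(*args, **kwargs)

# "meta" stands in for "mps" so this runs without Apple hardware:
with DeviceMode("meta"):
    t = torch.ones(2)                 # picks up the mode's device
    u = torch.ones(2, device="cpu")   # explicit device wins out
```

The mode is automatically disabled inside its own `__torch_function__` handler, so calling `func(*args, **kwargs)` there does not recurse.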
I would definitely suggest an explicit device overriding the default device mode. Though if you think this could cause problems, you could always raise a warning the first time this happens.
Is this in the development branch yet? I'd like to be able to try it out. |
It's definitely in the nightly, and it may also work on the most recent official release.
Thank you @ezyang. Due to #78681 and #79337, I am unable to test this on Torch nightly with my M1 Mac. Trying this on Torch 1.12.0 gives this error: […]
Ok, you are missing a bugfix that makes the syntax work; try writing the context manager as MyMode.push(device) instead.
Okay, thank you, that works, and if I create this with […] the created tensor appears to be on the MPS device. Is there a way to do this without having to move all my code inside of a […]?
Okay, that works in the simple case. Trying a more complex example causes an internal error (again on Torch 1.12.0): […]
The module sets the default device mode with these lines: […]
The line which is triggering this error is https://github.com/Unity-Technologies/ml-agents/blob/main/ml-agents/mlagents/trainers/sac/optimizer_torch.py#L159, which is calling […]
It's possible […]
No, that was not the problem, as […]
Confirmed this bug does not exist in the dev build. I was able to implement your code above with a few tweaks here, and it appears to have set MPS as the default device. Thank you.
We merge it to master 👀
@kulinseth, I would call @ezyang's solution here a hack. It works, but it would still be nice if there were a simpler way to set the default device to MPS.
What did you find hacky about it in the end?
Just that it involves creating a long list of each individual tensor constructor, then creating a context manager at global scope and manually calling […]
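The global-scope trick being discussed can be sketched as follows. This is a hedged illustration with assumed names: DefaultDeviceMode and the single torch.ones check stand in for the long constructor list, and "meta" substitutes for "mps" so the snippet runs anywhere.

```python
import torch
from torch.overrides import TorchFunctionMode

class DefaultDeviceMode(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        # A real version would check a long list of factory functions here.
        if func is torch.ones and kwargs.get("device") is None:
            kwargs["device"] = self.device
        return func(*args, **kwargs)

# Instead of wrapping all code in a `with` block, enter the context
# manager once at import time and simply never exit it:
_mode = DefaultDeviceMode("meta")
_mode.__enter__()

t = torch.ones(3)   # lands on the mode's device with no device= argument
```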
Yes, so supposing that PyTorch core maintained the internal implementation details and we gave a "global state" function API matching the old API, would that be fine?
@ezyang I'd have to see exactly what you are proposing, but that sounds like a better solution.
Err, it'd be exactly the same code you're running, just in the PyTorch library so you don't have to see the sausage 😛
Yeah, that works.
Fixes #82296. Fixes #27878. Fixes #260. Open to bikeshedding the module name. Open to bikeshedding whether any of the identifiers should get reexported from the torch top level. Open to bikeshedding the global setter API name / location. Tests coming later. Signed-off-by: Edward Z. Yang <ezyangfb.com> ghstack-source-id: 3f678907f143125bf0931d2c86eae7c9eb8ee156 Pull Request resolved: #91525
Is this fixed? The error also propagates to PyTorch-dependent libraries such as SpeechBrain (speechbrain/speechbrain#1794).
You have to use the new torch.utils.device_mode to set your default device.
That feature doesn't look like it's available yet in the PyTorch docs (or in my local package installation), so I guess it hasn't been released yet, right?
The feature as landed in the PR has some API changes; in particular, you can just use torch.device as the context manager, and there's also now torch.set_default_device.
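For reference, the API that eventually landed (PyTorch 2.0 and later) is used like this; "cpu" and "meta" are used here so the example runs without Apple hardware, but on an M1 machine either could be "mps":

```python
import torch

# Process-wide default device for factory functions:
torch.set_default_device("cpu")
a = torch.ones(2)

# Scoped default device: torch.device works as a context manager,
# overriding the global default within the block.
with torch.device("meta"):
    b = torch.empty(4)

print(a.device, b.device)
```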
But this change is available as an unstable release from master, right? It doesn't look like it's available in a stable release; am I wrong?
No, it's never been in stable. However, the snippet in this issue is self-contained, so you can backport it to a sufficiently recent stable release (I think 1.13 only).
You mean this one, right? #82296 (comment)
Yup.
I am still having this issue ://
🚀 The feature, motivation and pitch
An issue that has been debated ad nauseam, and apparently still doesn't have an agreed-upon answer as of PyTorch 1.12, is how (or whether) to set a default device for Torch operations. See for example #27878, which has been open nearly 3 years now. A method many libraries use, and which is recommended in at least a few places on Stack Overflow, is to use set_default_tensor_type to default to a CUDA tensor type. This works for CUDA, but does not work for MPS, as there do not appear to be equivalent tensor types for that device. It also creates problems with unclear type names: if, for example, I create a tensor on MPS with w = torch.tensor([1.0], device='mps'), then w.type() returns 'torch.mps.FloatTensor', but this is not actually a valid type. There is no torch.mps module, and if I try to pass it as a string, say with x = w.type('torch.mps.FloatTensor'), this returns an error that it is an invalid type.
As near as I can tell, there is no way to directly create a tensor on the MPS device without specifying device='mps' on every single call, which not only clutters code but also makes it very brittle if I happen to miss it on one call. Please correct me if I am wrong on this. My particular use case is that I would like to add MPS support to ML-Agents, which already supports CUDA by calling torch.set_default_tensor_type(torch.cuda.FloatTensor) if a CUDA device is available. There does not appear to be an equivalent way to do this for MPS, and I have no desire to track down however many hundreds of tensor creation calls there are in their code. I know there have been calls to deprecate set_default_tensor_type, i.e. #53124, but I would recommend not doing that without providing some other way to set a default device. In the meantime, I would really like an easy way to set the default tensor type to torch.mps.FloatTensor.
Alternatives
Provide a way to set the default device for torch.tensor() and similar calls to MPS.
Additional context
See also #260 and probably others.
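The invalid type-name behavior described above can be reproduced without an MPS device; this sketch uses a CPU tensor to show the same .type() round trip failing for a nonexistent tensor-type string:

```python
import torch

w = torch.tensor([1.0])   # device='mps' in the report above; CPU here
print(w.type())           # 'torch.FloatTensor' on CPU

# Passing a type string for a nonexistent tensor type fails, as reported:
try:
    x = w.type("torch.mps.FloatTensor")
except Exception as e:
    print("invalid type:", e)
```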
cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev