torch.mps.*Tensor datatypes #82296

Closed
@TV4Fun

Description

🚀 The feature, motivation and pitch

An issue that has been debated ad nauseam, and apparently still has no agreed-upon answer as of PyTorch 1.12, is how (or whether) to set a default device for Torch operations. See, for example, #27878, which has been open for nearly 3 years. A method many libraries use, and which is recommended in at least a few places on Stack Overflow, is to call `set_default_tensor_type` with a CUDA tensor type. This works for CUDA but not for MPS, because there do not appear to be equivalent MPS tensor types.

The type names are also inconsistent. For example, if I create a tensor on MPS with `w = torch.tensor([1.0], device='mps')`, then `w.type()` returns `'torch.mps.FloatTensor'`, but this is not actually a valid type: there is no `torch.mps` module, and passing the string back, as in `x = w.type('torch.mps.FloatTensor')`, raises an error saying it is an invalid type.
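Until there is a way to set a default device, one device-neutral way to avoid the `'torch.mps.FloatTensor'` string entirely is sketched below. It assumes PyTorch 1.12 or later (for `torch.backends.mps.is_available()`); `pick_device` is a hypothetical helper, not a PyTorch API. Instead of round-tripping a type string through `Tensor.type()`, it uses `Tensor.to(other)`, which copies both the dtype and the device of `other`.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then MPS, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps.is_available() exists from PyTorch 1.12 onward.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
w = torch.tensor([1.0], device=device)

# On MPS, w.type() reports 'torch.mps.FloatTensor', which (per this issue)
# cannot be passed back to Tensor.type(). Tensor.to(other) instead matches
# both the dtype and the device of `other`, on any backend.
y = torch.arange(3).to(w)  # int64 on CPU -> float32 on w's device
print(y.dtype, y.device)
```

This only addresses converting existing tensors; every creation call still needs `device=` until a global default is available.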

As near as I can tell, there is no way to create a tensor directly on the MPS device without specifying `device='mps'` on every single call. This not only clutters code but also makes it very brittle, since missing it on even one call leaves that tensor on the CPU. Please correct me if I am wrong on this.

My particular use case is adding MPS support to ML-Agents, which already supports CUDA by calling `torch.set_default_tensor_type(torch.cuda.FloatTensor)` when a CUDA device is available. There does not appear to be an equivalent way to do this for MPS, and I have no desire to track down however many hundreds of tensor creation calls there are in its code. I know there have been calls to deprecate `set_default_tensor_type` (e.g. #53124), but I would recommend not doing that without providing some other way to set a default device. In the meantime, I would really like an easy way to set the default tensor type to `torch.mps.FloatTensor`.
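A partial workaround, short of a real default device, is to bind `device=` once into the factory functions a codebase uses. The sketch below does this with `functools.partial`; `device_factories` and the `T` namespace are hypothetical names for illustration, and the list of factories is an assumption about which ones a given codebase calls.

```python
import functools
import types

import torch

# Factory functions that accept a `device=` keyword argument.
_FACTORIES = ("tensor", "zeros", "ones", "empty", "full", "rand", "randn", "arange")

def device_factories(device) -> types.SimpleNamespace:
    """Hypothetical helper: bind `device` into common torch factory functions.

    An explicit device= passed at call time still overrides the bound one,
    because functools.partial lets call-time keywords replace stored ones.
    """
    return types.SimpleNamespace(
        **{name: functools.partial(getattr(torch, name), device=device)
           for name in _FACTORIES}
    )

device = "mps" if torch.backends.mps.is_available() else "cpu"
T = device_factories(device)

a = T.zeros(2, 3)             # same as torch.zeros(2, 3, device=device)
b = T.tensor([1.0, 2.0])      # same as torch.tensor([1.0, 2.0], device=device)
c = T.ones(1, device="cpu")   # an explicit device= still wins
```

This still requires rewriting every `torch.zeros(...)`-style call to go through `T`, so it does not remove the brittleness described above for a codebase the size of ML-Agents; it only concentrates the device choice in one place.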

Alternatives

Provide a way to set MPS as the default device for `torch.tensor()` and similar tensor creation calls.
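For reference, later PyTorch releases (2.0 onward) provide this kind of mechanism: `torch.set_default_device` sets a global default, and a `torch.device` can be used as a context manager to scope one. The sketch below assumes PyTorch 2.0 or later and falls back to the CPU when MPS is unavailable; neither API exists in 1.12.

```python
import torch

# Requires PyTorch 2.0 or later; neither API below exists in 1.12.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Scoped default: factory calls inside the block allocate on `device`.
with torch.device(device):
    a = torch.zeros(2)

# Global default: later factory calls without device= allocate on `device`.
torch.set_default_device(device)
b = torch.tensor([1.0])

# Restore the usual CPU default.
torch.set_default_device("cpu")
c = torch.zeros(1)
```

The global form is the closer match for the ML-Agents pattern of calling `set_default_tensor_type` once at startup; the context-manager form keeps the change local to one block.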

Additional context

See also #260 and probably others.

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Metadata

Labels: high priority; module: mps (Related to Apple Metal Performance Shaders framework); triage review; triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
