Description
🚀 The feature, motivation and pitch
An issue that has been debated ad nauseam and apparently still doesn't have an agreed-upon answer as of PyTorch 1.12 is how, or whether, to set a default device for Torch operations. See for example #27878, which has been open for nearly three years now. A method many libraries use, and which is recommended in at least a few places on Stack Overflow, is to call `set_default_tensor_type` to default to a CUDA tensor type. This works for CUDA, but it does not work for MPS, since there appear to be no equivalent tensor types for that backend. It also creates confusion around type names: if, for example, I create a tensor on MPS with `w = torch.tensor([1.0], device='mps')`, then `w.type()` returns `'torch.mps.FloatTensor'`, but this is not actually a valid type. There is no `torch.mps` module, and if I try to pass it back as a string, say with `x = w.type('torch.mps.FloatTensor')`, I get an error that it is an invalid type.
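A minimal reproduction of the behaviour described above (this assumes an MPS-capable build of PyTorch 1.12 on Apple silicon; the printed type string and the error are as observed, the comments just restate them):

```python
import torch

# Create a tensor directly on the MPS device.
w = torch.tensor([1.0], device='mps')

# Reports 'torch.mps.FloatTensor', even though no torch.mps tensor type exists.
print(w.type())

# Passing that same string back fails with an "invalid type" error.
x = w.type('torch.mps.FloatTensor')
```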
As near as I can tell, there is no way to create a tensor directly on the MPS device without specifying `device='mps'` on every single call, which not only clutters the code but also makes it very brittle if I happen to miss it on one call. Please correct me if I am wrong about this. My particular use case is that I would like to add MPS support to ML-Agents, which already supports CUDA by calling `torch.set_default_tensor_type(torch.cuda.FloatTensor)` if a CUDA device is available. There does not appear to be an equivalent way to do this for MPS, and I have no desire to track down however many hundreds of tensor creation calls there are in their code. I know there have been calls to deprecate `set_default_tensor_type`, e.g. #53124, but I would recommend not doing that without providing some other way to set a default device. In the meantime, I would really like an easy way to set the default tensor type to `torch.mps.FloatTensor`.
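For reference, a sketch of the device-selection pattern described above (the function name `set_default_accelerator` is mine, purely for illustration; the CUDA branch mirrors what ML-Agents does today, and the MPS branch shows the gap):

```python
import torch

def set_default_accelerator() -> None:
    """Illustrative sketch: the CUDA branch works today, the MPS branch cannot."""
    if torch.cuda.is_available():
        # Existing ML-Agents approach: all new float tensors default to CUDA.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    elif torch.backends.mps.is_available():
        # There is no torch.mps.FloatTensor, so there is nothing valid
        # to pass to set_default_tensor_type here.
        pass
```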
Alternatives
Provide a way to set the default device for `torch.tensor()` and similar calls to MPS.
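For illustration only, one possible shape for such an API (the `set_default_device` call below does not exist in PyTorch 1.12; the name is made up here just to sketch the desired behaviour):

```python
import torch

# Hypothetical: a single process-wide switch for the default device.
torch.set_default_device('mps')   # made-up API, shown only as a sketch

# Factory calls would then land on MPS without repeating device='mps'.
w = torch.tensor([1.0])
assert w.device.type == 'mps'
```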
Additional context
See also #260 and probably others.
cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev