Description
🐛 Bug
There are four ways to get a device object. Three of them could be deprecated: they are likely unnecessary, and their naming clashes with torch.cuda.device(), which is a context manager and does not return a device.
torch.device('xla', 3)
torch_xla.device(3)
torch_xla.torch_xla.device(3)
torch_xla.core.xla_model.xla_device(3)
To Reproduce
Run this in a Colab notebook with a TPU runtime enabled and a compatible version of PyTorch/XLA installed via pip.
torch.device('xla', 3)
torch_xla.device(3)
torch_xla.torch_xla.device(3)
torch_xla.core.xla_model.xla_device(3)
Expected behavior
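The three ways to get a device that live inside torch_xla should be deprecated, leaving torch.device("xla", <optional int>) as the single way.
Docs that use the deprecated functions should be updated to torch.device("xla", <optional int>), and the import of torch_xla.core.xla_model should be removed where it is no longer needed.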
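One possible shape for the deprecation, as a minimal sketch only: the `deprecated` decorator below is a hypothetical helper, not an existing torch_xla API, and the `xla_device` stand-in returns a plain tuple so the snippet runs without torch_xla installed.

```python
import functools
import warnings


def deprecated(replacement):
    """Hypothetical decorator: warn on every call and name the replacement."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__}() is deprecated; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


# Stand-in for one of the torch_xla helpers (assumption: the real function
# would keep its current body and only gain the decorator).
@deprecated('torch.device("xla", index)')
def xla_device(index=None):
    return ("xla", index)
```

Calling `xla_device(3)` still returns its usual result, but emits a `DeprecationWarning` that points users at `torch.device("xla", index)`.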
Environment
TPU on Colab
Additional context
Before deprecating anything, we should verify that the four APIs ultimately call the same underlying function, so that equivalent behavior is guaranteed.
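A quick first check could compare the objects the four calls return. This is a rough sketch: `find_mismatches` and `xla_constructors` are hypothetical helper names, and equal return values only show that the results agree, not that the calls share one code path, so reading the sources is still needed.

```python
def find_mismatches(constructors, index):
    """Call each constructor with `index`; return names that disagree with the first."""
    names = list(constructors)
    reference = constructors[names[0]](index)
    return [name for name in names[1:] if constructors[name](index) != reference]


def xla_constructors():
    """Collect the four spellings from this issue (requires torch and torch_xla)."""
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    return {
        "torch.device": lambda i: torch.device("xla", i),
        "torch_xla.device": torch_xla.device,
        "torch_xla.torch_xla.device": torch_xla.torch_xla.device,
        "xm.xla_device": xm.xla_device,
    }
```

On a TPU runtime, `find_mismatches(xla_constructors(), 3)` lists any spelling whose result differs from `torch.device("xla", 3)`; an empty list means all four agree for that index.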