Remove non-parallel approaches to get a device #8827

@yaoshiang

Description

🐛 Bug

There are four ways to get a device object. Three of them could be deprecated: they are likely unnecessary, and they are also inconsistent with torch.cuda.device(), which does not return a device but is instead a context manager.

torch.device('xla', 3)
torch_xla.device(3)
torch_xla.torch_xla.device(3)
torch_xla.core.xla_model.xla_device(3)

To Reproduce

Run this in a Colab with a TPU enabled and the correct version of torch_xla installed via pip.

import torch
import torch_xla
import torch_xla.core.xla_model

torch.device('xla', 3)
torch_xla.device(3)
torch_xla.torch_xla.device(3)
torch_xla.core.xla_model.xla_device(3)

Expected behavior

Should deprecate the three ways to get a device inside torch_xla.
Docs that use them should be updated to torch.device("xla", <optional int>), removing the import of torch_xla.core.xla_model where it is no longer needed.
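The deprecation could be implemented as a thin wrapper that warns and forwards to the existing implementation. A minimal sketch, assuming a decorator-based approach (the decorator, message, and the placeholder device getter below are illustrative, not the actual torch_xla code):

```python
import functools
import warnings


def deprecated_alias(canonical_name):
    """Mark a device getter as a deprecated alias of the canonical API."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated; use {canonical_name} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)
        return wrapper
    return decorator


# Illustrative stand-in for a real device getter.
@deprecated_alias("torch.device('xla', index)")
def xla_device(index=None):
    # A real implementation would return torch.device('xla', index).
    return ("xla", index)
```

Callers would keep working unchanged but see a DeprecationWarning, giving a release cycle or two before removal.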

Environment

TPU on Colab

Additional context

Should verify that the four APIs ultimately call the same underlying function, to guarantee equivalent behavior.

Metadata

Assignees

No one assigned

Labels

enhancement (New feature or request), usability (Bugs/features related to improving the usability of PyTorch/XLA)
