Improve device auto-detection #7730

Closed
3 tasks done
will-cromar opened this issue Jul 23, 2024 · 2 comments · Fixed by #7787
Labels
usability Bugs/features related to improving the usability of PyTorch/XLA

will-cromar (Collaborator) commented Jul 23, 2024

Device auto-detection only works for some functions.

torch_xla.device() works:

# python
>>> import torch_xla
>>> torch_xla.device()
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
device(type='xla', index=0)

torch_xla.real_devices() does not work:

# python
>>> import torch_xla
>>> torch_xla.real_devices()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspaces/ptxla/pytorch/xla/torch_xla/torch_xla.py", line 49, in real_devices
    return torch_xla._XLAC._xla_real_devices()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:31 : $PJRT_DEVICE is not set.

Only functions that pass through a function wrapped in requires_pjrt (or that call using_pjrt) trigger auto-detection; we did this to accommodate XRT, which is no longer a concern. I believe we can trigger auto detection upon import, or at least trigger it more broadly so that it covers our public API.
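A minimal, self-contained sketch of the pattern described above, assuming simplified stand-ins: `_detect_device`, `_runtime_init`, and these toy versions of `device`, `real_devices`, `requires_pjrt`, and `_maybe_select_default_device` are illustrative only and are not the actual torch_xla implementations. It shows why auto-detection depends on which function happens to be called first.

```python
import functools
import os


def _detect_device() -> str:
    # Hypothetical stand-in for the real host inspection (e.g. looking for
    # libtpu.so); this toy version always reports CPU.
    return "CPU"


def _maybe_select_default_device() -> None:
    # Respect an explicit user choice; otherwise record the detected device.
    if not os.environ.get("PJRT_DEVICE"):
        os.environ["PJRT_DEVICE"] = _detect_device()


def requires_pjrt(fn):
    # Toy decorator: auto-detection happens only as a side effect of the call.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _maybe_select_default_device()
        return fn(*args, **kwargs)
    return wrapper


def _runtime_init() -> str:
    # Stand-in for the C++ check in runtime.cc, which only reads the env var.
    device = os.environ.get("PJRT_DEVICE")
    if not device:
        raise RuntimeError("$PJRT_DEVICE is not set.")
    return device


@requires_pjrt
def device() -> str:
    return _runtime_init()


def real_devices() -> str:
    # Not decorated, so it never triggers auto-detection.
    return _runtime_init()
```

With `PJRT_DEVICE` unset, calling `real_devices()` first raises the `RuntimeError`, while calling `device()` first sets the variable and makes a later `real_devices()` succeed: in this sketch the failure depends on call order, not on the function itself.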

Tasks:

  • Remove using_pjrt and requires_pjrt. Both functions are irrelevant now, and we only (ab)use them for a side effect.
  • Remove this warning during auto-detection: WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
  • Trigger device auto-detection somewhere else that makes sense (e.g. package import). Remember that, according to @JackCaoG, anything that references env vars will cause a graph break in Dynamo.
will-cromar added the usability label Jul 23, 2024
zpcore (Collaborator) commented Jul 29, 2024

With using_pjrt, users don't need to call _maybe_select_default_device to initialize the device themselves. To get rid of the using_pjrt decorator, we need to either:

  1. force users to call a function like _maybe_select_default_device at the beginning of their script; or
  2. place _maybe_select_default_device inside some other function that must be called at the beginning of their script.

Both options seem annoying. Can we do something in torch_xla/__init__.py to find the device automatically in the first place?

will-cromar (Collaborator, Author) commented

__init__.py sounds fine. That's what I meant by "trigger auto detection upon import".

Broadly, there are two options that I see:

  1. Decorate any function that could potentially initialize the runtime with something equivalent to requires_pjrt. This is arguably cleaner and more precise, but I'm concerned we'll miss some functions and not fully root out this $PJRT_DEVICE is not set error.
  2. Call _maybe_select_default_device in __init__.py so it runs automatically when the user imports torch_xla.
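For option 1, one hypothetical way to reduce the risk of missing a function is to wrap every public callable in a namespace at once. `requires_device` and `wrap_public_api` below are illustrative names, not torch_xla APIs, the namespace is a toy `types.SimpleNamespace`, and the detection step is a placeholder that falls back to CPU.

```python
import functools
import os
import types


def _maybe_select_default_device() -> None:
    # Hypothetical detection: fall back to CPU if the user chose nothing.
    if not os.environ.get("PJRT_DEVICE"):
        os.environ["PJRT_DEVICE"] = "CPU"


def requires_device(fn):
    # Run detection before the wrapped function touches the runtime.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _maybe_select_default_device()
        return fn(*args, **kwargs)
    wrapper.__requires_device__ = True  # marker so we never double-wrap
    return wrapper


def wrap_public_api(namespace: types.SimpleNamespace) -> list:
    # Decorate every public callable; return the names that were wrapped.
    wrapped = []
    for name, value in list(vars(namespace).items()):
        if name.startswith("_") or not callable(value):
            continue
        if getattr(value, "__requires_device__", False):
            continue
        setattr(namespace, name, requires_device(value))
        wrapped.append(name)
    return wrapped
```

Even with this, a function attached to the namespace after `wrap_public_api` runs, or a helper reached through another module, would still bypass detection, which is the gap option 1 risks leaving.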

I'm open to other solutions too. The core constraint is that runtime init happens in C++, while most of the logic we need to detect devices is in Python.
