Device auto-detection only works for some functions.
`torch_xla.device()` works:

```
# python
>>> import torch_xla
>>> torch_xla.device()
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
device(type='xla', index=0)
```
`torch_xla.real_devices()` does not work:

```
# python
>>> import torch_xla
>>> torch_xla.real_devices()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspaces/ptxla/pytorch/xla/torch_xla/torch_xla.py", line 49, in real_devices
    return torch_xla._XLAC._xla_real_devices()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:31 : $PJRT_DEVICE is not set.
```
Only functions that pass through a function wrapped in `requires_pjrt` (or that call `using_pjrt`) trigger auto-detection. We did this to accommodate XRT, which is no longer a concern. I believe we can trigger auto-detection upon import, or at least more broadly, so that it covers all of our public API.
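To make the current behavior concrete, here is a minimal, self-contained sketch of the pattern. It is not torch_xla's actual code: `_maybe_select_default_device`, `requires_pjrt`, `device`, and `real_devices` are simplified stand-ins that only borrow the names used in this issue, `_runtime_init` is a hypothetical stand-in for the C++ check in `runtime.cc`, and the real host probing is replaced by a CPU fallback.

```python
import functools
import os


def _maybe_select_default_device() -> None:
    # Hypothetical stand-in for torch_xla's detection helper. The real one
    # probes the host; this sketch just falls back to CPU if nothing is set.
    if "PJRT_DEVICE" not in os.environ:
        os.environ["PJRT_DEVICE"] = "CPU"


def requires_pjrt(fn):
    """Sketch of the pattern: detection runs as a side effect of the call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _maybe_select_default_device()
        return fn(*args, **kwargs)
    return wrapper


def _runtime_init() -> str:
    # Hypothetical stand-in for the check in torch_xla/csrc/runtime/runtime.cc,
    # which only sees the environment variable.
    if "PJRT_DEVICE" not in os.environ:
        raise RuntimeError("$PJRT_DEVICE is not set.")
    return os.environ["PJRT_DEVICE"]


@requires_pjrt
def device() -> str:
    # Wrapped: detection happens before the runtime is touched.
    return _runtime_init()


def real_devices() -> str:
    # Not wrapped: works only if something else already ran detection.
    return _runtime_init()
```

In a fresh process, calling `real_devices()` first raises, while calling `device()` first sets the variable as a side effect and makes a later `real_devices()` succeed. That order dependence is why the failure only shows up for some entry points.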
Tasks:

- Remove `using_pjrt` and `requires_pjrt`. Both functions are irrelevant now, and we only (ab)use them for a side effect.
- Remove this warning during auto-detection: `WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md`
- Trigger device auto-detection somewhere else that makes sense (e.g. package import). Remember that, according to @JackCaoG, anything that references env vars will cause a graph break in Dynamo.
With `using_pjrt`, the user doesn't need to call `_maybe_select_default_device` to initialize the device themselves. To get rid of `using_pjrt`, we would need to either:

1. force users to call a function like `_maybe_select_default_device` at the beginning of their script, or
2. place `_maybe_select_default_device` inside some other function that must be called at the beginning of their script.

Both approaches look annoying. Can we do something in `torch_xla/__init__.py` to detect the device automatically up front?
`__init__.py` sounds fine. That's what I meant by "trigger auto-detection upon import".

Broadly, I see two options:

1. Decorate any function that could potentially initialize the runtime with something equivalent to `requires_pjrt`. This is arguably cleaner and more precise, but I'm concerned we'll miss some functions and not fully root out the `$PJRT_DEVICE is not set` error.
2. Call `_maybe_select_default_device` in `__init__.py` so it runs automatically when the user imports `torch_xla`.

I'm open to other solutions too. The core constraint is that runtime initialization happens in C++, while most of the logic we need to detect devices is in Python.