Merged
8 changes: 0 additions & 8 deletions docs/pjrt.md
@@ -98,14 +98,6 @@ PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --f
 Currently, only a single host is supported, and multi-host GPU cluster support
 will be added in a future release.
 
-#### Known Issues
-
-The GPU integration has issues with replica groups in collectives (i.e. the
-`group` parameter of the XLA collective ops). If the replica groups are
-changed, there is a chance that the process will hang. For now, the
-recommendation is to use a single replica group containing all devices, as is
-the case in data parallel training.
-
 ## Key differences from XRT
 
 Although in most cases we expect PjRt and XRT to work mostly interchangeably
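For context, the "Known Issues" section removed above concerned the replica-group argument of the XLA collective ops, and its workaround was a single group spanning every device (the data-parallel case). A minimal sketch of what "a single replica group containing all devices" means, using a hypothetical helper (plain Python, not a torch_xla API):

```python
def single_replica_group(num_devices: int) -> list[list[int]]:
    """Build one replica group containing every device ordinal.

    Hypothetical helper for illustration: XLA collective ops take a list
    of groups, each group being a list of device ordinals. The workaround
    was to pass exactly one group covering all devices.
    """
    return [list(range(num_devices))]

# With 4 devices, every collective runs over the single group [[0, 1, 2, 3]].
groups = single_replica_group(4)
```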
2 changes: 2 additions & 0 deletions test/pjrt/test_ddp.py
@@ -31,6 +31,8 @@ def _ddp_init(index: int = ...):
   def test_ddp_init(self):
     pjrt._run_multiprocess(self._ddp_init)
 
+  @absltest.skipIf(pjrt.device_type() == 'GPU',
+                   "GPU device is not supported by pjrt.spawn_threads")
   def test_ddp_init_threaded(self):
     pjrt.spawn_threads(self._ddp_init)
 
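`absltest.skipIf` mirrors the standard-library `unittest.skipIf`: the condition is evaluated once, at class-definition time, and skipped tests are reported rather than failed. A self-contained sketch of the same pattern with stdlib `unittest` (`device_type()` here is a stand-in for `pjrt.device_type()`):

```python
import unittest


def device_type() -> str:
    # Stand-in for pjrt.device_type(); assume a CPU-only environment here.
    return 'CPU'


class SpawnThreadsTest(unittest.TestCase):

  # The skip condition is evaluated when the class is defined, not per run.
  @unittest.skipIf(device_type() == 'GPU',
                   'GPU device is not supported by pjrt.spawn_threads')
  def test_runs_when_not_gpu(self):
    self.assertTrue(True)
```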
3 changes: 1 addition & 2 deletions test/run_tests.sh
@@ -97,8 +97,7 @@ function run_xla_backend_mp {
 function run_pjrt {
   echo "Running in PjRt runtime: $@"
   if [ -x "$(command -v nvidia-smi)" ]; then
-    # TODO(jonbolin): Only run GPU tests with a single device due to collective failures.
-    PJRT_DEVICE=GPU GPU_NUM_DEVICES=1 run_test "$@"
+    PJRT_DEVICE=GPU run_test "$@"
   else
     # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
     PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
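The device-detection branch above relies on `command -v`, which prints the path of an executable if it is found, and `-x`, which checks the result is executable. A minimal standalone sketch of the same check (variable names are illustrative, not from the script):

```shell
#!/bin/sh
# Pick a PJRT device by probing for the NVIDIA driver tooling:
# `command -v nvidia-smi` prints its path if present, and `[ -x ... ]`
# confirms it is executable. Fall back to CPU otherwise.
if [ -x "$(command -v nvidia-smi)" ]; then
  device=GPU
else
  device=CPU
fi
echo "PJRT_DEVICE=$device"
```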
3 changes: 1 addition & 2 deletions test/utils/run_test_coverage.sh
@@ -53,8 +53,7 @@ function run_xla_backend_mp {
 function run_pjrt {
   echo "Running in PjRt runtime: $@"
   if [ -x "$(command -v nvidia-smi)" ]; then
-    # TODO(jonbolin): Only run GPU tests with a single device due to collective failures.
-    PJRT_DEVICE=GPU GPU_NUM_DEVICES=1 run_test "$@"
+    PJRT_DEVICE=GPU run_test "$@"
   else
     # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
     PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
1 change: 1 addition & 0 deletions torch_xla/experimental/pjrt.py
@@ -373,6 +373,7 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
 
 def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
   """Run function in one process with one thread per addressable device."""
+  assert device_type() != 'GPU', "spawn_threads does not support GPU device"
   spawn_fn = _SpawnFn(fn, *args)
   _run_thread_per_device(
       local_rank=0,
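The added `assert` is a fail-fast guard: an unsupported backend raises immediately with a clear message instead of hanging or failing later inside the thread pool. A self-contained sketch of the pattern (hypothetical names; `device_type()` stands in for `pjrt.device_type()`, and the body is stubbed out):

```python
def device_type() -> str:
    # Stand-in for pjrt.device_type(); assume a CPU environment here.
    return 'CPU'


def spawn_threads_guarded(fn, args=()):
    # Fail fast before doing any work if the backend is unsupported.
    assert device_type() != 'GPU', (
        'spawn_threads does not support GPU device')
    # ... the real function would run fn on one thread per device;
    # here we just invoke it once for illustration ...
    return fn(*args)


result = spawn_threads_guarded(lambda x: x + 1, (41,))
```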