🐛 Bug
The nightly PyTorch XLA build (20220413 for `torch`, `torchvision`, and `torch_xla`) gives an unexpected error when `all_reduce` is used together with `reduce_scatter`, as follows:
2022-04-15 06:35:51.117279: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
2022-04-15 06:35:51.117346: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
This error doesn't happen on the 20220408 build. It is likely a side effect of #3484, which removes `reduce_scatter`'s layout pinning.
To Reproduce
- Allocate a v3-8 TPU VM from the `tpu-vm-pt-1.10` runtime and install the `20220413` version of `torch`, `torchvision`, and `torch_xla`, while keeping the `20220408` version of libtpu (since the newer `20220413` libtpu version was reported broken in PyTorch XLA in #3502 (comment)).
# torch, torchvision and torch_xla 20220413
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl
# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
- Save the following content to a Python file (e.g. `/home/ronghanghu/test_all_gather_all_reduce_reduce_scatter.py` below).
```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    world_size = xm.xrt_world_size()
    t1 = torch.ones(1024, device=xm.xla_device())
    # The three collective ops under test: all_gather, all_reduce, reduce_scatter.
    t2 = xm.all_gather(t1).flatten()
    t3 = xm.all_reduce(xm.REDUCE_SUM, t2)
    t4 = xm.reduce_scatter(xm.REDUCE_SUM, t3, scale=1.0, scatter_dim=0, shard_count=world_size)
    t5 = t4.sum()
    xm.mark_step()
    print(f"t5: {t5}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```
- Run this file on the v3-8 TPU VM:
python3 /home/ronghanghu/test_all_gather_all_reduce_reduce_scatter.py
It prints
ronghanghu@t1v-n-f1525942-w-0:~$ python3 test_all_gather_all_reduce_reduce_scatter.py
2022-04-15 06:57:23.387994: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-04-15 06:57:23.388052: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-04-15 06:57:33.186197: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
2022-04-15 06:57:33.186250: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
https://symbolize.stripped_domain/r/?trace=7feaf5dfd03b,7feaf5dfd0bf,7fea3683ebcf,7fea30eb3922,7fea30e71ebd,7fea30ec1db0,7fea30ec18ae,7fea2cd23ed3,7fea323651b8,7fea362f08a0,7fea362f2633,7fea36807cb1,7fea368074e0,7fea367ef8cb,7feaf5d9f608&map=a
7d53509d90e6515f49ac71c94546cafe5812b54:7fea283df000-7fea396d4e30 *** SIGABRT received by PID 302216 (TID 303572) on cpu 1 from PID 302216; stack trace: ***
...
This example tries to cover all three distributed ops. However, the error is caused by `all_reduce` and `reduce_scatter` being used together (and the error persists if we remove the `all_gather` op).
I think it would be great to add this example to the test cases of PyTorch XLA.
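For reference, a stripped-down variant that keeps only the two offending ops (a sketch based on the observation above; the tensor size is illustrative) looks like this:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    world_size = xm.xrt_world_size()
    # Only all_reduce followed by reduce_scatter; no all_gather.
    t1 = torch.ones(1024 * world_size, device=xm.xla_device())
    t2 = xm.all_reduce(xm.REDUCE_SUM, t1)
    t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0,
                           scatter_dim=0, shard_count=world_size)
    xm.mark_step()
    print(f"t3 sum: {t3.sum()}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```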
Expected behavior
The "HloModule has a mix of layout constrained and unconstrained AllReduce instructions" error should not happen when `all_gather`, `all_reduce`, and `reduce_scatter` are used together.
Environment
- Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
- torch_xla version: 20220413 nightly from `tpu-vm-pt-1.10` (see Step 1 above); the installed builds can be confirmed with the sketch below
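The installed builds can be confirmed with a minimal sketch like the following (not part of the original repro; the exact version strings depend on the wheels installed in Step 1):

```python
# Print the versions of the installed nightly wheels.
import torch
import torchvision
import torch_xla

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("torch_xla:", torch_xla.__version__)
```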
Additional context
This error breaks the FSDP implementation in #3431, which often relies on all three APIs (`all_gather`, `all_reduce`, and `reduce_scatter`) in the same program.
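For illustration, a typical FSDP-style step mixes the three collectives roughly as follows (a hypothetical sketch, not code from #3431; the helper name and the stand-in gradient computation are made up):

```python
import torch
import torch_xla.core.xla_model as xm


def sharded_step(param_shard, world_size):
    # Unshard the parameters before the forward/backward pass.
    full_param = xm.all_gather(param_shard)
    # Stand-in for the real forward/backward computation.
    grad = full_param * 2.0
    # Re-shard the gradients across devices.
    grad_shard = xm.reduce_scatter(xm.REDUCE_SUM, grad, scale=1.0 / world_size,
                                   scatter_dim=0, shard_count=world_size)
    # A global quantity such as the squared gradient norm (e.g. for clipping).
    grad_norm_sq = xm.all_reduce(xm.REDUCE_SUM, grad_shard.pow(2).sum())
    return grad_shard, grad_norm_sq
```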
The error is also reproducible on today's `20220415` build of `torch`, `torchvision`, and `torch_xla`.
cc: @JackCaoG