
All-reduce together w/ reduce-scatter causes crash on nightly 20220413 build #3506


Description

@ronghanghu

🐛 Bug

The nightly PyTorch XLA build (20220413 for torch, torchvision, and torch_xla) gives an unexpected error when all_reduce is used together with reduce_scatter, as follows:

2022-04-15 06:35:51.117279: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.                           
2022-04-15 06:35:51.117346: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)

This error doesn't happen on the 20220408 build. It is likely a side effect of #3484, which removed reduce_scatter's layout pinning.

To Reproduce

  1. Allocate a v3-8 TPU VM from the tpu-vm-pt-1.10 runtime and install the 20220413 versions of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (since the newer 20220413 libtpu was reported to be broken in #3502 (comment), "PyTorch XLA .data assignment fails when the new tensor is a different shape").
# torch, torchvision and torch_xla 20220413
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl

# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
  2. Save the following content to a Python file (e.g. /home/ronghanghu/test_all_gather_all_reduce_reduce_scatter.py, as used in the command below).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    world_size = xm.xrt_world_size()
    # all_gather: concatenate a 1024-element tensor of ones from every replica
    t1 = torch.ones(1024, device=xm.xla_device())
    t2 = xm.all_gather(t1).flatten()
    # all_reduce: sum the gathered tensor across replicas
    t3 = xm.all_reduce(xm.REDUCE_SUM, t2)
    # reduce_scatter: sum across replicas and split dim 0 back into world_size shards
    t4 = xm.reduce_scatter(xm.REDUCE_SUM, t3, scale=1.0, scatter_dim=0, shard_count=world_size)
    t5 = t4.sum()
    xm.mark_step()
    print(f"t5: {t5}")

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
  3. Run this file on the v3-8 TPU VM:
python3 /home/ronghanghu/test_all_gather_all_reduce_reduce_scatter.py

It prints

ronghanghu@t1v-n-f1525942-w-0:~$ python3 test_all_gather_all_reduce_reduce_scatter.py                                                                                                                                                               
2022-04-15 06:57:23.387994: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-04-15 06:57:23.388052: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
2022-04-15 06:57:33.186197: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
2022-04-15 06:57:33.186250: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)                                                                                                             
https://symbolize.stripped_domain/r/?trace=7feaf5dfd03b,7feaf5dfd0bf,7fea3683ebcf,7fea30eb3922,7fea30e71ebd,7fea30ec1db0,7fea30ec18ae,7fea2cd23ed3,7fea323651b8,7fea362f08a0,7fea362f2633,7fea36807cb1,7fea368074e0,7fea367ef8cb,7feaf5d9f608&map=a7d53509d90e6515f49ac71c94546cafe5812b54:7fea283df000-7fea396d4e30
*** SIGABRT received by PID 302216 (TID 303572) on cpu 1 from PID 302216; stack trace: ***
...

This example tries to cover all three distributed ops. However, the error is caused by all_reduce and reduce_scatter being used together, and it persists if the all_gather op is removed, as in the minimal sketch below.
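
For reference, a minimal sketch of that two-op version (all_gather removed) is below; this is my own stripped-down variant of the script above, and I have mainly been running the full three-op script.

# Minimal sketch: all_reduce + reduce_scatter only (all_gather removed),
# which per the observation above is enough to trigger the same crash.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    world_size = xm.xrt_world_size()
    t1 = torch.ones(1024, device=xm.xla_device())
    # all_reduce followed by reduce_scatter on the same graph
    t2 = xm.all_reduce(xm.REDUCE_SUM, t1)
    t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0, scatter_dim=0,
                           shard_count=world_size)
    xm.mark_step()
    print(f"t3 sum: {t3.sum()}")

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)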

I think it would be great to add this example to the test cases of PyTorch XLA.
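
A test along these lines could also assert the numerical result. The sketch below is only a suggestion; the expected values are my own arithmetic (with all-ones inputs, each element of t4 should equal world_size ** 2), not output from a working build.

# Sketch of a possible multiprocess test case for the three-op chain.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    world_size = xm.xrt_world_size()
    t1 = torch.ones(1024, device=xm.xla_device())
    t2 = xm.all_gather(t1).flatten()
    t3 = xm.all_reduce(xm.REDUCE_SUM, t2)
    t4 = xm.reduce_scatter(xm.REDUCE_SUM, t3, scale=1.0, scatter_dim=0,
                           shard_count=world_size)
    xm.mark_step()
    # all_gather of ones -> 1024 * world_size ones; all_reduce(SUM) -> each
    # element becomes world_size; reduce_scatter(SUM) -> each element becomes
    # world_size ** 2 and the length goes back to 1024.
    expected = float(world_size ** 2)
    assert torch.allclose(t4.cpu(), torch.full((1024,), expected))

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)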

Expected behavior

The "HloModule has a mix of layout constrained and unconstrained AllReduce instructions" error should not occur when all_gather, all_reduce, and reduce_scatter are used together.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
  • torch_xla version: 20220413 nightly from tpu-vm-pt-1.10 (see Step 1 above)

Additional context

This error breaks the FSDP implementation in #3431, which often relies on all three APIs (all_gather, all_reduce, and reduce_scatter) in the same program.

The error is also reproducible on today's 20220415 build of torch, torchvision, and torch_xla.

cc: @JackCaoG
