
New all-gather API takes much more memory than 1.10 all-gather implementation via all-reduce #3510

@ronghanghu

Description

🐛 Bug

In our internal tests, the new xm.all_gather API implemented in #3275 takes significantly more memory to execute than the previous all-gather implementation via all_reduce in PyTorch XLA 1.10 (https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615).

For example, as shown in the reproduction steps below, the new xm.all_gather API fails to all-gather a 512 MB tensor on a v3-8 (which would produce only a 4 GB output tensor, much smaller than the 16 GB of memory on each TPU v3 core). Meanwhile, the old xm.all_gather implementation in PyTorch XLA 1.10 handles this case without a problem.

It is unclear why the new xm.all_gather API needs so much memory to execute; a well-implemented all-gather should in principle require about the same amount of memory as its output tensor. This is breaking many large-model use cases, e.g. models with a large number of sharded parameters or with large tensors, or ZeRO (FSDP) in #3431.
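For concreteness, here is a back-of-the-envelope sketch (not part of the repro itself) of the sizes involved, assuming float32 (4 bytes per element) and the 8 devices of a v3-8:

# Rough sizes for the repro below (float32 = 4 bytes per element), assuming a v3-8.
world_size = 8
total_numel = 1024 ** 3                    # elements in the gathered result
output_bytes = total_numel * 4             # 4 GiB gathered output tensor
shard_bytes = output_bytes // world_size   # 512 MiB per-device input shard

print(f"per-device input shard: {shard_bytes / 2**20:.0f} MiB")   # 512 MiB
print(f"gathered output:        {output_bytes / 2**30:.0f} GiB")  # 4 GiB
# An all-gather whose footprint is roughly the output size should fit comfortably
# in the 16 GB of memory on a v3 core, yet the new lowering tries to reserve 12 GB.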

One workaround is for large-model users to manually revert to the previous all-gather implementation via all_reduce from https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615. However, #3506 partially blocks this workaround, since xm.all_reduce can no longer be combined with xm.reduce_scatter.

To Reproduce

  1. Allocate a v3-8 TPU VM from the tpu-vm-pt-1.10 runtime and install the 20220415 versions of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (since the newer 20220415 libtpu was reported broken in "PyTorch XLA .data assignment fails when the new tensor is a different shape" #3502 (comment)).
# torch, torchvision and torch_xla 20220415
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220415-cp38-cp38-linux_x86_64.whl

# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
  2. Save the following content to a Python file (e.g. /home/ronghanghu/test_all_gather_only_mem.py below).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    world_size = xm.xrt_world_size()
    device = xm.xla_device()

    t1 = torch.ones(1024**3 // world_size, device=device)  # (4 GB // world_size) in float32, i.e. 512 MB on v3-8
    xm.mark_step()
    print(f"t1.sum(): {t1.sum()}, mem: {xm.get_memory_info(device)}")

    t2 = xm.all_gather(t1).flatten()  # 4 GB in float32
    del t1
    xm.mark_step()
    print(f"t2.sum(): {t2.sum()}, mem: {xm.get_memory_info(device)}")

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
  3. Run this file on the v3-8 TPU VM:
python3 /home/ronghanghu/test_all_gather_only_mem.py

It prints

...
2022-04-18 06:42:22.783996: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at tpu_execute_op.cc:266 : RESOURCE_EXHAUSTED: Attempting to reserve 12.00G at the bottom of memory. That was not possible. There are 10.98G free, 0B reserved, and 10.98G reservable.                      
...
  4. If we revert xm.all_gather to the older version implemented via all_reduce (adding the snippet below to the code), then this example /home/ronghanghu/test_all_gather_only_mem.py can run without memory errors.
import torch.nn.functional as F

def old_all_gather(value, dim=0, groups=None):
    """
    This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10 in
    https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
    """
    if dim < 0:
        dim = value.dim() + dim
    size = value.size(dim)
    padding = [0] * (2 * value.dim())
    ordinal = xm.get_ordinal()
    if groups is None:
        left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
    else:
        ordinals = dict()
        for g in groups:
            for i, x in enumerate(g):
                ordinals[x] = (i, len(g) - 1 - i)
        left, right = ordinals[ordinal]
    idx = value.dim() - 1 - dim
    padding[2 * idx] = left * size
    padding[2 * idx + 1] = right * size
    return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)


xm.all_gather = old_all_gather
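
For reference, the trick behind this older implementation is that each ordinal zero-pads its shard into its own slot of the full output, and a REDUCE_SUM all-reduce then overlays the padded tensors. A minimal CPU-only sketch of that pad-then-sum idea (simulating world_size = 2 with plain torch; no TPU or XLA collective is involved):

import torch
import torch.nn.functional as F

# One shard per simulated worker; every worker holds an identically shaped piece.
world_size = 2
shards = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]

padded = []
for ordinal, shard in enumerate(shards):
    left = ordinal * shard.numel()                       # zeros before this worker's slice
    right = (world_size - 1 - ordinal) * shard.numel()   # zeros after it
    padded.append(F.pad(shard, (left, right)))

# The all_reduce(REDUCE_SUM) step is just an element-wise sum of the padded tensors.
gathered = torch.stack(padded).sum(dim=0)
print(gathered)  # tensor([1., 2., 3., 4.]), same layout as an all-gather along dim 0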

Expected behavior

The new xm.all_gather API should not take this much memory to execute; in principle it should require about the same amount of memory as its output tensor. Otherwise it blocks many practical scaling applications, such as Microsoft's ZeRO optimizer.
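
Until this is fixed, one illustrative (untested) mitigation for large flat tensors is to gather in smaller pieces, so that each collective only materializes a fraction of the output at a time, at the cost of extra collective calls. chunked_all_gather and num_chunks below are hypothetical names rather than existing torch_xla APIs, and the sketch assumes every worker holds an identically shaped 1-D shard:

import torch
import torch_xla.core.xla_model as xm

def chunked_all_gather(value, num_chunks=8):
    # Split the local 1-D shard and all-gather the pieces one at a time, so each
    # xm.all_gather call only deals with ~(output size / num_chunks) of data.
    chunks = value.chunk(num_chunks)
    gathered = [xm.all_gather(c) for c in chunks]  # each result is rank-major along dim 0
    world_size = xm.xrt_world_size()
    # Reassemble to match xm.all_gather(value): all of rank 0's chunks first,
    # then rank 1's, and so on.
    parts = []
    for rank in range(world_size):
        for g, c in zip(gathered, chunks):
            parts.append(g.narrow(0, rank * c.size(0), c.size(0)))
    return torch.cat(parts)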

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
  • torch_xla version: 20220415 nightly from tpu-vm-pt-1.10 (see Step 1 above)

Additional context

Based on the error message Attempting to reserve 12.00G at the bottom of memory. That was not possible., it is also puzzling that the new xm.all_gather cannot run the example above at all: a TPU v3 core has 16 GB of memory, so the program should fit even if the new all-gather needs 12 GB (the log reports only 10.98 GB free and reservable).

cc: @JackCaoG
