🐛 Bug
In our internal tests, the new xm.all_gather API implemented in #3275 takes significantly more memory to execute than the previous all-gather implementation via all_reduce in PyTorch XLA 1.10 (https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615).
For example, as shown in the reproducing steps below, the new xm.all_gather API fails to all-gather a 512 MB tensor on a v3-8 (which would produce only a 4 GB output tensor, much smaller than the 16 GB HBM per TPU v3 core), while the old xm.all_gather API in PyTorch XLA 1.10 handles this case without a problem.
It is unclear why the new xm.all_gather API needs so much memory to execute: a well-implemented all-gather should in principle take roughly the same amount of memory as the output tensor. This breaks many large-model use cases, e.g. gathering a large number of sharded parameters or large tensors, or ZeRO (FSDP) in #3431.
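To make that expectation concrete, here is a back-of-the-envelope sketch (my own arithmetic, not taken from the PyTorch XLA code) of the memory involved in the repro below:
# Rough per-core memory for the repro below on a v3-8 (8 cores, float32 = 4 bytes).
world_size = 8
input_elems = 1024**3 // world_size                 # elements held by each core
input_gib = input_elems * 4 / 2**30                 # 0.5 GiB shard per core
output_gib = input_elems * 4 * world_size / 2**30   # 4 GiB gathered tensor per core
print(f"input per core:  {input_gib:.2f} GiB")      # 0.50 GiB
print(f"output per core: {output_gib:.2f} GiB")     # 4.00 GiB
# An ideal all-gather should peak near input + output (~4.5 GiB per core),
# well under the 16 GiB HBM of a TPU v3 core, yet the run below fails while
# trying to reserve 12 GiB.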
One workaround is for large-model users to manually revert to the previous all-gather implementation via all_reduce in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615. However, #3506 partially prevents this workaround, since xm.all_reduce can no longer be used together with xm.reduce_scatter.
To Reproduce
- Allocate a v3-8 TPU VM from the tpu-vm-pt-1.10 runtime and install the 20220415 version of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (since the newer 20220415 libtpu was reported broken in #3502 (comment), ".data assignment fails when the new tensor is a different shape"). An optional version check in Python follows the install commands below.
# torch, torchvision and torch_xla 20220415
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220415-cp38-cp38-linux_x86_64.whl
# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
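(Not part of the original steps: an optional sanity check, run on the TPU VM, that the intended nightlies were picked up.)
# Optional version check (my addition, not from the issue).
import torch
import torchvision
import torch_xla
print(torch.__version__, torchvision.__version__, torch_xla.__version__)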
- Save the following content to a python file (e.g. /home/ronghanghu/test_all_gather_only_mem.py below).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    t1 = torch.ones(1024**3 // world_size, device=device)  # (4 GB // world_size) in float32, i.e. 512 MB on v3-8
    xm.mark_step()
    print(f"t1.sum(): {t1.sum()}, mem: {xm.get_memory_info(device)}")
    t2 = xm.all_gather(t1).flatten()  # 4 GB in float32
    del t1
    xm.mark_step()
    print(f"t2.sum(): {t2.sum()}, mem: {xm.get_memory_info(device)}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
- Run this file on the v3-8 TPU VM:
python3 /home/ronghanghu/test_all_gather_only_mem.py
It prints
...
2022-04-18 06:42:22.783996: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at tpu_execute_op.cc:266 : RESOURCE_EXHAUSTED: Attempting to reserve 12.00G at the bottom of memory. That was not possible. There are 10.98G free, 0B reserved, and 10.98G reservable.
...
- If we revert xm.all_gather to the older version implemented via all_reduce (adding the snippet below to the code), then this example (/home/ronghanghu/test_all_gather_only_mem.py) runs without memory errors. A short CPU simulation of the pad-then-sum trick follows the snippet.
import torch.nn.functional as F


def old_all_gather(value, dim=0, groups=None):
    """
    This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10 in
    https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
    """
    if dim < 0:
        dim = value.dim() + dim
    size = value.size(dim)
    padding = [0] * (2 * value.dim())
    ordinal = xm.get_ordinal()
    if groups is None:
        left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
    else:
        ordinals = dict()
        for g in groups:
            for i, x in enumerate(g):
                ordinals[x] = (i, len(g) - 1 - i)
        left, right = ordinals[ordinal]
    idx = value.dim() - 1 - dim
    padding[2 * idx] = left * size
    padding[2 * idx + 1] = right * size
    return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)


xm.all_gather = old_all_gather
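For intuition about how this older implementation works (my reading of the snippet above, not text from the issue): each replica zero-pads its shard to the full output length at its own offset, and the summing all_reduce then reassembles the gathered tensor. A small CPU-only simulation of that pad-then-sum trick, with made-up toy sizes:
import torch
import torch.nn.functional as F

# Toy CPU simulation of the pad-then-sum trick in old_all_gather above
# (4 "replicas", each holding a shard of length 3; values are illustrative).
world_size, size = 4, 3
shards = [torch.full((size,), float(i + 1)) for i in range(world_size)]

padded = []
for ordinal, shard in enumerate(shards):
    left, right = ordinal, world_size - 1 - ordinal
    # F.pad with (left * size, right * size) places the shard at its own offset.
    padded.append(F.pad(shard, (left * size, right * size)))

# Summing the padded tensors is what xm.all_reduce(xm.REDUCE_SUM, ...) does across replicas.
gathered = torch.stack(padded).sum(dim=0)
print(gathered)  # tensor([1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.])
Note that this approach materializes one full-size padded tensor per replica before the reduce, so its peak memory is roughly the output tensor size, which matches the behavior observed with 1.10 above.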
Expected behavior
The new xm.all_gather API should not take so much memory to execute; in principle it should need about the same amount of memory as the output tensor. Otherwise it blocks many practical scaling applications, such as the ZeRO optimizer from Microsoft.
Environment
- Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
- torch_xla version: 20220415 nightly from tpu-vm-pt-1.10 (see Step 1 above)
Additional context
Based on the error message "Attempting to reserve 12.00G at the bottom of memory. That was not possible.", it is still puzzling why the new xm.all_gather API cannot run in the example above: a TPU v3 core has 16 GB of HBM, so it should be able to run even if the new all-gather takes 12 GB?
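For reference, a quick reading of the numbers in the RESOURCE_EXHAUSTED message above (my interpretation, not a statement from the XLA team):
# Numbers quoted from the error log in the repro above (TPU v3: 16 GiB HBM per core).
hbm_per_core_gib = 16.0
free_gib = 10.98      # "There are 10.98G free"
request_gib = 12.00   # "Attempting to reserve 12.00G"
print(f"already used or reserved: ~{hbm_per_core_gib - free_gib:.2f} GiB")
print(f"shortfall for the 12G request: ~{request_gib - free_gib:.2f} GiB")
So the 12 GiB reservation cannot fit because only ~11 GiB is reported free at that point; the open question remains why the new all-gather requests roughly 3x the 4 GB output size in the first place.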
cc: @JackCaoG