🚀 Feature
Add the ability to translate the following Collective Communication ops to native
XLA instructions:
all_gather
reduce_scatter
collective_permute
send
recv
At the moment all_gather is implemented with all_reduce in PyTorch/XLA. The other four
operators are not yet implemented. We propose to change the current all_gather
implementation to generate the native XLA all_gather instruction, and to add native
XLA support for the reduce_scatter, collective_permute, send, and recv operations.
We intend to submit PRs for this proposal.
Motivation
Add the ability for distributed training libraries (such as DeepSpeed) and
applications (such as MegatronLM) to utilize the underlying XLA devices/accelerators to
train very large models by calling the torch.distributed API.
Currently MegatronLM and DeepSpeed use the torch.distributed API to execute collective
communication primitives, but torch.distributed does not support XLA. To enable
MegatronLM and DeepSpeed to run on XLA, we propose to use PyTorch/XLA, the bridge
from PyTorch to XLA, to translate the collective communication primitives listed
above into native XLA instructions.
Pitch
We would like users to be able to call the following APIs to insert the corresponding
XLA instructions:
import torch_xla.core.xla_model as xm

# Gather `tensor` from all replicas.
xm.all_gather(tensor)
# Sum across the replica group [0, 1], then scatter the shards back.
xm.reduce_scatter(xm.REDUCE_SUM, tensor, groups=[[0,1]])
# Exchange `tensor` between source/target replica pairs 0->2 and 1->3.
xm.collective_permute(tensor, pairs=[[0,2],[1,3]])
# Point-to-point send/receive over `channel`.
xm.send(tensor, channel)
tensor = xm.recv(shape, channel)
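For context, below is a minimal sketch (not part of the proposal) of how a script might exercise the proposed all_gather from a standard PyTorch/XLA multiprocessing launch. xmp.spawn, xm.xla_device, xm.get_ordinal, and xm.mark_step are existing torch_xla helpers; the xm.all_gather signature is assumed to match the pitch above.

# Hedged sketch: exercises the proposed xm.all_gather from a torch_xla
# multiprocessing launch. The all_gather signature is assumed from the
# pitch above; this is not an authoritative implementation.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # Each replica contributes a small tensor tagged with its ordinal.
    local = torch.full((2,), float(xm.get_ordinal()), device=device)
    gathered = xm.all_gather(local)  # proposed native XLA AllGather
    xm.mark_step()                   # materialize the pending XLA graph
    print(index, gathered.cpu())

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())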
Alternatives
Users could implement all of these ops with all_reduce, but the efficiency would be
too low to be practical.
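To make the inefficiency concrete, here is a hedged sketch (not from the proposal) of emulating all_gather with all_reduce: each replica zero-pads its shard into a buffer of the full gathered size, and the buffers are summed. Every replica then pushes the full-size buffer through the reduction instead of only its own shard, which is why a native all_gather instruction is preferable.

# Hedged sketch of all_gather emulated via all_reduce. `world_size` and
# `ordinal` would normally come from xm.xrt_world_size() / xm.get_ordinal().
import torch
import torch_xla.core.xla_model as xm

def all_gather_via_all_reduce(shard, world_size, ordinal):
    # Zero buffer shaped like the fully gathered result.
    full = torch.zeros(world_size * shard.shape[0], *shard.shape[1:],
                       dtype=shard.dtype, device=shard.device)
    # Place this replica's shard in its own slot ...
    full[ordinal * shard.shape[0]:(ordinal + 1) * shard.shape[0]] = shard
    # ... and sum across replicas: every other slot is zero elsewhere, so
    # the sum reconstructs the concatenation, at full-buffer cost.
    return xm.all_reduce(xm.REDUCE_SUM, full)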