[RFC] Exposing additional XLA collective communication primitives. #3138

@hjm-aws

Description

🚀 Feature

Add the ability to translate the following Collective Communication ops to native
XLA instructions:

  • all_gather
  • reduce_scatter
  • collective_permute
  • send
  • recv

At the moment, all_gather is implemented in terms of all_reduce in PyTorch/XLA, and
the other four operators are not implemented at all. We propose to change the current
all_gather implementation to emit the native XLA all_gather instruction, and to add
native XLA support for the reduce_scatter, collective_permute, send, and recv
operations.
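For reference, the semantics of the two reduction-style collectives can be sketched in plain Python. The function names and list-of-lists representation below are illustrative only; the real ops are emitted as native XLA HLO instructions and run on-device.

```python
# Plain-Python sketch of all_gather and reduce_scatter semantics across
# N replicas. Each inner list stands for one replica's local tensor.

def all_gather(shards):
    """Every replica receives the concatenation of all replicas' shards."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]

def reduce_scatter(inputs):
    """Element-wise sum across replicas; each replica keeps one shard of the result."""
    n = len(inputs)
    summed = [sum(vals) for vals in zip(*inputs)]
    shard_len = len(summed) // n
    return [summed[i * shard_len:(i + 1) * shard_len] for i in range(n)]
```

Note that reduce_scatter is the inverse data movement of all_gather: it reduces a full-size tensor and leaves each replica with only its own shard.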

We intend to submit PRs for this proposal.

Motivation

Enable distributed training libraries (such as DeepSpeed) and applications (such as
MegatronLM) to train very large models on XLA devices/accelerators through the
torch.distributed API.

Currently, MegatronLM and DeepSpeed execute collective communication primitives
through the torch.distributed API, but torch.distributed does not support XLA. To
enable them to run on XLA, we propose to extend PyTorch/XLA, the bridge from PyTorch
to XLA, with the ability to translate the collective communication primitives listed
above into native XLA instructions.

Pitch

We would like users to be able to call the following APIs to insert the corresponding
XLA instructions:

import torch_xla.core.xla_model as xm

xm.all_gather(tensor)
xm.reduce_scatter(xm.REDUCE_SUM, tensor, groups=[[0,1]])
xm.collective_permute(tensor, pairs=[[0,2],[1,3]])
xm.send(tensor, channel)
tensor = xm.recv(shape, channel)
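For clarity, the pairs argument of collective_permute is a list of [source, destination] replica pairs. A plain-Python sketch of the data movement (illustrative only; in XLA, a replica that is not the target of any pair receives zeros):

```python
def collective_permute(values, pairs):
    """Each [src, dst] pair sends values[src] to replica dst.

    Replicas that are not a destination in any pair receive zeros,
    matching the semantics of XLA's CollectivePermute.
    """
    out = [0 for _ in values]  # default: no inbound edge -> zeros
    for src, dst in pairs:
        out[dst] = values[src]
    return out
```

For example, with four replicas holding [10, 20, 30, 40] and pairs=[[0, 2], [1, 3]], replicas 2 and 3 receive the values from replicas 0 and 1, while replicas 0 and 1 receive zeros.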

Alternatives

Users could implement all of these ops on top of all_reduce, but the efficiency
would be too low to be practical.
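To illustrate the cost of that emulation, here is a plain-Python sketch (names are illustrative) of all_gather built from all_reduce: each replica zero-pads its shard into a full-size buffer, and the element-wise sum recovers the concatenation. The collective therefore reduces N times more data than a native all_gather needs to move.

```python
def all_gather_via_all_reduce(shards):
    """Emulate all_gather with an all_reduce(sum) over zero-padded buffers."""
    n = len(shards)
    shard_len = len(shards[0])
    total = n * shard_len
    # Each replica scatters its shard into a full-size, zero-filled buffer.
    padded = []
    for rank, shard in enumerate(shards):
        buf = [0] * total
        buf[rank * shard_len:(rank + 1) * shard_len] = shard
        padded.append(buf)
    # all_reduce(sum): every replica ends up with the element-wise sum,
    # which equals the concatenation because the buffers do not overlap.
    summed = [sum(vals) for vals in zip(*padded)]
    return [list(summed) for _ in range(n)]
```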

Additional context
