torch.distributed.all_reduce not converted to stableHLO #8854

@AleksKnezevic

Description

🐛 Bug

As the title suggests, torch.distributed.all_reduce is not being converted to StableHLO.

To Reproduce

I run the following test:

import os
import torch
from torch import nn
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


def test():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)

    model = Basic()
    prog = export(model, (torch.rand(20, 10), ))
    shlo = exported_program_to_stablehlo(prog)
    print(shlo.get_stablehlo_text())

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    torch.distributed.init_process_group(world_size=1, rank=0)
    test()
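For contrast, the same collective runs fine in eager mode outside the export path. Here is a minimal single-process sketch using the gloo backend (the port number is an arbitrary choice, not taken from the repro); note that in eager mode all_reduce mutates its input in place and, with the default async_op=False, returns None rather than the tensor:

```python
import os
import torch
import torch.distributed as dist

# Single-process demo on the gloo backend (assumption: gloo is available,
# which it is in standard CPU builds of PyTorch).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group(backend="gloo", world_size=1, rank=0)

x = torch.ones(20, 10)
result = dist.all_reduce(x, op=dist.ReduceOp.SUM)

# all_reduce is in place: it returns None, not the reduced tensor.
assert result is None
# With world_size=1, a SUM reduction leaves the tensor unchanged.
assert torch.equal(x, torch.ones(20, 10))

dist.destroy_process_group()
```

This also means the repro's `return torch.distributed.all_reduce(...)` relies on how the export tracer functionalizes the collective, since the eager op has no tensor return value.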

Expected behavior

I would expect a StableHLO module containing an all_reduce op; instead I get the following error:

WARNING:root:Defaulting to PJRT_DEVICE=CPU
loc("all-reduce.10"): error: failed to legalize operation 'mhlo.all_reduce' that was explicitly marked illegal
[rank0]: Traceback (most recent call last):
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 28, in <module>
[rank0]:     test()
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 19, in test
[rank0]:     shlo = exported_program_to_stablehlo(prog)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 626, in exported_program_to_stablehlo
[rank0]:     bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 405, in _exported_program_to_stablehlo_bundle
[rank0]:     stablehlo_content = xm.get_stablehlo_bytecode(res)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 1103, in get_stablehlo_bytecode
[rank0]:     return torch_xla._XLAC._get_stablehlo(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:109 : Check failed: status.ok() 
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace()
[rank0]:        torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
[rank0]:        torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
[rank0]:        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
[rank0]:        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]:        _PyObject_MakeTpCall
[rank0]:        _PyEval_EvalFrameDefault
[rank0]: 
[rank0]:        PyEval_EvalCode
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]:        _PyRun_SimpleFileObject
[rank0]:        _PyRun_AnyFileObject
[rank0]:        Py_RunMain
[rank0]:        Py_BytesMain
[rank0]: 
[rank0]:        __libc_start_main
[rank0]:        _start
[rank0]: *** End stack trace ***
[rank0]: MHLO -> StableHLO conversion failed.
[rank0]: StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
[rank0]: Original HLO dump:
[rank0]: HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[20,10]{1,0})->(f32[20,10]{1,0}, f32[20,10]{1,0})}

[rank0]: %AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
[rank0]:   %x.7 = f32[] parameter(0)
[rank0]:   %y.8 = f32[] parameter(1)
[rank0]:   ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
[rank0]: }

[rank0]: ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[20,10]) -> (f32[20,10], f32[20,10]) {
[rank0]:   %p1.2 = f32[20,10]{1,0} parameter(1)
[rank0]:   %p0.1 = f32[] parameter(0)
[rank0]:   %tuple.3 = (f32[20,10]{1,0}, f32[]) tuple(f32[20,10]{1,0} %p1.2, f32[] %p0.1)
[rank0]:   %get-tuple-element.4 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=0
[rank0]:   %get-tuple-element.5 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=1
[rank0]:   %all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6
[rank0]:   %get-tuple-element.12 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=1
[rank0]:   %get-tuple-element.11 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=0
[rank0]:   ROOT %tuple.13 = (f32[20,10]{1,0}, f32[20,10]{1,0}) tuple(f32[20,10]{1,0} %get-tuple-element.11, f32[20,10]{1,0} %get-tuple-element.11)
[rank0]: }

Environment

  • Reproducible on the CPU XLA backend
  • torch_xla versions 2.5.0 and 2.6.0 (I tried both)

Labels

    bug (Something isn't working) · distributed (SPMD and other distributed things) · stablehlo (StableHLO related work)
