Closed
Labels: bug, distributed, stablehlo
## 🐛 Bug
As the title suggests, `torch.distributed.all_reduce` is not converted to StableHLO.
## To Reproduce
I run the following test:
```python
import os

import torch
from torch import nn
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


def test():

  class Basic(nn.Module):

    def __init__(self):
      super().__init__()

    def forward(self, x):
      return torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)

  model = Basic()
  prog = export(model, (torch.rand(20, 10),))
  shlo = exported_program_to_stablehlo(prog)
  print(shlo.get_stablehlo_text())


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
torch.distributed.init_process_group(world_size=1, rank=0)

if __name__ == "__main__":
  test()
```

## Expected behavior
I would expect a StableHLO module containing `all_reduce`; instead I get the following error:
```
WARNING:root:Defaulting to PJRT_DEVICE=CPU
loc("all-reduce.10"): error: failed to legalize operation 'mhlo.all_reduce' that was explicitly marked illegal
[rank0]: Traceback (most recent call last):
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 28, in <module>
[rank0]:     test()
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 19, in test
[rank0]:     shlo = exported_program_to_stablehlo(prog)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 626, in exported_program_to_stablehlo
[rank0]:     bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 405, in _exported_program_to_stablehlo_bundle
[rank0]:     stablehlo_content = xm.get_stablehlo_bytecode(res)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 1103, in get_stablehlo_bytecode
[rank0]:     return torch_xla._XLAC._get_stablehlo(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:109 : Check failed: status.ok()
[rank0]: *** Begin stack trace ***
[rank0]:         tsl::CurrentStackTrace()
[rank0]:         torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
[rank0]:         torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
[rank0]:         torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
[rank0]:         torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
[rank0]:
[rank0]:
[rank0]:
[rank0]:
[rank0]:         _PyObject_MakeTpCall
[rank0]:         _PyEval_EvalFrameDefault
[rank0]:
[rank0]:         PyEval_EvalCode
[rank0]:
[rank0]:
[rank0]:
[rank0]:         _PyRun_SimpleFileObject
[rank0]:         _PyRun_AnyFileObject
[rank0]:         Py_RunMain
[rank0]:         Py_BytesMain
[rank0]:
[rank0]:         __libc_start_main
[rank0]:         _start
[rank0]: *** End stack trace ***
[rank0]: MHLO -> StableHLO conversion failed.
[rank0]: StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
[rank0]: Original HLO dump:
[rank0]: HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[20,10]{1,0})->(f32[20,10]{1,0}, f32[20,10]{1,0})}
[rank0]: %AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
[rank0]:   %x.7 = f32[] parameter(0)
[rank0]:   %y.8 = f32[] parameter(1)
[rank0]:   ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
[rank0]: }
[rank0]: ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[20,10]) -> (f32[20,10], f32[20,10]) {
[rank0]:   %p1.2 = f32[20,10]{1,0} parameter(1)
[rank0]:   %p0.1 = f32[] parameter(0)
[rank0]:   %tuple.3 = (f32[20,10]{1,0}, f32[]) tuple(f32[20,10]{1,0} %p1.2, f32[] %p0.1)
[rank0]:   %get-tuple-element.4 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=0
[rank0]:   %get-tuple-element.5 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=1
[rank0]:   %all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6
[rank0]:   %get-tuple-element.12 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=1
[rank0]:   %get-tuple-element.11 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=0
[rank0]:   ROOT %tuple.13 = (f32[20,10]{1,0}, f32[20,10]{1,0}) tuple(f32[20,10]{1,0} %get-tuple-element.11, f32[20,10]{1,0} %get-tuple-element.11)
[rank0]: }
```

## Environment
- Reproducible on the XLA CPU backend
- torch_xla versions 2.5.0 and 2.6.0 (I tried both)