From 6dd61239c84900e79aa4f6814fa923ad869ab599 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 21 May 2024 19:12:18 -0700 Subject: [PATCH] Move host-compute-offloading wrapping after SPMD sharding to shard calls executed on host PiperOrigin-RevId: 635998816 --- third_party/xla/xla/service/spmd/spmd_partitioner.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index a487b00025362f..c9075863bbc374 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -2283,9 +2283,11 @@ Status SpmdPartitioningVisitor::HandleCall(HloInstruction* hlo) { call_graph_) .status()); SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction(HloInstruction::CreateCall( + auto* call = b_.AddInstruction(HloInstruction::CreateCall( MakePartitionedShape(hlo->shape(), hlo->sharding()), call_args, hlo->called_computations()[0])); + call->set_raw_backend_config_string(hlo->raw_backend_config_string()); + return call; }); return OkStatus(); }