diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index 1f670664434c02..ed44547af3fca4 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -15,6 +15,11 @@ limitations under the License. #include "xla/service/while_loop_invariant_code_motion.h" +#include +#include +#include +#include + #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -32,7 +37,6 @@ limitations under the License. #include "xla/service/while_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -112,10 +116,14 @@ static void CreateLoopInvariantCopy( } // Returns true if `instruction` is worth hoisting only if it lets us hoist some -// instruction using it. The rationale is that hoisting these instructions will -// prevent simplification and fusion in the while body. +// instruction using it. The rationale is that hoisting these instructions will +// prevent simplification, fusion, and sharding annotation in the while body. 
bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( const HloInstruction& instruction) { + if (instruction.IsCustomCall("Sharding")) { /* A Sharding custom call is reported as not worth hoisting by itself: hoisting it alone would prevent sharding annotation in the while body. */ + return true; + } + switch (instruction.opcode()) { default: return false; diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc index 5a9a35e31fbe9b..57f2768b458c0e 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc @@ -639,7 +639,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) { EXPECT_FALSE(simplified_loop); } -TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) { +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistSPMDFullToShardShape) { auto m = CreateNewVerifiedModule(); auto array_s32 = ShapeUtil::MakeShape(S32, {4}); Shape while_shape = @@ -690,5 +690,43 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) { EXPECT_FALSE(simplified_loop); } +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) { /* In the HLO below, tuple element 0 is forwarded unchanged by the body root, so sharding.0 (a Sharding custom call on gte.0) reads a value the loop never modifies, while sharding.1 reads element 1, which the body replaces with add.0 on every iteration. The test parses the module, runs the pass once, and expects Run() to yield false. */ + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], f32[2], s32[]) parameter(0) + gte.0 = f32[2] get-tuple-element(p_body), index=0 + gte.1 = f32[2] get-tuple-element(p_body), index=1 + sharding.0 = f32[2] custom-call(gte.0), custom_call_target="Sharding", sharding={devices=[2]<=[2]} + sharding.1 = f32[2] custom-call(gte.1), custom_call_target="Sharding", sharding={replicated} + add.0 = f32[2] add(sharding.0, sharding.1) + gte.2 = s32[] get-tuple-element(p_body), index=2 + const = s32[] constant(1) + add.1 = s32[] add(gte.2, const) + ROOT root = (f32[2], f32[2], s32[]) tuple(gte.0, add.0, add.1) + } + + condition { + p_cond = (f32[2], f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=2 + const = s32[] constant(5) + ROOT result = pred[] compare(gte, const), 
direction=LT + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], f32[2], s32[]) tuple(param.0, param.0, param.1) + ROOT while = (f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(module.get())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla