From 57d1de85571217deca6e4f8eb1227a646cb3d4a1 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 21 Feb 2024 09:12:13 -0800 Subject: [PATCH] [XLA:GPU][Coalescing] Fix a typo: operands index was passed instead of root index in thread_id map computation for reduce. PiperOrigin-RevId: 609017178 --- .../xla/xla/service/gpu/fusions/reduction.cc | 2 +- .../gpu/model/coalescing_analysis_test.cc | 35 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 39d7e956195dd9..fb48625c0871ac 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -1453,7 +1453,7 @@ std::optional ReductionFusion::ComputeThreadIdToInputIndexing( mlir::MLIRContext* ctx) const { const auto& groups = reduction_codegen_info_.GetIndexGroups(); - auto* hero = analysis_.fusion_heroes()[hero_operand_index]; + auto* hero = analysis_.fusion_heroes()[root_index]; if (groups.is_reduction_root[root_index] && hero_operand_index >= hero->operand_count() / 2) { // We don't have indexing for the init values. diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index f612cf48cf8865..5b250330bb8120 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -278,7 +278,7 @@ TEST_F(CoalescingTest, ColumnReduction) { EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); } -TEST_F(CoalescingTest, VariadicReduce) { +TEST_F(CoalescingTest, VariadicReduceViaLoopEmitter) { absl::string_view ir = R"( HloModule module max { @@ -310,6 +310,39 @@ TEST_F(CoalescingTest, VariadicReduce) { ElementsAre(true, true, true, true)); } +TEST_F(CoalescingTest, VariadicReduceViaReductionEmitter) { + absl::string_view ir = R"( + HloModule module + max { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + max01 = s32[] maximum(p0, p1) + max23 = s32[] maximum(p2, p3) + ROOT max = (s32[], s32[]) tuple(max01, max23) + } + fusion { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT reduce = (s32[32], s32[32]) + reduce(s32[32,40] p0, s32[32,40] p1, s32[] p2, s32[] p3), + dimensions={1}, to_apply=max + } + ENTRY entry { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT f = (s32[32], s32[32]) fusion(p0, p1, p2, p3), + kind=kInput, calls=fusion + })"; + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + TEST_F(CoalescingTest, UnusedParameter) { Shape shape = ShapeUtil::MakeShape(F32, {100000});