Skip to content

Commit

Permalink
[XLA:GPU][Coalescing] Fix a typo: operands index was passed instead of root index in thread_id map computation for reduce.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 609017178
  • Loading branch information
pifon2a authored and tensorflower-gardener committed Feb 21, 2024
1 parent a4e3b78 commit 57d1de8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/fusions/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1453,7 +1453,7 @@ std::optional<IndexingMap> ReductionFusion::ComputeThreadIdToInputIndexing(
mlir::MLIRContext* ctx) const {
const auto& groups = reduction_codegen_info_.GetIndexGroups();

auto* hero = analysis_.fusion_heroes()[hero_operand_index];
auto* hero = analysis_.fusion_heroes()[root_index];
if (groups.is_reduction_root[root_index] &&
hero_operand_index >= hero->operand_count() / 2) {
// We don't have indexing for the init values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ TEST_F(CoalescingTest, ColumnReduction) {
EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true));
}

TEST_F(CoalescingTest, VariadicReduce) {
TEST_F(CoalescingTest, VariadicReduceViaLoopEmitter) {
absl::string_view ir = R"(
HloModule module
max {
Expand Down Expand Up @@ -310,6 +310,39 @@ TEST_F(CoalescingTest, VariadicReduce) {
ElementsAre(true, true, true, true));
}

// Coalescing check for a variadic reduce wrapped in a kInput fusion.
//
// The fused computation reduces two s32[32,40] inputs along dimension 1 with
// two scalar s32 init values, producing a (s32[32], s32[32]) tuple. The test
// expects every one of the four fusion operands to be reported as a coalesced
// read.
// NOTE(review): per the test name this is meant to exercise the reduction
// emitter path (as opposed to the loop emitter) — confirm against the emitter
// selection logic.
TEST_F(CoalescingTest, VariadicReduceViaReductionEmitter) {
  const absl::string_view hlo_text = R"(
HloModule module
max {
p0 = s32[] parameter(0)
p1 = s32[] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
max01 = s32[] maximum(p0, p1)
max23 = s32[] maximum(p2, p3)
ROOT max = (s32[], s32[]) tuple(max01, max23)
}
fusion {
p0 = s32[32,40] parameter(0)
p1 = s32[32,40] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT reduce = (s32[32], s32[32])
reduce(s32[32,40] p0, s32[32,40] p1, s32[] p2, s32[] p3),
dimensions={1}, to_apply=max
}
ENTRY entry {
p0 = s32[32,40] parameter(0)
p1 = s32[32,40] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT f = (s32[32], s32[32]) fusion(p0, p1, p2, p3),
kind=kInput, calls=fusion
})";

  // One flag per fusion operand, in operand order.
  const auto coalesced_per_operand = IsReadCoalescedPerOperand(hlo_text);
  EXPECT_THAT(coalesced_per_operand, ElementsAre(true, true, true, true));
}

TEST_F(CoalescingTest, UnusedParameter) {
Shape shape = ShapeUtil::MakeShape(F32, {100000});

Expand Down

0 comments on commit 57d1de8

Please sign in to comment.