Skip to content

Commit

Permalink
Fix fusion parameter resolving in indexing analysis
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626011888
  • Loading branch information
beckerhe authored and tensorflower-gardener committed Apr 18, 2024
1 parent 0a985a4 commit 16cd582
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
22 changes: 22 additions & 0 deletions third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
Expand Up @@ -488,6 +488,28 @@ TEST_F(CoalescingTest, UnusedParameter) {
ElementsAre(true, true));
}

TEST_F(CoalescingTest, Param) {
absl::string_view ir = R"(
HloModule module
fusion {
%p0 = u32[48,2,1280] parameter(0)
%p1 = u32[48,1,1280] parameter(1)
%p2 = u32[48,1,1280] parameter(2)
%concat = u32[48,2,1280] concatenate(u32[48,1,1280] %p1,
u32[48,1,1280] %p2), dimensions={1}
ROOT %shift = u32[48,2,1280] shift-right-logical(
u32[48,2,1280] %concat, u32[48,2,1280] %p0)
}
ENTRY entry {
%p0 = u32[48,2,1280] parameter(0)
%p1 = u32[48,1,1280] parameter(1)
%p2 = u32[48,1,1280] parameter(2)
ROOT %fusion = u32[48,2,1280] fusion(p0, p1, p2), kind=kLoop, calls=fusion
})";
// thread_x to linearized input mapping for thread_x in [0, 31]:
EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true, true));
}

} // namespace
} // namespace gpu
} // namespace xla
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/model/indexing_analysis.cc
Expand Up @@ -1320,10 +1320,10 @@ GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing(
if (fusion_adaptor.ContainsInstruction(target_instr)) {
if (auto parameter_instr =
DynCast<HloParameterInstruction>(&target_instr.instruction())) {
const HloInstruction* user = parameter_instr->users().front();
auto fusion_operand = HloInstructionAdaptor(*user).GetOperand(
parameter_instr->parameter_number());
grouped_indexing_maps[&fusion_operand.instruction()] = {initial_map};
auto fusion_instr = parameter_instr->parent()->FusionInstruction();
auto fusion_operand =
fusion_instr->operand(parameter_instr->parameter_number());
grouped_indexing_maps[fusion_operand] = {initial_map};
return grouped_indexing_maps;
}
}
Expand Down

0 comments on commit 16cd582

Please sign in to comment.