Skip to content

Commit

Permalink
[XLA] remove degenerate indexing dimensions in algebraic simplifier
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642414848
  • Loading branch information
blakehechtman authored and tensorflower-gardener committed Jun 11, 2024
1 parent c0761a4 commit dd6e541
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
3 changes: 3 additions & 0 deletions third_party/xla/xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,9 @@ class HloGatherInstruction : public HloInstruction {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
GatherDimensionNumbers* mutable_gather_dimension_numbers() {
return gather_dimension_numbers_.get();
}
absl::Span<const int64_t> gather_slice_sizes() const {
return gather_slice_sizes_;
}
Expand Down
16 changes: 16 additions & 0 deletions third_party/xla/xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4163,6 +4163,22 @@ absl::Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
}
}
}

if (gather->gather_dimension_numbers().index_vector_dim() <
gather->operand(1)->shape().rank() &&
gather->gather_dimension_numbers().start_index_map_size() == 1) {
Shape updated_shape = ShapeUtil::DeleteDimension(
gather->gather_dimension_numbers().index_vector_dim(),
gather->operand(1)->shape());
Cast<HloGatherInstruction>(gather)
->mutable_gather_dimension_numbers()
->set_index_vector_dim(updated_shape.rank());
TF_RETURN_IF_ERROR(gather->ReplaceOperandWithDifferentShape(
1, gather->mutable_operand(1)->AddInstruction(
HloInstruction::CreateReshape(updated_shape,
gather->mutable_operand(1)))));
MarkAsChanged();
}
return absl::OkStatus();
}

Expand Down
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7842,6 +7842,31 @@ INSTANTIATE_TEST_SUITE_P(
DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));

TEST_F(AlgebraicSimplifierTest, GatherWithDegenerateIndex) {
const char* hlo_string = R"(
HloModule repeat
ENTRY main {
o = f32[25,25] parameter(0)
i = s32[1,100] parameter(1)
ROOT g = f32[100,25] gather(o, i), collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=0,
offset_dims={1},
slice_sizes={1,25}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).value());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Gather(m::Parameter(0),
m::Reshape(m::Parameter(1)))));
}

TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
const char* hlo_string = R"(
HloModule repeat
Expand Down

0 comments on commit dd6e541

Please sign in to comment.