Skip to content

Commit

Permalink
[XLA:AlgebraicSimplifier] Bugfix on gather simplification.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624207589
  • Loading branch information
Tongfei-Guo authored and tensorflower-gardener committed Apr 12, 2024
1 parent 42e7261 commit 7ecd526
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
10 changes: 5 additions & 5 deletions third_party/xla/xla/service/algebraic_simplifier.cc
Expand Up @@ -4006,17 +4006,17 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
bool padded_on_reshape_unmodified_dims = true;
bool padded_on_gather_operand_passthrough_operand_dims = true;
std::vector<int64_t> padded_dims = GetPaddedDims(pad);
absl::flat_hash_map<int64_t, int64_t> reshape_dims_to_padded_dims;
for (int64_t padded_dim : padded_dims) {
reshape_dims_to_padded_dims[reshape_unmodified_dims[padded_dim]] =
padded_dim;
}
for (int64_t padded_dim : padded_dims) {
if (!reshape_unmodified_dims.contains(padded_dim)) {
padded_on_reshape_unmodified_dims = false;
break;
}
}
absl::flat_hash_map<int64_t, int64_t> reshape_dims_to_padded_dims;
for (int64_t padded_dim : padded_dims) {
reshape_dims_to_padded_dims[reshape_unmodified_dims[padded_dim]] =
padded_dim;
}
for (auto& [padded_reshape_dim, _] : reshape_dims_to_padded_dims) {
if (!gather_operand_passthrough_operand_to_output_dims.contains(
padded_reshape_dim)) {
Expand Down
27 changes: 27 additions & 0 deletions third_party/xla/xla/service/algebraic_simplifier_test.cc
Expand Up @@ -7969,6 +7969,33 @@ ENTRY %entry {
m::ConstantScalar(0))));
}

TEST_F(AlgebraicSimplifierTest, GatherOfReshapeOfPad3) {
const char* hlo_string = R"(
HloModule module
ENTRY %entry {
parameter.0 = f32[2,4256]{1,0} parameter(0)
constant = f32[] constant(0)
pad.264 = f32[2,4480]{1,0} pad(parameter.0, constant), padding=0_0x0_224
slice.267 = f32[2,4480]{1,0} slice(pad.264), slice={[0:2], [0:4480]}
reshape.269 = f32[2,28,160]{2,1,0} reshape(slice.267)
parameter.1 = s32[27,2]{1,0} parameter(1)
ROOT gather.271 = f32[2,27,2,160]{3,2,1,0} gather(reshape.269, parameter.1), offset_dims={0,3}, collapsed_slice_dims={1}, start_index_map={1}, index_vector_dim=2, slice_sizes={2,1,160}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
VLOG(2) << "After rewrite \n" << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
GmockMatch(m::Gather(
m::Reshape(m::Slice(m::Pad(m::Parameter(0), m::ConstantScalar(0)))),
m::Parameter(1))));
}

TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
const char* hlo_string = R"(
HloModule module
Expand Down

0 comments on commit 7ecd526

Please sign in to comment.