Skip to content

Commit

Permalink
Add Slice in IsSupportedConstantExpression() in `xla::HloConstantSp…
Browse files Browse the repository at this point in the history
…litter`.

Slicing a constant expression is also a constant expression.

PiperOrigin-RevId: 631567220
  • Loading branch information
ZixuanJiang authored and tensorflower-gardener committed May 7, 2024
1 parent 27a713f commit f9de342
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
16 changes: 13 additions & 3 deletions third_party/xla/xla/hlo/transforms/hlo_constant_splitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,19 @@ bool IsSupportedConstant(const HloInstruction* instruction,
// Return if this is one of the constant expressions that we consider for
// duplication.
bool IsSupportedConstantExpression(const HloInstruction* instruction) {
return !instruction->HasSideEffect() &&
(instruction->opcode() == HloOpcode::kBroadcast ||
instruction->IsElementwise());
if (instruction->HasSideEffect()) {
return false;
}
if (instruction->IsElementwise()) {
return true;
}
switch (instruction->opcode()) {
case HloOpcode::kBroadcast:
case HloOpcode::kSlice:
return true;
default:
return false;
}
}

// Perform duplication of a certain constant expression and replace the
Expand Down
32 changes: 31 additions & 1 deletion third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TEST_F(HloConstantSplitterTest, PreservingConstantsWithZeroUsers) {
EXPECT_FALSE(status_or.value());
}

TEST_F(HloConstantSplitterTest, SplittingExpressions) {
TEST_F(HloConstantSplitterTest, SplittingExpressionsWithBroadcast) {
const char* module_str = R"(
HloModule test_module
Expand Down Expand Up @@ -121,6 +121,36 @@ TEST_F(HloConstantSplitterTest, SplittingExpressions) {
EXPECT_EQ(module->entry_computation()->instruction_count(), 23);
}

TEST_F(HloConstantSplitterTest, SplittingExpressionsWithSlice) {
const char* module_str = R"(
HloModule test_module
ENTRY entry_computation {
iota.0 = u32[64] iota(), iota_dimension=0
slice.0 = u32[32] slice(iota.0), slice={[0:32]}
broadcast.0 = u32[16,32] broadcast(slice.0), dimensions={1}
broadcast.1 = u32[32,32] broadcast(slice.0), dimensions={1}
p.0 = u32[16,32] parameter(0)
p.1 = u32[32,32] parameter(1)
add.0 = u32[16,32] add(p.0, broadcast.0)
add.1 = u32[32,32] add(p.1, broadcast.1)
ROOT root = (u32[16,32], u32[32,32]) tuple(add.0, add.1)
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(module_str));
HloConstantSplitter pass = HloConstantSplitter(/*split_expressions=*/true);
const auto status_or = HloTestBase::RunHloPass(&pass, module.get());
TF_ASSERT_OK(status_or.status());
// Verify that the changed flag returned is correct.
EXPECT_TRUE(status_or.value());
HloDCE dce;
TF_ASSERT_OK(dce.Run(module.get()).status());
XLA_VLOG_LINES(1, module->entry_computation()->ToString());
EXPECT_EQ(module->entry_computation()->instruction_count(), 11);
}

TEST_F(HloConstantSplitterTest, NoSplittingSideEffectExpressions) {
const char* module_str = R"(
HloModule test_module
Expand Down

0 comments on commit f9de342

Please sign in to comment.