-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[XLA] Allow the tuple simplifier to operate on only subcomputations #19769
Changes from all commits
909a332
8b4cca1
4e8eeaf
a62acbd
0ed43ff
76b5f20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { | |
TF_ASSERT_OK(changed_status.status()); | ||
EXPECT_EQ(change_expected, changed_status.ValueOrDie()); | ||
} | ||
void Run(HloModule* module, bool change_expected, bool exclude_entry) { | ||
TupleSimplifier simplifier(exclude_entry); | ||
auto changed_status = simplifier.Run(module); | ||
TF_ASSERT_OK(changed_status.status()); | ||
EXPECT_EQ(change_expected, changed_status.ValueOrDie()); | ||
} | ||
|
||
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); | ||
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( | ||
|
@@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { | |
EXPECT_THAT(computation->root_instruction(), tuple); | ||
} | ||
|
||
TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { | ||
// Verify that the root computation can be excluded | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit, move this comment up one line? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the other tests in the file seem to have their comments following the test declaration. i think it is best to leave it here. |
||
auto module = CreateNewModule(); | ||
|
||
HloInstruction* p0; | ||
HloInstruction* p1; | ||
HloComputation* c0; | ||
HloComputation* c1; | ||
HloComputation* entry; | ||
|
||
{ | ||
HloComputation::Builder builder(TestName() + "_1"); | ||
p0 = builder.AddInstruction( | ||
HloInstruction::CreateParameter(0, tuple_shape_, "param")); | ||
HloInstruction* gte0 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); | ||
HloInstruction* gte1 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); | ||
HloInstruction* gte2 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); | ||
|
||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); | ||
|
||
c0 = module->AddEmbeddedComputation(builder.Build()); | ||
} | ||
{ | ||
HloComputation::Builder builder(TestName() + "_2"); | ||
p1 = builder.AddInstruction( | ||
HloInstruction::CreateParameter(0, tuple_shape_, "param")); | ||
HloInstruction* gte0 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); | ||
HloInstruction* gte1 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); | ||
HloInstruction* gte2 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); | ||
|
||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); | ||
|
||
c1 = module->AddEmbeddedComputation(builder.Build()); | ||
} | ||
{ | ||
HloComputation::Builder builder(TestName() + "_Entry"); | ||
HloInstruction* tuple_param = builder.AddInstruction( | ||
HloInstruction::CreateParameter(0, tuple_shape_, "param")); | ||
HloInstruction* call0 = builder.AddInstruction( | ||
HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); | ||
HloInstruction* call1 = builder.AddInstruction( | ||
HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); | ||
HloInstruction* gte0 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); | ||
HloInstruction* gte1 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); | ||
HloInstruction* tuple0 = | ||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); | ||
HloInstruction* gte2 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); | ||
HloInstruction* gte3 = builder.AddInstruction( | ||
HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); | ||
|
||
builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); | ||
|
||
entry = module->AddEntryComputation(builder.Build()); | ||
} | ||
|
||
Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our style is no space between end of the comment and (Sorry... :-/ I believe clang-format will catch this, and it looks like we have open-sourced a |
||
|
||
EXPECT_THAT(c0->root_instruction(), p0); | ||
EXPECT_THAT(c1->root_instruction(), p1); | ||
EXPECT_THAT(entry->instruction_count(), 9); | ||
} | ||
|
||
} // namespace | ||
} // namespace xla |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to change it, how you have it here is totally fine, but FYI most of us enjoy writing tests using the textual format. (Update your tree and grep for
ParseHloString
.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i have some tests which use that text string processing thing, but I'm not sure how you can refer to the pre-optimized instructions at post-optimization check time.
i suppose you could walk up the tree from the root and check the opcode or metadata.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, see hlo_matchers.h.