Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Allow the tuple simplifier to operate on only subcomputations #19769

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions tensorflow/compiler/xla/service/tuple_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,17 @@ limitations under the License.

namespace xla {

TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) :
exclude_entry_computation_(exclude_entry_computation) {}

StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Initially add all GTE and Tuple instructions to the worklist.
std::queue<HloInstruction*> worklist;
for (auto* computation : module->computations()) {
if (exclude_entry_computation_ &&
computation == module->entry_computation()) {
continue;
}
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kTuple ||
instruction->opcode() == HloOpcode::kGetTupleElement) {
Expand Down
9 changes: 8 additions & 1 deletion tensorflow/compiler/xla/service/tuple_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,20 @@ namespace xla {
// the module.
class TupleSimplifier : public HloPassInterface {
public:
TupleSimplifier() {}
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
tensorflow::StringPiece name() const override { return "tuple-simplifier"; }

// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;

private:
// When set, this pipeline stage will perform optimization of all computations
// apart from the module's entry computation. This is used by Graphcore's
// backend.
bool exclude_entry_computation_;
};

} // namespace xla
Expand Down
77 changes: 77 additions & 0 deletions tensorflow/compiler/xla/service/tuple_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase {
TF_ASSERT_OK(changed_status.status());
EXPECT_EQ(change_expected, changed_status.ValueOrDie());
}
void Run(HloModule* module, bool change_expected, bool exclude_entry) {
TupleSimplifier simplifier(exclude_entry);
auto changed_status = simplifier.Run(module);
TF_ASSERT_OK(changed_status.status());
EXPECT_EQ(change_expected, changed_status.ValueOrDie());
}

const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
Expand Down Expand Up @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
}

TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change it, how you have it here is totally fine, but FYI most of us enjoy writing tests using the textual format. (Update your tree and grep for ParseHloString.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have some tests which use that text string processing thing, but I'm not sure how you can refer to the pre-optimized instructions at post-optimization check time.

i suppose you could walk up the tree from the root and check the opcode or metadata.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i suppose you could walk up the tree from the root and check the opcode or metadata.

Yes, see hlo_matchers.h.

// Verify that the root computation can be excluded
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, move this comment up one line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the other tests in the file seem to have their comments following the test declaration. i think it is best to leave it here.

auto module = CreateNewModule();

HloInstruction* p0;
HloInstruction* p1;
HloComputation* c0;
HloComputation* c1;
HloComputation* entry;

{
HloComputation::Builder builder(TestName() + "_1");
p0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape_, "param"));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
HloInstruction* gte2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));

builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));

c0 = module->AddEmbeddedComputation(builder.Build());
}
{
HloComputation::Builder builder(TestName() + "_2");
p1 = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape_, "param"));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
HloInstruction* gte2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));

builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));

c1 = module->AddEmbeddedComputation(builder.Build());
}
{
HloComputation::Builder builder(TestName() + "_Entry");
HloInstruction* tuple_param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape_, "param"));
HloInstruction* call0 = builder.AddInstruction(
HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
HloInstruction* call1 = builder.AddInstruction(
HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
HloInstruction* tuple0 =
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
HloInstruction* gte2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
HloInstruction* gte3 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));

builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));

entry = module->AddEntryComputation(builder.Build());
}

Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our style is no space between end of the comment and true.

(Sorry... :-/ I believe clang-format will catch this, and it looks like we have open-sourced a .clang-format file for XLA. If you use clang-format it's very important to configure it only to format the lines you actually changed. Otherwise you'll have tons of spurious whitespace changes to the file -- its format is not stable. The git-clang-format script in LLVM is one way to do this. https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/git-clang-format)


EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);
EXPECT_THAT(entry->instruction_count(), 9);
}

} // namespace
} // namespace xla