
Commit 08fdea6
Reverts 49a561f
PiperOrigin-RevId: 628367605
akuegel authored and tensorflower-gardener committed Apr 26, 2024
1 parent 83a53c0 commit 08fdea6
Showing 4 changed files with 24 additions and 168 deletions.
84 changes: 24 additions & 60 deletions third_party/xla/xla/hlo/ir/hlo_instructions.cc
@@ -2322,8 +2322,6 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
// it's only added when inserting to the computation.
absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
std::vector<HloInstruction*> unfused_instructions;
absl::flat_hash_set<const HloInstruction*> new_roots;
std::vector<std::pair<HloInstruction*, int64_t>> old_fusion_outputs;
auto computation_to_merge =
instruction_to_merge->fused_instructions_computation();
for (auto fused_instruction :
@@ -2334,30 +2332,6 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
fused_instruction->parameter_number()));
continue;
}
// If 'instruction_to_merge' is a multi-output fusion, we need to skip the
// root tuple, but remember which of the fusion outputs need to become
// fusion outputs of the merged fusion.
if (fused_instruction->opcode() == HloOpcode::kTuple &&
fused_instruction == instruction_to_merge->fused_expression_root()) {
for (const HloInstruction* user : instruction_to_merge->users()) {
CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement);
old_fusion_outputs.emplace_back(
fused_instruction->mutable_operand(user->tuple_index()),
user->tuple_index());
bool has_outside_user = false;
for (HloInstruction* gte_user : user->users()) {
if (gte_user != this) {
has_outside_user = true;
break;
}
}
if (has_outside_user) {
new_roots.insert(
FindOrDie(old_to_new, old_fusion_outputs.back().first));
}
}
continue;
}

// Here we clone the instruction and call FuseInstructionIntoMultiOutput()
// which clones again. This can be improved.
@@ -2372,45 +2346,35 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
unfused_instructions.push_back(cloned_instruction);
InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
}
if (instruction_to_merge->IsMultiOutputFusion()) {
for (auto [old_root, tuple_index] : old_fusion_outputs) {
auto new_root = FindOrDie(old_to_new, old_root);
// Replace the get-tuple-element op on 'instruction_to_merge' referencing
// the same tuple index as 'old_root' with 'new_root'.
for (HloInstruction* gte : instruction_to_merge->users()) {
if (gte->opcode() == HloOpcode::kGetTupleElement &&
gte->tuple_index() == tuple_index) {
TF_CHECK_OK(gte->ReplaceAllUsesWith(new_root));
TF_CHECK_OK(gte->parent()->RemoveInstruction(gte));
}
}
}
} else {
// If there are no unfused instructions, the fused computation must consist
// only of kParameter instructions. Make the operand of the corresponding
// parameter number the new root.
HloInstruction* unfused_root =
unfused_instructions.empty()
? instruction_to_merge->mutable_operand(
instruction_to_merge->fused_instructions_computation()
->root_instruction()
->parameter_number())
: unfused_instructions.back();
new_roots.insert(unfused_root);
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
}

// If there are no unfused instructions, the fused computation must consist
// only of kParameter instructions. Make the operand of the corresponding
// parameter number the new root.
HloInstruction* unfused_root =
unfused_instructions.empty()
? instruction_to_merge->mutable_operand(
instruction_to_merge->fused_instructions_computation()
->root_instruction()
->parameter_number())
: unfused_instructions.back();
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

TF_CHECK_OK(
instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
if (GetModule()) {
TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
}
for (int64_t i = unfused_instructions.size() - 1; i >= 0; --i) {
HloInstruction* instruction = unfused_instructions[i];
if (new_roots.contains(instruction)) {
FuseInstructionIntoMultiOutput(instruction);
} else {
FuseInstruction(instruction);
}

// Fuse the root instruction and generate multiple outputs.
if (unfused_instructions.empty()) {
return;
}
FuseInstructionIntoMultiOutput(unfused_root);
TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
// The remaining instructions are fused normally.
for (int64_t i = unfused_instructions.size() - 2; i >= 0; --i) {
auto instruction = unfused_instructions[i];
FuseInstruction(instruction);
TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
}
}
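
For context, here is a minimal usage sketch of the method edited above, modeled on the MergeMultiOutputProducerFusionIntoMultiOutputFusion test removed from hlo_instruction_test.cc further down. The helper name and the "producer"/"consumer" instruction names are assumptions taken from that test, not part of this commit:

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"

// Merges a producer fusion into its consumer fusion in place. Assumes the
// module's entry computation contains fusion instructions named "producer"
// and "consumer", as in the removed test below.
void MergeProducerIntoConsumer(xla::HloModule* module) {
  xla::HloInstruction* producer = nullptr;
  xla::HloInstruction* consumer = nullptr;
  for (xla::HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->name() == "producer") producer = instr;
    if (instr->name() == "consumer") consumer = instr;
  }
  if (producer == nullptr || consumer == nullptr) return;
  // After this call, the producer fusion is removed from the computation and
  // its (cloned) root is exposed as an additional output of the consumer
  // fusion.
  consumer->MergeFusionInstructionIntoMultiOutput(producer);
}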
2 changes: 0 additions & 2 deletions third_party/xla/xla/service/BUILD
@@ -880,8 +880,6 @@ xla_cc_test(
srcs = ["hlo_instruction_test.cc"],
tags = ["no_aarch64"],
deps = [
":pattern_matcher",
":pattern_matcher_gmock",
"//xla:literal",
"//xla:protobuf_util",
"//xla:shape_util",
44 changes: 0 additions & 44 deletions third_party/xla/xla/service/gpu/multi_output_fusion_test.cc
@@ -590,50 +590,6 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add())));
}

TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingMultiOutputLoopAndMultiOutputLoop) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,16]{1,0} parameter(0)
mul = f32[8,16]{1,0} multiply(p0.1, p0.1)
exp = f32[8,16]{1,0} exponential(p0.1)
ROOT tuple = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(mul, exp)
}
fused_computation_2 {
p0.2 = f32[8,16]{1,0} parameter(0)
const.2 = f32[] constant(0)
broadcast = f32[8,16]{1,0} broadcast(const.2),
dimensions={}
add = f32[8,16]{1,0} add(p0.2, broadcast)
ROOT tuple.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(add, broadcast)
}
ENTRY entry {
p0 = f32[8,16]{1,0} parameter(0)
fusion.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
calls=fused_computation_1
fusion.2 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
calls=fused_computation_2
gte0 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=1
gte2 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=0
gte3 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=1
ROOT root = (f32[8,16]{1,0}, f32[8,16]{1,0}, f32[8,16]{1,0},
f32[8,16]{1,0})
tuple(gte0, gte1, gte2, gte3)
})"))
.value();
ASSERT_TRUE(mof_.Run(module.get()).value());
SCOPED_TRACE(module->ToString());
const HloInstruction* fusion =
module->entry_computation()->root_instruction()->operand(0)->operand(0);
ASSERT_TRUE(fusion->IsMultiOutputFusion());
EXPECT_THAT(
fusion->fused_expression_root(),
GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add(), m::Broadcast())));
}

TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
62 changes: 0 additions & 62 deletions third_party/xla/xla/service/hlo_instruction_test.cc
@@ -30,8 +30,6 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/protobuf_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/test_helpers.h"
@@ -44,8 +42,6 @@ limitations under the License.
namespace xla {
namespace {

namespace m = ::xla::match;

using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;

@@ -2875,63 +2871,5 @@ TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithDiffOpcode) {
EXPECT_FALSE(add2->has_backend_config());
}

TEST_F(HloInstructionTest,
MergeMultiOutputProducerFusionIntoMultiOutputFusion) {
const std::string& hlo_string = R"(
HloModule mof
mof_producer {
param0 = f32[10]{0} parameter(0)
param1 = f32[10]{0} parameter(1)
add = f32[10]{0} add(param0, param1)
sub = f32[10]{0} subtract(param0, param1)
ROOT res = (f32[10]{0}, f32[10]{0}, f32[10]{0}, f32[10]{0}) tuple(param1, add, sub, param0)
}
mof_consumer {
param0.0 = f32[10]{0} parameter(0)
param1.0 = f32[10]{0} parameter(1)
param2.0 = f32[10]{0} parameter(2)
mul = f32[10]{0} multiply(param0.0, param1.0)
div = f32[10]{0} divide(param0.0, param1.0)
ROOT res = (f32[10]{0}, f32[10]{0}, f32[10]{0}) tuple(mul, div, param2.0)
}
ENTRY main {
p0 = f32[10]{0} parameter(0)
p1 = f32[10]{0} parameter(1)
producer = (f32[10]{0}, f32[10]{0}, f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=mof_producer
gte0 = f32[10]{0} get-tuple-element(producer), index=0
gte1 = f32[10]{0} get-tuple-element(producer), index=1
gte2 = f32[10]{0} get-tuple-element(producer), index=2
gte3 = f32[10]{0} get-tuple-element(producer), index=3
consumer = (f32[10]{0}, f32[10]{0}, f32[10]{0}) fusion(gte1, gte2, gte3), kind=kLoop, calls=mof_consumer
gte4 = f32[10]{0} get-tuple-element(consumer), index=0
gte5 = f32[10]{0} get-tuple-element(consumer), index=1
gte6 = f32[10]{0} get-tuple-element(consumer), index=2
ROOT res = tuple(gte0, gte1, gte3, gte4, gte5, gte6)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
HloInstruction* producer = FindInstruction(module.get(), "producer");
HloInstruction* consumer = FindInstruction(module.get(), "consumer");
consumer->MergeFusionInstructionIntoMultiOutput(producer);
HloInstruction* fusion = nullptr;
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Tuple(
m::Parameter(1), m::GetTupleElement(m::Fusion(&fusion), 3),
m::Parameter(0), m::GetTupleElement(m::Fusion(), 0),
m::GetTupleElement(m::Fusion(), 1),
m::GetTupleElement(m::Fusion(), 2))));
EXPECT_THAT(fusion->fused_instructions_computation()->root_instruction(),
GmockMatch(m::Tuple(
m::Multiply(m::Add(m::Parameter(0), m::Parameter(1)),
m::Subtract(m::Parameter(0), m::Parameter(1))),
m::Divide(m::Add(m::Parameter(0), m::Parameter(1)),
m::Subtract(m::Parameter(0), m::Parameter(1))),
m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
}

} // namespace
} // namespace xla
