H fusion sharing opnd with user upstream again #47384

Merged
50 changes: 50 additions & 0 deletions tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2877,6 +2877,56 @@ ENTRY main {
EXPECT_EQ(CountCopies(*module), 1);
}

TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
const string& hlo_string = R"(
HloModule test

fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
gte0 = f32[200] get-tuple-element(fusion), index=0
gte1 = f32[100] get-tuple-element(fusion), index=1
bitcast0 = f32[10,20] bitcast(gte0)
bitcast1 = f32[10,10] bitcast(gte1)
ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0},
/*param_number=*/0,
/*param_index=*/{}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/3,
/*param_index=*/{}));

InsertCopies(module.get());

// There should be no copies inserted.
EXPECT_EQ(CountCopies(*module), 0);
}

TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
const string& hlo_string = R"(
HloModule TestModule
29 changes: 21 additions & 8 deletions tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
@@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
return false;
}

-// We can emit DUS in-place, horizontally fusing it makes the emitter no
-// longer recognize that it can be done in-place. This creates much slower
-// code. This restriction could be lifted if buffer assignment would recognize
-// that the DUS can be done in-place even inside of a horizontal fusion.
-if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
-return false;
-}

return true;
}

@@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is also used by another fusion in `fusion_instrs`.
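// Illustrative sketch (hypothetical names, mirroring the negative test added
// in this change): if entry parameter arg.2 feeds both fusion.1 and fusion.2,
// then AnyOpndIsParamSharedAmongFusions(fusion.1, {fusion.1, fusion.2})
// returns true, because arg.2 has a user in the set (fusion.2) other than
// fusion.1 itself.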
bool AnyOpndIsParamSharedAmongFusions(
const HloInstruction* instr,
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
return opnd->opcode() == HloOpcode::kParameter &&
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
return user != instr && fusion_instrs.contains(user);
});
});
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
HloInstruction* consumer) {
// First, find out all fusion instructions. We will filter out
Expand Down Expand Up @@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
} else if (!HasOnlyRowMajorLayout(*instr)) {
VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
continue;
} else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
// Don't fuse fusions whose operands are parameter instructions shared
// among the fusion candidates: the concat that horizontal fusion inserts
// would prevent I/O aliasing of the resulting fusion.
VLOG(2) << "Reject the fusion instr because it shares a parameter with"
<< " other fusion candidates, instr: " << instr->ToString();
continue;
} else {
VLOG(2) << "Find a fusion candidate " << instr->ToString();
fusion_instrs_.push_back(instr);
58 changes: 49 additions & 9 deletions tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
@@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}

-TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule NegativeTestForDynamicUpdateSlice

fusion.1 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
-p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
-pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}

fusion.2 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
-p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
-pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}

ENTRY entry {
p.00 = f16[5,9,10]{2,1,0} parameter(0)
p.01 = f16[5,9,10]{2,1,0} parameter(1)
-p.10 = s32[1]{0} parameter(2)
-p.11 = s32[1]{0} parameter(3)
p.10 = s32[] parameter(2)
p.11 = s32[] parameter(3)
p.20 = f16[1,9,10]{2,1,0} parameter(4)
p.21 = f16[1,9,10]{2,1,0} parameter(5)

@@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
})")
.ValueOrDie();

EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

VLOG(2) << "Dump after horizontal fusion:";
VLOG(2) << module->ToString();

EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}

TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule BasicTest

fused_computation.1 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
}

fused_computation.2 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
}

ENTRY entry_computation {
arg.1 = f16[123]{0} parameter(0)
// arg.2 is shared by fusion.1 and fusion.2
arg.2 = f16[123]{0} parameter(1)
arg.3 = f16[123]{0} parameter(2)
fusion.1 = f16[123]{0}
fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
fusion.2 = f16[123]{0}
fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
tuple(fusion.1, fusion.2)
}
)")
.ValueOrDie();

EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

182 changes: 182 additions & 0 deletions tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return true;
}

namespace {
bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kSlice &&
1 == instr->slice_starts().size() &&
1 == instr->slice_limits().size() &&
1 == instr->slice_strides().size() &&
1 == instr->slice_strides().at(0);
}

bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
if (!unnested_hlo.IsInputFusion()) {
return false;
}
const HloInstruction* root = unnested_hlo.fused_expression_root();
if (root->opcode() != HloOpcode::kTuple) {
return false;
}
return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
return Is1dSliceWithoutStrides(instr);
});
}
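// For example (a sketch of the shape horizontal input fusion produces): a
// kInput fusion whose root is
//   tuple(f32[200]{0} slice(concat), f32[100]{0} slice(concat))
// is a SliceInputFusion; a fusion whose root is a single slice, or a tuple
// containing any non-slice operand, is not.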

struct ConcatUsageInfo {
// Pointer to a previously seen concat. nullptr if no previously seen concat.
const HloInstruction* prev_concat;
// The opnd id of the seen concat.
int64 concat_opnd_idx;
// The slice that recovers the opnd in the concat outputs.
const HloInstruction* slice_to_recover_opnd;
};

// Returns an optional ConcatUsageInfo to denote whether the concat is used in
// an elementwise manner. A concat followed by slices is considered effectively
// elementwise if the slices, taken together, are the inverse of the concat.
absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
const HloInstruction& concat, const HloInstruction& operand,
const ConcatUsageInfo& info) {
// First, check that this concat matches the pattern below, and that the
// slices, taken together, are the inverse of the concat.
//
// Concat
// | |
// v v
// Slice Slice
//
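// For example (from the copy-insertion test added in this change):
//
//   concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
//   slice0 = f32[200] slice(concat0), slice={[0:200]}
//   slice1 = f32[100] slice(concat0), slice={[200:300]}
//
// slice0 and slice1 together exactly recover the two concat operands, so
// this concat is used in an effectively elementwise way.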
std::vector<HloInstruction*> users = concat.users();
if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
// Limit our supported cases to 1 dimensional slices.
return absl::optional<ConcatUsageInfo>();
}
// Verify that each operand to the concat is reversed by a slice.
if (users.size() != concat.operand_count() ||
concat.operand_count() != concat.unique_operands().size()) {
return absl::optional<ConcatUsageInfo>();
}
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
return a->slice_starts().at(0) < b->slice_starts().at(0);
});
int64 prev_limit = 0;
for (int64 i = 0; i < users.size(); ++i) {
const HloInstruction* u = users[i];
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
if (u->slice_starts().at(0) != prev_limit ||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
return absl::optional<ConcatUsageInfo>();
}
prev_limit = u->slice_limits().at(0);
}
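// E.g. with operands of 200 and 100 elements, the sorted slices must be
// [0:200] and [200:300]: each slice starts at the previous limit and covers
// exactly its operand's element count.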

// If we have seen other concats, make sure they are identical. Multiple
// concats exist because horizontal fusion inserts one concat for each output
// of the fusion candidates. Check that all concats and operand ids are the
// same to know that the "transitive use closure" will be computed in the same
// iteration space.
int64 operand_idx = concat.operand_index(&operand);
if (info.prev_concat != nullptr) {
bool is_concat_identical = info.prev_concat->Identical(
concat,
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
// Operands don't need to be the same.
return true;
});
if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
return absl::optional<ConcatUsageInfo>();
}
}

const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
return absl::optional<ConcatUsageInfo>(
ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}

// Returns whether we can prove the transitive uses of `param` are in effect
// elementwise. In other words, we prove that the "transitive use closure" will
// all be computed in the same iteration space without any reorder of elements.
// In addition, we check that the "transitive use closure" includes the output
// in the `root_tuple`.
// Theoretically, we could prove more patterns, but our primary use case is
// SliceInputFusion.
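// A sketch of the accepted pattern (assuming the SliceInputFusion form
// above): param -> elementwise ops -> bitcast reshapes -> concat -> the
// slice that recovers param's piece -> the root tuple output at
// `out_shape_idx`.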
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
const HloInstruction* root_tuple,
const ShapeIndex& out_shape_idx) {
CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
CHECK_EQ(out_shape_idx.size(), 1);
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(param);
ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
bool is_output_reachable = false;
while (!stack.empty()) {
const HloInstruction* current = stack.back();
stack.pop_back();
visited.insert(current);
for (const HloInstruction* user : current->users()) {
VLOG(3) << "Visiting: " << user->ToString();
switch (user->opcode()) {
case HloOpcode::kTuple:
if (user == root_tuple &&
current == root_tuple->operand(out_shape_idx.back())) {
// We need to know if the output is reachable by the `param` to make
// sure that they will be computed in the same iteration space.
is_output_reachable = true;
}
break;
case HloOpcode::kReshape:
if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
return false;
}
break;
case HloOpcode::kConcatenate: {
absl::optional<ConcatUsageInfo> optional_concat_info =
ConcatIsEffectivelyElementwise(*user, *current,
concat_usage_info);
if (!optional_concat_info) {
return false;
}
concat_usage_info = *optional_concat_info;
// Early continue as we only want to traverse through the slice that
// recovers the operand. It is guaranteed that the operand to the
// concat and the slice have the same iteration space. Insert the
// slice instead of the concat.
CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
stack.push_back(concat_usage_info.slice_to_recover_opnd);
continue;
}
default:
for (const int64 use_index : user->OperandIndices(current)) {
if (!user->IsElementwiseOnOperand(use_index)) {
// Found a user that is non-elementwise on the current
// instruction.
return false;
}
}
if (!LayoutUtil::Equal(current->shape().layout(),
user->shape().layout())) {
// Make sure the layout is not changed by the elementwise op.
return false;
}
break;
} // end of switch
if (!visited.contains(user)) {
stack.push_back(user);
}
}
}
return is_output_reachable;
}
} // namespace

bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
const HloValueSet& value_set = GetValueSet(instruction, index);
Expand Down Expand Up @@ -1266,10 +1435,23 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
if (operand->opcode() == HloOpcode::kConstant) {
return false;
}

const Shape& operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
if (IsSliceInputFusion(*user)) {
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
// We don't require the same dimensions, only the same number of elements
// and the same element type, so that the buffer sizes match.
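// E.g. an f32[10,20] operand and an f32[200] fusion output have the same
// element count and element type, so the output may alias the operand's
// buffer even though the shapes differ (see the HorizontalLoopFusionNoCopy
// test added in this change).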
return operand_subshape.IsArray() && user_subshape.IsArray() &&
ShapeUtil::ElementsIn(operand_subshape) ==
ShapeUtil::ElementsIn(user_subshape) &&
ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
AreTransitiveUsesEffectivelyElementwise(
fusion_param, user->fused_expression_root(), user_index);
}

// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {