Skip to content

Commit

Permalink
Merge pull request #47384 from trentlo:h-fusion-sharing-opnd-with-use…
Browse files Browse the repository at this point in the history
…r-upstream-again

PiperOrigin-RevId: 359595301
Change-Id: Idaa94398fcae7de25946a2b197939d0209a54c5b
  • Loading branch information
tensorflower-gardener committed Feb 25, 2021
2 parents 0b5618f + 941e049 commit 3c4bd04
Show file tree
Hide file tree
Showing 5 changed files with 447 additions and 17 deletions.
50 changes: 50 additions & 0 deletions tensorflow/compiler/xla/service/copy_insertion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2877,6 +2877,56 @@ ENTRY main {
EXPECT_EQ(CountCopies(*module), 1);
}

TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
const string& hlo_string = R"(
HloModule test
fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}
ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
gte0 = f32[200] get-tuple-element(fusion), index=0
gte1 = f32[100] get-tuple-element(fusion), index=1
bitcast0 = f32[10,20] bitcast(gte0)
bitcast1 = f32[10,10] bitcast(gte1)
ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0},
/*param_number=*/0,
/*param_index=*/{}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/3,
/*param_index=*/{}));

InsertCopies(module.get());

// There should be no copies inserted.
EXPECT_EQ(CountCopies(*module), 0);
}

TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
const string& hlo_string = R"(
HloModule TestModule
Expand Down
29 changes: 21 additions & 8 deletions tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
return false;
}

// We can emit DUS in-place, horizontally fusing it makes the emitter no
// longer recognize that it can be done in-place. This creates much slower
// code. This restriction could be lifted if buffer assignment would recognize
// that the DUS can be done in-place even inside of a horizontal fusion.
if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
return false;
}

return true;
}

Expand All @@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
const HloInstruction* instr,
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
return opnd->opcode() == HloOpcode::kParameter &&
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
return user != instr && fusion_instrs.contains(user);
});
});
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
HloInstruction* consumer) {
// First, find out all fusion instructions. We will filter out
Expand Down Expand Up @@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
} else if (!HasOnlyRowMajorLayout(*instr)) {
VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
continue;
} else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
// Don't fuse fusions whose operands are parameter instructions that are
// shared among fusions because we cannot i/o alias the produced
// horizontal fusion due to the concat insertion.
VLOG(2) << "Reject the fusion instr because it shares parameter with"
<< " other fusion candidates, instr: ",
instr->ToString();
continue;
} else {
VLOG(2) << "Find a fusion candidate " << instr->ToString();
fusion_instrs_.push_back(instr);
Expand Down
58 changes: 49 additions & 9 deletions tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule NegativeTestForDynamicUpdateSlice
fusion.1 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
fusion.2 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
ENTRY entry {
p.00 = f16[5,9,10]{2,1,0} parameter(0)
p.01 = f16[5,9,10]{2,1,0} parameter(1)
p.10 = s32[1]{0} parameter(2)
p.11 = s32[1]{0} parameter(3)
p.10 = s32[] parameter(2)
p.11 = s32[] parameter(3)
p.20 = f16[1,9,10]{2,1,0} parameter(4)
p.21 = f16[1,9,10]{2,1,0} parameter(5)
Expand All @@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
})")
.ValueOrDie();

EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

VLOG(2) << "Dump after horizontal fusion:";
VLOG(2) << module->ToString();

EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}

TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule BasicTest
fused_computation.1 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
}
fused_computation.2 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
}
ENTRY entry_computation {
arg.1 = f16[123]{0} parameter(0)
// arg.2 is shared by fusion.1 and fusion.2
arg.2 = f16[123]{0} parameter(1)
arg.3 = f16[123]{0} parameter(2)
fusion.1 = f16[123]{0}
fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
fusion.2 = f16[123]{0}
fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
tuple(fusion.1, fusion.2)
}
)")
.ValueOrDie();

EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

Expand Down
182 changes: 182 additions & 0 deletions tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return true;
}

namespace {
bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kSlice &&
1 == instr->slice_starts().size() &&
1 == instr->slice_limits().size() &&
1 == instr->slice_strides().size() &&
1 == instr->slice_strides().at(0);
}

bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
if (!unnested_hlo.IsInputFusion()) {
return false;
}
const HloInstruction* root = unnested_hlo.fused_expression_root();
if (root->opcode() != HloOpcode::kTuple) {
return false;
}
return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
return Is1dSliceWithoutStrides(instr);
});
}

struct ConcatUsageInfo {
// Pointer to a previously seen concat. nullptr if no previously seen concat.
const HloInstruction* prev_concat;
// The opnd id of the seen concat.
int64 concat_opnd_idx;
// The slice that recovers the opnd in the concat outputs.
const HloInstruction* slice_to_recover_opnd;
};

// Returns an optional concat usage info to denote whether the concat is used in
// an elementwise manner. A concat followed by slices is considered effectively
// elementwise if the slices combinedly is a reverse function of the concat.
absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
const HloInstruction& concat, const HloInstruction& operand,
const ConcatUsageInfo& info) {
// First, check if this concat is in the below pattern. Also, we check
// that the slices combinedly are in effect a reverse function of the concat.
//
// Concat
// | |
// v v
// Slice Slice
//
std::vector<HloInstruction*> users = concat.users();
if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
// Limit our supported cases to 1 dimensional slices.
return absl::optional<ConcatUsageInfo>();
}
// Verify that each operand to the concat is reversed by a slice.
if (users.size() != concat.operand_count() ||
concat.operand_count() != concat.unique_operands().size()) {
return absl::optional<ConcatUsageInfo>();
}
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
return a->slice_starts().at(0) < b->slice_starts().at(0);
});
int64 prev_limit = 0;
for (int64 i = 0; i < users.size(); ++i) {
const HloInstruction* u = users[i];
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
if (u->slice_starts().at(0) != prev_limit ||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
return absl::optional<ConcatUsageInfo>();
}
prev_limit = u->slice_limits().at(0);
}

// If we have seen other concats, make sure they are identical. Multiple
// concats exist because horizontal fusion inserts one concat for each output
// of the fusion candidates. Check that all concats and operand ids are the
// same to know that the "transitive use closure" will be computed in the same
// iteration space.
int64 operand_idx = concat.operand_index(&operand);
if (info.prev_concat != nullptr) {
bool is_concat_identical = info.prev_concat->Identical(
concat,
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
// Operands don't need to be the same.
return true;
});
if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
return absl::optional<ConcatUsageInfo>();
}
}

const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
return absl::optional<ConcatUsageInfo>(
ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}

// Returns whether we can prove the transitive uses of `param` are in effect
// elementwise. In other words, we prove that the "transitive use closure" will
// all be computed in the same iteration space without any reorder of elements.
// In addition, we check that the "transitive use closure" includes the output
// in the `root_tuple`.
// Theoretically, We can prove more patterns but our primary use case is
// SliceInputFusion.
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
const HloInstruction* root_tuple,
const ShapeIndex& out_shape_idx) {
CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
CHECK_EQ(out_shape_idx.size(), 1);
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(param);
ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
bool is_output_reachable = false;
while (!stack.empty()) {
const HloInstruction* current = stack.back();
stack.pop_back();
visited.insert(current);
for (const HloInstruction* user : current->users()) {
VLOG(3) << "Visiting: " << user->ToString();
switch (user->opcode()) {
case HloOpcode::kTuple:
if (user == root_tuple &&
current == root_tuple->operand(out_shape_idx.back())) {
// We need to know if the output is reachable by the `param` to make
// sure that they will be computed in the same iteration space.
is_output_reachable = true;
}
break;
case HloOpcode::kReshape:
if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
return false;
}
break;
case HloOpcode::kConcatenate: {
absl::optional<ConcatUsageInfo> optional_concat_info =
ConcatIsEffectivelyElementwise(*user, *current,
concat_usage_info);
if (!optional_concat_info) {
return false;
}
concat_usage_info = *optional_concat_info;
// Early continue as we only want to traverse through the slice that
// recovers the operand. It is guaranteed that the operand to the
// concat and the slice have the same iteration space. Insert the
// slice instead of the concat.
CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
stack.push_back(concat_usage_info.slice_to_recover_opnd);
continue;
}
default:
for (const int64 use_index : user->OperandIndices(current)) {
if (!user->IsElementwiseOnOperand(use_index)) {
// Found a user that is non-elementwise on the current
// instruction.
return false;
}
}
if (!LayoutUtil::Equal(current->shape().layout(),
user->shape().layout())) {
// Make sure the layout is not changed by the elementwise op.
return false;
}
break;
} // end of switch
if (!visited.contains(user)) {
stack.push_back(user);
}
}
}
return is_output_reachable;
}
} // namespace

bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
const HloValueSet& value_set = GetValueSet(instruction, index);
Expand Down Expand Up @@ -1266,10 +1435,23 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
if (operand->opcode() == HloOpcode::kConstant) {
return false;
}

const Shape& operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
if (IsSliceInputFusion(*user)) {
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
// We don't require the same dimensions but only the same number of elements
// and type (to make sure the same buffer size).
return operand_subshape.IsArray() && user_subshape.IsArray() &&
ShapeUtil::ElementsIn(operand_subshape) ==
ShapeUtil::ElementsIn(user_subshape) &&
ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
AreTransitiveUsesEffectivelyElementwise(
fusion_param, user->fused_expression_root(), user_index);
}

// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
Expand Down

0 comments on commit 3c4bd04

Please sign in to comment.