Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement an algsimp optimization for dot operation. #28170

Merged
merged 2 commits into from Apr 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
211 changes: 211 additions & 0 deletions tensorflow/compiler/xla/service/algebraic_simplifier.cc
Expand Up @@ -382,6 +382,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {

StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
HloInstruction* dot);

HloComputation* GetOrCreateScalarAddComputation() {
if (scalar_add_computation_) {
return scalar_add_computation_;
Expand Down Expand Up @@ -1499,6 +1502,202 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
return memoized_lookup;
}

// This function tries to transform
// dot(reshape(transpose(A)), Const) to
// dot(reshape(A), reshape(transpose(reshape(Const)))),
// so that the reshape and transpose on the Const side can be constant folded.
//
// The basic idea is that since the accumulation in the dot operation is
// associative, so as long as we permute the elements of the contracting
// dimensions on both sides of the dot in the same way, the result of the
// dot is not affected.
StatusOr<HloInstruction*>
AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
HloInstruction* dot) {
// This transformation assumes layout is not assigned yet.
if (options_.is_layout_sensitive()) {
return nullptr;
}

// Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
// remainder of this function easier.
auto dnums = dot->dot_dimension_numbers();
auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
auto* lhs = dot->mutable_operand(0);
auto* rhs = dot->mutable_operand(1);
if (dot->operand(0)->IsConstant()) {
std::swap(lhs, rhs);
std::swap(lhs_contracting_dims, rhs_contracting_dims);
}

// Require single contracting dim to make the implementation easier to
// track contracting dims.
if (dnums.lhs_contracting_dimensions_size() != 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we pull the vectors lhs_contracting_dims and rhs_contracting_dims below this if statement, then we can simply do

// Comment explaining why we're pulling these into vectors, I am still not sure what is the problem this solves, it seems to be more complex to have two copies of one piece of data?
std::vector<int64> lhs_contracting_dims = {dnums.lhs_contracting_dims[0]};

return nullptr;
}

// Pattern match Dot(reshape(transpose(input), constant))
HloInstruction* reshape;
HloInstruction* transpose;
HloInstruction* input;
HloInstruction* constant;
if (!Match(lhs,
m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
!Match(rhs, m::Constant(&constant))) {
return nullptr;
}

// Check that reshape squishes some dims into one dim and that this one
// dim is the dot's lhs contracting dim. The size of unmodified_dims should
// be N - 1, where N is the rank of the reshape output. This means that the
// reshape squishes some dims into one dim. lhs contracting dim should not
// be in unmodified_dims. This means that the squishing target dim is the
// lhs contracting dim.
auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
reshape->operand(0)->shape(), reshape->shape());
CHECK_EQ(lhs_contracting_dims.size(), 1);
if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
return p.second == lhs_contracting_dims[0];
})) {
return nullptr;
}
// Virtually pull the reshape into the dot. Now the dot is equivalent to a
// new dot with "unsquished" lhs contracting dims. We don't need to actually
// create a new dot instruction. We can just keep track of lhs and
// lhs_contracting_dims.
CHECK_GT(reshape->operand(0)->shape().rank(), reshape->shape().rank());
lhs_contracting_dims.Resize(
reshape->operand(0)->shape().rank() - reshape->shape().rank() + 1,
lhs_contracting_dims[0]);
absl::c_iota(lhs_contracting_dims, lhs_contracting_dims[0]);
lhs = lhs->mutable_operand(0);

// Check if transpose only permutes the contracting dims.
const auto& transpose_dims = transpose->dimensions();
for (int64 i = 0; i < transpose_dims.size(); ++i) {
if (transpose_dims[i] != i &&
!absl::c_linear_search(lhs_contracting_dims, i)) {
return nullptr;
}
}
// Virtually pull the transpose into the dot. Now the dot is equivalent to
// a new dot with "permuted" lhs contracting dims.
std::vector<int64> permutation;
for (auto dim : lhs_contracting_dims) {
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
}
auto new_lhs_contracting_dims =
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
lhs_contracting_dims.Clear();
for (auto dim : new_lhs_contracting_dims) {
lhs_contracting_dims.Add(dim);
}
lhs = lhs->mutable_operand(0);

// All checks are passed at this point.
//
// Transform lhs. Remove the transpose and reshape by sorting the lhs
// contracting dims and squishing them into a single one. We don't actually
// squish the lhs_contracting_dims here because we still need the unsquished
// contracting dims to invert reshape and transpose.
absl::c_sort(lhs_contracting_dims);
lhs = computation_->AddInstruction(
HloInstruction::CreateReshape(reshape->shape(), lhs));

// Transform rhs. Say the input HLO is:
//
// t0 = f32[2, 2, 3] parameter(0)
// t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
// t2 = f32[2, 6] reshape(t1)
// t3 = f32[6, 2] constant(...)
// dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
// rhs_contracting_dims={0}
//
// At this point in the function, we have decided that the second and third
// dims of t0 can be switched to remove the transpose, and we have
// "virtually decomposed" the input HLO to:
//
// t0 = f32[2, 2, 3] parameter(0)
// t2' = f32[2, 6] reshape(t0)
// t3' = f32[6, 2] ops-to-be-filled ...
// dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
// rhs_contracting_dims={0}
//
// The rest of this function is to fill in the ops of t3'. To do this, we
// unsquish the contracting dimensions in t3 and then apply the inverse of
// the transpose from t1.

// Invert reshape.
CHECK_EQ(rhs_contracting_dims.size(), 1);
auto rhs_unsquished_shape_dims = constant->shape().dimensions();
auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
rhs_contracting_dims[0]);
for (auto dim : lhs_contracting_dims) {
it = rhs_unsquished_shape_dims.insert(it,
transpose->shape().dimensions(dim));
++it;
}
HloInstruction* rhs_reshape =
computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(constant->shape().element_type(),
rhs_unsquished_shape_dims),
constant));
rhs = rhs_reshape;

// Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims.
rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
rhs_contracting_dims[0]);
absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);

// Invert transpose. First compute the shape.
auto rhs_transpose_shape_dims = rhs_reshape->shape().dimensions();
it = rhs_transpose_shape_dims.erase(
rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
rhs_contracting_dims.size());
for (auto dim : lhs_contracting_dims) {
it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
++it;
}
// Then compute the transpose dims.
std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
absl::c_iota(rhs_transpose_dims, 0);
it = rhs_transpose_dims.erase(
rhs_transpose_dims.begin() + rhs_contracting_dims[0],
rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
rhs_contracting_dims.size());
auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
for (auto dim : lhs_contracting_dims) {
it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
lhs_contracting_dims[0] +
rhs_contracting_dims[0]);
++it;
}
HloInstruction* rhs_transpose =
computation_->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(constant->shape().element_type(),
rhs_transpose_shape_dims),
rhs_reshape, rhs_transpose_dims));
rhs = rhs_transpose;

// Squish the multiple rhs contracting dims into a single one.
rhs = computation_->AddInstruction(
HloInstruction::CreateReshape(constant->shape(), rhs));

// If we virtually swapped lhs and rhs, we need to swap it back before
// creating new dot.
if (dot->operand(0)->IsConstant()) {
std::swap(lhs, rhs);
}

HloInstruction* new_dot =
computation_->AddInstruction(HloInstruction::CreateDot(
dot->shape(), lhs, rhs, dnums, dot->precision_config()));
return new_dot;
}

Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
Expand Down Expand Up @@ -1632,6 +1831,18 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
return ReplaceInstruction(dot, new_dot);
}

// Simplify dot(reshape(transpose(A)), Const) to:
// dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
// and transpose on the Const side can be constant folded.
TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
OptimizeDotOfReorderContractingDims(dot));
if (dot_of_reorder_optimized) {
VLOG(10) << " Replaced dot " << dot->ToString()
<< " with new dot operation: "
<< dot_of_reorder_optimized->ToString();
return ReplaceInstruction(dot, dot_of_reorder_optimized);
}

TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
OptimizeDotOfConcat(dot));
if (dot_of_concat_optimized) {
Expand Down