From 423bb50ddddc450bb9505cf6ae737a889fe57326 Mon Sep 17 00:00:00 2001
From: Bin Fan
Date: Fri, 26 Apr 2019 15:01:05 -0700
Subject: [PATCH 1/2] Implement an algsimp optimization for dot operation.

The basic idea is that dot(reshape(transpose(A)), constant) can be replaced
by dot(reshape(A), reshape(transpose(reshape(constant)))) if the effect of
the lhs transpose and reshape is to reorder elements in the contracting
dims. We invert the reordering on the constant side so that it can be
constant folded.
---
 .../xla/service/algebraic_simplifier.cc      | 211 ++++++++++++++++++
 .../xla/service/algebraic_simplifier_test.cc | 208 +++++++++++++++++
 .../compiler/xla/tests/dot_operation_test.cc | 179 +++++++++++++++
 3 files changed, 598 insertions(+)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 45653908458cf1..c257610ebf665b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -382,6 +382,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
 
+  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
+      HloInstruction* dot);
+
   HloComputation* GetOrCreateScalarAddComputation() {
     if (scalar_add_computation_) {
       return scalar_add_computation_;
     }
@@ -1499,6 +1502,202 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
   return memoized_lookup;
 }
 
+// This function tries to transform
+// dot(reshape(transpose(A)), Const) to
+// dot(reshape(A), reshape(transpose(reshape(Const)))),
+// so that the reshape and transpose on the Const side can be constant folded.
+//
+// The basic idea is that since the accumulation in the dot operation is
+// associative, as long as we permute the elements of the contracting
+// dimensions on both sides of the dot in the same way, the result of the
+// dot is not affected.
+StatusOr<HloInstruction*>
+AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
+    HloInstruction* dot) {
+  // This transformation assumes layout is not assigned yet.
+  if (options_.is_layout_sensitive()) {
+    return nullptr;
+  }
+
+  // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
+  // remainder of this function easier.
+  auto dnums = dot->dot_dimension_numbers();
+  auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
+  auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
+  auto* lhs = dot->mutable_operand(0);
+  auto* rhs = dot->mutable_operand(1);
+  if (dot->operand(0)->IsConstant()) {
+    std::swap(lhs, rhs);
+    std::swap(lhs_contracting_dims, rhs_contracting_dims);
+  }
+
+  // Require a single contracting dim to make it easier to track the
+  // contracting dims.
+  if (dnums.lhs_contracting_dimensions_size() != 1) {
+    return nullptr;
+  }
+
+  // Pattern match dot(reshape(transpose(input)), constant).
+  HloInstruction* reshape;
+  HloInstruction* transpose;
+  HloInstruction* input;
+  HloInstruction* constant;
+  if (!Match(lhs,
+             m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
+      !Match(rhs, m::Constant(&constant))) {
+    return nullptr;
+  }
+
+  // Check if reshape squishes some dims into one dim, and that this one
+  // dim is the dot's lhs contracting dim.
+  // The size of unmodified_dims should be N - 1, where N is the rank of the
+  // reshape output. This means that the reshape squishes some dims into one
+  // dim. lhs contracting dim should not be in unmodified_dims. This means
+  // that the squishing target dim is the lhs contracting dim.
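+  //
+  // An illustrative example (added for exposition, not part of the original
+  // patch): for a reshape from f32[2, 3, 2] to f32[2, 6] with
+  // lhs_contracting_dims = {1}, DimensionsUnmodifiedByReshape returns
+  // {{0, 0}}. Its size equals rank(output) - 1 = 1, and no pair maps onto
+  // output dim 1, so input dims 1 and 2 were squished into the contracting
+  // dim and the rewrite can proceed.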
+  auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
+      reshape->operand(0)->shape(), reshape->shape());
+  CHECK_EQ(lhs_contracting_dims.size(), 1);
+  if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
+      (absl::c_find_if(unmodified_dims, [&](const std::pair<int64, int64>& p) {
+         return p.second == lhs_contracting_dims[0];
+       }) != unmodified_dims.end())) {
+    return nullptr;
+  }
+  // Virtually pull the reshape into the dot. Now the dot is equivalent to a
+  // new dot with "unsquished" lhs contracting dims. We don't need to actually
+  // create a new dot instruction. We can just keep track of lhs and
+  // lhs_contracting_dims.
+  CHECK_GT(reshape->operand(0)->shape().rank(), reshape->shape().rank());
+  lhs_contracting_dims.Resize(
+      reshape->operand(0)->shape().rank() - reshape->shape().rank() + 1,
+      lhs_contracting_dims[0]);
+  absl::c_iota(lhs_contracting_dims, lhs_contracting_dims[0]);
+  lhs = lhs->mutable_operand(0);
+
+  // Check if transpose only permutes the contracting dims.
+  const auto& transpose_dims = transpose->dimensions();
+  for (int64 i = 0; i < transpose_dims.size(); ++i) {
+    if (transpose_dims[i] != i &&
+        !absl::c_linear_search(lhs_contracting_dims, i)) {
+      return nullptr;
+    }
+  }
+  // Virtually pull the transpose into the dot. Now the dot is equivalent to
+  // a new dot with "permuted" lhs contracting dims.
+  std::vector<int64> permutation;
+  for (auto dim : lhs_contracting_dims) {
+    permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
+  }
+  auto new_lhs_contracting_dims =
+      ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
+  lhs_contracting_dims.Clear();
+  for (auto dim : new_lhs_contracting_dims) {
+    lhs_contracting_dims.Add(dim);
+  }
+  lhs = lhs->mutable_operand(0);
+
+  // All checks have passed at this point.
+  //
+  // Transform lhs. Remove the transpose and reshape by sorting the lhs
+  // contracting dims and squishing them into a single one. We don't actually
+  // squish the lhs_contracting_dims here because we still need the unsquished
+  // contracting dims to invert reshape and transpose.
+  absl::c_sort(lhs_contracting_dims);
+  lhs = computation_->AddInstruction(
+      HloInstruction::CreateReshape(reshape->shape(), lhs));
+
+  // Transform rhs. Say the input HLO is:
+  //
+  // t0 = f32[2, 2, 3] parameter(0)
+  // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
+  // t2 = f32[2, 6] reshape(t1)
+  // t3 = f32[6, 2] constant(...)
+  // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
+  //                             rhs_contracting_dims={0}
+  //
+  // At this point in the function, we have decided that the second and third
+  // dims of t0 can be switched to remove the transpose, and we have
+  // "virtually decomposed" the input HLO to:
+  //
+  // t0 = f32[2, 2, 3] parameter(0)
+  // t2' = f32[2, 6] reshape(t0)
+  // t3' = f32[6, 2] ops-to-be-filled ...
+  // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
+  //                               rhs_contracting_dims={0}
+  //
+  // The rest of this function is to fill in the ops of t3'. To do this, we
+  // unsquish the contracting dimensions in t3 and then apply the inverse of
+  // the transpose from t1.
+
+  // Invert reshape.
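+  //
+  // Worked numbers for the running example (illustrative, added for
+  // exposition): t3 = f32[6, 2] has rhs_contracting_dims = {0}, and the
+  // unsquished lhs contracting dims are {1, 2} on t1 = f32[2, 3, 2]. Erasing
+  // rhs dim 0 and inserting the sizes of transpose dims 1 and 2 (3 and 2)
+  // gives the unsquished rhs shape f32[3, 2, 2].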
+  CHECK_EQ(rhs_contracting_dims.size(), 1);
+  auto rhs_unsquished_shape_dims = constant->shape().dimensions();
+  auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
+                                            rhs_contracting_dims[0]);
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_unsquished_shape_dims.insert(it,
+                                          transpose->shape().dimensions(dim));
+    ++it;
+  }
+  HloInstruction* rhs_reshape =
+      computation_->AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(constant->shape().element_type(),
+                               rhs_unsquished_shape_dims),
+          constant));
+  rhs = rhs_reshape;
+
+  // Rhs reshape "unsquishes" the single rhs contracting dim into multiple
+  // dims.
+  rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
+                              rhs_contracting_dims[0]);
+  absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
+
+  // Invert transpose. First compute the shape.
+  auto rhs_transpose_shape_dims = rhs_reshape->shape().dimensions();
+  it = rhs_transpose_shape_dims.erase(
+      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
+      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
+          rhs_contracting_dims.size());
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
+    ++it;
+  }
+  // Then compute the transpose dims.
+  std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
+  absl::c_iota(rhs_transpose_dims, 0);
+  it = rhs_transpose_dims.erase(
+      rhs_transpose_dims.begin() + rhs_contracting_dims[0],
+      rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
+          rhs_contracting_dims.size());
+  auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
+                                           lhs_contracting_dims[0] +
+                                           rhs_contracting_dims[0]);
+    ++it;
+  }
+  HloInstruction* rhs_transpose =
+      computation_->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(constant->shape().element_type(),
+                               rhs_transpose_shape_dims),
+          rhs_reshape, rhs_transpose_dims));
+  rhs = rhs_transpose;
+
+  // Squish the multiple rhs contracting dims into a single one.
+  rhs = computation_->AddInstruction(
+      HloInstruction::CreateReshape(constant->shape(), rhs));
+
+  // If we virtually swapped lhs and rhs, we need to swap them back before
+  // creating the new dot.
+  if (dot->operand(0)->IsConstant()) {
+    std::swap(lhs, rhs);
+  }
+
+  HloInstruction* new_dot =
+      computation_->AddInstruction(HloInstruction::CreateDot(
+          dot->shape(), lhs, rhs, dnums, dot->precision_config()));
+  return new_dot;
+}
+
 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   HloInstruction *lhs, *rhs;
   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
@@ -1632,6 +1831,18 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
     return ReplaceInstruction(dot, new_dot);
   }
 
+  // Simplify dot(reshape(transpose(A)), Const) to:
+  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
+  // and transpose on the Const side can be constant folded.
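+  // E.g. (a sketch matching the DotContractingReorder_RL test below):
+  //   dot(f32[2,6] reshape(f32[2,3,2] transpose(f32[2,2,3] p0)), f32[6,2] c)
+  // becomes
+  //   dot(f32[2,6] reshape(p0),
+  //       f32[6,2] reshape(f32[2,3,2] transpose(f32[3,2,2] reshape(c))))
+  // with transpose dimensions={1, 0, 2}.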
+  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
+                      OptimizeDotOfReorderContractingDims(dot));
+  if (dot_of_reorder_optimized) {
+    VLOG(10) << " Replaced dot " << dot->ToString()
+             << " with new dot operation: "
+             << dot_of_reorder_optimized->ToString();
+    return ReplaceInstruction(dot, dot_of_reorder_optimized);
+  }
+
   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                       OptimizeDotOfConcat(dot));
   if (dot_of_concat_optimized) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index f1f2c77a27bcfc..f8d8084242f60d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -5094,5 +5094,213 @@ TEST_F(AlgebraicSimplifierTest, CopyReshape) {
       GmockMatch(m::Reshape(m::Parameter(0)).WithShapeEqualTo(&result_shape)));
 }
 
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
+  auto shape2 = ShapeUtil::MakeShape(F32, {3, 2, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  // The transformation of moving transpose and reshape to the constant side
+  // is layout insensitive. We ignore layout when checking shapes.
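+  // Sketch of the expected rewrite (for exposition): the parameter side
+  // collapses to a single reshape to [2, 6], while the constant side becomes
+  // reshape([6, 2] -> [3, 2, 2]), transpose(dimensions={1, 0, 2} -> [2, 3, 2]),
+  // reshape(-> [6, 2]), which the matcher below checks shape by shape.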
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({1, 0, 2}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
+                                {1, 1, 1, 1, 1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3});
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
+                                {1, 1, 1, 1, 1, 1}})
+      t0 = f32[2, 3, 2] parameter(0)
+      t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
+      lhs = f32[6, 2] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {6, 2});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3});
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
+      t0 = f32[2, 2, 2, 2] parameter(0)
+      t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
+      lhs = f32[2, 8] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
+        rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 8});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({2, 0, 1, 3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
+                                   {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
+      t0 = f32[2, 2, 3, 2] parameter(0)
+      t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
+      lhs = f32[2, 6, 2] reshape(t1)
+      ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
+        rhs_batch_dims={0}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6, 2});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({0, 2, 1, 3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
+      t0 = f32[3, 4, 2] parameter(0)
+      t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
+      lhs = f32[2, 12] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // The transpose affects a non-contracting dimension. The transpose and
+  // reshape should not be moved to the constant side.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
+      t0 = f32[2, 4, 3] parameter(0)
+      t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[3, 8] reshape(t1)
+      ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // The reshape affects non-contracting dimensions. The transpose and
+  // reshape should not be moved to the constant side.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      t0 = f32[2, 3, 4] parameter(0)
+      t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 12] reshape(t1)
+      rhs = f32[12, 2] parameter(1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // Both operands are non-constant, so the optimization should not happen.
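+  // (With parameters on both sides there is nothing for constant folding to
+  // collapse, so the rewrite would only add reshape/transpose ops.)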
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // We disable converting reshape to bitcast to make sure the algsimp pass
+  // does not catch the reshape in this test; then we can simply check that
+  // the algsimp pass does not make any change.
+  AlgebraicSimplifierOptions options(
+      [](const Shape&, const Shape&) { return false; });
+  options.set_is_layout_sensitive(true);
+  // The transformation of moving transpose and reshape to the constant side is
+  // layout insensitive. It should not happen if AlgebraicSimplifier is set up
+  // to be layout sensitive.
+  ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index fcd66f4b4a06e0..587db49957b590 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -1403,5 +1404,183 @@ ENTRY MatrixVectorComplex {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
 }
 
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
+  Array3D<float> input_arr(2, 3, 2);
+  Array2D<float> const_arr(2, 6);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder);
+  auto t1 = Transpose(t0, {1, 0, 2});
+  auto rhs = Reshape(t1, {6, 2});
+  auto lhs = ConstantR2FromArray2D<float>(&builder, const_arr);
+  Dot(lhs, rhs);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
+  Array3D<float> input_arr(2, 3, 2);
+  Array2D<float> const_arr(2, 6);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder);
+  auto t1 = Transpose(t0, {1, 0, 2});
+  auto lhs = Reshape(t1, {6, 2});
+  auto rhs = ConstantR2FromArray2D<float>(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(0);
+  dims.add_rhs_contracting_dimensions(1);
+  DotGeneral(lhs, rhs, dims);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
+  Array4D<float> input_arr(2, 2, 3, 4);
+  Array2D<float> const_arr(24, 2);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR4FromArray4D<float>(input_arr), &builder);
+  auto t1 = Transpose(t0, {0, 2, 3, 1});
+  auto lhs = Reshape(t1, {2, 24});
+  auto rhs = ConstantR2FromArray2D<float>(&builder, const_arr);
+  Dot(lhs, rhs);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
+  Array3D<float> input_arr(2, 6, 2);
+  Array3D<float> const_arr(2, 6, 3);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D<float>(input_arr), &builder);
+  auto t1 = Reshape(t0, {2, 2, 3, 2});
+  auto t2 = Transpose(t1, {0, 2, 1, 3});
+  auto lhs = Reshape(t2, {2, 6, 2});
+  auto rhs = ConstantR3FromArray3D<float>(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(1);
+  dims.add_rhs_contracting_dimensions(1);
+  dims.add_lhs_batch_dimensions(0);
+  dims.add_rhs_batch_dimensions(0);
+  DotGeneral(lhs, rhs, dims);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
+  Array4D<float> input_arr(2, 2, 3, 5);
+  Array2D<float> const_arr(2, 30);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR4FromArray4D<float>(input_arr), &builder);
+  auto t1 = Transpose(t0, {0, 2, 1, 3});
+  auto t2 = Reshape(t1, {2, 6, 5});
+  auto t3 = Transpose(t2, {0, 2, 1});
+  auto lhs = Reshape(t3, {2, 30});
+  auto rhs = ConstantR2FromArray2D<float>(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(1);
+  dims.add_rhs_contracting_dimensions(1);
+  DotGeneral(lhs, rhs, dims);
+
+  // Constant folding is disabled by default in unit tests. The algsimp
+  // optimization can be applied multiple times if we fold the transpose
+  // and reshape that are moved to the constant side of the dot.
+  mutable_debug_options()->clear_xla_disable_hlo_passes();
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+// This benchmark is to show the performance impact of the following
+// transformation:
+//   dot(reshape(transpose(A)), Const) ==>
+//   dot(reshape(A), reshape(transpose(reshape(Const)))),
+// and then fold the reshape and transpose on the Const side.
+// We can compare performance with and without algsimp pass to see the impact.
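+//
+// Rough traffic estimate (for exposition): with d0 = d1 = d2 = d3 = 128, each
+// iteration touches d0*d1*d2 + d1*d2*d3 + d0*d3 floats (input, constant, and
+// output), about 4.2M floats or ~16.8 MB, which is what the BytesProcessed
+// call below reports.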
+void DOT_ReorderContracting(int num_iters) {
+  tensorflow::testing::StopTiming();
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+  StreamExecutorMemoryAllocator allocator(platform, executors);
+
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform);
+  auto client =
+      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
+
+  int device_ordinal = client->default_device_ordinal();
+
+  const int64 d0 = 128;
+  const int64 d1 = 128;
+  const int64 d2 = 128;
+  const int64 d3 = 128;
+
+  Array3D<float> input_arr(d0, d1, d2);
+  Array2D<float> const_arr(d1 * d2, d3);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+  XlaBuilder builder("ReorderContracting");
+  auto t0 =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
+  auto t1 = Transpose(t0, {0, 2, 1});
+  auto lhs = Reshape(t1, {d0, d2 * d1});
+  auto rhs = ConstantR2FromArray2D<float>(&builder, const_arr);
+  Dot(lhs, rhs);
+  auto computation = builder.Build().ConsumeValueOrDie();
+
+  auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
+  ScopedShapedBuffer buffer0 =
+      client->LiteralToShapedBuffer(input_literal, device_ordinal)
+          .ConsumeValueOrDie();
+
+  std::unique_ptr<LocalExecutable> executable =
+      client
+          ->Compile(computation, {&buffer0.on_host_shape()},
+                    ExecutableBuildOptions())
+          .ConsumeValueOrDie();
+
+  se::Stream stream(executors[device_ordinal]);
+  stream.Init();
+
+  ExecutableRunOptions options;
+  options.set_allocator(&allocator);
+
+  const int kWarmups = 2;
+  for (int i = 0; i < kWarmups; ++i) {
+    ASSERT_IS_OK(executable->Run({&buffer0}, options));
+  }
+
+  const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
+  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
+                                      total_bytes * sizeof(float));
+  tensorflow::testing::UseRealTime();
+  tensorflow::testing::StartTiming();
+  for (int i = 0; i < num_iters; ++i) {
+    ASSERT_IS_OK(executable->Run({&buffer0}, options));
+  }
+}
+
+BENCHMARK(DOT_ReorderContracting);
+
 }  // namespace
 }  // namespace xla

From e7f555ded70ed0c6a2bb8adbc85c7e2c21abc31c Mon Sep 17 00:00:00 2001
From: Bin Fan
Date: Fri, 26 Apr 2019 15:56:41 -0700
Subject: [PATCH 2/2] Address review comments

---
 .../xla/service/algebraic_simplifier.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c257610ebf665b..a10de2cd8028f1 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1548,19 +1548,19 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
     return nullptr;
   }
 
-  // Check if reshape squishes some dims into one dim, and that this one
-  // dim is the dot's lhs contracting dim.
-  // The size of unmodified_dims should be N - 1, where N is the rank of the
-  // reshape output. This means that the reshape squishes some dims into one
-  // dim. lhs contracting dim should not be in unmodified_dims. This means
-  // that the squishing target dim is the lhs contracting dim.
+  // Check that reshape squishes some dims into one dim and that this one
+  // dim is the dot's lhs contracting dim. The size of unmodified_dims should
+  // be N - 1, where N is the rank of the reshape output. This means that the
+  // reshape squishes some dims into one dim. lhs contracting dim should not
+  // be in unmodified_dims. This means that the squishing target dim is the
+  // lhs contracting dim.
   auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
       reshape->operand(0)->shape(), reshape->shape());
   CHECK_EQ(lhs_contracting_dims.size(), 1);
   if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
-      (absl::c_find_if(unmodified_dims, [&](const std::pair<int64, int64>& p) {
-         return p.second == lhs_contracting_dims[0];
-       }) != unmodified_dims.end())) {
+      absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
+        return p.second == lhs_contracting_dims[0];
+      })) {
     return nullptr;
   }
   // Virtually pull the reshape into a