diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 45653908458cf1..a10de2cd8028f1 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -382,6 +382,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
 
+  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
+      HloInstruction* dot);
+
   HloComputation* GetOrCreateScalarAddComputation() {
     if (scalar_add_computation_) {
       return scalar_add_computation_;
@@ -1499,6 +1502,202 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
   return memoized_lookup;
 }
 
+// This function tries to transform
+// dot(reshape(transpose(A)), Const) to
+// dot(reshape(A), reshape(transpose(reshape(Const)))),
+// so that the reshape and transpose on the Const side can be constant folded.
+//
+// The basic idea is that since the accumulation in the dot operation is
+// associative, as long as we permute the elements of the contracting
+// dimensions on both sides of the dot in the same way, the result of the
+// dot is not affected.
+StatusOr<HloInstruction*>
+AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
+    HloInstruction* dot) {
+  // This transformation assumes layout is not assigned yet.
+  if (options_.is_layout_sensitive()) {
+    return nullptr;
+  }
+
+  // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
+  // remainder of this function easier.
+  auto dnums = dot->dot_dimension_numbers();
+  auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
+  auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
+  auto* lhs = dot->mutable_operand(0);
+  auto* rhs = dot->mutable_operand(1);
+  if (dot->operand(0)->IsConstant()) {
+    std::swap(lhs, rhs);
+    std::swap(lhs_contracting_dims, rhs_contracting_dims);
+  }
+
+  // Require a single contracting dim to make it easier for the implementation
+  // to track the contracting dims.
+  if (dnums.lhs_contracting_dimensions_size() != 1) {
+    return nullptr;
+  }
+
+  // Pattern match dot(reshape(transpose(input)), constant).
+  HloInstruction* reshape;
+  HloInstruction* transpose;
+  HloInstruction* input;
+  HloInstruction* constant;
+  if (!Match(lhs,
+             m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
+      !Match(rhs, m::Constant(&constant))) {
+    return nullptr;
+  }
+
+  // Check that the reshape squishes some dims into one dim and that this one
+  // dim is the dot's lhs contracting dim. The size of unmodified_dims should
+  // be N - 1, where N is the rank of the reshape output; this means that the
+  // reshape squishes some dims into one dim. The lhs contracting dim should
+  // not be in unmodified_dims; this means that the squishing target dim is
+  // the lhs contracting dim.
+  auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
+      reshape->operand(0)->shape(), reshape->shape());
+  CHECK_EQ(lhs_contracting_dims.size(), 1);
+  if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
+      absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
+        return p.second == lhs_contracting_dims[0];
+      })) {
+    return nullptr;
+  }
+  // Virtually pull the reshape into the dot. Now the dot is equivalent to a
+  // new dot with "unsquished" lhs contracting dims. We don't need to actually
+  // create a new dot instruction. We can just keep track of lhs and
+  // lhs_contracting_dims.
+  CHECK_GT(reshape->operand(0)->shape().rank(), reshape->shape().rank());
+  lhs_contracting_dims.Resize(
+      reshape->operand(0)->shape().rank() - reshape->shape().rank() + 1,
+      lhs_contracting_dims[0]);
+  absl::c_iota(lhs_contracting_dims, lhs_contracting_dims[0]);
+  lhs = lhs->mutable_operand(0);
+
+  // Check if the transpose only permutes the contracting dims.
+  const auto& transpose_dims = transpose->dimensions();
+  for (int64 i = 0; i < transpose_dims.size(); ++i) {
+    if (transpose_dims[i] != i &&
+        !absl::c_linear_search(lhs_contracting_dims, i)) {
+      return nullptr;
+    }
+  }
+  // Virtually pull the transpose into the dot. Now the dot is equivalent to
+  // a new dot with "permuted" lhs contracting dims.
+  std::vector<int64> permutation;
+  for (auto dim : lhs_contracting_dims) {
+    permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
+  }
+  auto new_lhs_contracting_dims =
+      ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
+  lhs_contracting_dims.Clear();
+  for (auto dim : new_lhs_contracting_dims) {
+    lhs_contracting_dims.Add(dim);
+  }
+  lhs = lhs->mutable_operand(0);
+
+  // All checks have passed at this point.
+  //
+  // Transform lhs. Remove the transpose and reshape by sorting the lhs
+  // contracting dims and squishing them into a single one. We don't actually
+  // squish the lhs_contracting_dims here because we still need the unsquished
+  // contracting dims to invert reshape and transpose.
+  absl::c_sort(lhs_contracting_dims);
+  lhs = computation_->AddInstruction(
+      HloInstruction::CreateReshape(reshape->shape(), lhs));
+
+  // Transform rhs. Say the input HLO is:
+  //
+  // t0 = f32[2, 2, 3] parameter(0)
+  // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
+  // t2 = f32[2, 6] reshape(t1)
+  // t3 = f32[6, 2] constant(...)
+  // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
+  //                             rhs_contracting_dims={0}
+  //
+  // At this point in the function, we have decided that the second and third
+  // dims of t0 can be switched to remove the transpose, and we have
+  // "virtually decomposed" the input HLO to:
+  //
+  // t0 = f32[2, 2, 3] parameter(0)
+  // t2' = f32[2, 6] reshape(t0)
+  // t3' = f32[6, 2] ops-to-be-filled ...
+  // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
+  //                               rhs_contracting_dims={0}
+  //
+  // The rest of this function is to fill in the ops of t3'. To do this, we
+  // unsquish the contracting dimensions in t3 and then apply the inverse of
+  // the transpose from t1.
+
+  // Invert the reshape.
+  CHECK_EQ(rhs_contracting_dims.size(), 1);
+  auto rhs_unsquished_shape_dims = constant->shape().dimensions();
+  auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
+                                            rhs_contracting_dims[0]);
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_unsquished_shape_dims.insert(it,
+                                          transpose->shape().dimensions(dim));
+    ++it;
+  }
+  HloInstruction* rhs_reshape =
+      computation_->AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(constant->shape().element_type(),
+                               rhs_unsquished_shape_dims),
+          constant));
+  rhs = rhs_reshape;
+
+  // The rhs reshape "unsquishes" the single rhs contracting dim into multiple
+  // dims.
+  rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
+                              rhs_contracting_dims[0]);
+  absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
+
+  // Invert the transpose. First compute the shape.
+  auto rhs_transpose_shape_dims = rhs_reshape->shape().dimensions();
+  it = rhs_transpose_shape_dims.erase(
+      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
+      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
+          rhs_contracting_dims.size());
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
+    ++it;
+  }
+  // Then compute the transpose dims.
+  std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
+  absl::c_iota(rhs_transpose_dims, 0);
+  it = rhs_transpose_dims.erase(
+      rhs_transpose_dims.begin() + rhs_contracting_dims[0],
+      rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
+          rhs_contracting_dims.size());
+  auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
+  for (auto dim : lhs_contracting_dims) {
+    it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
+                                           lhs_contracting_dims[0] +
+                                           rhs_contracting_dims[0]);
+    ++it;
+  }
+  HloInstruction* rhs_transpose =
+      computation_->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(constant->shape().element_type(),
+                               rhs_transpose_shape_dims),
+          rhs_reshape, rhs_transpose_dims));
+  rhs = rhs_transpose;
+
+  // Squish the multiple rhs contracting dims into a single one.
+  rhs = computation_->AddInstruction(
+      HloInstruction::CreateReshape(constant->shape(), rhs));
+
+  // If we virtually swapped lhs and rhs, we need to swap them back before
+  // creating the new dot.
+  if (dot->operand(0)->IsConstant()) {
+    std::swap(lhs, rhs);
+  }
+
+  HloInstruction* new_dot =
+      computation_->AddInstruction(HloInstruction::CreateDot(
+          dot->shape(), lhs, rhs, dnums, dot->precision_config()));
+  return new_dot;
+}
+
 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   HloInstruction *lhs, *rhs;
   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
@@ -1632,6 +1831,18 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
     return ReplaceInstruction(dot, new_dot);
   }
 
+  // Simplify dot(reshape(transpose(A)), Const) to:
+  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
+  // and transpose on the Const side can be constant folded.
+  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
+                      OptimizeDotOfReorderContractingDims(dot));
+  if (dot_of_reorder_optimized) {
+    VLOG(10) << " Replaced dot " << dot->ToString()
+             << " with new dot operation: "
+             << dot_of_reorder_optimized->ToString();
+    return ReplaceInstruction(dot, dot_of_reorder_optimized);
+  }
+
   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                       OptimizeDotOfConcat(dot));
   if (dot_of_concat_optimized) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index f1f2c77a27bcfc..f8d8084242f60d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -5094,5 +5094,213 @@ TEST_F(AlgebraicSimplifierTest, CopyReshape) {
       GmockMatch(m::Reshape(m::Parameter(0)).WithShapeEqualTo(&result_shape)));
 }
 
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
+  auto shape2 = ShapeUtil::MakeShape(F32, {3, 2, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  // The transformation of moving transpose and reshape to the constant side
+  // is layout insensitive. We ignore layout when checking shapes.
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({1, 0, 2}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
+                                {1, 1, 1, 1, 1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3});
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
+                                {1, 1, 1, 1, 1, 1}})
+      t0 = f32[2, 3, 2] parameter(0)
+      t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
+      lhs = f32[6, 2] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {6, 2});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3});
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
+      t0 = f32[2, 2, 2, 2] parameter(0)
+      t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
+      lhs = f32[2, 8] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
+        rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 8});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({2, 0, 1, 3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
+                                   {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
+      t0 = f32[2, 2, 3, 2] parameter(0)
+      t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
+      lhs = f32[2, 6, 2] reshape(t1)
+      ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
+        rhs_batch_dims={0}, rhs_contracting_dims={1}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6, 2});
+  auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
+  const HloInstruction* transpose;
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+  ASSERT_TRUE(transpose->dimensions() == std::vector<int64>({0, 2, 1, 3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
+      t0 = f32[3, 4, 2] parameter(0)
+      t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
+      lhs = f32[2, 12] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // The transpose affects a non-contracting dimension. The transpose and
+  // reshape should not be moved to the constant side.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
+      t0 = f32[2, 4, 3] parameter(0)
+      t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[3, 8] reshape(t1)
+      ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // The reshape affects non-contracting dimensions. The transpose and
+  // reshape should not be moved to the constant side.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      t0 = f32[2, 3, 4] parameter(0)
+      t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 12] reshape(t1)
+      rhs = f32[12, 2] parameter(1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // Both operands are non-constant, so the optimization should not happen.
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  // We disable converting reshape to bitcast so that the algsimp pass does
+  // not touch the reshape in this test; then we can simply check that the
+  // pass makes no change at all.
+  AlgebraicSimplifierOptions options(
+      [](const Shape&, const Shape&) { return false; });
+  options.set_is_layout_sensitive(true);
+  // The transformation of moving transpose and reshape to the constant side is
+  // layout insensitive. It should not happen if AlgebraicSimplifier is set up
+  // to be layout sensitive.
+  ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index fcd66f4b4a06e0..587db49957b590 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -1403,5 +1404,183 @@ ENTRY MatrixVectorComplex {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
 }
 
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
+  Array3D<float> input_arr(2, 3, 2);
+  Array2D<float> const_arr(2, 6);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder);
+  auto t1 = Transpose(t0, {1, 0, 2});
+  auto rhs = Reshape(t1, {6, 2});
+  auto lhs = ConstantR2FromArray2D(&builder, const_arr);
+  Dot(lhs, rhs);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
+  Array3D<float> input_arr(2, 3, 2);
+  Array2D<float> const_arr(2, 6);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder);
+  auto t1 = Transpose(t0, {1, 0, 2});
+  auto lhs = Reshape(t1, {6, 2});
+  auto rhs = ConstantR2FromArray2D(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(0);
+  dims.add_rhs_contracting_dimensions(1);
+  DotGeneral(lhs, rhs, dims);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
+  Array4D<float> input_arr(2, 2, 3, 4);
+  Array2D<float> const_arr(24, 2);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR4FromArray4D(input_arr), &builder);
+  auto t1 = Transpose(t0, {0, 2, 3, 1});
+  auto lhs = Reshape(t1, {2, 24});
+  auto rhs = ConstantR2FromArray2D(&builder, const_arr);
+  Dot(lhs, rhs);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
+  Array3D<float> input_arr(2, 6, 2);
+  Array3D<float> const_arr(2, 6, 3);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder);
+  auto t1 = Reshape(t0, {2, 2, 3, 2});
+  auto t2 = Transpose(t1, {0, 2, 1, 3});
+  auto lhs = Reshape(t2, {2, 6, 2});
+  auto rhs = ConstantR3FromArray3D(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(1);
+  dims.add_rhs_contracting_dimensions(1);
+  dims.add_lhs_batch_dimensions(0);
+  dims.add_rhs_batch_dimensions(0);
+  DotGeneral(lhs, rhs, dims);
+
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
+  Array4D<float> input_arr(2, 2, 3, 5);
+  Array2D<float> const_arr(2, 30);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+
+  XlaBuilder builder(TestName());
+  auto t0 =
+      AddParam(LiteralUtil::CreateR4FromArray4D(input_arr), &builder);
+  auto t1 = Transpose(t0, {0, 2, 1, 3});
+  auto t2 = Reshape(t1, {2, 6, 5});
+  auto t3 = Transpose(t2, {0, 2, 1});
+  auto lhs = Reshape(t3, {2, 30});
+  auto rhs = ConstantR2FromArray2D(&builder, const_arr);
+
+  DotDimensionNumbers dims;
+  dims.add_lhs_contracting_dimensions(1);
+  dims.add_rhs_contracting_dimensions(1);
+  DotGeneral(lhs, rhs, dims);
+
+  // Constant folding is disabled by default in unit tests. The algsimp
+  // optimization can be applied multiple times if we fold the transpose
+  // and reshape that are moved to the constant side of the dot.
+  mutable_debug_options()->clear_xla_disable_hlo_passes();
+  ComputeAndCompare(&builder, {}, error_spec_);
+}
+
+// This benchmark shows the performance impact of the following
+// transformation:
+//   dot(reshape(transpose(A)), Const) ==>
+//   dot(reshape(A), reshape(transpose(reshape(Const)))),
+// followed by folding the reshape and transpose on the Const side.
+// We can compare performance with and without the algsimp pass to see the
+// impact.
+void DOT_ReorderContracting(int num_iters) {
+  tensorflow::testing::StopTiming();
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+  StreamExecutorMemoryAllocator allocator(platform, executors);
+
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform);
+  auto client =
+      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
+
+  int device_ordinal = client->default_device_ordinal();
+
+  const int64 d0 = 128;
+  const int64 d1 = 128;
+  const int64 d2 = 128;
+  const int64 d3 = 128;
+
+  Array3D<float> input_arr(d0, d1, d2);
+  Array2D<float> const_arr(d1 * d2, d3);
+  input_arr.FillIota(0);
+  const_arr.FillIota(0);
+  XlaBuilder builder("ReorderContracting");
+  auto t0 =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
+  auto t1 = Transpose(t0, {0, 2, 1});
+  auto lhs = Reshape(t1, {d0, d2 * d1});
+  auto rhs = ConstantR2FromArray2D(&builder, const_arr);
+  Dot(lhs, rhs);
+  auto computation = builder.Build().ConsumeValueOrDie();
+
+  auto input_literal = LiteralUtil::CreateR3FromArray3D(input_arr);
+  ScopedShapedBuffer buffer0 =
+      client->LiteralToShapedBuffer(input_literal, device_ordinal)
+          .ConsumeValueOrDie();
+
+  std::unique_ptr<LocalExecutable> executable =
+      client
+          ->Compile(computation, {&buffer0.on_host_shape()},
+                    ExecutableBuildOptions())
+          .ConsumeValueOrDie();
+
+  se::Stream stream(executors[device_ordinal]);
+  stream.Init();
+
+  ExecutableRunOptions options;
+  options.set_allocator(&allocator);
+
+  const int kWarmups = 2;
+  for (int i = 0; i < kWarmups; ++i) {
+    ASSERT_IS_OK(executable->Run({&buffer0}, options));
+  }
+
+  const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
+  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
+                                      total_bytes * sizeof(float));
+  tensorflow::testing::UseRealTime();
+  tensorflow::testing::StartTiming();
+  for (int i = 0; i < num_iters; ++i) {
+    ASSERT_IS_OK(executable->Run({&buffer0}, options));
+  }
+}
+
+BENCHMARK(DOT_ReorderContracting);
+
 }  // namespace
 }  // namespace xla
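The rewrite in OptimizeDotOfReorderContractingDims rests on the claim in its header comment: permuting the elements of the contracting dimension in the same way on both operands of a dot does not change the result, because the same products are accumulated, only in a different order. The standalone C++ sketch below (illustrative only, not part of the patch; it uses no XLA APIs) checks that claim numerically for a small [2, 6] x [6, 2] dot. The permutation {0, 2, 4, 1, 3, 5} used here corresponds to transposing a [2, 3] block into [3, 2] before flattening it into the length-6 contracting dimension, the same kind of reordering the pass pulls onto the constant side.

#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  constexpr int kM = 2, kK = 6, kN = 2;
  float lhs[kM][kK], rhs[kK][kN];
  // Arbitrary test values.
  for (int i = 0; i < kM; ++i)
    for (int k = 0; k < kK; ++k) lhs[i][k] = 1.0f + i * kK + k;
  for (int k = 0; k < kK; ++k)
    for (int j = 0; j < kN; ++j) rhs[k][j] = 0.5f * (k * kN + j);

  // Permutation of the contracting dimension, e.g. induced by transposing a
  // [2, 3] block into [3, 2] before flattening it to length 6.
  const std::array<int, kK> perm = {0, 2, 4, 1, 3, 5};

  for (int i = 0; i < kM; ++i) {
    for (int j = 0; j < kN; ++j) {
      float plain = 0.0f, permuted = 0.0f;
      for (int k = 0; k < kK; ++k) {
        plain += lhs[i][k] * rhs[k][j];
        // Read both operands through the same permutation of k: the same
        // terms are summed, only in a different order.
        permuted += lhs[i][perm[k]] * rhs[perm[k]][j];
      }
      // Tolerance because reassociating a floating-point sum can round
      // differently in general (it happens to be exact for these values).
      assert(std::fabs(plain - permuted) < 1e-4f);
      std::printf("dot[%d][%d] = %g\n", i, j, plain);
    }
  }
  return 0;
}

Only the summation order of identical terms changes, which is why the pass can move the reordering onto the constant operand and let constant folding remove it.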