diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 5594c998dd1f69..f1000d318f9bf0 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -34,6 +34,7 @@ namespace tensorflow {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;
 using SYCLDevice = Eigen::SyclDevice;
+using Index = Eigen::Index;
 
 namespace {
 template <class T>
@@ -310,6 +311,19 @@ struct ApplyAdamNonCuda {
                   typename TTypes<T>::ConstScalar beta2,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
+    // Get params length and check if they can be vectorized by packet size.
+    Index length = var.size();
+    Index packet_size = Eigen::internal::packet_traits<T>::size;
+    if (length % packet_size == 0) {
+      length = length / packet_size;
+    } else {
+      packet_size = 1;
+    }
+
+    T* var_ptr = var.data();
+    T* m_ptr = m.data();
+    T* v_ptr = v.data();
+    const T* g_ptr = grad.data();
     const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
                     (T(1) - beta1_power());
     // beta1 == μ
@@ -317,14 +331,45 @@ struct ApplyAdamNonCuda {
     // v     == n
     // var   == θ
 
-    m.device(d) += (grad - m) * (T(1) - beta1());
-    v.device(d) += (grad.square() - v) * (T(1) - beta2());
-    if (use_nesterov) {
-      var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) /
-                       (v.sqrt() + epsilon());
-    } else {
-      var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
-    }
+    auto shard = [this, var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2,
+                  epsilon, use_nesterov, packet_size](int begin, int end) {
+      int t_size = (end - begin) * packet_size;
+      begin = begin * packet_size;
+      auto var = typename TTypes<T>::UnalignedTensor(var_ptr + begin, t_size);
+      auto m = typename TTypes<T>::UnalignedTensor(m_ptr + begin, t_size);
+      auto v = typename TTypes<T>::UnalignedTensor(v_ptr + begin, t_size);
+      auto g = typename TTypes<T>::UnalignedConstTensor(g_ptr + begin, t_size);
+
+      if (use_nesterov) {
+        m += (g - m) * (T(1) - beta1());
+        v += (g.square() - v) * (T(1) - beta2());
+        var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) /
+               (v.sqrt() + epsilon());
+      } else {
+        m += (g - m) * (T(1) - beta1());
+        v += (g.square() - v) * (T(1) - beta2());
+        var -= (m * alpha) / (v.sqrt() + epsilon());
+      }
+    };
+
+    // Input data: var, v, m, grad.
+    // Output data: var, v, m.
+    const int input_bytes = length * packet_size * sizeof(T) * 4;
+    const int output_bytes = length * packet_size * sizeof(T) * 3;
+    const int compute_cycles =
+        // Consider Sub as Add
+        (Eigen::TensorOpCost::AddCost<int>() * 5 +
+         Eigen::TensorOpCost::MulCost<int>() * 2 +
+         Eigen::TensorOpCost::AddCost<T>() * 10 +
+         Eigen::TensorOpCost::MulCost<T>() * 6 +
+         Eigen::TensorOpCost::DivCost<T>()) *
+        length;
+    const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
+
+    // Eigen device must update 3 variables with 3 different expressions,
+    // which is bad for cache locality on CPU. Here use ParallelFor instead of
+    // "regular" tensor expressions to get better performance.
+    d.parallelFor(length, cost, shard);
   }
 };
 
@@ -1250,22 +1295,19 @@ class ApplyProximalAdagradOp : public OpKernel {
                                 var.shape().DebugString(), " ",
                                 accum.shape().DebugString()));
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr.shape()) &&
-                    lr.scalar<T>()() > static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+                         lr.scalar<T>()() > static_cast<T>(0),
                 errors::InvalidArgument("lr is not a positive scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(3);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+                         l1.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(4);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+                         l2.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
@@ -1497,22 +1539,19 @@ class SparseApplyProximalAdagradOp : public OpKernel {
                 errors::InvalidArgument("var must be at least 1 dimensional"));
 
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr.shape()) &&
-                    lr.scalar<T>()() > static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+                         lr.scalar<T>()() > static_cast<T>(0),
                 errors::InvalidArgument("lr is not a positive scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(3);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+                         l1.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(4);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+                         l2.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
@@ -1989,30 +2028,26 @@ class ApplyFtrlOp : public OpKernel {
                                         grad.shape().DebugString()));
 
     const Tensor& lr = ctx->input(4);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr.shape()) &&
-                    lr.scalar<T>()() > static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+                         lr.scalar<T>()() > static_cast<T>(0),
                 errors::InvalidArgument("lr is not a positive scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(5);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+                         l1.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(6);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+                         l2.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
     const int lr_power_index = has_l2_shrinkage ? 8 : 7;
     const Tensor& lr_power = ctx->input(lr_power_index);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr_power.shape()) &&
-                    lr_power.scalar<T>()() <= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()) &&
+                         lr_power.scalar<T>()() <= static_cast<T>(0),
                 errors::InvalidArgument("lr_power is not a"
                                         " non-positive scalar: ",
                                         lr_power.shape().DebugString()));
@@ -2021,9 +2056,8 @@ class ApplyFtrlOp : public OpKernel {
     if (has_l2_shrinkage) {
       const Tensor& l2_shrinkage = ctx->input(7);
       OP_REQUIRES(
-          ctx,
-          TensorShapeUtils::IsScalar(l2_shrinkage.shape()) &&
-              l2_shrinkage.scalar<T>()() >= static_cast<T>(0),
+          ctx, TensorShapeUtils::IsScalar(l2_shrinkage.shape()) &&
+                   l2_shrinkage.scalar<T>()() >= static_cast<T>(0),
           errors::InvalidArgument("l2 shrinkage regularization strength "
                                   "is not a non-negative scalar: ",
                                   l2_shrinkage.shape().DebugString()));
@@ -2141,31 +2175,27 @@ class SparseApplyFtrlOp : public OpKernel {
                 errors::InvalidArgument("indices must be one-dimensional"));
 
     const Tensor& lr = ctx->input(5);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr.shape()) &&
-                    lr.scalar<T>()() > static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+                         lr.scalar<T>()() > static_cast<T>(0),
                 errors::InvalidArgument("lr is not a positive scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(6);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l1.shape()) &&
-                    l1.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+                         l1.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l1 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l1.shape().DebugString()));
     const Tensor& l2 = ctx->input(7);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(l2.shape()) &&
-                    l2.scalar<T>()() >= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+                         l2.scalar<T>()() >= static_cast<T>(0),
                 errors::InvalidArgument("l2 regularization strength is not a "
                                         "non-negative scalar: ",
                                         l2.shape().DebugString()));
     const int lr_power_index = has_l2_shrinkage ? 9 : 8;
     const Tensor& lr_power = ctx->input(lr_power_index);
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::IsScalar(lr_power.shape()) &&
-                    lr_power.scalar<T>()() <= static_cast<T>(0),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()) &&
+                         lr_power.scalar<T>()() <= static_cast<T>(0),
                 errors::InvalidArgument("lr_power is not a "
                                         "non-positive scalar: ",
                                         lr_power.shape().DebugString()));
@@ -2190,9 +2220,8 @@ class SparseApplyFtrlOp : public OpKernel {
     if (has_l2_shrinkage) {
       l2_shrinkage = &ctx->input(8);
       OP_REQUIRES(
-          ctx,
-          TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
-              l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
+          ctx, TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
+                   l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
          errors::InvalidArgument("l2 shrinkage regularization strength "
                                   "is not a non-negative scalar: ",
                                   l2_shrinkage->shape().DebugString()));
@@ -2234,9 +2263,10 @@ class SparseApplyFtrlOp : public OpKernel {
       linear += grad_maybe_with_shrinkage -                                 \
                 (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;        \
     } else {                                                                \
-      linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
-                                             accum.pow(-lr_power_scalar)) / \
-                    lr_scalar * var;                                        \
+      linear +=                                                             \
+          grad_maybe_with_shrinkage -                                       \
+          (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) / \
+              lr_scalar * var;                                              \
     }                                                                       \
     auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar);   \
     auto x = l1_reg_adjust - linear;                                        \
diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc
index 2dcc4a500e6c64..9ccca71047d44a 100644
--- a/tensorflow/core/kernels/training_ops_test.cc
+++ b/tensorflow/core/kernels/training_ops_test.cc
@@ -176,23 +176,28 @@ static void Adam(int32 n, Graph** init_g, Graph** train_g) {
     auto beta2 = Scalar(g, 0.99);
     auto epsilon = Scalar(g, 1e-8);
     auto grad = Random(g, n);
-    test::graph::Multi(
-        g, "ApplyAdam",
-        {var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad});
+    test::graph::Multi(g, "ApplyAdam", {var, m, v, beta1_power, beta2_power, lr,
+                                        beta1, beta2, epsilon, grad});
     *train_g = g;
   }
 }
 
-static void BM_Adam(int iters, int params) {
+static void BM_Adam(int iters, int params, int is_multi_threaded) {
   const int64 tot = static_cast<int64>(iters) * params;
   testing::ItemsProcessed(tot);
   testing::BytesProcessed(tot * sizeof(float));
   Graph* init;
   Graph* train;
   Adam(params, &init, &train);
-  test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+  if (is_multi_threaded) {
+    // Use the maximum thread count to test multi-threaded performance.
+    test::Benchmark("cpu", train, nullptr, init).Run(iters);
+  } else {
+    test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+  }
 }
-BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10);
+BENCHMARK(BM_Adam)->ArgPair(128 << 10, 0)->ArgPair(256 << 10, 0);
+BENCHMARK(BM_Adam)->ArgPair(256 << 5, 1)->ArgPair(256 << 16, 1);
 
 static void RMSProp(int32 n, Graph** init_g, Graph** train_g) {
   TensorShape shape({n});
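
Reviewer note: the per-element updates in the new `shard` body are the standard Adam recurrences, written here in LaTeX for checking against the code (the μ/ν naming follows the code comments; `beta1_power` and `beta2_power` hold the accumulated powers beta1^t and beta2^t):

\begin{aligned}
  m_t      &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
  v_t      &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
  \alpha_t &= \mathrm{lr} \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \\
  \theta_t &= \theta_{t-1} - \alpha_t\, \frac{m_t}{\sqrt{v_t} + \epsilon}
\end{aligned}

The Nesterov branch replaces the numerator m_t with \beta_1 m_t + (1 - \beta_1) g_t, which matches `(g * (T(1) - beta1()) + beta1() * m)` in the shard.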
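Reviewer note: for anyone unfamiliar with `d.parallelFor`, below is a minimal standalone sketch of how Eigen's `ThreadPoolDevice::parallelFor` consumes a `TensorOpCost` to shard a range across pool threads. The pool size, element count, and per-element costs are illustrative assumptions, not values from the patch; `parallelFor` and `TensorOpCost` are the real Eigen APIs the kernel uses.

// Sketch only: shows the parallelFor/TensorOpCost mechanism used above.
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include <cstdio>
#include <vector>

int main() {
  Eigen::ThreadPool pool(/*num_threads=*/4);  // illustrative pool size
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  const Eigen::Index n = 1 << 20;
  std::vector<float> x(n, 1.0f), y(n, 2.0f);

  // Per-iteration cost estimate: bytes loaded, bytes stored, compute cycles.
  // Eigen combines this with the thread count to choose a shard size,
  // just as the Adam kernel's hand-built cost does.
  const Eigen::TensorOpCost cost(/*bytes_loaded=*/2 * sizeof(float),
                                 /*bytes_stored=*/sizeof(float),
                                 /*compute_cycles=*/1);

  // Each shard [begin, end) runs on some pool thread, like the `shard`
  // lambda in ApplyAdamNonCuda; the call blocks until all shards finish.
  device.parallelFor(n, cost, [&x, &y](Eigen::Index begin, Eigen::Index end) {
    for (Eigen::Index i = begin; i < end; ++i) y[i] += x[i];
  });

  std::printf("y[0] = %g\n", y[0]);  // prints 3
  return 0;
}

One design point worth noting: the patch calls `parallelFor` with `length` counted in packets, and the shard lambda scales `begin`/`end` back up by `packet_size`, so every shard boundary stays packet-aligned and each `UnalignedTensor` span can vectorize cleanly; when `var.size()` is not a multiple of the packet size, it falls back to `packet_size = 1` and shards over scalars.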