[Intel MKL] Use Shard function instead of Eigen device to parallelize Adam kernel. #26424

Merged
148 changes: 89 additions & 59 deletions tensorflow/core/kernels/training_ops.cc
@@ -34,6 +34,7 @@ namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
using SYCLDevice = Eigen::SyclDevice;
+using Index = Eigen::Index;

namespace {
template <class T>
@@ -310,21 +311,65 @@ struct ApplyAdamNonCuda {
typename TTypes<T>::ConstScalar beta2,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
+// Get the params length and check whether it can be vectorized by the packet size.
+Index length = var.size();
+Index packet_size = Eigen::internal::packet_traits<T>::size;
+if (length % packet_size == 0) {
+length = length / packet_size;
+} else {
+packet_size = 1;
+}
Member:

There is no need to divide the input size by the packet size and do "manual vectorization". If it's desirable for the shard size (end - begin) to be a multiple of the packet size, you can pass block_align to parallelFor (see https://bitbucket.org/eigen/eigen/src/4b28c8008901c6d760f48f26ee2e3423fd8a2b40/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h#lines-185).

I think this should work:

[](Index index) -> Index { return Eigen::divup(index, packet_size); }

Contributor Author:

I ran into some questions when trying to use this function; please see my comment below.
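(For reference: a self-contained sketch of how the suggested block_align overload might be used. The parallelFor(n, cost, block_align, f) signature is taken from the header linked above; the thread count, cost values, and rounding lambda are illustrative, not code from this PR.)

// Self-contained sketch (not from this PR): shard over raw element indices
// and let block_align round each suggested block size up to a multiple of
// the SIMD packet size, so end - begin is packet-aligned for every shard
// except possibly the last one.
#define EIGEN_USE_THREADS
#include <vector>

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  using Index = Eigen::Index;
  const Index n = 1000003;  // Deliberately not a multiple of the packet size.
  std::vector<float> var(n, 1.0f), grad(n, 0.5f);

  Eigen::ThreadPool pool(4);  // Thread count chosen arbitrarily for the demo.
  Eigen::ThreadPoolDevice d(&pool, 4);

  const Index packet_size = Eigen::internal::packet_traits<float>::size;
  // Rough per-element cost: load var and grad, store var, one multiply-add.
  const Eigen::TensorOpCost cost(2 * sizeof(float), sizeof(float), 2);

  auto block_align = [packet_size](Index block_size) -> Index {
    // Round the suggested block size up to a multiple of the packet size.
    return Eigen::divup(block_size, packet_size) * packet_size;
  };

  d.parallelFor(n, cost, block_align, [&](Index begin, Index end) {
    for (Index i = begin; i < end; ++i) var[i] -= 0.01f * grad[i];
  });
  return 0;
}

With the alignment handled by block_align, the explicit length / packet_size bookkeeping and the packet_size = 1 fallback above would presumably become unnecessary, which is the point of the suggestion.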


+T* var_ptr = var.data();
+T* m_ptr = m.data();
+T* v_ptr = v.data();
+const T* g_ptr = grad.data();
const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
(T(1) - beta1_power());
// beta1 == μ
// beta2 == ν
// v == n
// var == θ

-m.device(d) += (grad - m) * (T(1) - beta1());
-v.device(d) += (grad.square() - v) * (T(1) - beta2());
-if (use_nesterov) {
-var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) /
-(v.sqrt() + epsilon());
-} else {
-var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
-}
+auto shard = [this, var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2,
+epsilon, use_nesterov, packet_size](int begin, int end) {
+int t_size = (end - begin) * packet_size;
+begin = begin * packet_size;
+auto var = typename TTypes<T>::UnalignedTensor(var_ptr + begin, t_size);
+auto m = typename TTypes<T>::UnalignedTensor(m_ptr + begin, t_size);
+auto v = typename TTypes<T>::UnalignedTensor(v_ptr + begin, t_size);
+auto g = typename TTypes<T>::UnalignedConstTensor(g_ptr + begin, t_size);
+
+if (use_nesterov) {
+m += (g - m) * (T(1) - beta1());
+v += (g.square() - v) * (T(1) - beta2());
+var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) /
+(v.sqrt() + epsilon());
+} else {
+m += (g - m) * (T(1) - beta1());
+v += (g.square() - v) * (T(1) - beta2());
+var -= (m * alpha) / (v.sqrt() + epsilon());
+}
+};
+
+// Input data: var, v, m, grad.
+// Output data: var, v, m.
+const int input_bytes = length * packet_size * sizeof(T) * 4;
+const int output_bytes = length * packet_size * sizeof(T) * 3;
+const int compute_cycles =
+// Count Sub as Add (same cost).
+(Eigen::TensorOpCost::AddCost<int>() * 5 +
+Eigen::TensorOpCost::MulCost<int>() * 2 +
+Eigen::TensorOpCost::AddCost<T>() * 10 +
+Eigen::TensorOpCost::MulCost<T>() * 6 +
+Eigen::TensorOpCost::DivCost<T>()) *
+length;
+const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
+
+// The Eigen device approach updates the three variables with three separate
+// expressions, which is bad for cache locality on the CPU. Use parallelFor
+// instead of "regular" tensor expressions to get better performance.
+d.parallelFor(length, cost, shard);
}
};
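As a concrete illustration of the byte and cycle accounting above, the following self-contained snippet (not part of the PR) evaluates the same TensorOpCost inputs for the 262,144-parameter case exercised by BM_Adam in the benchmark changes below; it only assumes Eigen's public packet_traits and TensorOpCost helpers and the usual Eigen include path.

#include <iostream>

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  using T = float;
  using Index = Eigen::Index;
  Index length = 256 << 10;  // 262,144 parameters, as in BM_Adam below.
  Index packet_size = Eigen::internal::packet_traits<T>::size;
  if (length % packet_size == 0) {
    length = length / packet_size;  // Shard over packets rather than elements.
  } else {
    packet_size = 1;
  }
  // Four tensors are read (var, m, v, grad) and three are written (var, m, v).
  const long long input_bytes = length * packet_size * sizeof(T) * 4;
  const long long output_bytes = length * packet_size * sizeof(T) * 3;
  const long long compute_cycles =
      (Eigen::TensorOpCost::AddCost<int>() * 5 +
       Eigen::TensorOpCost::MulCost<int>() * 2 +
       Eigen::TensorOpCost::AddCost<T>() * 10 +
       Eigen::TensorOpCost::MulCost<T>() * 6 +
       Eigen::TensorOpCost::DivCost<T>()) *
      length;
  std::cout << "input_bytes=" << input_bytes
            << " output_bytes=" << output_bytes
            << " compute_cycles=" << compute_cycles << std::endl;
  return 0;
}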

@@ -1250,22 +1295,19 @@ class ApplyProximalAdagradOp : public OpKernel {
var.shape().DebugString(), " ",
accum.shape().DebugString()));
const Tensor& lr = ctx->input(2);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr.shape()) &&
-lr.scalar<T>()() > static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+lr.scalar<T>()() > static_cast<T>(0),
errors::InvalidArgument("lr is not a positive scalar: ",
lr.shape().DebugString()));
const Tensor& l1 = ctx->input(3);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l1.shape()) &&
-l1.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+l1.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l1 regularization strength is not a "
"non-negative scalar: ",
l1.shape().DebugString()));
const Tensor& l2 = ctx->input(4);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l2.shape()) &&
-l2.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+l2.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 regularization strength is not a "
"non-negative scalar: ",
l2.shape().DebugString()));
@@ -1497,22 +1539,19 @@ class SparseApplyProximalAdagradOp : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional"));

const Tensor& lr = ctx->input(2);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr.shape()) &&
-lr.scalar<T>()() > static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+lr.scalar<T>()() > static_cast<T>(0),
errors::InvalidArgument("lr is not a positive scalar: ",
lr.shape().DebugString()));
const Tensor& l1 = ctx->input(3);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l1.shape()) &&
-l1.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+l1.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l1 regularization strength is not a "
"non-negative scalar: ",
l1.shape().DebugString()));
const Tensor& l2 = ctx->input(4);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l2.shape()) &&
-l2.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+l2.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 regularization strength is not a "
"non-negative scalar: ",
l2.shape().DebugString()));
@@ -1989,30 +2028,26 @@ class ApplyFtrlOp : public OpKernel {
grad.shape().DebugString()));

const Tensor& lr = ctx->input(4);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr.shape()) &&
-lr.scalar<T>()() > static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+lr.scalar<T>()() > static_cast<T>(0),
errors::InvalidArgument("lr is not a positive scalar: ",
lr.shape().DebugString()));
const Tensor& l1 = ctx->input(5);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l1.shape()) &&
-l1.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+l1.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l1 regularization strength is not a "
"non-negative scalar: ",
l1.shape().DebugString()));
const Tensor& l2 = ctx->input(6);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l2.shape()) &&
-l2.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+l2.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 regularization strength is not a "
"non-negative scalar: ",
l2.shape().DebugString()));
const int lr_power_index = has_l2_shrinkage ? 8 : 7;
const Tensor& lr_power = ctx->input(lr_power_index);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr_power.shape()) &&
-lr_power.scalar<T>()() <= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()) &&
+lr_power.scalar<T>()() <= static_cast<T>(0),
errors::InvalidArgument("lr_power is not a"
" non-positive scalar: ",
lr_power.shape().DebugString()));
@@ -2021,9 +2056,8 @@ class ApplyFtrlOp : public OpKernel {
if (has_l2_shrinkage) {
const Tensor& l2_shrinkage = ctx->input(7);
OP_REQUIRES(
-ctx,
-TensorShapeUtils::IsScalar(l2_shrinkage.shape()) &&
-l2_shrinkage.scalar<T>()() >= static_cast<T>(0),
+ctx, TensorShapeUtils::IsScalar(l2_shrinkage.shape()) &&
+l2_shrinkage.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 shrinkage regularization strength "
"is not a non-negative scalar: ",
l2_shrinkage.shape().DebugString()));
@@ -2141,31 +2175,27 @@ class SparseApplyFtrlOp : public OpKernel {
errors::InvalidArgument("indices must be one-dimensional"));

const Tensor& lr = ctx->input(5);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr.shape()) &&
-lr.scalar<T>()() > static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) &&
+lr.scalar<T>()() > static_cast<T>(0),
errors::InvalidArgument("lr is not a positive scalar: ",
lr.shape().DebugString()));

const Tensor& l1 = ctx->input(6);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l1.shape()) &&
-l1.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) &&
+l1.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l1 regularization strength is not a "
"non-negative scalar: ",
l1.shape().DebugString()));
const Tensor& l2 = ctx->input(7);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(l2.shape()) &&
-l2.scalar<T>()() >= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) &&
+l2.scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 regularization strength is not a "
"non-negative scalar: ",
l2.shape().DebugString()));
const int lr_power_index = has_l2_shrinkage ? 9 : 8;
const Tensor& lr_power = ctx->input(lr_power_index);
-OP_REQUIRES(ctx,
-TensorShapeUtils::IsScalar(lr_power.shape()) &&
-lr_power.scalar<T>()() <= static_cast<T>(0),
+OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()) &&
+lr_power.scalar<T>()() <= static_cast<T>(0),
errors::InvalidArgument("lr_power is not a "
"non-positive scalar: ",
lr_power.shape().DebugString()));
@@ -2190,9 +2220,8 @@ class SparseApplyFtrlOp : public OpKernel {
if (has_l2_shrinkage) {
l2_shrinkage = &ctx->input(8);
OP_REQUIRES(
-ctx,
-TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
-l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
+ctx, TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
+l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
errors::InvalidArgument("l2 shrinkage regularization strength "
"is not a non-negative scalar: ",
l2_shrinkage->shape().DebugString()));
@@ -2234,9 +2263,10 @@ class SparseApplyFtrlOp : public OpKernel {
linear += grad_maybe_with_shrinkage - \
(new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
} else { \
-linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
-accum.pow(-lr_power_scalar)) / \
-lr_scalar * var; \
+linear += \
+grad_maybe_with_shrinkage - \
+(new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) / \
+lr_scalar * var; \
} \
auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar); \
auto x = l1_reg_adjust - linear; \
17 changes: 11 additions & 6 deletions tensorflow/core/kernels/training_ops_test.cc
@@ -176,23 +176,28 @@ static void Adam(int32 n, Graph** init_g, Graph** train_g) {
auto beta2 = Scalar(g, 0.99);
auto epsilon = Scalar(g, 1e-8);
auto grad = Random(g, n);
-test::graph::Multi(
-g, "ApplyAdam",
-{var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad});
+test::graph::Multi(g, "ApplyAdam", {var, m, v, beta1_power, beta2_power, lr,
+beta1, beta2, epsilon, grad});
*train_g = g;
}
}

-static void BM_Adam(int iters, int params) {
+static void BM_Adam(int iters, int params, int is_multi_threaded) {
const int64 tot = static_cast<int64>(iters) * params;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * sizeof(float));
Graph* init;
Graph* train;
Adam(params, &init, &train);
-test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+if (is_multi_threaded) {
+// Use the maximum number of threads when measuring multi-threaded performance.
+test::Benchmark("cpu", train, nullptr, init).Run(iters);
+} else {
+test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
}
-BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10);
+BENCHMARK(BM_Adam)->ArgPair(128 << 10, 0)->ArgPair(256 << 10, 0);
+BENCHMARK(BM_Adam)->ArgPair(256 << 5, 1)->ArgPair(256 << 16, 1);
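A note on the threading switch in BM_Adam: the second ArgPair value is forwarded as is_multi_threaded, and GetOptions() is defined earlier in training_ops_test.cc (outside this diff). Judging from the new comment, GetOptions() restricts the session to a single thread, while passing nullptr falls back to default SessionOptions so the Adam kernel can use every available core. A sketch of what such single-threaded options might look like (an assumption, since the real definition is not shown here):

#include "tensorflow/core/public/session_options.h"

using tensorflow::SessionOptions;

// Assumed sketch of the single-threaded options used by the existing
// benchmarks; the real GetOptions() lives earlier in training_ops_test.cc
// and is not part of this diff.
static SessionOptions* GetOptions() {
  static SessionOptions* opts = [] {
    SessionOptions* o = new SessionOptions();
    o->config.set_intra_op_parallelism_threads(1);  // One compute thread.
    o->config.set_inter_op_parallelism_threads(1);  // One scheduling thread.
    return o;
  }();
  return opts;
}

// Passing nullptr instead lets test::Benchmark construct default
// SessionOptions, which do not cap the intra-op thread pool.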

static void RMSProp(int32 n, Graph** init_g, Graph** train_g) {
TensorShape shape({n});