threadpool support for relu, eltwise and softmax.
Srini511 committed May 13, 2020
1 parent ab67ad7 commit e0153c3
Showing 3 changed files with 60 additions and 75 deletions.
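The same refactoring runs through every hunk below: the MKL-DNN execution stream stops being a member cached inside the memoized primitive object and becomes a per-call argument, created from the kernel's OpKernelContext via CreateStream(ctx, engine) and handed to Execute(). That is what allows the eltwise, relu and softmax work to be scheduled on TensorFlow's threadpool. A self-contained sketch of the before/after shape of the change (illustrative stand-ins, not TensorFlow code):

// Illustrative model of this commit's refactoring: the stream moves from a
// member of the cached primitive to an argument of Execute(), so each call
// can supply a context-appropriate (e.g. threadpool-backed) stream.
#include <iostream>
#include <memory>

struct Stream {  // stands in for mkldnn::stream
  void Run(const char* what) { std::cout << "run " << what << "\n"; }
};

// Before: the primitive owned one stream for its whole lifetime, even though
// a factory caches and shares the primitive across ops and threads.
class PrimitiveBefore {
 public:
  PrimitiveBefore() : stream_(std::make_shared<Stream>()) {}
  void Execute() { stream_->Run("eltwise"); }

 private:
  std::shared_ptr<Stream> stream_;
};

// After: the caller creates the stream (cf. CreateStream(ctx, engine)) and
// passes it in; the cached primitive no longer pins execution to one stream.
class PrimitiveAfter {
 public:
  void Execute(const std::shared_ptr<Stream>& stream) { stream->Run("eltwise"); }
};

int main() {
  PrimitiveBefore before;
  before.Execute();  // always the constructor-time stream

  PrimitiveAfter after;
  auto per_call_stream = std::make_shared<Stream>();  // fresh for this call
  after.Execute(per_call_stream);
  return 0;
}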
12 changes: 6 additions & 6 deletions tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -244,7 +244,7 @@ class MklAddNOp : public OpKernel {

// Create Sum op, and submit net for execution.
std::vector<primitive> net;
- auto sum_stream = CPU_STREAM(cpu_engine);
+ stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
#ifdef ENABLE_MKLDNN_V1
mkldnn::sum sum_op(sum_pd);
std::unordered_map<int, memory> net_args = {
@@ -253,15 +253,15 @@
for (int i = 0; i < num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
}
- sum_op.execute(sum_stream, net_args);
+ sum_op.execute(*fwd_cpu_stream, net_args);
#else
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
- sum_stream.submit(net).wait();
+ fwd_cpu_stream->submit(net).wait();
#endif
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
ctx, errors::Aborted("Operation received an exception:", error_msg));
}
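The CreateStream(ctx, cpu_engine) helper used above is defined outside this diff (in mkl_util.h), so only its call sites are visible here. The sketch below is a hypothetical reading of those call sites: hand back a threadpool-backed stream when the kernel context can supply one, otherwise fall back to a plain CPU stream. Every type in it is a stand-in, not the real TensorFlow/MKL-DNN API.

// Hypothetical stand-ins; the real helper works with OpKernelContext,
// mkldnn::engine and mkldnn::stream.
struct Engine {};
struct ThreadPool {};
struct Stream {
  explicit Stream(const Engine&) {}      // default in-order CPU stream
  Stream(const Engine&, ThreadPool*) {}  // stream that runs on a threadpool
};
struct KernelContext {
  ThreadPool* pool = nullptr;  // may be absent outside an op invocation
};

// Presumed behavior of CreateStream(ctx, engine): prefer the context's pool.
Stream* CreateStream(KernelContext* ctx, const Engine& engine) {
  if (ctx != nullptr && ctx->pool != nullptr) {
    return new Stream(engine, ctx->pool);  // run MKL work on the framework pool
  }
  return new Stream(engine);  // fallback: plain CPU stream
}

int main() {
  KernelContext ctx;  // no pool attached: takes the fallback path
  delete CreateStream(&ctx, Engine());
  return 0;
}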
93 changes: 40 additions & 53 deletions tensorflow/core/kernels/mkl_relu_op.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <unordered_map>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using mkldnn::algorithm;
using mkldnn::eltwise_forward;
@@ -61,13 +61,11 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
- : cpu_engine_(ENGINE_CPU, 0) {
+ : MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
#endif
- context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
-
// create eltwise primitive
if (context_.eltwise_fwd == nullptr) {
Setup(fwdParams);
@@ -79,20 +77,19 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
- void Execute(const T* src_data, T* dst_data) {
+ void Execute(const T* src_data, T* dst_data,
+ std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));

#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
- for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
- context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
- context_.fwd_primitives_args.at(i));
- }
+ execute_primitives(context_.fwd_primitives, fwd_stream,
+ context_.fwd_primitives_args);
#else
- context_.fwd_stream->submit(context_.fwd_primitives);
+ fwd_stream->submit(context_.fwd_primitives);
#endif

// After execution, set data handle back.
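execute_primitives() is another helper defined outside this diff (presumably alongside CreateStream in mkl_util.h). Judging from the loop it replaces above, an equivalent definition would be the following; the real helper may differ in detail:

// Reconstructed from the deleted loop: run each primitive on the caller's
// stream with its matching argument map.
inline void execute_primitives(
    std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
    std::vector<std::unordered_map<int, mkldnn::memory>>& net_args) {
  DCHECK_EQ(primitives.size(), net_args.size());
  for (size_t i = 0; i < primitives.size(); ++i) {
    primitives.at(i).execute(*stream, net_args.at(i));
  }
}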
@@ -134,7 +131,6 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd;

- std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;

#ifdef ENABLE_MKLDNN_V1
@@ -153,8 +149,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
src_md(nullptr),
dst_md(nullptr),
src_mpd(nullptr),
- eltwise_fwd(nullptr),
- fwd_stream(nullptr) {
+ eltwise_fwd(nullptr) {
}
};

@@ -169,14 +164,12 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
#else
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
#endif
-
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
auto fwd_pd = context_.fwd_pd.get();
-
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
@@ -195,12 +188,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.eltwise_fwd.reset(new eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif
-
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
}

struct EltwiseFwdContext context_;
- engine cpu_engine_;
};

template <typename T>
@@ -281,14 +272,13 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
- : cpu_engine_(ENGINE_CPU, 0) {
+ : MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
#endif
- context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive
if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams);
@@ -301,7 +291,8 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
- void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
+ void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
+ std::shared_ptr<stream> bwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_dst_mem->set_data_handle(
@@ -311,12 +302,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
- for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
- context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
- context_.bwd_primitives_args.at(i));
- }
+ execute_primitives(context_.bwd_primitives, bwd_stream,
+ context_.bwd_primitives_args);
#else
- context_.bwd_stream->submit(context_.bwd_primitives);
+ bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1

// after execution, set data handle back
@@ -367,7 +356,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd;

- std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;

#ifdef ENABLE_MKLDNN_V1
@@ -391,8 +379,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
- eltwise_bwd(nullptr),
- bwd_stream(nullptr) {
+ eltwise_bwd(nullptr) {
}
};

@@ -448,7 +435,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
}

struct EltwiseBwdContext context_;
- engine cpu_engine_;
};

template <typename T>
@@ -525,12 +511,10 @@ class MklReluOpBase : public OpKernel {
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);
-
if (src_tensor.dims() == 0) {
Compute_Scalar(context);
return;
}
-
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
Tensor* dst_tensor = nullptr;
@@ -542,7 +526,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst);
return;
}
-
// Set DNN primitive - src
MklDnnData<T> src(&cpu_engine);
memory::dims src_dims;
@@ -556,26 +539,25 @@ class MklReluOpBase : public OpKernel {
// Create blocked memory descriptor
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}
-
// Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_);
-
MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
-
auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
-
+ std::shared_ptr<stream> fwd_cpu_stream;
+ fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine()));
// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
- src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
- eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
+ src.CheckReorderToOpMem(
+ MEMORY_PD_WITHOUT_DATA(eltwise_fwd_pd->PRIMITIVE_DESC_SRC,
+ cpu_engine),
+ context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
-
// Allocate dst tensor, always set it as MKL-DNN layout
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
@@ -590,7 +572,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
-
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
@@ -600,11 +581,11 @@ class MklReluOpBase : public OpKernel {
T* dst_data = dst_tensor->flat<T>().data();

// execute eltwise
- eltwise_fwd->Execute(src_data, dst_data);
+ eltwise_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
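The per-call stream matters here because eltwise_fwd comes out of MklEltwiseFwdPrimitiveFactory's cache: the same primitive object can be handed to later Compute() calls with different contexts and threadpools, so a stream cached inside it (the pre-commit design) would pin every caller to whichever context created the primitive first. Passing context into CheckReorderToOpMem follows the same logic, letting any reorder primitive run on a context-derived stream as well. A toy model of the cache-plus-per-call-stream arrangement (stand-in types, not the real factory):

#include <map>
#include <memory>
#include <string>

struct Stream {
  int id;  // stands in for a context-specific threadpool binding
};

struct Primitive {
  // The primitive holds no stream; it uses whatever the caller provides.
  void Execute(const std::shared_ptr<Stream>& s) { (void)s->id; }
};

class PrimitiveFactory {
 public:
  // Memoizing lookup, as the Mkl*PrimitiveFactory classes do for real
  // primitives (locking omitted for brevity).
  static std::shared_ptr<Primitive> Get(const std::string& key) {
    static std::map<std::string, std::shared_ptr<Primitive>> cache;
    auto& p = cache[key];
    if (!p) p = std::make_shared<Primitive>();
    return p;  // same object for the same key, across calls
  }
};

void Compute(int call_id) {
  auto prim = PrimitiveFactory::Get("relu_f32_4x8");        // shared, cached
  auto stream = std::make_shared<Stream>(Stream{call_id});  // fresh per call
  prim->Execute(stream);
}

int main() {
  Compute(1);
  Compute(2);  // reuses the cached primitive, but with its own stream
  return 0;
}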
@@ -727,13 +708,16 @@ class MklReluGradOpBase : public OpKernel {
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);

auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
-
+ std::shared_ptr<stream> bwd_cpu_stream;
+ bwd_cpu_stream.reset(CreateStream(context, eltwise_bwd->GetEngine()));
// check whether need reorder for src / diff_dst
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
- src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
- eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
+ src.CheckReorderToOpMem(
+ MEMORY_PD_WITHOUT_DATA(
+ eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
+ context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
@@ -742,8 +726,10 @@ class MklReluGradOpBase : public OpKernel {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
- diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
- eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
+ diff_dst.CheckReorderToOpMem(
+ MEMORY_PD_WITHOUT_DATA(
+ eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
+ context);
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}
@@ -779,11 +765,12 @@ class MklReluGradOpBase : public OpKernel {
T* diff_src_data = diff_src_tensor->flat<T>().data();

// execute eltwise bwd
- eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
+ eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data,
+ bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
[Diff for the third changed file (14 additions & 16 deletions; per the commit title, presumably the softmax kernel) did not load.]
