Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTEL MKL] MKL-DNNL v1.0 integration with AddN ops #36498

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow/core/graph/mkl_layout_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// End - element-wise ops. See note above.

// NOTE: names are alphabetically sorted.
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
Expand Down
54 changes: 41 additions & 13 deletions tensorflow/core/kernels/mkl_aggregate_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,23 +152,30 @@ class MklAddNOp : public OpKernel {
return;
}

auto cpu_engine = engine(engine::cpu, 0);
auto cpu_engine = engine(ENGINE_CPU, 0);
std::vector<float> coeff(num_inputs, 1.0);
std::vector<memory::primitive_desc> srcs_pd;
std::vector<MklDnnData<T>> srcs(num_inputs, MklDnnData<T>(&cpu_engine));
std::vector<MEMORY_PRIMITIVE_DESC> srcs_pd;

#ifdef ENABLE_MKLDNN_V1
std::vector<memory> inputs;
#else
std::vector<primitive::at> inputs;
#endif

MklDnnData<T> dst(&cpu_engine);
MklDnnData<T> src(&cpu_engine);
bool has_mkl_input = false;
int mkl_input_index = FindMKLInputIndex(ctx);
memory::format mkl_data_format;
MKL_TENSOR_FORMAT mkl_data_format;
TensorFormat tf_data_format;
MEMORY_FORMAT dnn_fmt = MEMORY_FORMAT::any;
if (mkl_input_index >= 0) {
has_mkl_input = true;
GetMklShape(ctx, mkl_input_index, &mkl_shape);
// MKL input has the data format information.
mkl_data_format = mkl_shape.GetTfDataFormat();
tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
}

// Create memory descriptor for MKL-DNN.
Expand All @@ -177,7 +184,8 @@ class MklAddNOp : public OpKernel {
for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
MklDnnShape src_mkl_shape;
GetMklShape(ctx, src_idx, &src_mkl_shape);
memory::desc md({}, memory::data_undef, memory::format_undef);
memory::desc md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
src = MklDnnData<T>(&cpu_engine);
const Tensor& src_tensor = MklGetInput(ctx, src_idx);

if (src_mkl_shape.IsMklTensor()) {
Expand All @@ -193,22 +201,30 @@ class MklAddNOp : public OpKernel {
src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
tf_data_format);
}
md = memory::desc(src_dims, MklDnnType<T>(), mkl_data_format);
md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
} else {
// Create block memory descriptor for TensorFlow format input.
auto dims = TFShapeToMklDnnDims(src_tensor.shape());
auto strides = CalculateTFStrides(dims);
md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
}
}
#ifdef ENABLE_MKLDNN_V1
srcs_pd.push_back(memory::desc(md));
#else
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
srcs[src_idx].SetUsrMem(md, &src_tensor);
inputs.push_back(srcs[src_idx].GetOpMem());
#endif
src.SetUsrMem(md, &src_tensor);
inputs.push_back(src.GetOpMem());
}

#ifdef ENABLE_MKLDNN_V1
auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
#else
auto sum_pd = sum::primitive_desc(coeff, srcs_pd);
#endif
output_mkl_shape.SetMklTensor(has_mkl_input);
auto output_pd = sum_pd.dst_primitive_desc();
auto output_pd = sum_pd.PRIMITIVE_DESC_DST;
dst.SetUsrMem(output_pd);

if (has_mkl_input) {
Expand All @@ -228,12 +244,24 @@ class MklAddNOp : public OpKernel {

// Create Sum op, and submit net for execution.
std::vector<primitive> net;
auto sum_stream = CPU_STREAM(cpu_engine);
#ifdef ENABLE_MKLDNN_V1
mkldnn::sum sum_op(sum_pd);
std::unordered_map<int, memory> net_args = {
{ MKLDNN_ARG_DST,
dst.GetOpMem() }};
for (int i = 0; i < num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
}
sum_op.execute(sum_stream, net_args);
#else
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
sum_stream.submit(net).wait();
#endif
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
ctx, errors::Aborted("Operation received an exception:", error_msg));
}
Expand Down