[INTEL MKL] Slice and Reshape op support with MKLDNN 1.0 #36497

Merged
32 changes: 16 additions & 16 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -359,9 +359,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mul = "Mul";
csinfo_.squared_difference = "SquaredDifference";
csinfo_.sub = "Sub";
// End - element-wise ops. See note above.

// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
@@ -671,25 +671,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
rinfo_.push_back({csinfo_.tanh,
mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
rinfo_.push_back({csinfo_.tanh,
mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.slice,
mkl_op_registry::GetMklOpName(csinfo_.slice),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -701,12 +703,10 @@ rinfo_.push_back({csinfo_.tanh_grad,
rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.transpose,
mkl_op_registry::GetMklOpName(csinfo_.transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});

#ifndef ENABLE_MKLDNN_V1
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
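With MKL-DNN 1.0 the Reshape and Slice registrations above move out of the #ifndef ENABLE_MKLDNN_V1 region, and Slice is registered with the RewriteIfAtleastOneMklInput predicate rather than AlwaysRewrite. A minimal sketch of that predicate pattern follows; the types are toy stand-ins, not TensorFlow's real graph classes, and the "_Mkl" prefix is an assumption about how GetMklOpName names rewritten ops.

#include <string>
#include <vector>

// Toy stand-in for a graph node; illustration only.
struct Node {
  std::string op;
  std::vector<const Node*> inputs;
};

bool IsMklOp(const Node& n) { return n.op.rfind("_Mkl", 0) == 0; }

// AlwaysRewrite: rewrite the op unconditionally.
bool AlwaysRewrite(const Node&) { return true; }

// RewriteIfAtleastOneMklInput: rewrite only when some producer already
// emits an MKL layout, so an isolated Slice in a plain-TF subgraph stays
// on the default kernel and avoids needless layout conversions.
bool RewriteIfAtleastOneMklInput(const Node& n) {
  for (const Node* in : n.inputs) {
    if (in != nullptr && IsMklOp(*in)) return true;
  }
  return false;
}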
29 changes: 16 additions & 13 deletions tensorflow/core/kernels/mkl_reshape_op.cc
@@ -16,20 +16,23 @@ limitations under the License.
#ifdef INTEL_MKL

#include <memory>
#include "mkldnn.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

template <typename Device, typename T>
class MklReshapeOp : public OpKernel {
public:
@@ -43,7 +46,6 @@ class MklReshapeOp : public OpKernel {
bool SkipReorder(const MklDnnShape& mkl_shape_input,
const TensorShape& reshape_to) {
CHECK_EQ(mkl_shape_input.IsMklTensor(), true);
bool ret = false;

// If Tensorflow's data format and the underlying format maintained by
// MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
@@ -52,12 +54,7 @@ class MklReshapeOp : public OpKernel {
// blocking_desc_is_equal() for checking all the stride arrays in
// mkl-dnn/blob/master/src/common/type_helpers.hpp
auto input_mkl_md = mkl_shape_input.GetMklLayout();
if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format &&
mkl_shape_input.GetTfDataFormat() != memory::format::blocked) {
ret = true;
}

return ret;
return SKIP_INPUT_REORDER(mkl_shape_input, input_mkl_md);
}

public:
Expand Down Expand Up @@ -140,7 +137,7 @@ class MklReshapeOp : public OpKernel {
return;
} else {
try {
auto cpu_engine = engine(engine::cpu, 0);
auto cpu_engine = engine(ENGINE_CPU, 0);
MklDnnData<T> dnn_data_input(&cpu_engine);
// Reshape is just a logical view change operation for a tensor.
// It does not change underlying layout. But MKLDNN may maintain
@@ -162,8 +159,10 @@ class MklReshapeOp : public OpKernel {
dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
// Get expected Tensorflow layout of input tensor.
auto output_tf_md = mkl_shape_input.GetTfLayout();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd =
memory::primitive_desc(output_tf_md, cpu_engine);
#endif // !ENABLE_MKLDNN_V1

Tensor* output_tensor = nullptr;
MklDnnShape mkl_shape_output;
@@ -177,7 +176,7 @@ class MklReshapeOp : public OpKernel {
// shape_from != shape_to), then we just copy input tensor to
// output tensor with target shape (we cannot forward Mkl layout
// in such case because shape has changed.)
if (dnn_data_input.CheckReorderToOpMem(output_tf_pd,
if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD,
output_tensor)) {
} else {
OP_REQUIRES(
@@ -194,16 +193,18 @@ class MklReshapeOp : public OpKernel {
auto output_strides = CalculateTFStrides(output_dims);
auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(
output_dims, output_strides);
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd =
memory::primitive_desc(output_tf_md, cpu_engine);
#endif // !ENABLE_MKLDNN_V1

// Set MklDnnShape
MklDnnShape mkl_shape_output;
mkl_shape_output.SetMklTensor(true);
mkl_shape_output.SetMklLayout(&output_tf_pd);
mkl_shape_output.SetMklLayout(&OUTPUT_TF_MD);
mkl_shape_output.SetElemType(MklDnnType<T>());
mkl_shape_output.SetTfLayout(output_dims.size(), output_dims,
memory::format::blocked);
MKL_TENSOR_FORMAT_BLOCKED);

// We now simply forward input Mkl tensor to output and change its
// output MklDnnShape object.
@@ -280,7 +281,9 @@ class MklReshapeOp : public OpKernel {

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);

#undef REGISTER_MKL_CPU

} // namespace tensorflow

#endif // INTEL_MKL
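For reference, the v0.x logic that the new SKIP_INPUT_REORDER macro replaces is visible in the removed lines of SkipReorder above: the reorder can be skipped when TensorFlow's data format matches the format recorded in the MKL memory descriptor and that format is not the opaque blocked layout. A self-contained sketch of that check, with a toy enum standing in for the MKL-DNN format type (names hypothetical):

#include <cstdio>

// Toy stand-in for mkldnn's memory::format values; illustration only.
enum class Format { nchw, nhwc, blocked };

// Hedged sketch of the check SKIP_INPUT_REORDER is assumed to wrap: a
// reorder is unnecessary when both sides agree on a real (non-blocked)
// format, because the bytes are already laid out the way TF expects.
bool SkipInputReorder(Format tf_format, Format mkl_format) {
  return tf_format == mkl_format && tf_format != Format::blocked;
}

int main() {
  std::printf("%d\n", SkipInputReorder(Format::nhwc, Format::nhwc));        // 1
  std::printf("%d\n", SkipInputReorder(Format::blocked, Format::blocked));  // 0
  std::printf("%d\n", SkipInputReorder(Format::nchw, Format::nhwc));        // 0
}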
107 changes: 77 additions & 30 deletions tensorflow/core/kernels/mkl_slice_op.cc
@@ -18,18 +18,21 @@ limitations under the License.
#ifdef INTEL_MKL

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using mkldnn::stream;
#ifndef ENABLE_MKLDNN_V1
using mkldnn::view;
#endif

namespace tensorflow {

@@ -177,8 +180,9 @@ struct MklSliceParams {
template <typename T>
class MklSlicePrimitive : public MklPrimitive {
public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams) {
context_.slice_stream.reset(new stream(stream::kind::eager));
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
Setup(sliceParams);
}

@@ -187,7 +191,13 @@ class MklSlicePrimitive : public MklPrimitive {
void Execute(const MklSliceParams& sliceParams) {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());

#ifdef ENABLE_MKLDNN_V1
Member:
I mentioned in #36496 that we should refactor this part (lines 193-203) out (either to a real function or just a macro). If #36496 does that, please call the new refactored function/macro here.

Contributor:
@penpornk That specific change is not addressable in the near term, so no refactor is attempted here. See PR #36496 for comments.

Contributor:
@penpornk @mdfaijul Changes done in the other PR. FYI.

Contributor (Author):
@penpornk Addressed the comments.
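A minimal sketch of the refactor discussed in this thread, hoisting the v0.x/v1.x dispatch below into a shared helper. The name, signature, and placement are hypothetical; per the thread, the real refactor went in through the other PR.

// Hypothetical helper, not the committed code: call sites would replace
// the #ifdef block with a single ExecutePrimitives(...) call.
#include <memory>
#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

inline void ExecutePrimitives(
    std::vector<mkldnn::primitive>& net,
    const std::shared_ptr<mkldnn::stream>& net_stream,
    std::vector<std::unordered_map<int, mkldnn::memory>>& net_args) {
#ifdef ENABLE_MKLDNN_V1
  // v1.x: each primitive executes with its own argument map.
  for (size_t i = 0; i < net.size(); ++i) {
    net[i].execute(*net_stream, net_args[i]);
  }
#else
  // v0.x: the whole net is submitted to the stream in one call.
  net_stream->submit(net);
#endif
}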

execute_primitives(context_.slice_primitives, context_.slice_stream,
context_.slice_primitives_args);
#else
context_.slice_stream->submit(context_.slice_primitives);
#endif

// We should set it back to DummyData so as to make the primitive
// in cache pool stateless. Otherwise, if the result for previous
@@ -206,25 +216,45 @@ class MklSlicePrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::memory> dst_mem;
std::shared_ptr<primitive> reorder_prim;
std::shared_ptr<reorder::primitive_desc> reorder_pd;
std::shared_ptr<view::primitive_desc> view_pd;
std::shared_ptr<mkldnn::stream> slice_stream;
std::vector<mkldnn::primitive> slice_primitives;
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<mkldnn::memory> src_sub_mem;
std::vector<std::unordered_map<int, memory>> slice_primitives_args;
#else
std::shared_ptr<view::primitive_desc> view_pd;
#endif // ENABLE_MKLDNN_V1
SliceContext()
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;

engine cpu_engine_ = engine(engine::cpu, 0);
engine cpu_engine_;

void Setup(const MklSliceParams& sliceParams) {
// Actually, this DummyData will not be used in computation,
// because the real data will be filled before real execution.
context_.src_mem.reset(
new memory({sliceParams.from->get_primitive_desc().desc(), cpu_engine_},
DummyData));
context_.dst_mem.reset(new memory(
{sliceParams.to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
auto src_pd = context_.src_mem->get_primitive_desc();
auto dst_pd = context_.dst_mem->get_primitive_desc();
// Actually, DummyData will not be used in computation,
// because the real data will be filled before execution.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
sliceParams.from, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
sliceParams.to, cpu_engine_, DummyData));
auto src_pd = context_.src_mem->GET_DESC;
auto dst_pd = context_.dst_mem->GET_DESC;
#ifdef ENABLE_MKLDNN_V1
// MKL-DNN 1.x removes struct view, alias of memory in 0.x version.
// So the implementation is based on submemory.
auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
sliceParams.size_dims, sliceParams.begin_dims);
context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
context_.reorder_pd = std::make_shared<reorder::primitive_desc>(
reorder::primitive_desc(*context_.src_sub_mem, *context_.dst_mem));
context_.reorder_prim =
std::make_shared<mkldnn::reorder>(reorder(*context_.reorder_pd));

context_.slice_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.view_pd =
std::make_shared<view::primitive_desc>(view::primitive_desc(
src_pd, sliceParams.size_dims, sliceParams.begin_dims));
Expand All @@ -233,6 +263,7 @@ class MklSlicePrimitive : public MklPrimitive {
context_.view_pd->dst_primitive_desc(), dst_pd));
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
#endif
context_.slice_primitives.push_back(*context_.reorder_prim);
}
};
@@ -263,27 +294,35 @@ class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
static string CreateKey(const MklSliceParams& sliceParams) {
string prefix = "reorder";
FactoryKeyCreator key_creator;
auto const& from_desc = sliceParams.from->get_primitive_desc().desc().data;
auto const& to_desc = sliceParams.to->get_primitive_desc().desc().data;
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.from).data;
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.to).data;
const int kIdxFirstStride = 0;
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
memory::dims from_strides(
from_desc.layout_desc.blocking.strides[kIdxFirstStride],
&from_desc.layout_desc.blocking
.strides[kIdxFirstStride][from_desc.ndims]);
memory::dims to_strides(
to_desc.layout_desc.blocking.strides[kIdxFirstStride],
&to_desc.layout_desc.blocking.strides[kIdxFirstStride][to_desc.ndims]);

// MKL-DNN removes "struct view". Submemory has similar capability.
auto from_strides = from_desc.MEMORY_FORMAT_DESC.blocking.strides;
auto to_strides = to_desc.MEMORY_FORMAT_DESC.blocking.strides;
memory::dims from_strides_outer_blocks(
GET_BLOCK_STRIDES(from_strides, kIdxFirstStride),
&GET_BLOCK_STRIDES(from_strides, kIdxFirstStride)[from_desc.ndims]);
memory::dims to_strides_outer_blocks(
GET_BLOCK_STRIDES(to_strides, kIdxFirstStride),
&GET_BLOCK_STRIDES(to_strides, kIdxFirstStride)[to_desc.ndims]);

key_creator.AddAsKey(prefix);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(from_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
key_creator.AddAsKey(from_dims);
key_creator.AddAsKey(from_strides);
key_creator.AddAsKey(from_strides_outer_blocks);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(to_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
key_creator.AddAsKey(to_dims);
key_creator.AddAsKey(to_strides);
key_creator.AddAsKey(to_strides_outer_blocks);
key_creator.AddAsKey(sliceParams.begin_dims);
key_creator.AddAsKey(sliceParams.size_dims);
return key_creator.GetKey();
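The rewritten CreateKey above folds the outer-block strides into the cache key on both sides, since two memories with identical dims but different strides need different reorder primitives. A toy sketch of that idea; the key builder here is a simplification (the real FactoryKeyCreator lives in mkl_util.h), and real keys also include formats and data types:

#include <sstream>
#include <string>
#include <vector>

// Toy key builder; illustration only.
struct KeySketch {
  std::ostringstream out;
  void Add(const std::string& s) { out << s << '#'; }
  void Add(const std::vector<long>& v) {
    for (long x : v) out << x << ':';
    out << '#';
  }
  std::string Get() const { return out.str(); }
};

int main() {
  KeySketch a, b;
  a.Add("reorder"); a.Add({1, 3, 4, 4}); a.Add({48, 16, 4, 1});  // NCHW strides
  b.Add("reorder"); b.Add({1, 3, 4, 4}); b.Add({48, 1, 12, 3});  // NHWC strides
  // Same dims, different strides -> different keys -> distinct primitives.
  return a.Get() == b.Get();  // 0: the keys differ
}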
@@ -358,7 +397,7 @@ class MklSliceOp : public OpKernel {
// primitive descriptor. And the reorder uses source memory as input but
// traverses it according to a view in_submem_pd.

auto cpu_engine = engine(engine::cpu, 0);
auto cpu_engine = engine(ENGINE_CPU, 0);
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> output(&cpu_engine);

@@ -425,14 +464,22 @@ class MklSliceOp : public OpKernel {
// Or else do nothing for it.
auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
#endif

// Step 2 - Create memory for output.
auto output_strides = CalculateTFStrides(size_dims);
auto output_md =
MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
#ifdef ENABLE_MKLDNN_V1
auto output_pd = output_md;
#else
auto output_pd = memory::primitive_desc(output_md, cpu_engine);
#endif
AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
&output_tensor, &output_mkl_shape);
DCHECK(output_tensor);
@@ -447,9 +494,9 @@ class MklSliceOp : public OpKernel {
// Execute slice reorder.
reorder_prim->Execute(sliceParams);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -459,7 +506,7 @@ class MklSliceOp : public OpKernel {
private:
void AllocateOutputTensor(OpKernelContext* context,
const MklDnnShape& input_mkl_shape,
memory::primitive_desc* output_pd,
MEMORY_PRIMITIVE_DESC* output_pd,
const memory::dims& output_dims,
Tensor** output_tensor,
MklDnnShape* output_mkl_shape) {
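To make the submemory-based slice concrete: MKL-DNN 1.x removed struct view, so a slice becomes a reorder whose source is a submemory descriptor into the parent buffer, exactly as the Setup code above does. A standalone sketch, assuming MKL-DNN (DNNL) 1.x headers; shapes and values are illustrative only:

#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // A 1x1x4x4 f32 tensor in plain NCHW layout.
  memory::desc src_md({1, 1, 4, 4}, memory::data_type::f32,
                      memory::format_tag::nchw);
  std::vector<float> src_buf(16), dst_buf(4, 0.f);
  for (int i = 0; i < 16; ++i) src_buf[i] = static_cast<float>(i);

  // No view in 1.x: a submemory desc selects a 1x1x2x2 window starting
  // at offset (0, 0, 1, 1) within the parent buffer.
  memory::desc sub_md = src_md.submemory_desc({1, 1, 2, 2}, {0, 0, 1, 1});
  memory sub_mem(sub_md, eng, src_buf.data());

  memory::desc dst_md({1, 1, 2, 2}, memory::data_type::f32,
                      memory::format_tag::nchw);
  memory dst_mem(dst_md, eng, dst_buf.data());

  // The slice is just a reorder from the submemory to the destination.
  reorder::primitive_desc rpd(sub_mem, dst_mem);
  reorder slice(rpd);
  slice.execute(s, {{MKLDNN_ARG_SRC, sub_mem}, {MKLDNN_ARG_DST, dst_mem}});
  s.wait();
  // dst_buf now holds {5, 6, 9, 10}.
  return 0;
}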