MKL: Adding MKL-DNN Reshape op #14682

Merged 1 commit on Dec 6, 2017
182 changes: 182 additions & 0 deletions tensorflow/core/kernels/mkl_reshape_op.cc
@@ -28,13 +28,19 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
using mkldnn::stream;
#endif

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
template <typename Device, typename T>
class MklReshapeOp : public OpKernel {
public:
explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}

#ifndef INTEL_MKL_DNN
void Compute(OpKernelContext* context) override {
const Tensor& input = MklGetInput(context, 0);
const Tensor& sizes = MklGetInput(context, 1);
@@ -129,7 +135,183 @@ class MklReshapeOp : public OpKernel {
}
}

#else

private:
// When the input tensor is in MKL layout and we are reshaping it to a
// different shape than its actual one, we use the MKLDNN reorder primitive
// to put the tensor back in TensorFlow layout. But we can sometimes skip
// this reordering. This function checks for all such cases.
bool SkipReorder(const MklDnnShape& mkl_shape_input,
const TensorShape& reshape_to) {
CHECK_EQ(mkl_shape_input.IsMklTensor(), true);
bool ret = false;

// If Tensorflow's data format and the underlying format maintained by
// MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
// safely return true.
auto input_mkl_md = mkl_shape_input.GetMklLayout();
if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
ret = true;
}

return ret;
}
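// Illustrative example (not part of the original diff): if the producer op
// left the input in plain NHWC and the TF data format recorded in the
// MklDnnShape is also NHWC, the two layouts already agree, so SkipReorder()
// returns true and Compute() below only rewrites the shape metadata instead
// of issuing an MKL-DNN reorder.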

public:
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = MklGetInput(context, 0);
const Tensor& sizes = MklGetInput(context, 1);

MklDnnShape mkl_shape_input;
GetMklShape(context, kInputSlotIdx, &mkl_shape_input);
bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
const int64 nelems = input_in_mkl_format ?
mkl_shape_input.GetTfShape().num_elements()
: input_tensor.NumElements();

// Preliminary validation of sizes.
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().DebugString()));

// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one.
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
switch (sizes.dtype()) {
case DT_INT32:
OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
&unknown_index, &shape));
break;
case DT_INT64:
OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
&unknown_index, &shape));
break;
default:
context->CtxFailure(errors::InvalidArgument(
"desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
DataTypeString(sizes.dtype())));
return;
}
if (unknown_index != -1) {
OP_REQUIRES(
context, product > 0,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = nelems / product;
OP_REQUIRES(
context, product * missing == nelems,
errors::InvalidArgument(
"Input to reshape is a tensor with ", nelems,
" values, but the requested shape requires a multiple of ",
product));
shape.set_dim(unknown_index, missing);
}
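// Worked example (illustrative, not from the original code): reshaping a
// tensor with nelems = 24 using sizes = {2, -1, 4} gives product = 8 and
// unknown_index = 1, so missing = 24 / 8 = 3 and the final shape is
// {2, 3, 4}. If nelems were not a multiple of product, the OP_REQUIRES
// check above would fail with an InvalidArgument error.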
OP_REQUIRES(context, shape.num_elements() == nelems,
errors::InvalidArgument("Input to reshape is a tensor with ",
nelems,
" values, but the requested shape has ",
shape.num_elements()));

if (input_in_mkl_format) {
TensorShape& shape_to = shape;
TensorShape shape_from = mkl_shape_input.GetTfShape();
if (shape_from == shape_to) {
CopyMklTensorInToOut(context, kInputSlotIdx, kOutputSlotIdx);
return;
} else {
try {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> dnn_data_input(&cpu_engine);
// Reshape is just a logical view change for a tensor; it does not
// change the underlying layout. But MKLDNN may keep the tensor data
// in a different layout than the one TensorFlow expects. If so, we
// need to reorder the tensor and then put it in the shape expected
// by TensorFlow. If MKLDNN already keeps the input tensor in the
// layout TensorFlow expects, we don't need to reorder the tensor
// contents; we only need to update the MklDnnShape object associated
// with the input tensor to reflect the shape change requested by
// Reshape.
if (!SkipReorder(mkl_shape_input, shape_to)) {
// If the dimensions being expanded or collapsed are not maintained
// contiguously by MKLDNN, we use a reorder.

// Get Mkl layout of input tensor.
auto input_mkl_md = mkl_shape_input.GetMklLayout();
// Set input Mkl layout as the user layout.
dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
// Get expected Tensorflow layout of input tensor.
auto output_tf_md = mkl_shape_input.GetTfLayout();
auto output_tf_pd = memory::primitive_desc(output_tf_md,
cpu_engine);

Tensor* output_tensor = nullptr;
MklShape mkl_shape_output;
mkl_shape_output.SetMklTensor(false);
// We allocate output tensor in the shape expected by Reshape.
AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
shape_to, mkl_shape_output);

// Insert reorder between Mkl layout and TensorFlow layout.
std::vector<primitive> net;
CHECK_EQ(dnn_data_input.CheckReorderToOpMem(output_tf_pd,
output_tensor, &net), true);
stream(stream::kind::eager).submit(net).wait();
return;
} else {
// If the dimensions being expanded or collapsed are maintained
// contiguously by MKLDNN, we skip the reorder: we just update the
// MklDnnShape object for the TensorFlow tensor and forward the
// tensor as-is to the output.
auto output_dims = TFShapeToMklDnnDims(shape_to);
auto output_strides = CalculateTFStrides(output_dims);
auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(output_dims,
output_strides);
auto output_tf_pd = memory::primitive_desc(output_tf_md,
cpu_engine);

// Set MklDnnShape
MklDnnShape mkl_shape_output;
mkl_shape_output.SetMklTensor(true);
mkl_shape_output.SetMklLayout(&output_tf_pd);
mkl_shape_output.SetElemType(MklDnnType<T>());
mkl_shape_output.SetTfLayout(output_dims.size(), output_dims,
memory::format::blocked);

// We now simply forward the input Mkl tensor to the output and
// attach the updated MklDnnShape object to it.
ForwardMklTensorInToOutWithMklShape(context, kInputSlotIdx,
kOutputSlotIdx, mkl_shape_output);
return;
}
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) +
", in file " + string(__FILE__) + ":" +
std::to_string(__LINE__);
OP_REQUIRES_OK(context,
errors::Aborted("Operation received an exception:",
error_msg));
}
}
} else {
// If the input tensor is not in Mkl format, just copy the TensorFlow
// tensor to the output with the specified shape.
CopyTfTensorInToOutWithShape(context, kInputSlotIdx, kOutputSlotIdx,
shape);
}
}

#endif // INTEL_MKL_DNN

private:
const int kInputSlotIdx = 0;
const int kOutputSlotIdx = 0;

template <typename Tshape>
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
TensorShape* shape) {
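The diff view stops inside ValidateSizes. For orientation, an MKL CPU kernel of this kind is normally exposed to the graph through a labeled kernel registration; the sketch below shows the usual pattern. The op name "_MklReshape", the Tshape constraint, and the mkl_op_registry::kMklOpLabel label are assumptions based on other MKL kernels, not lines taken from this PR.

// Sketch only -- the registration details are assumptions, not part of this diff.
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklReshape")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int32>("Tshape")      \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReshapeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU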