Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Intel MKL] Enable 3d shapes in the MklSlice op #31432

Closed
Closed
12 changes: 9 additions & 3 deletions tensorflow/core/kernels/mkl_slice_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#ifdef INTEL_MKL

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
Expand All @@ -27,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using mkldnn::stream;
using mkldnn::view;
Expand Down Expand Up @@ -313,6 +313,7 @@ class MklSliceOp : public OpKernel {
bool done = false;

CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);

if (!context->status().ok() || done == true) return;

// Though MKL-DNN supports more than 8 dimension and
Expand Down Expand Up @@ -394,8 +395,13 @@ class MklSliceOp : public OpKernel {
if (input_mkl_shape.IsMklTensor()) {
auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);

bool is_slice2d = (input_mkl_shape.GetDimension() == 4);
begin_dims = is_slice2d
? MklDnnDimsInNCHW(begin_dims, input_tf_format)
: MklDnnDimsInNCDHW(begin_dims, input_tf_format);
size_dims = is_slice2d ? MklDnnDimsInNCHW(size_dims, input_tf_format)
: MklDnnDimsInNCDHW(size_dims, input_tf_format);
auto input_md = input_mkl_shape.GetMklLayout();
src.SetUsrMem(input_md, &input_tensor);

Expand Down
22 changes: 20 additions & 2 deletions tensorflow/core/util/mkl_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,8 @@ inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
return memory::dims({n, c, d, h, w});
}

/// Overloaded version of function above. Input parameters are
/// self-explanatory.
/// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
/// Input parameters are self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
TensorFormat format) {
// Validate format.
Expand All @@ -1117,6 +1117,24 @@ inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
return memory::dims({n, c, h, w});
}

/// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
/// Input parameters are self-explanatory.
inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
TensorFormat format) {
// Validate format.
CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
Copy link
Member

@penpornk penpornk Aug 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please change this to DCHECK_NE?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind. I think I'll fix it internally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

memory::format::format_undef);

int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
int d = in_dims[GetTensorDimIndex<3>(format, '0')];
int h = in_dims[GetTensorDimIndex<3>(format, '1')];
int w = in_dims[GetTensorDimIndex<3>(format, '2')];

// MKL DNN requires dimensions in NCDHW format.
return memory::dims({n, c, d, h, w});
}

/// Map MklDnn memory::dims object into TensorShape object.
///
/// This function will simply map input shape in MKL-DNN memory::dims format
Expand Down
32 changes: 31 additions & 1 deletion tensorflow/python/kernel_tests/slice_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


class SliceTest(test.TestCase):

def testEmpty(self):
Expand Down Expand Up @@ -146,6 +146,36 @@ def testSingleDimension(self):
slice_val = self.evaluate(slice_t)
self.assertAllEqual(slice_val, inp[lo:hi])

def test3Dimension(self):
with self.session():
input_shape = [8, 16, 16, 16, 8]
total_input_size = 1
for s in input_shape:
total_input_size *= s
inputs = [i * 1.0 / total_input_size for i in range(1, total_input_size
+ 1)]
a = constant_op.constant(inputs, shape=input_shape,
dtype=dtypes.float32)

filter_shape = [1, 1, 1, 8, 8]
total_filter_size = 1
for s in filter_shape:
total_filter_size *= s
filters = [i * 1.0 / total_filter_size for i in range(1,
total_filter_size + 1)]
f = constant_op.constant(filters, shape=filter_shape,
dtype=dtypes.float32)

conv_t = nn_ops.conv3d(a,
filter=f,
strides=[1, 1, 1, 1, 1],
padding="VALID")
slice_t = array_ops.slice(conv_t, [0, 1, 1, 1, 0], [1, 1, 1, 1, 8])
result = self.evaluate(slice_t)
penpornk marked this conversation as resolved.
Show resolved Hide resolved
expected = [0.03028321, 0.03132677, 0.03237033, 0.03341389,
0.03445745, 0.035501, 0.03654456, 0.03758812]
self.assertAllClose(expected, result.flatten(), rtol=1e-6)

@test_util.run_deprecated_v1
def testScalarInput(self):
input_val = 0
Expand Down