[Intel-MKL] Support for N-D Transpose using MKL-DNN #20066

Merged
14 changes: 12 additions & 2 deletions tensorflow/compiler/tf2xla/kernels/BUILD
@@ -6,6 +6,10 @@ package(

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)

tf_kernel_library(
name = "xla_ops",
@@ -140,8 +144,14 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
"//tensorflow/core/kernels:transpose_op",
],
] + if_mkl(
[
"//tensorflow/core/kernels:mkl_transpose_op",
],
[
"//tensorflow/core/kernels:transpose_op",
],
),
)

tf_kernel_library(
42 changes: 29 additions & 13 deletions tensorflow/core/kernels/BUILD
@@ -647,7 +647,14 @@ cc_library(
":split_v_op",
":strided_slice_op",
":tile_ops",
":transpose_op",
] + if_mkl(
[
":mkl_transpose_op",
],
[
":transpose_op",
],
) + [
":unique_op",
":unpack_op",
":unravel_index_op",
@@ -885,18 +892,27 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)

tf_kernel_library(
name = "transpose_op",
srcs = [
"transpose_op.cc",
] + if_mkl([
"mkl_transpose_op.cc",
]),
hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
]),
if_mkl(
[tf_mkl_kernel_library(
name = "mkl_transpose_op",
srcs = [
"transpose_op.cc",
"mkl_transpose_op.cc",
],
hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
]),
)],
[tf_kernel_library(
name = "transpose_op",
srcs = [
"transpose_op.cc",
],
hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS,
)],
)

tf_kernel_library(
102 changes: 99 additions & 3 deletions tensorflow/core/kernels/mkl_transpose_op.cc
@@ -15,13 +15,23 @@ limitations under the License.

// See docs in ../ops/array_ops.cc.

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL)
#define EIGEN_USE_THREADS

#if !defined(DO_NOT_USE_ML)
#include "mkl_trans.h"
#endif

#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;
#endif
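A note on the three guard macros, judging from how this file uses them: INTEL_MKL gates the whole file, DO_NOT_USE_ML hides the MKL-ML binary-blob path (the 2-D MKLTranspose2D fast path below), and INTEL_MKL_ML hides the open-source MKL-DNN path (the new N-D transpose). A condensed sketch of the matrix these guards produce (macro names are taken from this diff; the build flags that define them are outside this change):

#if defined(INTEL_MKL)          // file is compiled at all
#if !defined(DO_NOT_USE_ML)     // MKL-ML blob present: 2-D BLAS-style transpose
#endif
#ifndef INTEL_MKL_ML            // MKL-DNN present: N-D transpose via reorder
#endif
#endif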

namespace tensorflow {

// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
@@ -40,6 +50,7 @@ namespace tensorflow {
// REQUIRES: perm is a permutation.

namespace {
#if !defined(DO_NOT_USE_ML)
template <typename T>
Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);

@@ -93,11 +104,67 @@ Status MKLTranspose2D<complex128>(const char trans, const Tensor& in,
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';

#endif // if !defined(DO_NOT_USE_ML)

#ifndef INTEL_MKL_ML
// MKL-DNN based Transpose implementation
template <typename T>
Status MKLTransposeND(OpKernelContext* ctx, const Tensor& in, Tensor* out,
const gtl::ArraySlice<int32>& perm);


static inline memory::dims ReorderStrides(const memory::dims& strides,
const gtl::ArraySlice<int32>& perm) {
memory::dims reordered_strides;
reordered_strides.resize(strides.size());
for (size_t i = 0; i < strides.size(); ++i) {
reordered_strides[perm[i]] = strides[i];
}
return reordered_strides;
}
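To make the stride bookkeeping concrete, here is a small self-contained example of what ReorderStrides computes (plain C++; the shape and permutation are chosen for illustration only): with perm = {2, 0, 1}, a [2, 3, 4] input transposes to a [4, 2, 3] output, and the output's natural row-major strides {6, 3, 1} are mapped back into input-axis order as {3, 1, 6}.

#include <cassert>
#include <vector>

int main() {
  std::vector<int> strides = {6, 3, 1};  // natural strides of the [4, 2, 3] output
  std::vector<int> perm = {2, 0, 1};     // output axis i comes from input axis perm[i]
  std::vector<int> reordered(strides.size());
  for (size_t i = 0; i < strides.size(); ++i)
    reordered[perm[i]] = strides[i];     // the same loop as ReorderStrides above
  assert((reordered == std::vector<int>{3, 1, 6}));
  return 0;
}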

// Transpose of N-dimensional tensor using MKL-DNN
template<typename T>
Status MKLTransposeND(OpKernelContext* context,
const Tensor& in_tensor, Tensor* out_tensor,
const gtl::ArraySlice<int32>& perm) {
try {
engine cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> in(&cpu_engine);
MklDnnData<T> out(&cpu_engine);

memory::dims in_dims = TFShapeToMklDnnDims(in_tensor.shape());
memory::dims out_dims = TFShapeToMklDnnDims(out_tensor->shape());
memory::dims in_strides = CalculateTFStrides(in_dims);
// Reorder output strides based on permutation requested.
memory::dims out_strides = ReorderStrides(CalculateTFStrides(out_dims),
perm);

in.SetUsrMem(in_dims, in_strides, &in_tensor);
// Output dimensions are the same as input dimensions. We adjust the layout
// using strides.
out.SetUsrMem(in_dims, out_strides, out_tensor);

std::vector<primitive> net;
net.push_back(in.CreateReorder(in.GetUsrMem(), out.GetUsrMem()));
stream(stream::kind::eager).submit(net).wait();
return Status::OK();
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + std::string(e.message) +
", in file " + std::string(__FILE__) + ":" +
std::to_string(__LINE__);
return errors::Aborted("Operation received an exception:", error_msg);
}
}
#endif // #ifndef INTEL_MKL_ML

} // namespace
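With the strides in place, the transpose itself needs no explicit permutation loop: both memories are described with the input's dimensions, but the destination carries the permuted strides, so the reorder primitive (in effect a strided copy) lands each element at its transposed offset. A minimal plain-C++ stand-in for that reorder, under the same illustrative [2, 3, 4] shape and perm = {2, 0, 1} as above:

#include <cassert>
#include <vector>

int main() {
  const int D0 = 2, D1 = 3, D2 = 4;                // input shape [2, 3, 4]
  const std::vector<int> in_strides = {12, 4, 1};  // row-major input strides
  const std::vector<int> out_strides = {3, 1, 6};  // ReorderStrides result from above
  std::vector<float> in(D0 * D1 * D2), out(D0 * D1 * D2);
  for (int i = 0; i < D0 * D1 * D2; ++i) in[i] = static_cast<float>(i);

  // The "reorder": walk the input in order, scatter into permuted offsets.
  for (int a = 0; a < D0; ++a)
    for (int b = 0; b < D1; ++b)
      for (int c = 0; c < D2; ++c)
        out[a * out_strides[0] + b * out_strides[1] + c * out_strides[2]] =
            in[a * in_strides[0] + b * in_strides[1] + c * in_strides[2]];

  // Spot check: in(a, b, c) must land at out(c, a, b). The output is [4, 2, 3]
  // with row-major strides {6, 3, 1}, so out(1, 0, 2) == in(0, 2, 1).
  assert(out[1 * 6 + 0 * 3 + 2 * 1] == in[0 * 12 + 2 * 4 + 1 * 1]);
  return 0;
}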

Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
#if !defined(DO_NOT_USE_ML)
if (in.dims() == 2) {
if (perm[0] == 0 && perm[1] == 1) {
return Status::OK();
@@ -115,7 +182,21 @@ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
break;
}
}
// Fall back to Eigen if the transpose parameters are not supported by MKL
#endif

#ifndef INTEL_MKL_ML
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
// Fall back to Eigen for unsupported cases.
if (in.dims() <= TENSOR_MAX_DIMS) {
switch (in.dtype()) {
case DT_FLOAT: return MKLTransposeND<float>(ctx, in, out, perm); break;
// TODO(nhasabni): support other types such as INT8.
default: break;
}
}
#endif

// Fall back to Eigen if the transpose parameters are not supported by MKL or MKL-DNN
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
out);
@@ -125,6 +206,7 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
#if !defined(DO_NOT_USE_ML)
if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
// TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
// for any transpose that can be reduced to swapping the last two
@@ -143,7 +225,21 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
break;
}
}
// Fall back to Eigen if the transpose parameters are not supported by MKL
#endif

#ifndef INTEL_MKL_ML
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
// Fall back to Eigen for unsupported cases.
if (in.dims() <= TENSOR_MAX_DIMS) {
switch (in.dtype()) {
case DT_FLOAT: return MKLTransposeND<float>(ctx, in, out, perm); break;
// TODO(nhasabni): support other types such as INT8.
default: break;
}
}
#endif

// Fall back to Eigen if the transpose parameters are not supported by MKL or MKL-DNN
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
perm, out);
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/transpose_op.h
@@ -42,7 +42,7 @@ class TransposeCpuOp : public TransposeOp {
gtl::ArraySlice<int32> perm, Tensor* out) override;
};

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL)
class MklTransposeCpuOp : public TransposeOp {
public:
explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
@@ -85,7 +85,7 @@ class ConjugateTransposeCpuOp : public TransposeOp {
bool IsConjugate() const override { return true; }
};

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL)
class MklConjugateTransposeCpuOp : public TransposeOp {
public:
explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
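For orientation, the header changes keep TensorFlow's existing template-method arrangement: the shared TransposeOp base performs validation and output allocation in Compute(), and each variant only overrides DoTranspose() (plus IsConjugate() for the conjugate kernels). A schematic sketch of that pattern (names abbreviated and bodies invented for illustration; the real base class lives in transpose_op.h):

#include <iostream>

// Schematic only: the base drives Compute(); back ends plug in DoTranspose().
struct TransposeOpSketch {
  virtual ~TransposeOpSketch() = default;
  void Compute() {
    // ... validate perm, allocate the output tensor, then dispatch:
    DoTranspose();
  }
  virtual void DoTranspose() = 0;  // Eigen, MKL-ML, or MKL-DNN implementation
  virtual bool IsConjugate() const { return false; }
};

struct MklTransposeSketch : TransposeOpSketch {
  void DoTranspose() override { std::cout << "MKL/MKL-DNN path\n"; }
};

int main() {
  MklTransposeSketch op;
  op.Compute();  // prints "MKL/MKL-DNN path"
  return 0;
}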