INTEL MKL: Fix concat related issues #19065

Merged: 2 commits, May 17, 2018
213 changes: 153 additions & 60 deletions tensorflow/core/kernels/mkl_concat_op.cc
@@ -14,6 +14,7 @@ limitations under the License.

#include <limits>
#include <vector>
#include <unordered_map>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -589,8 +590,8 @@ class MklConcatOp : public OpKernel {
const int N = input_tensors.size();

// Get Tensor shapes.
std::vector<MklDnnShape> input_shapes(N);
GetMklShapeList(context, "values", &input_shapes);
std::vector<MklDnnShape> mkl_input_shapes(N);
GetMklShapeList(context, "values", &mkl_input_shapes);

const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
? MklGetInput(context, 0)
@@ -609,19 +610,14 @@ class MklConcatOp : public OpKernel {
int i = 0;
bool invoke_eigen = false;
bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
const TensorShape expected_shape = input_shapes[0].IsMklTensor()
? input_shapes[0].GetTfShape()
: input_tensors[0].shape();
const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor()
? mkl_input_shapes[0].GetTfShape()
: input_tensors[0].shape();
size_t expected_dims = expected_shape.dims();

if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

for (auto& s : input_shapes) {
if (s == expected_shape) {
++i;
continue;
}

for (auto& s : mkl_input_shapes) {
TensorShape s_shape =
s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
size_t s_dims = s_shape.dims();
@@ -664,21 +660,14 @@ class MklConcatOp : public OpKernel {

// Call Eigen library
if (invoke_eigen) {
TensorShapeList tf_input_shapes;
i = 0;
for (auto& s : input_shapes) {
TensorShape s_shape =
s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
tf_input_shapes.push_back(s_shape);
++i;
}
CallEigenVersion(context, input_tensors, tf_input_shapes);
CallEigenVersion(context, input_tensors, mkl_input_shapes);
return;
}

memory::dims dst_dims;

if (are_all_mkl_inputs)
dst_dims = TFShapeToMklDnnDims(input_shapes[0].GetTfShape());
dst_dims = TFShapeToMklDnnDims(mkl_input_shapes[0].GetTfShape());
else
// When all the inputs are in TensorFlow format, we don't know
// what the input data format is. In that case, we just use
@@ -688,26 +677,61 @@ class MklConcatOp : public OpKernel {
std::vector<memory::primitive_desc> srcs_pd;
std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
int64 dst_concat_dim_size = 0;
for (int k = 0; k < N; k++) {
bool is_mkl_tensor = input_shapes[k].IsMklTensor();
memory::dims src_dims;

// Same comment as dst_dims for src_dims.
src_dims = (is_mkl_tensor)
? TFShapeToMklDnnDims(input_shapes[k].GetTfShape())
: TFShapeToMklDnnDims(input_tensors[k].shape());

dst_concat_dim_size += src_dims[concat_dim];
auto src_md =
is_mkl_tensor ? input_shapes[k].GetMklLayout() :
// It does not matter what data format we use here
// (NHWC or NCHW). We just need to ensure that output
// of Concat uses same data format as input.
memory::desc(src_dims, MklDnnType<T>(), memory::format::nchw);

srcs[k].SetUsrMem(src_md, &input_tensors[k]);
auto src_mpd = srcs[k].GetUsrMemPrimDesc();
srcs_pd.push_back(src_mpd);

bool isMklReorderNeeded = false;
memory::format mkl_common_format = memory::format::any;
if (are_all_mkl_inputs) {
mkl_common_format =
FindMklCommonFormat(mkl_input_shapes, concat_dim,
&isMklReorderNeeded, &dst_concat_dim_size);

if (!isMklReorderNeeded) {
// All MKL tensors have the same format. No reorder is needed.
for (int k = 0; k < N; k++) {
if (input_tensors[k].NumElements() == 0)
continue;

auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
auto src_mpd = srcs[k].GetUsrMemPrimDesc();
srcs_pd.push_back(src_mpd);
}
} else {
// MKL tensors have different formats.
// Reorder them to the most common format.
for (int k = 0; k < N; k++) {
if (input_tensors[k].NumElements() == 0)
continue;

auto src_dims = TFShapeToMklDnnDims(
mkl_input_shapes[k].GetTfShape());
auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);

if (src_md.data.format != mkl_common_format)
src_md = memory::desc(src_dims, MklDnnType<T>(),
mkl_common_format);

srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
}
}
} else { // All TF inputs
for (int k = 0; k < N; k++) {
if (input_tensors[k].NumElements() == 0)
continue;

memory::dims src_dims = TFShapeToMklDnnDims(input_tensors[k].shape());
dst_concat_dim_size += src_dims[concat_dim];

// It does not matter which data format is used here (NHWC versus NCHW).
// We just need to ensure that the output uses the same data format as the inputs.
auto src_md =
[Review comment, Member] If I understand your comment correctly: For clarity, should we copy the format from the input rather than hardcoding it to NCHW here?

[Reply, Contributor (Author)] Currently there is no "data_format" attribute for TensorFlow tensors, so it is not possible to copy the format from them. Only MKL tensors track their format (NHWC, NCHW, MKL blocked format, etc.).

memory::desc(src_dims, MklDnnType<T>(), memory::format::nchw);

srcs[k].SetUsrMem(src_md, &input_tensors[k]);
auto src_mpd = srcs[k].GetUsrMemPrimDesc();
srcs_pd.push_back(src_mpd);
}
}
dst_dims[concat_dim] = dst_concat_dim_size;

Expand All @@ -717,25 +741,33 @@ class MklConcatOp : public OpKernel {
if (are_all_mkl_inputs) {
// Since we are passing a specific format for destination,
// we need to have dst_dims in MklDnn order (NCHW).
auto orig_tf_format = input_shapes[0].GetTfDataFormat();
auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat();
dst_dims_in_nchw = MklDnnDimsInNCHW(
dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
// We will set the output in the same format as input to avoid layout
// conversions.
// Currently we are setting dst format same as input format.
// See if we can make this choice in a better way.
// Set the output format to the most common format of the inputs
// to avoid layout conversions.
dst_md = memory::desc(
dst_dims_in_nchw, MklDnnType<T>(),
(memory::format)input_shapes[0].GetMklLayout().data.format);
dst_dims_in_nchw, MklDnnType<T>(), mkl_common_format);
} else {
// Again, format does not matter here. We just need to make it same as
// input format.
// All inputs are TF tensors.
// Set the output format to match the input format (nchw).
dst_md = memory::desc(dst_dims, MklDnnType<T>(), memory::format::nchw);
}

std::vector<primitive::at> inputs;
for (int k = 0; k < input_tensors.size(); k++)
inputs.push_back(srcs[k].GetOpMem());
std::vector<primitive> net;
if (isMklReorderNeeded) {
for (int k = 0; k < input_tensors.size(); k++) {
if (input_tensors[k].NumElements() > 0) {
srcs[k].CheckReorderToOpMem(srcs_pd[k], &net);
}
}
}
for (int k = 0; k < input_tensors.size(); k++) {
if (input_tensors[k].NumElements() > 0) {
inputs.push_back(srcs[k].GetOpMem());
}
}

// If all inputs are in MKL format, then the meaning of concat_dim needs to
// change. The value of concat_dim is tied to the input TensorFlow data format
@@ -744,7 +776,8 @@ class MklConcatOp : public OpKernel {
// But if input tensors are in NHWC order, then the semantics need to change.
// E.g., if we are concatenating over Channel (dimension 3 for NHWC),
// then since the MklDnn order is NCHW, concat_dim needs to be 1
// (see the standalone sketch after this hunk).
if (are_all_mkl_inputs) concat_dim = input_shapes[0].TfDimIdx(concat_dim);
if (are_all_mkl_inputs)
concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim);

auto concat_pd = concat::primitive_desc(dst_md, concat_dim, srcs_pd);
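The concat_dim remapping above can be checked with a short standalone sketch, assuming the usual NHWC-to-NCHW axis permutation; the helper name TfNhwcDimToMklNchw is hypothetical and not part of this PR:

// Map a TF NHWC axis index to its position in the MklDnn NCHW axis order.
// NHWC axes (N, H, W, C) land at NCHW positions (0, 2, 3, 1).
#include <array>
#include <cassert>
#include <cstdio>

int TfNhwcDimToMklNchw(int tf_dim) {
  constexpr std::array<int, 4> kMap = {0, 2, 3, 1};
  assert(tf_dim >= 0 && tf_dim < 4);
  return kMap[tf_dim];
}

int main() {
  // Concatenating over Channel: NHWC dimension 3 becomes NCHW dimension 1.
  std::printf("%d\n", TfNhwcDimToMklNchw(3));  // prints 1
}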

@@ -757,7 +790,7 @@ class MklConcatOp : public OpKernel {
dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
input_shapes[0].GetTfDataFormat());
mkl_input_shapes[0].GetTfDataFormat());
tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
} else {
dnn_shape_dst.SetMklTensor(false);
@@ -772,7 +805,6 @@ class MklConcatOp : public OpKernel {
dst.SetUsrMem(dst_md, dst_tensor);

auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
@@ -786,15 +818,27 @@ class MklConcatOp : public OpKernel {
}

void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
const TensorShapeList& input_shapes) {
CHECK_EQ(values.size(), input_shapes.size());
const MklDnnShapeList& mkl_input_shapes) {
CHECK_EQ(values.size(), mkl_input_shapes.size());

std::vector<Tensor> converted_values;
for (int i = 0; i < input_shapes.size(); i++)
converted_values.push_back(values[i]);
TensorShapeList tf_input_shapes;
for (int i = 0; i < mkl_input_shapes.size(); i++) {
if (mkl_input_shapes[i].IsMklTensor()) {
// Convert from MKL format to TF format.
Tensor tmp_tensor =
ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i]);
converted_values.push_back(tmp_tensor);
tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
} else {
// No conversion needed; this is already a TF tensor.
converted_values.push_back(values[i]);
tf_input_shapes.push_back(values[i].shape());
}
}

// Call Eigen concat.
eigen_concat_op_.Compute(context, converted_values, input_shapes);
eigen_concat_op_.Compute(context, converted_values, tf_input_shapes);

// Set output Mkl tensor for this op.
MklDnnShape dnn_shape_output;
@@ -811,6 +855,55 @@ class MklConcatOp : public OpKernel {
output_tensor->flat<uint8>().data(),
output_tensor->flat<uint8>().size() * sizeof(uint8));
}

// This method finds the most common format across all MKL inputs.
// (A standalone sketch of this majority vote follows this file's diff.)
// Inputs:
// 1. input_shapes: shapes of the input (MKL) tensors.
// 2. concat_dim: the concat dimension.
// Outputs:
// 1. is_reorder_needed is set to true if the inputs have different formats;
// it is set to false otherwise.
// 2. concat_dim_size is the size of concat_dim.
// Returns:
// the most common MKL format.
memory::format FindMklCommonFormat(const MklDnnShapeList& input_shapes,
int concat_dim, bool* is_reorder_needed, int64* concat_dim_size) {
*is_reorder_needed = false;
*concat_dim_size = 0;
std::unordered_map<memory::format, int> occurrence_map;
if (input_shapes.size() == 0)
return memory::format::any;

// Count the occurrences of each format among the inputs.
for (int k = 0; k < input_shapes.size(); k++) {
auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
*concat_dim_size += src_dims[concat_dim];
memory::format fmt = static_cast<memory::format>(
input_shapes[k].GetMklLayout().data.format);
occurrence_map[fmt] += 1;
}

if (occurrence_map.size() == 1) {
// This means that all inputs have the same format;
// return it with is_reorder_needed left as false.
return static_cast<memory::format>(
input_shapes[0].GetMklLayout().data.format);
}

// Input tensors have different formats. Thus, a reorder is needed.
// We pick the most common format to minimize the total
// number of input reorders.
memory::format commonest_format = memory::format::any;
int max_occurrence = 0;
*is_reorder_needed = true;
for (auto item : occurrence_map) {
if (item.second > max_occurrence) {
commonest_format = item.first;
max_occurrence = item.second;
}
}
return commonest_format;
}
};

#endif
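
As referenced in the FindMklCommonFormat doc comment above, here is a minimal standalone sketch of the majority vote it performs, assuming formats are modeled as plain ints rather than memory::format values; MostCommonFormat is a hypothetical name, not TensorFlow or MKL-DNN API:

// Count how often each format occurs and pick the most frequent one,
// flagging whether any input deviates from it (and thus needs a reorder).
#include <cstdio>
#include <unordered_map>
#include <vector>

int MostCommonFormat(const std::vector<int>& formats, bool* reorder_needed) {
  std::unordered_map<int, int> occurrence;
  for (int f : formats) ++occurrence[f];
  *reorder_needed = occurrence.size() > 1;
  int best = formats.empty() ? -1 : formats[0];
  int max_count = 0;
  for (const auto& item : occurrence) {
    if (item.second > max_count) {
      best = item.first;
      max_count = item.second;
    }
  }
  return best;
}

int main() {
  bool reorder = false;
  // The ints stand in for formats, e.g. {nChw8c, nChw8c, nchw}.
  int common = MostCommonFormat({8, 8, 1}, &reorder);
  std::printf("common=%d reorder=%d\n", common, reorder);  // common=8 reorder=1
}

As in the PR's own loop, ties are broken by unordered_map iteration order, which is unspecified; any of the tied formats is an equally good choice for minimizing reorders.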
43 changes: 38 additions & 5 deletions tensorflow/core/util/mkl_util.h
@@ -706,15 +706,48 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor;
}
#else
using mkldnn::stream;
template <typename T> class MklDnnData;

template <typename T>
inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
const MklDnnShape& mkl_shape) {
Tensor output_tensor;
TensorShape output_shape;

TF_CHECK_OK(
Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function"));

try {
if (!mkl_shape.IsMklTensor())
return mkl_tensor; // Return the input as-is; it is already a TF tensor.

TensorShape output_shape = mkl_shape.GetTfShape();

// Allocate output tensor.
context->allocate_temp(DataTypeToEnum<T>::v(),
output_shape, &output_tensor);

auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> input(&cpu_engine);

// Get Mkl layout of input tensor.
auto input_mkl_md = mkl_shape.GetMklLayout();
auto output_tf_md = mkl_shape.GetTfLayout();
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
input.SetUsrMem(input_mkl_md, &mkl_tensor);

// Reorder the MKL input into TF layout if the layouts differ.
if (input.IsReorderNeeded(output_tf_pd)) {
std::vector<primitive> net;
CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net),
true);
stream(stream::kind::eager).submit(net).wait();
} else {
// If not, just forward input tensor to output tensor.
CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape));
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
LOG(FATAL) << "Operation received an exception: " << error_msg;
}
return output_tensor;
}
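
The reorder inside ConvertMklToTF is a pure layout change: same values, different element order. A minimal standalone sketch, assuming the reorder is a plain NCHW-to-NHWC permutation over a flat float buffer (real MKL blocked formats are more involved); NchwToNhwc is a hypothetical name:

// Permute a flat NCHW buffer into NHWC order, element by element.
#include <cstdio>
#include <vector>

std::vector<float> NchwToNhwc(const std::vector<float>& src,
                              int n, int c, int h, int w) {
  std::vector<float> dst(src.size());
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          dst[((in * h + ih) * w + iw) * c + ic] =
              src[((in * c + ic) * h + ih) * w + iw];
  return dst;
}

int main() {
  // A 1x2x2x2 tensor: channel-major in, channel-last out.
  std::vector<float> nchw = {0, 1, 2, 3, 4, 5, 6, 7};
  for (float v : NchwToNhwc(nchw, 1, 2, 2, 2)) std::printf("%g ", v);
  std::printf("\n");  // prints: 0 4 1 5 2 6 3 7
}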
#endif
Expand Down