Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MKL: cifar 10 divergence fix and batchnorm unit test fix #17004

Merged
merged 2 commits into from
Mar 1, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 65 additions & 31 deletions tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1110,19 +1110,12 @@ class MklFusedBatchNormGradOp : public OpKernel {
return;
}

if (dnn_shape_src.IsMklTensor())
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
else
ExtractParams(context);

memory::format format_m;
if (dnn_shape_src.IsMklTensor()) {
if (dnn_shape_src.IsTensorInNCHWFormat())
format_m = memory::format::nchw;
else
format_m = memory::format::nhwc;
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
} else if (dnn_shape_diff_dst.IsMklTensor()) {
depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
} else {
format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
ExtractParams(context);
}

MklDnnData<T> src(&cpu_engine);
Expand All @@ -1146,20 +1139,20 @@ class MklFusedBatchNormGradOp : public OpKernel {
diff_dst_dims =
TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);

// set src and diff_dst primitives
// set src and diff_dst primitives according to input layout
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = src_md;
} else {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
src_md = diff_dst_md;
}
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
} else {
src_md = memory::desc(src_dims, MklDnnType<T>(), format_m);
diff_dst_md = src_md;
src_md = memory::desc(src_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
if (dnn_shape_diff_dst.IsMklTensor()) {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
} else {
diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
src.SetUsrMem(src_md, &src_tensor);
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
Expand Down Expand Up @@ -1211,28 +1204,64 @@ class MklFusedBatchNormGradOp : public OpKernel {
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor()) {

// MKL-DNN's BN primitive does not provide an API to fetch its internal
// format, so set common_md as the OpMem format:
// src and diff_dst will be reordered to common_md,
// and diff_src will be set to common_md
memory::desc common_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor()) {
common_md = dnn_shape_src.GetMklLayout();
} else {
common_md = dnn_shape_diff_dst.GetMklLayout();
}
} else {
common_md = memory::desc(src_dims, MklDnnType<T>(),
TFDataFormatToMklDnnDataFormat(tensor_format_));
}
// if either src or diff_dst is in mkl layout,
// then we set diff_src as mkl layout
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
dnn_shape_diff_src.SetMklTensor(true);
auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
// set diff_src's mkl layout as common_md
auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
format_m);
dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(),
tensor_format_);
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_diff_src.SetTfLayout(
dnn_shape_src.GetDimension(),
src_dims,
dnn_shape_src.GetTfDataFormat());
dnn_shape_diff_src.SetTfDimOrder(
dnn_shape_src.GetDimension(),
tensor_format_);
} else {
dnn_shape_diff_src.SetTfLayout(
dnn_shape_diff_dst.GetDimension(),
src_dims,
dnn_shape_diff_dst.GetTfDataFormat());
dnn_shape_diff_src.SetTfDimOrder(
dnn_shape_diff_dst.GetDimension(),
tensor_format_);
}
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
// both src and diff_dst are in tf layout,
// so getting the tf shape from either one should be ok
tf_shape_diff_src = src_tensor.shape();
}
AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
tf_shape_diff_src, dnn_shape_diff_src);

diff_src.SetUsrMem(src_md, diff_src_tensor);
// set diff_src
diff_src.SetUsrMem(common_md, diff_src_tensor);

prop_kind pk = prop_kind::backward;
auto bnrm_bwd_desc = batch_normalization_backward::desc(
pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_,
pk, common_md, common_md, epsilon_,
/* for inference, specify use_global_stats
1. on fwd prop, use mean and variance
provided as inputs
Expand All @@ -1245,11 +1274,16 @@ class MklFusedBatchNormGradOp : public OpKernel {
auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);

std::vector<primitive> net;
src.CheckReorderToOpMem(memory::primitive_desc(common_md,
cpu_engine), &net);
diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md,
cpu_engine), &net);

auto bnrm_bwd_op = batch_normalization_backward(
bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);

std::vector<primitive> net;
net.push_back(bnrm_bwd_op);
stream(stream::kind::eager).submit(net).wait();

Expand Down
20 changes: 16 additions & 4 deletions tensorflow/core/kernels/mkl_relu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,11 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.MklCleanup();
}



#else // INTEL_MKL_ML


template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
public:
Expand Down Expand Up @@ -579,17 +582,26 @@ class MklReluGradOpBase : public OpKernel {
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor()) {
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
dnn_shape_diff_src.SetMklTensor(true);
auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
} else {
dnn_shape_diff_src.SetTfLayout(dnn_shape_diff_dst.GetDimension(),
dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
dnn_shape_diff_dst.GetTfDataFormat());
}
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
// both src and diff_dst are in tf layout,
// so getting the tf shape from either one should be ok
tf_shape_diff_src = src_tensor.shape();
}
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
Expand Down