Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Intel MKL] Fix incorrect DNNL1.2 integration in pooling backprop #38137

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 9 additions & 7 deletions tensorflow/core/kernels/mkl_avgpooling_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
: memory::desc(orig_input_dims_mkl_order, MklDnnType<T>(),
this->data_format_mkldnn_);

// Get diff_dst memory::desc.
memory::desc diff_dst_md =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);

// Pass prop_kind::forward_training to create a forward primitive
// that is used in the backward pass.
#ifdef ENABLE_MKLDNN_V1
Expand All @@ -241,7 +248,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md,
diff_dst_md);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
Expand All @@ -256,12 +264,6 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
this->tensor_format_mkldnn_, &output_tensor);
// Get diff_dst memory::desc.
memory::desc diff_dst_md =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);

// TODO(nammbash): Refactor (lines 249-262) common code for
// max & avg pooling into superclass or common utils function.
Expand Down
17 changes: 9 additions & 8 deletions tensorflow/core/kernels/mkl_maxpooling_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,21 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
: memory::desc(orig_input_dims_mkl_order, MklDnnType<T>(),
this->data_format_mkldnn_);

// Get diff_dst memory descriptor.
memory::desc diff_dst_md =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);

#ifdef ENABLE_MKLDNN_V1
// TODO(DNNL): Find out what should be used for src_md.data.format.
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right, ALGORITHM::pooling_max,
prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md,
diff_dst_md);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
Expand All @@ -320,13 +328,6 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order,
this->tensor_format_mkldnn_, &output_tensor);

// Get diff_dst memory descriptor.
memory::desc diff_dst_md =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);

// Check if diff_dst needs to be reordered.
std::shared_ptr<PoolingBwdPd> pooling_bwd_pd =
pooling_bwd->GetPoolingBwdPd();
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/mkl_pooling_ops_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,12 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
// Create memory descriptor.
context_.diff_src_md.reset(new memory::desc(
{bwdParams.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
#ifndef ENABLE_MKLDNN_V1
context_.diff_dst_md.reset(new memory::desc(
{bwdParams.dst_dims}, MklDnnType<T>(), bwdParams.src_format));
#else
context_.diff_dst_md.reset(new memory::desc(bwdParams.diff_dst_md.data));
#endif // !ENABLE_MKLDNN_V1

#ifndef ENABLE_MKLDNN_V1
context_.bwd_desc.reset(new pooling_backward::desc(
Expand Down
17 changes: 16 additions & 1 deletion tensorflow/core/kernels/mkl_pooling_ops_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,20 @@ struct MklPoolingParams {
mkldnn::prop_kind prop_kind;
MEMORY_FORMAT src_format;
memory::desc src_md;
#ifdef ENABLE_MKLDNN_V1
memory::desc diff_dst_md;
#endif // ENABLE_MKLDNN_V1

MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
memory::dims filter_dims, memory::dims strides,
memory::dims padding_left, memory::dims padding_right,
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
#ifdef ENABLE_MKLDNN_V1
MEMORY_FORMAT src_format, memory::desc src_md,
memory::desc diff_dst_md = memory::desc())
#else
MEMORY_FORMAT src_format, memory::desc src_md)
#endif // ENABLE_MKLDNN_V1
: src_dims(src_dims),
dst_dims(dst_dims),
filter_dims(filter_dims),
Expand All @@ -64,7 +72,14 @@ struct MklPoolingParams {
alg_kind(alg_kind),
prop_kind(prop_kind),
src_format(src_format),
src_md(src_md) {}
#ifdef ENABLE_MKLDNN_V1
src_md(src_md),
diff_dst_md(diff_dst_md) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be guarded with #ifdef ENABLE_MKLDNN_V1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed it.

}
#else
src_md(src_md) {
}
#endif // ENABLE_MKLDNN_V1
};

template <typename T>
Expand Down