Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

INTEL MKL: Enhance MkL Pooling ops with primitive reuse #19403

Merged
274 changes: 116 additions & 158 deletions tensorflow/core/kernels/mkl_avgpooling_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,22 +442,21 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {

void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
const Tensor& input_tensor = MklGetInput(context,
this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
this->SanityCheckInput(context, input_tensor, dnn_shape_input);
if (!context->status().ok()) return;

MklDnnData<T> dnn_data_input(&cpu_engine);
MklDnnData<T> dnn_data_output(&cpu_engine);
MklDnnData<T> dnn_data_input(&cpu_engine_);

// initialize variables for the pooling op
MklPoolParameters pool_params;
// Get the input tensor and initialize the pooling parameters
this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
&dnn_data_input);
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params,
dnn_shape_input, input_tensor_shape);
OP_REQUIRES_OK(context, context->status());

// Declare output tensor
Expand All @@ -467,65 +466,58 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {

// If input is an empty tensor, allocate an empty output tensor and return
if (input_tensor.NumElements() == 0) {
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
TensorShape output_tf_shape;
if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
} else {
memory::dims output_dims_NHWC_order;
output_dims_NHWC_order = {pool_params.tensor_in_batch,
static_cast<int>(pool_params.out_height),
static_cast<int>(pool_params.out_width),
pool_params.out_depth};
output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
}
const int kOutputIndex = 0;
AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
output_tf_shape, output_mkl_shape);
CHECK_NOTNULL(output_tensor);
this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
output_dims_mkl_order, &output_tensor);
return;
}

// If input is in Mkl layout, then just get the memory format from it
// directly, instead of using input data_format to AvgPool.
if (dnn_shape_input.IsMklTensor()) {
dnn_data_output.SetUsrMem(
output_dims_mkl_order,
static_cast<memory::format>(
dnn_data_input.GetUsrMemDesc().data.format));

} else {
dnn_data_output.SetUsrMem(output_dims_mkl_order,
this->data_format_mkldnn_);
}

// describe the memory layout
dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);

// 3. create a pooling primitive descriptor
auto pool_desc = pooling_forward::desc(
prop_kind::forward, algorithm::pooling_avg_exclude_padding,
dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_prim_desc =
pooling_forward::primitive_desc(pool_desc, cpu_engine);

this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
&padding_left, &padding_right);

// Get the input memory descriptor
memory::desc input_md = dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetMklLayout()
: memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);

// Get src/filter/stride/padding information
memory::dims src_dims = dnn_shape_input.IsMklTensor()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please put all these memory::dims in a struct and add a helper function PoolParamsToDims to populate them. Use it in the functions below as well as much as possible. There's a lot of duplicated boilerplate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done by code change.

? dnn_shape_input.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
this->data_format_tf_);

// Get an average pooling primitive from the op pool
MklPoolingFwdPrimitive<T> *pooling_fwd = nullptr;
MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
algorithm::pooling_avg_exclude_padding);
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);

// allocate output tensor
this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor);
CHECK_NOTNULL(output_tensor);

OP_REQUIRES_OK(context, context->status());
dnn_data_output.SetUsrMemDataHandle(output_tensor);

this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
&dnn_data_output);
// check whether we need to reorder src
const T* src_data = input_tensor.flat<T>().data();
if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
dnn_data_input.SetUsrMem(input_md, &input_tensor);
auto src_target_primitive_desc = memory::primitive_desc({{src_dims},
MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()}, cpu_engine_);
dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
src_data = const_cast<T*>(
reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle()));
}

T* dst_data = output_tensor->flat<T>().data();

// execute pooling
pooling_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
Expand All @@ -535,9 +527,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
errors::Aborted("Operation received an exception:", error_msg));
}
} // Compute
}; // MklAvgPoolingOp

//-----------------------------------------------------------------------------
private:
engine cpu_engine_ = engine(engine::cpu, 0);
}; // MklAvgPoolingOp

template <class Device, class T>
class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
Expand All @@ -547,125 +540,90 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {

void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
const Tensor& tensor_in_shape =
const Tensor& orig_input_tensor =
MklGetInput(context, kInputTensorIndexInputShape);
const Tensor& input_gradient_tensor =
const Tensor& grad_tensor =
MklGetInput(context, kInputTensorIndexInputGradient);

MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
GetMklShape(context, kInputTensorIndexInputShape,
&original_input_mkl_shape);
&orig_input_mkl_shape);
GetMklShape(context, kInputTensorIndexInputGradient,
&input_gradient_mkl_shape);

SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
original_input_mkl_shape, input_gradient_mkl_shape);
&grad_mkl_shape);
if (!context->status().ok()) return;

// Used to allocate output_diff_src/diff_src
// and create pool_fwd mdm desc
// 0. Input("orig_input_shape: int32") //NOT a T Tensor!
// 1. Input("grad: T")

MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
MklDnnData<T> output_diff_src(&cpu_engine);
Tensor* output_tensor_diff_src = nullptr;
TensorShape original_input_shape;
MklDnnData<T> grad_dnn_data(&cpu_engine_);
MklPoolParameters pool_params;
memory::dims output_dims_mkl_order, original_input_dims_nchw;
// Configure the original input memory descriptor
memory::desc original_input_md = ConfigureOriginalInput(
context, tensor_in_shape, original_input_mkl_shape,
&original_input_dims_nchw, &pool_params, &original_input_shape);

// configure the original output memory descriptor
// by definition, the shape of the original output is the same
// as the shape of the gradient diff_dst
memory::desc original_output_md = this->ConfigureOriginalOutput(
pool_params, input_gradient_mkl_shape, output_dims_mkl_order);

memory::desc target_diff_dst_md = this->ConfigureInputGradient(
input_gradient_mkl_shape, input_gradient_tensor,
&input_gradient_diff_dst, original_output_md);
// The shape of the output diff src needs to be the same shape as the
// original input. But we will set its format to be same as the format of
// input gradient. We won't use format of original input since it will
// always be in Tensorflow layout (given that AvgPoolGrad gets shape of
// the input rather than actual input).
output_diff_src.SetUsrMem(
original_input_dims_nchw,
static_cast<memory::format>(target_diff_dst_md.data.format));

// Create the forward pooling primitive descriptor so we can reference it
// in the backward pooling primitive descriptor
auto pool_fwd_desc = pooling_forward::desc(
prop_kind::forward, algorithm::pooling_avg_exclude_padding,
original_input_md, original_output_md,
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_fwd_prim_desc =
pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);

auto pool_bkwd_desc = pooling_backward::desc(
algorithm::pooling_avg_exclude_padding,
output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
this->AllocateOutputTensor(
context, pool_bkwd_prim_desc, original_input_dims_nchw,
this->data_format_mkldnn_, &output_tensor_diff_src);

output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);

this->PrepareAndExecuteNet(
pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
memory::primitive_desc(target_diff_dst_md, cpu_engine));
auto shape_vec = orig_input_tensor.vec<int32>();
TensorShape orig_input_shape;
for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
orig_input_shape.AddDim(shape_vec(i));
}
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);

memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
&padding_left, &padding_right);

memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(orig_input_shape, this->data_format_tf_);

memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);

MklPoolingParams bwdParams(orig_input_dims_mkl_order,
output_dims_mkl_order, filter_dims, strides,
padding_left, padding_right, algorithm::pooling_avg_exclude_padding);
MklPoolingBwdPrimitive<T> *pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);

Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
// get diff_dst memory::desc
memory::desc diff_dst_md = grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);
// Check whether we need to reorder diff_dst
const T* diff_dst_data = grad_tensor.flat<T>().data();
if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
auto target_diff_dst = memory::primitive_desc({{diff_dst_dims},
MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()}, cpu_engine_);
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()));
}

T* diff_src_data = output_tensor->flat<T>().data();

// execute pooling op
pooling_bwd->Execute(diff_dst_data, diff_src_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
error_msg));
}
} // Compute
}

private:
// 0. Input("orig_input_shape: int32")
// 1. Input("grad: T")
const int kInputTensorIndexInputShape = 0;
const int kInputTensorIndexInputGradient = 1;

memory::desc ConfigureOriginalInput(
OpKernelContext* context, const Tensor& tensor_original_input_shape,
const MklDnnShape& original_input_mkl_shape,
memory::dims* original_input_dims_mkl_order,
MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
CHECK_NOTNULL(original_input_dims_mkl_order);
CHECK_NOTNULL(pool_params);
CHECK_NOTNULL(input_tensor_shape);
// For AvgPoolGrad, we only get the size of the original input because
// The original data is irrelvant.
auto shape_vec = tensor_original_input_shape.vec<int32>();
for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
input_tensor_shape->AddDim(shape_vec(i));
}

return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
context, tensor_original_input_shape, original_input_mkl_shape,
original_input_dims_mkl_order, pool_params, *input_tensor_shape);
}
engine cpu_engine_ = engine(engine::cpu, 0);

void SanityCheckInputs(OpKernelContext* context,
const Tensor& tensor_in_shape,
Expand Down