[Intel MKL] Adding support for AvgPool with native format #43505

Merged
16 changes: 16 additions & 0 deletions tensorflow/core/framework/common_shape_fns.cc
@@ -1120,6 +1120,14 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
return Status::OK();
}

Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
c->set_output(0, s);
return Status::OK();
}

Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
@@ -1771,6 +1779,14 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
return Status::OK();
}

Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
c->set_output(0, s);
return Status::OK();
}

Status UnknownShape(shape_inference::InferenceContext* c) {
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
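The two new helpers mirror the shape logic of the AvgPool gradient ops: input 0 carries the shape of the original forward input as a 1-D tensor, which MakeShapeFromShapeTensor turns into a shape handle; that handle is then constrained to rank 4 (rank 5 for the 3D variant) and becomes the output shape. Below is a minimal sketch of how such a helper is typically attached to an op registration; the op name and attribute list are illustrative and not taken from this diff.

// --- Illustrative sketch, not part of this diff ---
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("_MklNativeAvgPoolGrad")
    .Input("orig_input_shape: int32")  // input 0: shape of the forward input
    .Input("grad: T")                  // input 1: backpropagated gradient
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("T: {float, bfloat16}")
    .SetShapeFn(shape_inference::AvgPoolGradShape);

}  // namespace tensorflow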
6 changes: 6 additions & 0 deletions tensorflow/core/framework/common_shape_fns.h
@@ -147,6 +147,9 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);
// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for AvgPoolGrad-like operations.
Status AvgPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

@@ -181,6 +184,9 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool3DGrad-like operations.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

99 changes: 64 additions & 35 deletions tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc
@@ -40,22 +40,24 @@ namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
template <typename Device, typename T, bool native_format = false>
class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
public:
explicit MklAvgPoolingOp(OpKernelConstruction* context)
: MklPoolingForwardOpBase<T>(context) {
// Workspace is an MKLDNN construct that is only used in Max Pooling.
// So set workspace_enabled_ to false.
this->workspace_enabled_ = false;
this->native_format_ = native_format;
}

void Compute(OpKernelContext* context) override {
try {
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input,
this->native_format_);
this->SanityCheckInput(context, input_tensor, dnn_shape_input);
if (!context->status().ok()) return;

@@ -116,7 +118,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_avg_exclude_padding,
pooling_prop_kind,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md,
this->native_format_);
#else
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
@@ -174,11 +177,13 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
engine cpu_engine_ = engine(ENGINE_CPU, 0);
}; // MklAvgPoolingOp

template <class Device, class T>
template <class Device, class T, bool native_format = false>
class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
public:
explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
: MklPoolingBackwardOpBase<T>(context) {}
: MklPoolingBackwardOpBase<T>(context) {
this->native_format_ = native_format;
}

void Compute(OpKernelContext* context) override {
try {
@@ -188,8 +193,10 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklGetInput(context, kInputTensorIndexInputGradient);

MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape);
GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape);
GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape,
this->native_format_);
GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape,
this->native_format_);
if (!context->status().ok()) return;

// Used to allocate output_diff_src/diff_src.
@@ -249,7 +256,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md,
this->native_format_);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
@@ -273,7 +281,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
std::shared_ptr<PoolingBwdPd> pooling_bwd_pd =
pooling_bwd->GetPoolingBwdPd();
T* diff_dst_data = nullptr;
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
if (!this->native_format_ &&
IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
pooling_bwd)) {
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
@@ -307,36 +316,56 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
engine cpu_engine_ = engine(ENGINE_CPU, 0);
}; // MklAvgPoolingGradOp

#define REGISTER_MKL_AVGPOOL3D_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T>);
#define REGISTER_MKL_AVGPOOL3D_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklAvgPoolingOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T, true>);

TF_CALL_float(REGISTER_MKL_AVGPOOL3D_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL3D_KERNELS);

#define REGISTER_MKL_AVGPOOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T>);
#define REGISTER_MKL_AVGPOOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklAvgPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklAvgPoolingOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklAvgPoolingGradOp<CPUDevice, T, true>);

TF_CALL_float(REGISTER_MKL_AVGPOOL_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL_KERNELS);
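With these changes each supported type gets four kernel registrations: the existing _MklAvgPool/_MklAvgPoolGrad pair under kMklLayoutDependentOpLabel, which keeps the MKL-layout path with its extra metadata tensors, and the new _MklNativeAvgPool/_MklNativeAvgPoolGrad pair under kMklNameChangeOpLabel, which consumes and produces plain TensorFlow tensors via the native_format template argument. As a rough illustration, TF_CALL_float(REGISTER_MKL_AVGPOOL_KERNELS) expands to the following for the two native-format registrations (sketch only; the macro above is authoritative):

// --- Rough expansion for T = float, shown for illustration ---
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklNameChangeOpLabel),
                        MklAvgPoolingOp<CPUDevice, float, true>);
REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklNameChangeOpLabel),
                        MklAvgPoolingGradOp<CPUDevice, float, true>);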
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc
@@ -143,7 +143,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md,
this->native_format_);
#else
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
@@ -312,7 +313,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right, ALGORITHM::pooling_max,
prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md,
this->native_format_);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
10 changes: 6 additions & 4 deletions tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc
@@ -54,8 +54,9 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
#else
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
#endif // !ENABLE_MKLDNN_V1
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
MEMORY_FORMAT::any));
context_.dst_md.reset(new memory::desc(
{fwdParams.dst_dims}, MklDnnType<T>(),
fwdParams.native_format ? fwdParams.src_format : MEMORY_FORMAT::any));

#ifndef ENABLE_MKLDNN_V1
// Create a pooling descriptor.
@@ -187,8 +188,9 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
{bwdParams.dst_dims}, MklDnnType<T>(), bwdParams.src_format));
#else
context_.src_md.reset(new memory::desc(bwdParams.src_md.data));
context_.dst_md.reset(new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
MEMORY_FORMAT::any));
context_.dst_md.reset(new memory::desc(
{bwdParams.dst_dims}, MklDnnType<T>(),
bwdParams.native_format ? bwdParams.src_format : MEMORY_FORMAT::any));
#endif // !ENABLE_MKLDNN_V1

#ifndef ENABLE_MKLDNN_V1
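The substantive change in both Setup() methods is the destination descriptor: when native_format is set, dst_md is created with the plain layout carried in src_format instead of MEMORY_FORMAT::any, so the library is not free to pick a blocked layout for the output and no reorder back to TensorFlow's layout is required. A self-contained sketch of that choice, assuming the MKL-DNN v1.x API and illustrative values:

// --- Illustrative sketch, not part of this diff ---
#include "mkldnn.hpp"
using mkldnn::memory;

memory::desc BuildDstDesc(const memory::dims& dst_dims, bool native_format) {
  // With native_format, pin the plain TF layout (nchw here stands in for
  // fwdParams.src_format); otherwise let the library choose a possibly
  // blocked layout via format_tag::any.
  return memory::desc(dst_dims, memory::data_type::f32,
                      native_format ? memory::format_tag::nchw
                                    : memory::format_tag::any);
}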
24 changes: 19 additions & 5 deletions tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h
@@ -49,12 +49,14 @@ struct MklPoolingParams {
mkldnn::prop_kind prop_kind;
MEMORY_FORMAT src_format;
memory::desc src_md;
bool native_format;

MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
memory::dims filter_dims, memory::dims strides,
memory::dims padding_left, memory::dims padding_right,
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
MEMORY_FORMAT src_format, memory::desc src_md)
MEMORY_FORMAT src_format, memory::desc src_md,
bool native_format)
: src_dims(src_dims),
dst_dims(dst_dims),
filter_dims(filter_dims),
Expand All @@ -64,7 +66,8 @@ struct MklPoolingParams {
alg_kind(alg_kind),
prop_kind(prop_kind),
src_format(src_format),
src_md(src_md) {}
src_md(src_md),
native_format(native_format) {}
};

template <typename T>
@@ -583,7 +586,8 @@ class MklPoolingOpBase : public OpKernel {
output_tf_shape = MklDnnDimsToTFShape(output_dims_order);
}
AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
output_tf_shape, output_mkl_shape);
output_tf_shape, output_mkl_shape,
native_format_);
DCHECK(output_tensor);
}

@@ -608,6 +612,7 @@
// Either memory::format (MKL-DNN v-0.x) or memory::format_tag (MKL-DNN v-1.x)
MEMORY_FORMAT data_format_mkldnn_;
bool workspace_enabled_;
bool native_format_ = false;
};

template <class T>
@@ -671,8 +676,13 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
output_dims_mkl_order, output_tf_format);
// Only allocate enough space for the elements we need.
output_tf_shape.AddDim(this->GetNumTElements(dst_pd));

if (this->native_format_) {
output_tf_shape = output_mkl_shape.GetTfShape();
}
AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
output_tf_shape, output_mkl_shape);
output_tf_shape, output_mkl_shape,
this->native_format_);
DCHECK(*output_tensor);
}

@@ -719,8 +729,12 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {

TensorShape output_tf_shape;
output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
if (this->native_format_) {
output_tf_shape = output_mkl_shape.GetTfShape();
}
AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
output_tf_shape, output_mkl_shape);
output_tf_shape, output_mkl_shape,
this->native_format_);
DCHECK(*output_tensor);
}
};
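Since native_format is a new trailing constructor argument with no default, every MklPoolingParams call site must now pass it explicitly (the kernels above pass this->native_format_). A sketch of a forward-pooling call site under the MKL-DNN v1.x build, with illustrative dimensions:

// --- Illustrative sketch, not part of this diff ---
using mkldnn::memory;

memory::dims src_dims = {1, 8, 28, 28};  // NCHW input
memory::dims dst_dims = {1, 8, 14, 14};  // 2x2 average pool, stride 2
memory::dims filter = {2, 2}, strides = {2, 2}, pad_l = {0, 0}, pad_r = {0, 0};
memory::desc src_md(src_dims, memory::data_type::f32, memory::format_tag::nchw);

MklPoolingParams fwdParams(
    src_dims, dst_dims, filter, strides, pad_l, pad_r,
    mkldnn::algorithm::pooling_avg_exclude_padding,
    mkldnn::prop_kind::forward_inference,
    memory::format_tag::nchw, src_md,
    /*native_format=*/true);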