[Intel MKL] Enabling MKL Conv2D BWD in eager mode #30402

Merged
102 changes: 55 additions & 47 deletions tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -357,7 +357,8 @@ class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
}
};

template <typename Device, class T, bool bias_enabled, bool is_depthwise>
template <typename Device, class T, bool bias_enabled, bool is_depthwise,
bool eager_mode>
class MklConvCustomBackpropFilterOp
: public MklConvBackpropCommonOp<Device, T, is_depthwise> {
public:
@@ -382,9 +383,9 @@ class MklConvCustomBackpropFilterOp
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);

MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
GetMklShape(context, kInputIdx, &src_mkl_shape);
GetMklShape(context, kFilterIdx, &filter_mkl_shape);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
// Allow operator-specific sanity checking of shapes.
ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);

@@ -395,7 +396,8 @@ class MklConvCustomBackpropFilterOp
// allow this class to handle this case.
TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
TensorShape diff_dst_tf_shape =
GetTfShape(context, kOutbpropIdx, eager_mode);

// Corner cases: output with 0 elements and 0 batch size.
Tensor* diff_filter_tensor = nullptr;
@@ -408,7 +410,8 @@ class MklConvCustomBackpropFilterOp
GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
const int kOutputIdx = 0;
AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape);
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_filter_tensor);

// if output tensor has more than 0 elements, we need to 0 them out.
@@ -493,8 +496,8 @@ class MklConvCustomBackpropFilterOp
bwd_output_dims[MklDnnDims::Dim_I],
bwd_output_dims[MklDnnDims::Dim_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape,
diff_filter_mkl_shape);
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
} else {
// Depthwise Conv2d: bwd_output_dims is GOIHW format
// | TensorFlow | MKLDNN
@@ -592,9 +595,9 @@ class MklConvCustomBackpropFilterOp
// delete primitive since it is not cached.
if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -620,7 +623,7 @@ class MklConvCustomBackpropFilterOp
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
size_t input_idx = 0;
return GetTfShape(context, input_idx);
return GetTfShape(context, input_idx, eager_mode);
}

// Get TensorFlow shape of filter tensor.
@@ -654,10 +657,9 @@ class MklConvCustomBackpropFilterOp
// Output layout is Tensorflow's filter layout
// Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO
memory::format GetOutputFormat(const memory::format data_format) {
return is_depthwise
? memory::format::hwigo
: ((this->strides_.size() == 4) ? memory::format::hwio
: memory::format::dhwio);
return is_depthwise ? memory::format::hwigo : ((this->strides_.size() == 4)
? memory::format::hwio
: memory::format::dhwio);
}

// Allocate output tensor.
@@ -699,37 +701,43 @@ class MklConvCustomBackpropFilterOp
}
};

#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, true, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklDepthwiseConv2dNativeBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, true>); \
REGISTER_KERNEL_BUILDER( \
Name("__MklDummyConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklDummyOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>);
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklEagerConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklDepthwiseConv2dNativeBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, false>); \
REGISTER_KERNEL_BUILDER( \
Name("__MklDummyConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklDummyOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>);

TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);
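The registration hunk above shows the core of the change for the filter-gradient kernel: `MklConvCustomBackpropFilterOp` gains a fourth `bool eager_mode` template parameter, and a new `_MklEagerConv2DBackpropFilter` kernel is registered under `mkl_op_registry::kMklNameChangeOpLabel`, giving eager execution (which has no graph rewrite pass to insert MKL layout-conversion ops) its own instantiation. Below is a minimal, self-contained sketch of the pattern — plain C++, not TensorFlow code, with all names illustrative:

```cpp
#include <iostream>

// Sketch: one kernel class serves both graph-mode and eager-mode
// registrations via a compile-time eager_mode flag, mirroring the
// template-parameter change in the diff above.
template <typename T, bool bias_enabled, bool is_depthwise, bool eager_mode>
class ConvBackpropFilterKernel {
 public:
  void Compute() {
    // In the real kernel the flag is threaded into GetMklShape(),
    // GetTfShape(), and AllocateOutputSetMklShape() so that eager
    // callers, whose tensors carry no MKL layout metadata, are
    // handled; the branch below stands in for that plumbing.
    if (eager_mode) {
      std::cout << "eager: plain TF-layout inputs and outputs\n";
    } else {
      std::cout << "graph: tensors may carry MKL layout metadata\n";
    }
  }
};

int main() {
  // Graph-mode instantiation, cf. _MklConv2DBackpropFilter.
  ConvBackpropFilterKernel<float, false, false, false>().Compute();
  // Eager-mode instantiation, cf. _MklEagerConv2DBackpropFilter.
  ConvBackpropFilterKernel<float, false, false, true>().Compute();
  return 0;
}
```

Because `eager_mode` is a template parameter, the branch resolves at compile time, so the graph-mode instantiation carries no runtime overhead from the eager path.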
95 changes: 62 additions & 33 deletions tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -295,7 +295,7 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
}
};

template <typename Device, class T, bool is_depthwise>
template <typename Device, class T, bool is_depthwise, bool eager_mode>
class MklConvCustomBackpropInputOp
: public MklConvBackpropCommonOp<Device, T, is_depthwise> {
public:
@@ -319,9 +319,9 @@ class MklConvCustomBackpropInputOp
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);

MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
GetMklShape(context, kInputIdx, &src_mkl_shape);
GetMklShape(context, kFilterIdx, &filter_mkl_shape);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
// Allow operator-specific sanity checking of shapes.
ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);

@@ -332,7 +332,8 @@ class MklConvCustomBackpropInputOp
// allow this class to handle this case.
TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
TensorShape diff_dst_tf_shape =
GetTfShape(context, kOutbpropIdx, eager_mode);

// Corner cases: output with 0 elements and 0 batch size.
Tensor* diff_src_tensor = nullptr;
@@ -345,7 +346,8 @@ class MklConvCustomBackpropInputOp
GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
const int kOutputIdx = 0;
AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
diff_src_tf_shape, diff_src_mkl_shape);
diff_src_tf_shape, diff_src_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_src_tensor);

// if output tensor has more than 0 elements, we need to 0 them out.
@@ -429,9 +431,13 @@ class MklConvCustomBackpropInputOp
bwd_diff_src_dims, bwd_diff_src_format);
TensorShape diff_src_tf_shape;
diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T));
Tensor tmp_tensor;
if (eager_mode) {
AllocTmpBuffer<T>(context, &tmp_tensor, diff_src_tf_shape);
diff_src_tf_shape = diff_src_mkl_shape.GetTfShape();
}
AllocateOutputSetMklShape(context, 0, &diff_src_tensor, diff_src_tf_shape,
diff_src_mkl_shape);

diff_src_mkl_shape, eager_mode);
T* diff_src_data =
static_cast<T*>(const_cast<T*>(diff_src_tensor->flat<T>().data()));

@@ -458,16 +464,34 @@ class MklConvCustomBackpropInputOp
}

// execute convolution input bwd
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
} else {
// In eager mode we first write the output to temporary
// buffer in MKL format. Then we convert the data to TF format.
T* tmp_data =
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
mkldnn::reorder::primitive_desc reorder_pd =
mkldnn::reorder::primitive_desc(diff_src_pd, output_tf_pd);
std::vector<mkldnn::primitive> net;
memory* tmp_data_mem = new memory(diff_src_pd, tmp_data);
memory* dst_data_mem = new memory(output_tf_pd, diff_src_data);
net.push_back(
mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
stream(stream::kind::eager).submit(net).wait();
}

// delete primitive since it is not cached.
if (do_not_cache) {
delete conv_bwd_input;
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -506,7 +530,7 @@ class MklConvCustomBackpropInputOp
// Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
return GetTfShape(context, kInputIndex_Filter);
return GetTfShape(context, kInputIndex_Filter, eager_mode);
}

// Get the Tensorflow shape of Output (diff_src),
@@ -557,26 +581,31 @@ class MklConvCustomBackpropInputOp
}
};

#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv3DBackpropInputV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklDepthwiseConv2dNativeBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, true>);

#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklEagerConv2DBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, false, true>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv3DBackpropInputV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklDepthwiseConv2dNativeBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, true, false>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
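The inline comment in the input-backprop hunk summarizes the eager-mode output path: the backward primitive writes into a temporary buffer in MKL layout (allocated with `AllocTmpBuffer`), and a `mkldnn::reorder` then converts that buffer into the plain TF layout of the pre-allocated output tensor. Here is a condensed sketch of just the conversion step, reusing the MKL-DNN 0.x calls that appear in the diff; `diff_src_pd`, `output_tf_md`, `cpu_engine`, and the two raw pointers are assumed to be set up as in the kernel, so this is illustrative rather than a drop-in function:

```cpp
#include <vector>
#include "mkldnn.hpp"

// Reorder a buffer from the primitive's preferred (MKL) layout into the
// plain TensorFlow layout, mirroring the eager-mode branch in the hunk
// above but with stack-allocated memory objects instead of `new`.
void ReorderMklToTf(const mkldnn::memory::primitive_desc& diff_src_pd,
                    const mkldnn::memory::desc& output_tf_md,
                    const mkldnn::engine& cpu_engine,
                    float* tmp_data,         // MKL-layout scratch buffer
                    float* diff_src_data) {  // TF-layout output buffer
  // Wrap both raw buffers in MKL-DNN memory objects.
  mkldnn::memory::primitive_desc output_tf_pd(output_tf_md, cpu_engine);
  mkldnn::memory src_mem(diff_src_pd, tmp_data);
  mkldnn::memory dst_mem(output_tf_pd, diff_src_data);

  // A single reorder primitive performs the layout conversion.
  mkldnn::reorder::primitive_desc reorder_pd(diff_src_pd, output_tf_pd);
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::reorder(reorder_pd, src_mem, dst_mem));
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
}
```

One difference from the hunk: the diff heap-allocates `tmp_data_mem` and `dst_data_mem` with `new` and does not free them within the visible lines, whereas the stack-allocated `memory` wrappers above only need to outlive the `submit(...).wait()` call.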