[INTEL MKL] Enabled Conv2D fprop for MKL-DNN v1.0. #30549
Thank you for your PR! I'm a little concerned about the redundant code. I think we should minimize it.
void AllocatePersistentTensor(OpKernelContext* context,
                              const ConvFwdPd& conv_prim_desc,
                              Tensor** filter_tensor,
                              const MklDnnShape& filter_mkl_shape) {
Please try to merge this with the existing code (either the original code or code elsewhere in this PR). This overload can call the three-parameter version and then do the extra processing.
Done.
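A minimal sketch of the delegation the reviewer asks for: the four-parameter overload calls the three-parameter one and then does only its extra step. The structs below are simplified, hypothetical stand-ins for `OpKernelContext`, `ConvFwdPd`, `Tensor`, and `MklDnnShape`, and the "extra processing" (recording the filter's data format) is an assumption for illustration, not the PR's actual code.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified stand-ins for the TensorFlow / MKL-DNN types.
struct OpKernelContext {};
struct ConvFwdPd {
  int filter_elements;  // number of elements the reordered filter needs
};
struct Tensor {
  int num_elements = 0;
  std::string data_format;  // stays empty unless the extended overload sets it
};
struct MklDnnShape {
  std::string tf_data_format;
};

// Base overload: size the persistent filter storage from the primitive desc.
void AllocatePersistentTensor(OpKernelContext* context,
                              const ConvFwdPd& conv_prim_desc,
                              Tensor** filter_tensor) {
  (void)context;  // unused in this sketch
  assert(filter_tensor != nullptr && *filter_tensor != nullptr);
  (*filter_tensor)->num_elements = conv_prim_desc.filter_elements;
}

// Extended overload: reuse the base overload, then do only the extra work,
// so the shared allocation logic lives in exactly one place.
void AllocatePersistentTensor(OpKernelContext* context,
                              const ConvFwdPd& conv_prim_desc,
                              Tensor** filter_tensor,
                              const MklDnnShape& filter_mkl_shape) {
  AllocatePersistentTensor(context, conv_prim_desc, filter_tensor);
  (*filter_tensor)->data_format = filter_mkl_shape.tf_data_format;
}
```

With this shape, a later fix to the allocation path only has to be made in the three-parameter version.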
It has been 14 days with no activity and the
@tensorflowbutler yes, this PR is still being worked on. I'll push the requested changes within a couple of days. Thanks!
@penpornk I have addressed your review comments. Please take a look. Thanks!
Thank you very much for all the changes! This looks great! I have a few more comments.
    static_cast<int32>(
        GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else
    GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc);
To be added.
Done
    static_cast<int32>(
        GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else
    GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc);
This looks similar to the v1 case. `static_cast<int32>` shouldn't make a difference since we are assigning it to the `int32` version of `second_tensor` anyway. Can we merge this?
Done.
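A small sketch of why the two `#ifdef` branches above can be merged: when the returned format is an unscoped enum and the destination is `int32`, the explicit `static_cast<int32>` and the implicit conversion store the same value. It assumes `GetFilterTfDataFormat` returns an unscoped enum like TensorFlow's `TensorFormat`; the enum and `GetFilterTfDataFormatStub` below are hypothetical stand-ins.

```cpp
#include <cstdint>

// Assumption: int32 mirrors TensorFlow's alias for a 32-bit signed integer.
using int32 = std::int32_t;

// Hypothetical unscoped enum standing in for tensorflow::TensorFormat.
enum TensorFormat { FORMAT_NHWC = 0, FORMAT_NCHW = 1 };

// Hypothetical stand-in for GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc).
constexpr TensorFormat GetFilterTfDataFormatStub() { return FORMAT_NCHW; }

// One branch: explicit cast before storing into an int32 slot.
constexpr int32 FormatWithCast() {
  return static_cast<int32>(GetFilterTfDataFormatStub());
}

// Other branch: rely on the implicit unscoped-enum-to-integer conversion.
constexpr int32 FormatWithoutCast() { return GetFilterTfDataFormatStub(); }

// Both forms yield the same value, so a single unguarded statement suffices.
static_assert(FormatWithCast() == FormatWithoutCast(),
              "the cast does not change the stored value");
```

If the return type were a scoped `enum class`, the implicit conversion would not compile, and the cast would be needed in every branch; that still allows a single merged statement.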
@@ -136,8 +138,13 @@ class MklDnnConvUtil {
    CHECK_BOUNDS(input_cols_raw, "Input cols too large");
    int input_cols = static_cast<int>(input_cols_raw);

#ifdef ENABLE_MKLDNN_V1
    // MKL-DNN always requires input in NCHW format Conv2D.
    std::vector<long int> mkldnn_sizes(4, -1);
You can define `MKLDNN_SIZE_DTYPE` (or any name you prefer) to be either `long int` or `int` at the top of the file and then use it throughout the file.
Suggested change:
- std::vector<long int> mkldnn_sizes(4, -1);
+ std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
Done
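A sketch of the suggested macro pattern: choose the size type once at the top of the file (`long int` under `ENABLE_MKLDNN_V1`, as in the hunk above, and `int` otherwise), then use it wherever the size vector is built. It also fills the vector in NCHW order, per the comment in the hunk. The helper `MakeNchwSizes`, the `Dim_*` index names, and the NHWC input layout are hypothetical, chosen only for illustration.

```cpp
#include <cassert>
#include <vector>

// Choose the MKL-DNN size type once, at the top of the file.
#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE long int
#else
#define MKLDNN_SIZE_DTYPE int
#endif

// Hypothetical index names for the NCHW positions.
enum NchwDim { Dim_N = 0, Dim_C = 1, Dim_H = 2, Dim_W = 3 };

// Hypothetical helper: reorders NHWC sizes into the NCHW order that the
// hunk's comment says MKL-DNN requires for Conv2D input.
std::vector<MKLDNN_SIZE_DTYPE> MakeNchwSizes(const std::vector<int>& nhwc) {
  assert(nhwc.size() == 4);
  std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
  mkldnn_sizes[Dim_N] = nhwc[0];
  mkldnn_sizes[Dim_H] = nhwc[1];
  mkldnn_sizes[Dim_W] = nhwc[2];
  mkldnn_sizes[Dim_C] = nhwc[3];
  return mkldnn_sizes;
}
```

Because only the `#define` differs between versions, the code that builds and fills the vector no longer needs its own `#ifdef`.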
@penpornk I have addressed your latest review comments. Please let me know if it looks okay. Thank you!
Thank you for your patience! One more comment, please.
@@ -658,6 +631,10 @@ class MklDummyOp : public OpKernel {
  }
};

#ifdef ENABLE_MKLDNN_V1
#undef MKLDNN_SIZE_DTYPE
`MKLDNN_SIZE_DTYPE` is always defined, so we don't need to guard the `#undef` with the v1 `#ifdef`.
Yes, you're right :) Fixed it.
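A sketch of the fixed end-of-file cleanup: since `MKLDNN_SIZE_DTYPE` is defined on both preprocessor paths, a plain `#undef` is enough. The `MklDnnSizeType` alias is a hypothetical stand-in for the file's real uses of the macro.

```cpp
#include <type_traits>

// Defined on both branches, so the macro always exists past this point.
#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE long int
#else
#define MKLDNN_SIZE_DTYPE int
#endif

// Hypothetical stand-in for the file's uses of the macro.
using MklDnnSizeType = MKLDNN_SIZE_DTYPE;

// End of file: no #ifdef ENABLE_MKLDNN_V1 guard is needed around the #undef.
#undef MKLDNN_SIZE_DTYPE

// Compile-time check that the macro does not leak into files that include this one.
#ifdef MKLDNN_SIZE_DTYPE
#error "MKLDNN_SIZE_DTYPE leaked past the end of the file"
#endif

static_assert(std::is_same<MklDnnSizeType, int>::value ||
                  std::is_same<MklDnnSizeType, long int>::value,
              "size type must be int or long int");
```

Applying `#undef` to a name that is not currently defined is ignored by the preprocessor, so the guard would be unnecessary even if the definition were conditional.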
Thank you very much again! I think we're good to go. :)
PiperOrigin-RevId: 260989299