Update fold_old_batch_norms.cc to accommodate 'NCHW' format. #17602
Fixes a problem that occurs when fused batch normalization is used together with this transform. It only shows up when the data format is 'NCHW', because the default is 'NHWC'.
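To illustrate why the layout matters for the BiasAdd node that this transform creates, here is a minimal standalone sketch. It is not TensorFlow code: `Layout` and `AddBias` are hypothetical names, and the tensor is assumed to be a flat, densely packed 4-D buffer. It only shows that the channel index is derived differently in the two layouts, so a bias add that assumes 'NHWC' on 'NCHW' data would add the wrong bias values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of TensorFlow): adds a per-channel bias to a
// flat 4-D tensor stored in either NHWC or NCHW order.
enum class Layout { kNHWC, kNCHW };

std::vector<float> AddBias(const std::vector<float>& input,
                           const std::vector<float>& bias,
                           std::size_t spatial_size,  // height * width
                           Layout layout) {
  const std::size_t channels = bias.size();
  assert(channels > 0 && spatial_size > 0);
  assert(input.size() % (channels * spatial_size) == 0);

  std::vector<float> out(input);
  for (std::size_t i = 0; i < out.size(); ++i) {
    // NHWC: the channel is the fastest-varying index.
    // NCHW: the channel changes once every height * width elements.
    const std::size_t channel = (layout == Layout::kNHWC)
                                    ? i % channels
                                    : (i / spatial_size) % channels;
    out[i] += bias[channel];
  }
  return out;
}
```

With the same input and bias, the two layouts produce different outputs, which is why the BiasAdd node needs to carry the same data_format as the convolution it follows.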
I signed it!
CLAs look good, thanks!
@@ -159,6 +159,7 @@ Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
NodeDef bias_add_node;
bias_add_node.set_op("BiasAdd");
bias_add_node.set_name(conv_output_name);
CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node);
This would fail the whole program if conv_node doesn't have a "data_format" attribute. You probably need to verify that it has the attribute before calling CopyNodeAttr.
I don't see any checks like this in any other part of the code; could you point me to one?
Also, since data_format is now part of the tf.nn.conv2d API, doesn't the node get this attribute regardless? I'm not sure, but I thought it would fall back to the default.
This is just an extra safety check for legacy TF graphs. You are right: if the original graph was created with tf.nn.conv2d, this check is not needed. However, I am not sure whether there are any legacy TF graphs that do not contain this attr.
@mingxingtan Added the check as requested.
@mingxingtan can you take another look?
@@ -159,6 +159,9 @@ Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
NodeDef bias_add_node;
bias_add_node.set_op("BiasAdd");
bias_add_node.set_name(conv_output_name);
if (HasAttr(conv_node, "data_format")) {
I don't see where HasAttr is defined. Maybe you can check the attribute map directly, with something like:
if (conv_node.attr().count("data_format") > 0) {
  CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node);
}
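The guard discussed in this thread can be sketched without TensorFlow as follows. This is only an illustration under stated assumptions: `FakeNode` and `CopyAttrIfPresent` are hypothetical names, and a plain `std::map` of strings stands in for the protobuf attribute map of a real NodeDef. The point is that the attribute is copied only when the source node actually has it, so a legacy node without "data_format" leaves the destination untouched instead of causing a failure.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a graph node. The real NodeDef is a protobuf
// message whose attr() map holds AttrValue messages, not plain strings.
struct FakeNode {
  std::string op;
  std::map<std::string, std::string> attr;
};

// Copies `key` from `source` to `dest` only if `source` has it.
// Returns true when the attribute was copied.
bool CopyAttrIfPresent(const FakeNode& source, const std::string& key,
                       FakeNode* dest) {
  assert(dest != nullptr);
  const auto it = source.attr.find(key);
  if (it == source.attr.end()) {
    return false;  // Legacy node without the attribute: leave dest untouched.
  }
  dest->attr[key] = it->second;
  return true;
}
```

In this sketch, a node without the attribute simply keeps whatever default the consuming op would apply, which matches the intent of the extra safety check requested in the review.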
Updated as requested