Support fused batchnorm with any ndims and axis #40338
Conversation
- Modifies the BatchNormalizationBase layer to support fused=True with any number of dimensions and any (single) axis by adding reshape operations around the calls to nn.fused_batch_norm (which itself only supports NCHW and NHWC tensors).
- Fixes tests to account for the broader support and adds a new test for a 3D convnet.
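The core idea, sketched below in standalone form for the channels-last case (the function and variable names are illustrative, not the PR's actual code), is to collapse any extra dimensions into an equivalent 4D NHWC tensor, call the fused kernel, and then restore the original shape. The per-channel statistics are unchanged because only non-channel dimensions are merged.

```python
import tensorflow as tf

def fused_bn_channels_last(inputs, gamma, beta):
  """Rough sketch: run fused batch norm on a channels-last input of rank >= 4."""
  # Assumes the channel dimension is statically known.
  channels = inputs.shape[-1]
  orig_shape = tf.shape(inputs)
  # e.g. [N, D, H, W, C] -> [N, D, H * W, C]: fold the excess spatial dims into
  # one, which leaves the set of elements reduced per channel unchanged.
  x = tf.reshape(inputs, [orig_shape[0], orig_shape[1], -1, channels])
  y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
      x, scale=gamma, offset=beta, data_format='NHWC', is_training=True)
  # Restore the original N-D shape.
  return tf.reshape(y, orig_shape), batch_mean, batch_var
```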
Thanks for the PR!
@@ -315,18 +352,22 @@ def build(self, input_shape):
raise ValueError('When using virtual_batch_size, adjustment cannot '
                 'be specified')

fused_axis = self.axis
self._input_fused_shape = None
if self.fused in (None, True):
`self.fused` can no longer be None, so just have this be `if self.fused:`. Then, this can be merged with the `if self.fused:` block below.
Done.
self._input_fused_shape = [-1] + fused_shape.as_list()[1:]
fused_axis = [fused_axis]

if not self._USE_V2_BEHAVIOR:
You don't want to run the `if len(self.axis) == 1` block above if `self.fused` ends up being False.
Done.
@@ -281,6 +273,51 @@ def _support_zero_size_input(self):
    distribution_strategy_context.get_strategy().extended,
    'experimental_enable_get_next_as_optional', False)

def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
  '''Returns a 4D shape and axis (1 or 3) to which nd_shape and nd_axis can
Format the docstring like the others: use triple quotes (`"""`), have the first line be a single sentence, then a blank line, then a more detailed description, and have an Args and Returns section.
Done.
'''Returns a 4D shape and axis (1 or 3) to which nd_shape and nd_axis can
be changed without changing the result of the batch normalization operation.
'''
assert(isinstance(nd_axis, int))
In Python 2, a `long` is also acceptable. Use `six.integer_types` or just remove the assert.
`isinstance(axis, int)` is already used in several places in this class, including in an error check at the beginning of the `__init__` function.
else:
  # Merge excess pre-axis dims into first dim.
  # Transform [N, ..., C, ...] to [product(N, ...), C, ...].
  for dim in range(axis - 1, 0, -1):
For this block, I think the following is simpler:
product = 1
for elem in shape[:axis]:
  product *= elem
shape[:axis] = [product]
ndims -= (axis - 1)
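For illustration, here is that suggestion applied to a concrete (hypothetical) shape; the values are just an example, not taken from the PR.

```python
# Fold the leading dims of shape = [2, 3, 5, 5] with axis = 2 into a single dim.
shape = [2, 3, 5, 5]
axis = 2
ndims = len(shape)
product = 1
for elem in shape[:axis]:
  product *= elem
shape[:axis] = [product]
ndims -= (axis - 1)
# Result: shape == [6, 5, 5], ndims == 3, and the channel axis is now 1.
```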
Done.
# Merge excess spatial dims into the second spatial dim.
# Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
# Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
merge_dim = 2 if is_channels_last else 3
Similar to above, maybe do:
product = 1
for elem in shape[merge_dim:merge_dim + ndims - 4]:
  product *= elem
shape[merge_dim:merge_dim + ndims - 4] = [product]
Done.
# reshape the input/output tensor to/from an equivalent 4D shape.
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(
    input_shape.dims, self.axis[0])
fused_shape = tensor_shape.TensorShape(fused_shape)
Why convert this to a TensorShape and then immediately convert back to a list?
Done. (Twas a remnant of a fight I had initially with TensorShape vs. list, None vs. -1 etc.).
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(
    input_shape.dims, self.axis[0])
fused_shape = tensor_shape.TensorShape(fused_shape)
self._input_fused_shape = [-1] + fused_shape.as_list()[1:]
The issue with computing this in build() is that the input shape might change between calls to the layer. The channel dimension must stay the same since that affects the size of the weights, but other dimensions can differ. Instead, compute this in call().
Done.
@@ -138,6 +138,27 @@ def test_batchnorm_convnet_channel_last(self):
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused(self):
I would also add a test with simple hand-written inputs. No need to test gradients. I'm worried this won't catch errors in `_get_shape_and_axis_for_fused`. Note gamma and beta default to ones and zeros, so you can effectively ignore them in tests that don't take gradients.
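A minimal version of such a hand-written-input test might look roughly like the sketch below (hypothetical values, assuming the PR's n-D fused support and eager execution; the test that was actually added may differ). With gamma and beta at their defaults, the training-mode output should simply be the input normalized by the batch statistics.

```python
import numpy as np
from tensorflow import keras

x = np.arange(48, dtype='float32').reshape((2, 2, 2, 2, 3))
norm = keras.layers.BatchNormalization(axis=-1, fused=True, epsilon=1e-3)
out = np.array(norm(x, training=True))

mean = x.mean(axis=(0, 1, 2, 3), keepdims=True)
var = x.var(axis=(0, 1, 2, 3), keepdims=True)   # biased variance, as used for normalization
expected = (x - mean) / np.sqrt(var + 1e-3)
np.testing.assert_allclose(out, expected, rtol=1e-4, atol=1e-4)
```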
Done.
- Simplify condition.
- Improve and format docstring.
- Use simpler shape manipulation logic.
- Remove redundant shape -> list transform.
- Moves fused shape calculation from build() to call() so that it handles dynamic input shapes.
- Adds a new test that checks for specific results and ensures that dynamic input shapes work.
else:
  raise ValueError('Unsupported axis, fused batch norm only supports '
                   'axis == [1] or axis == [3]')
self._input_fused_shape = None
This is unused
Fixed.
@@ -499,6 +525,27 @@ def _fused_batch_norm(self, inputs, training):
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const

original_shape = None
fused_axis = self.axis[0]
input_shape = inputs.shape.as_list()
This isn't going to work in graph mode if the shape is not fully defined, which would cause Nones to be in `inputs.shape`. Luckily, the layer's input spec requires the input to at least have known rank, so the shape itself cannot be None.
I think you can use `tf.shape(inputs)` instead of `inputs.shape`. You'll have to change `_get_shape_and_axis_for_fused` to deal with a tensor instead of a list, which is irritating but should work. Unfortunately, I don't think AutoGraph is used within layers defined inside Keras itself, so you can't use for-loops to iterate over the tensor anymore.
Also add a test for unknown shapes. Probably easiest to make the test V1 only and use a placeholder.
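A rough sketch of what that suggestion amounts to for the channels-last case (illustrative names only, not the code the PR ended up with): the 4D shape is assembled from tf.shape(inputs), so unknown dimensions are resolved when the graph runs.

```python
import tensorflow as tf

def fused_shape_channels_last(inputs):
  """Collapse [N, H, W, ..., C] to [N, H, product(W, ...), C] in-graph."""
  shape = tf.shape(inputs)                 # dynamic shape, works with unknown dims
  merged = tf.reduce_prod(shape[2:-1], keepdims=True)
  return tf.concat([shape[:2], merged, shape[-1:]], axis=0)
```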
Done.
moving_mean = keras.backend.eval(norm.moving_mean)
moving_variance = keras.backend.eval(norm.moving_variance)
np.testing.assert_allclose(
    moving_mean, np.array([936., 937., 938.]), rtol=1e-5)
Can you explain where these numbers came from? Maybe use smaller input values to make it simpler, e.g. (1, 2, 1, 2, 2).
I've changed it to compute the target values on the fly so that it's clear where they come from.
@benbarsdell Can you please check @reedwm's comments and keep us posted. Thanks!
- Remove unneeded variable.
- Do 4D shape computation in-graph so that unknown shapes are supported in graph mode.
- Compute test target values on the fly so that it's clear where they come from.
- Add a test to check that unknown shapes are supported in graph mode.
I managed to root-cause the Let me know how you want to deal with it.
for dim in shape[:axis]:
  product *= dim
shape[:axis] = [product]
product = math_ops.reduce_prod(shape[:axis])
If you pass `keepdims=True` here, I think you don't need to call `array_ops.reshape` on the line below. This makes it slightly simpler.
And the same goes for the place below where you call `math_ops.reduce_prod`.
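A tiny illustration of the point (hypothetical values): with keepdims=True the product already has rank 1, so it can be concatenated into the new shape without a reshape.

```python
import tensorflow as tf

shape = tf.constant([8, 4, 5, 6, 3])
leading = tf.reduce_prod(shape[:3], keepdims=True)  # [160], rank 1, no reshape needed
fused = tf.concat([leading, shape[3:]], axis=0)     # [160, 6, 3]
```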
Done.
original_shape = [-1] + input_shape[1:]
inputs = array_ops.reshape(inputs, [-1] + fused_shape[1:])
original_shape = array_ops.concat(
    [constant_op.constant([-1]), input_shape[1:]], axis=0)
Can you pass `[-1]` here instead of `constant_op.constant([-1])`? The former is slightly simpler.
Removed.
original_shape = array_ops.concat(
    [constant_op.constant([-1]), input_shape[1:]], axis=0)
fused_shape = array_ops.concat(
    [constant_op.constant([-1]), fused_shape[1:]], axis=0)
Why do you need to replace the first element of `fused_shape` with -1?
Good point, I've removed these operations.
@@ -213,7 +277,7 @@ def call(self, x, training):
model = MyModel()

for _ in range(10):
  x = constant_op.constant(0.5, shape=[1, 1])
  x = constant_op.constant(0.5, shape=[2, 1])
Why this change?
The reason is... complicated :)
First, Bessel's correction (N/(N-1)) is used for the fused implementation but not for the non-fused. There is a discussion of this here.
Then, for the fused case, the CPU implementation avoids returning NaN when N < 2, while the GPU implementation (CUDNN) explicitly returns NaN in these cases.
So when the test calls batchnorm with N = 1 and it (with this PR) calls the fused implementation, it returns NaN (on the GPU) and the test fails.
Changing the test size to N = 2 avoids this problem.
This probably isn't ever observed in the real world because batchnorm doesn't work when N is small anyway.
Let me know if you'd like me to take another action here.
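For anyone following along, a small numeric illustration of why N = 1 is a problem (this is just arithmetic, not code from the PR): Bessel's correction multiplies the batch variance by N / (N - 1), which is undefined when N = 1.

```python
import numpy as np

x = np.array([[0.5]], dtype=np.float32)   # batch of N = 1, one channel
n = np.float32(x.shape[0])
batch_var = x.var(axis=0)                  # biased variance: [0.]
corrected = batch_var * n / (n - 1.0)      # 0 * (1 / 0) -> [nan], like the cuDNN path returns
print(corrected)
```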
If we know that N = 1, we should probably not allow calling the fused version on GPU. I'd prefer an error message over silently producing NaNs.
Unfortunately this can't easily be done in the Python layer because it only applies in training mode. Adding a check to build() causes a bunch of tests to fail (ones that don't use training mode). It would have to be an in-graph check, which I don't think there's much precedent for(?).
You can add a tf.assert in-graph check here. This is not unprecedented and no different from adding a check in the kernel itself.
I've added an in-graph check using Assert.
It caused a bunch of other tests to fail. I'll have to look into them.
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2, 3)), 1.0, atol=1e-1)

@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused_correctness(self):
I think there are no non-4D fused tests where the axis isn't the last element. Either change this test or the previous to have the axis be not the last element, or add a new test.
I've added channels-first versions of the two tests.
Can you create a new PR to fix it? This line has an example of how to cast. If this is trickier than I'm imagining and you don't want to deal with this, I can also get someone internally to fix it.
Thanks, I tried that but unfortunately only Eigen tensors have a
- Use keepdims=True to avoid needing to reshape.
- Remove unnecessary replacements of first dim with -1.
- Add channel-first versions of the two 3d tests.
This change looks good, but I am worried about using Bessel's correction in more cases now. It's possible (but unlikely) we may break someone who currently relies on Bessel's correction not being applied.
I think the best solution is to change this layer to use Bessel's correction in the unfused case before submitting this change. Then we can extensively test the Bessel's correction change before submitting, potentially rolling it back if it breaks someone. Once submitted, this change is a lot safer since it doesn't affect correctness, only performance. I can work on using Bessel's correction in the unfused case since I can run all Google tests before submitting it.
Do you know why Bessel's correction for cuDNN is only used for the moving variance, not the internal variance the input is divided by? We copy this behavior in the CPU version of FusedBatchNorm as well.
@benbarsdell Can you please check @reedwm's comments and keep us posted. Thanks!
Regarding Bessel's correction at inference time but not training time, it seems that the original batchnorm paper indicated doing it this way; there is some discussion here (evidently it is the same way in PyTorch).
It sounds like you are going to take care of the remaining items internally @reedwm. If I've misinterpreted that or you need me to do anything else, just let me know.
Yes, I will take care of using Bessel's correction for the non-fused implementation internally. Then we can merge this PR.
@reedwm Any update on this PR? Please. Thanks!
@benbarsdell Can you please resolve conflicts? Thanks!
- Conflicts fixed in normalization.py and normalization_test.py in tensorflow/python/keras/layers/.
- This is an in-graph check required due to the use of Bessel's correction in fused batchnorm.
- Previously the layer would silently return invalid or NaN results if N <= 1.
- Also fixes a test that failed this condition (since it is now run with fused=True).
Reed wanted to discuss a few design issues before approving.
@benbarsdell Can you please resolve conflicts? Thanks!
# The use of Bessel's correction in training mode imposes the requirement
# that the number of elements in the reduced dimensions is > 1.
check = control_flow_ops.Assert(
    training == False or math_ops.reduce_prod(input_shape) > self.depth,
I think training can also be a tensor. The assert should be added to this function, I think.
It's also probably better to raise an error immediately if the shape and training boolean are statically known.
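A rough sketch of the combined behaviour being discussed (a hypothetical helper using the public tf API instead of the internal modules in the diff above): raise immediately when both the shape and the training flag are static, and fall back to an in-graph assertion otherwise.

```python
import tensorflow as tf

def check_fused_bn_batch_size(inputs, depth, training):
  """Sketch: require > 1 element per channel when training with fused batch norm."""
  static_size = inputs.shape.num_elements()          # None if any dim is unknown
  if isinstance(training, bool) and static_size is not None:
    if training and static_size <= depth:
      raise ValueError('Fused batch norm requires more than one element per '
                       'channel in training mode; got shape %s' % (inputs.shape,))
    return inputs
  # Dynamic fallback, analogous to the control_flow_ops.Assert in the diff above.
  ok = tf.logical_or(tf.logical_not(tf.cast(training, tf.bool)),
                     tf.reduce_prod(tf.shape(inputs)) > depth)
  check = tf.debugging.Assert(ok, ['fused batch norm needs N > 1 when training'])
  with tf.control_dependencies([check]):
    return tf.identity(inputs)
```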
It turns out that the reshapes added by this PR prevent the grappler layout optimizer from optimizing out all the NCHW<->NHWC transposes. For this reason I think we will have to abandon this PR and use #42970 instead.
cc @reedwm @nluehr