
Support fused batchnorm with any ndims and axis #40338

Conversation

benbarsdell (Contributor)

  • Modifies the BatchNormalizationBase layer to support fused=True with any number of dimensions and any (single) axis by adding reshape operations around the calls to nn.fused_batch_norm (which itself only supports NCHW and NHWC tensors).
  • Fixes tests to account for the broader support and adds a new test for a 3D convnet.
  • Use of fused=True is particularly important for mixed-precision training.

cc @reedwm @nluehr
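
For context, the core reshape-wrapping idea can be sketched as follows for a channels-last 5D input such as a 3D convnet. This is illustrative only: it uses the public tf.compat.v1.nn.fused_batch_norm rather than the internal nn.fused_batch_norm the layer calls, and it ignores moving statistics and inference mode.

import tensorflow as tf

def fused_bn_5d_channels_last(x, scale, offset):
  # Collapse NDHWC to an equivalent NHWC shape; merging the D and H
  # dimensions leaves the per-channel statistics unchanged.
  n, d, h, w, c = tf.unstack(tf.shape(x))
  y4, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
      tf.reshape(x, [n, d * h, w, c]), scale, offset,
      data_format='NHWC', is_training=True)
  # Restore the original 5D shape for the caller.
  return tf.reshape(y4, tf.shape(x)), batch_mean, batch_var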

- Modifies the BatchNormalizationBase layer to support fused=True with
  any number of dimensions and any (single) axis by adding reshape
  operations around the calls to nn.fused_batch_norm (which itself only
  supports NCHW and NHWC tensors).
- Fixes tests to account for the broader support and adds a new test for
  a 3D convnet.
@google-ml-butler google-ml-butler bot added the size:M CL Change Size: Medium label Jun 9, 2020
@reedwm reedwm self-requested a review June 9, 2020 21:46
reedwm (Member) left a comment:

Thanks for the PR!

@@ -315,18 +352,22 @@ def build(self, input_shape):
raise ValueError('When using virtual_batch_size, adjustment cannot '
'be specified')

fused_axis = self.axis
self._input_fused_shape = None
if self.fused in (None, True):
reedwm (Member):

self.fused can no longer be None, so just have this be if self.fused:. Then, this can be merged with the if self.fused: block below

benbarsdell (Contributor, Author):

Done.

self._input_fused_shape = [-1] + fused_shape.as_list()[1:]
fused_axis = [fused_axis]

if not self._USE_V2_BEHAVIOR:
reedwm (Member):

You don't want to run the if len(self.axis) == 1 block above if self.fused ends up being False.

benbarsdell (Contributor, Author):

Done.

@@ -281,6 +273,51 @@ def _support_zero_size_input(self):
distribution_strategy_context.get_strategy().extended,
'experimental_enable_get_next_as_optional', False)

def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
'''Returns a 4D shape and axis (1 or 3) to which nd_shape and nd_axis can
reedwm (Member):

Format the docstring like the others: use triple quotes ("""), have the first line be a single sentence, then a blank line, then a more detailed description, and have an Args and Returns section.
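
For reference, one plausible rendering of the requested format (the argument and return descriptions here are guesses):

def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
  """Returns an equivalent 4D shape and axis for fused batch norm.

  Reshaping the input to the returned shape and normalizing over the
  returned axis (1 or 3) gives the same result as normalizing the
  original nd_shape input over nd_axis.

  Args:
    nd_shape: List of dimensions of the N-D input.
    nd_axis: Integer channel axis of the N-D input.

  Returns:
    A tuple (shape, axis) of the equivalent 4D shape and its channel axis.
  """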

benbarsdell (Contributor, Author):

Done.

'''Returns a 4D shape and axis (1 or 3) to which nd_shape and nd_axis can
be changed without changing the result of the batch normalization operation.
'''
assert(isinstance(nd_axis, int))
reedwm (Member):

In Python 2, a long is also acceptable. Use six.integer_types or just remove the assert

benbarsdell (Contributor, Author):

isinstance(axis, int) is already used in several places in this class, including in an error check at the beginning of the __init__ function.

else:
# Merge excess pre-axis dims into first dim.
# Transform [N, ..., C, ...] to [product(N, ...), C, ...].
for dim in range(axis - 1, 0, -1):
reedwm (Member):

For this block, I think the following is simpler

product = 1
for elem in shape[:axis]:
    product *= elem
shape[:axis] = [product]
ndims -= (axis - 1)

benbarsdell (Contributor, Author):

Done.

# Merge excess spatial dims into the second spatial dim.
# Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
# Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
merge_dim = 2 if is_channels_last else 3
reedwm (Member):

Similar to above, maybe do:

product = 1
for elem in shape[merge_dim:merge_dim + ndims - 4]:
  product *= elem
shape[merge_dim:merge_dim + ndims - 4] = [product]

benbarsdell (Contributor, Author):

Done.

# reshape the input/output tensor to/from an equivalent 4D shape.
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(
input_shape.dims, self.axis[0])
fused_shape = tensor_shape.TensorShape(fused_shape)
reedwm (Member):

Why convert this to a TensorShape, then immediately convert back to a list?

benbarsdell (Contributor, Author):

Done. ('Twas a remnant of a fight I had initially with TensorShape vs. list, None vs. -1, etc.)

fused_shape, fused_axis = self._get_shape_and_axis_for_fused(
input_shape.dims, self.axis[0])
fused_shape = tensor_shape.TensorShape(fused_shape)
self._input_fused_shape = [-1] + fused_shape.as_list()[1:]
reedwm (Member):

The issue with computing this in build() is that the input shape might change between calls to the layer. The channel dimension must stay the same since that affects the size of the weights, but other dimensions can differ.

Instead, compute this in call.

benbarsdell (Contributor, Author):

Done.

@@ -138,6 +138,27 @@ def test_batchnorm_convnet_channel_last(self):
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused(self):
reedwm (Member):

I would also add a test with simple hand-written inputs. No need to test gradients. I'm worried this won't catch errors in _get_shape_and_axis_for_fused. Note gamma and beta default to ones and zeros, so you can effectively ignore them in tests that don't take gradients.
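
For illustration, such a test might look like the following sketch (the shape, epsilon, and tolerance are arbitrary, and it assumes the broadened fused support from this PR). In training mode the normalization uses the biased batch variance, so with default gamma/beta the expected output is a simple per-channel standardization:

import numpy as np
import tensorflow as tf

def test_batchnorm_fused_hand_written_inputs():
  # gamma defaults to ones and beta to zeros, so in training mode the
  # output is just (x - mean(x)) / sqrt(var(x) + epsilon) per channel.
  epsilon = 1e-3
  x = np.arange(16, dtype=np.float32).reshape(1, 2, 2, 2, 2)  # NDHWC
  norm = tf.keras.layers.BatchNormalization(axis=-1, fused=True,
                                            epsilon=epsilon)
  out = np.asarray(norm(x, training=True))
  mean = x.mean(axis=(0, 1, 2, 3))
  var = x.var(axis=(0, 1, 2, 3))
  np.testing.assert_allclose(out, (x - mean) / np.sqrt(var + epsilon),
                             rtol=1e-4)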

benbarsdell (Contributor, Author):

Done.

- Simplify condition.
- Improve and format docstring.
- Use simpler shape manipulation logic.
- Remove redundant shape -> list transform.
- Moves fused shape calculation from build() to call() so that it
  handles dynamic input shapes.
- Adds a new test that checks for specific results and ensures that
  dynamic input shapes work.
@gbaned gbaned self-assigned this Jun 10, 2020
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Jun 10, 2020
else:
raise ValueError('Unsupported axis, fused batch norm only supports '
'axis == [1] or axis == [3]')
self._input_fused_shape = None
reedwm (Member):

This is unused

benbarsdell (Contributor, Author):

Fixed.

@@ -499,6 +525,27 @@ def _fused_batch_norm(self, inputs, training):
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const

original_shape = None
fused_axis = self.axis[0]
input_shape = inputs.shape.as_list()
reedwm (Member):

This isn't going to work in graph mode if the shape is not fully defined, which would leave Nones in inputs.shape. Luckily, the layer's input spec requires the input to have at least a known rank, so the shape itself cannot be None.

I think you can use tf.shape(inputs) instead of inputs.shape. You'll have to change _get_shape_and_axis_for_fused to deal with a tensor instead of a list, which is irritating but should work. Unfortunately, I don't think AutoGraph is used within layers defined inside Keras itself, so you can't use for-loops to iterate over the tensor anymore.

Also add a test for unknown shapes. Probably easiest to make the test V1 only and use a placeholder.
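
For illustration, the in-graph version of the channels-last reshape suggested above could look roughly like this (a sketch using the public API; the real helper must also handle channels-first and other axes):

shape = tf.shape(inputs)  # usable even when dims are statically unknown
# Channels-last: [N, H, W, ..., C] -> [-1, H, prod(W, ...), C]
merged = tf.reduce_prod(shape[2:-1], keepdims=True)
fused_shape = tf.concat([[-1], shape[1:2], merged, shape[-1:]], axis=0)
inputs = tf.reshape(inputs, fused_shape)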

benbarsdell (Contributor, Author):

Done.

moving_mean = keras.backend.eval(norm.moving_mean)
moving_variance = keras.backend.eval(norm.moving_variance)
np.testing.assert_allclose(
moving_mean, np.array([936., 937., 938.]), rtol=1e-5)
reedwm (Member):

Can you explain where these numbers came from? Maybe use smaller input values to make it simpler, e.g. (1, 2, 1, 2, 2).

benbarsdell (Contributor, Author):

I've changed it to compute the target values on the fly so that it's clear where they come from.
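
For reference, the on-the-fly targets amount to the standard exponential-moving-average update; a numpy sketch (the momentum, shape, and step count here are illustrative, not necessarily what the test uses):

import numpy as np

momentum, steps = 0.9, 10
x = np.arange(16, dtype=np.float32).reshape(1, 2, 2, 2, 2)  # fixed batch, C=2
n = x.size // 2                       # elements per channel
moving_mean = np.zeros(2, dtype=np.float32)
moving_var = np.ones(2, dtype=np.float32)
for _ in range(steps):
  batch_mean = x.mean(axis=(0, 1, 2, 3))
  batch_var = x.var(axis=(0, 1, 2, 3)) * n / (n - 1)  # fused applies Bessel's
  moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
  moving_var = moving_var * momentum + batch_var * (1 - momentum)
# moving_mean / moving_var are then the expected values for
# keras.backend.eval(norm.moving_mean) and norm.moving_variance.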

PR Queue automation moved this from Assigned Reviewer to Reviewer Requested Changes Jun 12, 2020
gbaned (Contributor) commented Jun 17, 2020

@benbarsdell Can you please check @reedwm's comments and keep us posted. Thanks!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Jun 17, 2020
- Remove unneeded variable.
- Do 4D shape computation in-graph so that unknown shapes are supported
  in graph mode.
- Compute test target values on the fly so that it's clear where they
  come from.
- Add a test to check that unknown shapes are supported in graph mode.
benbarsdell (Contributor, Author)

I managed to root-cause the layer_correctness_test failure. There appears to be a bug here (also here) where the input is interpreted as OutputT instead of InputT.
I'm not sure how to add a tensor typecast there, but just removing that piece of code does fix the problem.

Let me know how you want to deal with it.

for dim in shape[:axis]:
product *= dim
shape[:axis] = [product]
product = math_ops.reduce_prod(shape[:axis])
reedwm (Member):

If you pass keepdims=True here, I think you don't need to call array_ops.reshape on the line below. This makes it slightly simpler.

And the same for the place below where you call math_ops.reduce_prod.
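
That is, something along these lines (a sketch; the variable names are assumed from the surrounding diff):

# keepdims=True keeps the product rank-1 (shape [1]), so it can be
# concatenated into the new shape directly, with no extra reshape.
product = math_ops.reduce_prod(shape[:axis], keepdims=True)
fused_shape = array_ops.concat([product, shape[axis:]], axis=0)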

benbarsdell (Contributor, Author):

Done.

original_shape = [-1] + input_shape[1:]
inputs = array_ops.reshape(inputs, [-1] + fused_shape[1:])
original_shape = array_ops.concat(
[constant_op.constant([-1]), input_shape[1:]], axis=0)
reedwm (Member):

Can you pass [-1] here instead of constant_op.constant([-1])? The former is slightly simpler.

benbarsdell (Contributor, Author):

Removed.

original_shape = array_ops.concat(
[constant_op.constant([-1]), input_shape[1:]], axis=0)
fused_shape = array_ops.concat(
[constant_op.constant([-1]), fused_shape[1:]], axis=0)
reedwm (Member):

Why do you need to replace the first element of fused_shape with -1?

benbarsdell (Contributor, Author):

Good point, I've removed these operations.

@@ -213,7 +277,7 @@ def call(self, x, training):
model = MyModel()

for _ in range(10):
x = constant_op.constant(0.5, shape=[1, 1])
x = constant_op.constant(0.5, shape=[2, 1])
reedwm (Member):

Why this change?

benbarsdell (Contributor, Author):

The reason is... complicated :)

First, Bessel's correction (N/(N-1)) is used for the fused implementation but not for the non-fused. There is a discussion of this here.

Then, for the fused case, the CPU implementation avoids returning NaN when N < 2, while the GPU implementation (CUDNN) explicitly returns NaN in these cases.

So when the test calls batchnorm with N = 1 and it (with this PR) calls the fused implementation, it returns NaN (on the GPU) and the test fails.
Changing the test size to N = 2 avoids this problem.
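
Concretely, the correction scales the biased variance by N/(N-1), which is undefined for N = 1; a small numpy illustration:

import numpy as np

def unbiased_var(x):
  n = x.size
  return x.var() * n / (n - 1)  # Bessel's correction: scale by N/(N-1)

unbiased_var(np.array([0.5, 0.5]))  # N = 2: finite (0.0)
unbiased_var(np.array([0.5]))       # N = 1: 0/0 -> nan (with a RuntimeWarning)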

This probably isn't ever observed in the real world because batchnorm doesn't work when N is small anyway.

Let me know if you'd like me to take another action here.

reedwm (Member):

If we know that N = 1, we should probably not allow calling the fused version on GPU. I'd prefer an error message over silently producing NaNs.

benbarsdell (Contributor, Author):

Unfortunately this can't easily be done in the Python layer because it only applies in training mode. Adding a check to build() causes a bunch of tests to fail (ones that don't use training mode). It would have to be an in-graph check, which I don't think there's much precedent for(?).

rmlarsen (Member) commented Sep 1, 2020:

You can add a tf.assert in-graph check here. This is not unprecedented and no different from adding a check in the kernel itself.

benbarsdell (Contributor, Author):

I've added an in-graph check using Assert.

benbarsdell (Contributor, Author):

It caused a bunch of other tests to fail. I'll have to look into them.

np.testing.assert_allclose(np.std(out, axis=(0, 1, 2, 3)), 1.0, atol=1e-1)

@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused_correctness(self):
reedwm (Member):

I think there are no non-4D fused tests where the axis isn't the last element. Either change this test or the previous to have the axis be not the last element, or add a new test.

benbarsdell (Contributor, Author):

I've added channels-first versions of the two tests.

reedwm (Member) commented Jun 19, 2020

I managed to root-cause the layer_correctness_test failure. There appears to be a bug here (also here) where the input is interpreted as OutputT instead of InputT.
I'm not sure how to add a tensor typecast there, but just removing that piece of code does fix the problem.

Let me know how you want to deal with it.

Can you create a new PR to fix it? This line has an example of how to cast. If this is trickier than I'm imagining and you don't want to deal with this, I can also get someone internally to fix it.

benbarsdell (Contributor, Author)

I managed to root-cause the layer_correctness_test failure. There appears to be a bug here (also here) where the input is interpreted as OutputT instead of InputT.
I'm not sure how to add a tensor typecast there, but just removing that piece of code does fix the problem.
Let me know how you want to deal with it.

Can you create a new PR to fix it? This line has an example of how to cast. If this is trickier than I'm imagining and you don't want to deal with this, I can also get someone internally to fix it.

Thanks, I tried that but unfortunately only Eigen tensors have a cast method; the TF Tensor type does not. At this stage it's probably easier for someone internal to fix it, unless there's some other very easy fix that I haven't identified.

- Use keepdims=True to avoid needing to reshape.
- Remove unnecessary replacements of first dim with -1.
- Add channels-first versions of the two 3D tests.
@gbaned gbaned removed the stat:awaiting response Status - Awaiting response from author label Jun 23, 2020
@gbaned gbaned requested a review from reedwm June 23, 2020 15:49
reedwm (Member) left a comment:

This change looks good, but I am worried about using Bessel's correction in more cases now. It's possible (but unlikely) we may break someone who currently relies on Bessel's correction not being applied.

I think the best solution is to change this layer to use Bessel's correction in the unfused case before submitting this change. Then we can extensively test the Bessel's correction change before submitting, potentially rolling it back if it breaks someone. Once submitted, this change is a lot safer since it doesn't affect correctness, only performance. I can work on using Bessel's correction in the unfused case since I can run all Google tests before submitting it.

Do you know why Bessel's correction for cuDNN is only used for the moving variance, not the internal variance the input is divided by? We copy this behavior in the CPU version of FusedBatchNorm as well.

gbaned (Contributor) commented Jun 29, 2020

@benbarsdell Can you please check @reedwm's comments and keep us posted. Thanks!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Jun 29, 2020
benbarsdell (Contributor, Author)

Regarding Bessel's correction at inference time but not training time, it seems that the original batchnorm paper indicated doing it this way; there is some discussion here (evidently it is the same way in PyTorch):
pytorch/pytorch#1410
https://stats.stackexchange.com/questions/311074/batch-normalization-variance-calculation

It sounds like you are going to take care of the remaining items internally @reedwm. If I've misinterpreted that or you need me to do anything else just let me know.

reedwm (Member) commented Jun 30, 2020

Yes, I will take care of using Bessel's correction for the non-fused implementation internally. Then we can merge this PR.

@reedwm reedwm added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Jun 30, 2020
gbaned (Contributor) commented Jul 9, 2020

@reedwm Any update on this PR? Please. Thanks!

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 11, 2020
@gbaned gbaned added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 14, 2020
gbaned (Contributor) commented Aug 6, 2020

@reedwm Any update on this PR? Please. Thanks!

gbaned (Contributor) commented Aug 25, 2020

@benbarsdell Can you please resolve conflicts? Thanks!

@gbaned gbaned removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Aug 25, 2020
- Conflicts fixed in normalization.py and normalization_test.py in
  tensorflow/python/keras/layers/.
@gbaned gbaned requested review from rmlarsen and reedwm and removed request for rmlarsen September 1, 2020 17:09
@gbaned gbaned added the awaiting review Pull request awaiting review label Sep 1, 2020
- This is an in-graph check required due to the use of Bessel's
  correction in fused batchnorm.
- Previously the layer would silently return invalid or NaN results
  if N <= 1.
- Also fixes a test that failed this condition (since it is now run with
  fused=True).
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Sep 3, 2020
@rmlarsen rmlarsen self-requested a review September 3, 2020 19:23
rmlarsen (Member) left a comment:

Reed wanted to discuss a few design issues before approving.

@rmlarsen rmlarsen removed the ready to pull PR ready for merge process label Sep 3, 2020
@tensorflowbutler tensorflowbutler removed the awaiting review Pull request awaiting review label Sep 5, 2020
gbaned (Contributor) commented Sep 7, 2020

@benbarsdell Can you please resolve conflicts? Thanks!

@gbaned gbaned added comp:keras Keras related issues and removed kokoro:force-run Tests on submitted change labels Sep 7, 2020
# The use of Bessel's correction in training mode imposes the requirement
# that the number of elements in the reduced dimensions is > 1.
check = control_flow_ops.Assert(
training == False or math_ops.reduce_prod(input_shape) > self.depth,
reedwm (Member):

I think training can also be a tensor; the assert should be added to this function instead.

reedwm (Member):

Also, it's probably better to raise an error immediately if the shape and training boolean are statically known.
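
One way to combine the static and in-graph checks (a sketch; the helper name and placement are hypothetical):

import tensorflow as tf

def _check_fused_batch_size(inputs, training, channels):
  # Static path: raise immediately when shape and training flag are known.
  num_elements = inputs.shape.num_elements()  # None unless fully defined
  if isinstance(training, bool) and num_elements is not None:
    if training and num_elements <= channels:
      raise ValueError('Fused batch norm requires more than one element '
                       'per channel in training mode.')
    return inputs
  # Dynamic path: in-graph Assert, evaluated only at run time.
  ok = tf.logical_or(
      tf.logical_not(tf.cast(training, tf.bool)),
      tf.reduce_prod(tf.shape(inputs)) > channels)
  with tf.control_dependencies([tf.debugging.Assert(ok, [tf.shape(inputs)])]):
    return tf.identity(inputs)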

benbarsdell (Contributor, Author)

It turns out that the reshapes added by this PR prevent the grappler layout optimizer from optimizing out all the NCHW<->NHWC transposes.

For this reason I think we will have to abandon this PR and use #42970 instead.

PR Queue automation moved this from Reviewer Requested Changes to Closed/Rejected Sep 16, 2020
Labels
cla: yes, comp:keras (Keras related issues), size:M (CL Change Size: Medium)
Projects
PR Queue: Closed/Rejected

6 participants