Support fused batchnorm with any ndims and axis #40338

Closed
102 changes: 75 additions & 27 deletions tensorflow/python/keras/layers/normalization.py
@@ -198,10 +198,8 @@ def __init__(self,
if self._USE_V2_BEHAVIOR:
if fused:
self._raise_if_fused_cannot_be_used()
# We leave fused as None if self._fused_can_be_used()==True, since we
# still may set it to False in self.build() if the input rank is not 4.
elif fused is None and not self._fused_can_be_used():
fused = False
elif fused is None:
fused = self._fused_can_be_used()
elif fused is None:
fused = True
self.supports_masking = True
@@ -221,26 +219,20 @@ def __init__(self,

def _raise_if_fused_cannot_be_used(self):
"""Raises a ValueError if fused implementation cannot be used.

In addition to the checks done in this function, the input tensors rank must
be 4. The input rank check can only be done once the input shape is known.
"""
# Note the ValueErrors in this function are caught and not reraised in
# _fused_can_be_used(). No other exception besides ValueError should be
# raised here.

# Currently fused batch norm doesn't support renorm. It also only supports a
# channel dimension on axis 1 or 3, when no virtual batch size or adjustment
# is used.
# single axis, when no virtual batch size or adjustment is used.
if self.renorm:
raise ValueError('Passing both fused=True and renorm=True is '
'unsupported')
axis = [self.axis] if isinstance(self.axis, int) else self.axis
# Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
# input rank is required to be 4 (which is checked later).
if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
raise ValueError('Passing fused=True is only supported when axis is 1 '
'or 3')
if len(axis) > 1:
raise ValueError('Passing fused=True is only supported when operating '
'over a single axis.')
if self.virtual_batch_size is not None:
raise ValueError('Passing fused=True is unsupported when '
'virtual_batch_size is specified.')
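
With this change, fused=True is accepted at construction time for any single normalization axis; only multi-axis normalization, renorm, virtual_batch_size, and adjustment still raise here. A quick illustration of the relaxed check (a sketch mirroring the updated tests further down, assuming the v2 layer):

from tensorflow.python.keras.layers import normalization_v2

normalization_v2.BatchNormalization(fused=True, axis=2)       # no longer raises at construction
normalization_v2.BatchNormalization(fused=True, axis=[1, 3])  # still raises ValueError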
@@ -281,6 +273,51 @@ def _support_zero_size_input(self):
distribution_strategy_context.get_strategy().extended,
'experimental_enable_get_next_as_optional', False)

def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
'''Returns a 4D shape and axis (1 or 3) to which nd_shape and nd_axis can
Member: Format the docstring like the others: use triple quotes ("""), have the first line be a single sentence, then a blank line, then a more detailed description, and have an Args and Returns section.

Contributor Author: Done.

be changed without changing the result of the batch normalization operation.
'''
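
For reference, a docstring in the format the reviewer requests above might look like the following sketch (illustrative only; the wording finally committed may differ):

  """Returns a 4D shape and axis to which nd_shape and nd_axis can be mapped.

  The returned shape and axis describe a tensor that is equivalent to the
  original for the purposes of batch normalization: reshaping the input to the
  returned 4D shape and normalizing over the returned axis (1 or 3) gives the
  same result as normalizing the original input over nd_axis.

  Args:
    nd_shape: The input shape, as a list of dimensions.
    nd_axis: The single integer axis being normalized over.

  Returns:
    A tuple (shape, axis) with a 4D shape and an axis of 1 or 3.
  """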
assert(isinstance(nd_axis, int))
Member: In Python 2, a long is also acceptable. Use six.integer_types or just remove the assert.

Contributor Author: isinstance(axis, int) is already used in several places in this class, including in an error check at the beginning of the __init__ function.

ndims = len(nd_shape)
shape = nd_shape[:]
axis = ndims + nd_axis if nd_axis < 0 else nd_axis
# First check if the axis needs to be moved.
if axis not in (1, ndims - 1):
# Move axis to dim 1.
if axis == 0:
# Transform [C, ...] to [1, C, ...].
shape.insert(0, 1)
ndims += 1
else:
# Merge excess pre-axis dims into first dim.
# Transform [N, ..., C, ...] to [product(N, ...), C, ...].
for dim in range(axis - 1, 0, -1):
Member: For this block, I think the following is simpler:

product = 1
for elem in shape[:axis]:
    product *= elem
shape[:axis] = [product]
ndims -= (axis - 1)

Contributor Author: Done.

shape[0] *= shape[dim]
del shape[dim]
ndims -= 1
axis = 1
# Now change shape to 4D.
is_channels_last = axis == ndims - 1
if ndims < 4:
# Insert new dims after existing spatial dim or before channel dim.
new_dims = [1] * (4 - ndims)
if is_channels_last:
# Transform [..., C] to [..., 1..., C] (ndims=4).
shape = shape[:-1] + new_dims + shape[-1:]
else:
# Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
shape += new_dims
elif ndims > 4:
# Merge excess spatial dims into the second spatial dim.
# Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
# Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
merge_dim = 2 if is_channels_last else 3
Member: Similar to above, maybe do:

product = 1
for elem in shape[merge_dim:merge_dim + ndims - 4]:
  product *= elem
shape[merge_dim:merge_dim + ndims - 4] = [product]

Contributor Author: Done.

for dim in range(merge_dim + (ndims - 4), merge_dim, -1):
shape[merge_dim] *= shape[dim]
del shape[dim]
axis = 3 if is_channels_last else 1
return shape, axis
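
To make the transformation concrete, here are a few illustrative (nd_shape, nd_axis) to (4D shape, fused axis) pairs, traced through the code above rather than taken from the PR:

# ([N, C],            axis=1)   ->  ([N, 1, 1, C],    axis=3)   NHWC
# ([C, H, W],         axis=0)   ->  ([1, C, H, W],    axis=1)   NCHW
# ([N, X, C, H, W],   axis=2)   ->  ([N*X, C, H, W],  axis=1)   NCHW
# ([N, D, H, W, C],   axis=-1)  ->  ([N, D, H*W, C],  axis=3)   NHWC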

def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -315,18 +352,22 @@ def build(self, input_shape):
raise ValueError('When using virtual_batch_size, adjustment cannot '
'be specified')

fused_axis = self.axis
self._input_fused_shape = None
Member: This is unused.

Contributor Author: Fixed.
if self.fused in (None, True):
Member: self.fused can no longer be None, so just have this be if self.fused:. Then, this can be merged with the if self.fused: block below.

Contributor Author: Done.
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
# output back to its original shape accordingly.
if self._USE_V2_BEHAVIOR:
if self.fused is None:
self.fused = (ndims == 4)
elif self.fused and ndims != 4:
raise ValueError('Batch normalization layers with fused=True only '
'support 4D input tensors.')
else:
if len(self.axis) == 1 and (self.axis[0] not in (1, ndims - 1) or
ndims != 4):
# The fused implementation only supports NCHW or NHWC, so we will
# reshape the input/output tensor to/from an equivalent 4D shape.
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(
input_shape.dims, self.axis[0])
fused_shape = tensor_shape.TensorShape(fused_shape)
Member: Why convert this to a TensorShape then immediately convert back to a list?

Contributor Author: Done. (Twas a remnant of a fight I had initially with TensorShape vs. list, None vs. -1, etc.)

self._input_fused_shape = [-1] + fused_shape.as_list()[1:]
Member: The issue with computing this in build() is that the input shape might change between calls to the layer. The channel dimension must stay the same since that affects the size of the weights, but other dimensions can differ. Instead, compute this in call.

Contributor Author: Done.

fused_axis = [fused_axis]

if not self._USE_V2_BEHAVIOR:
Member: You don't want to run the if len(self.axis) == 1 block above if self.fused ends up being False.

Contributor Author: Done.

assert self.fused is not None
self.fused = (ndims == 4 and self._fused_can_be_used())
self.fused = self._fused_can_be_used()
# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
@@ -335,9 +376,9 @@ def build(self, input_shape):
# common use case (turning 5D w/ virtual batch to NCHW)

if self.fused:
if self.axis == [1]:
if fused_axis == [1]:
self._data_format = 'NCHW'
elif self.axis == [3]:
elif fused_axis == [3]:
self._data_format = 'NHWC'
else:
raise ValueError('Unsupported axis, fused batch norm only supports '
@@ -499,6 +540,10 @@ def _fused_batch_norm(self, inputs, training):
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const

original_shape = [-1] + inputs.shape.as_list()[1:]
if self._input_fused_shape is not None:
inputs = array_ops.reshape(inputs, self._input_fused_shape)

# TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well.
if self._support_zero_size_input():
@@ -575,8 +620,11 @@ def _fused_batch_norm_inference():

output, mean, variance = tf_utils.smart_cond(training, train_op,
_fused_batch_norm_inference)
variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

if self._input_fused_shape is not None:
output = array_ops.reshape(output, original_shape)

variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
training_value = tf_utils.constant_value(training)
if training_value or training_value is None:
if not use_fused_avg_updates:
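The call-time changes above follow a reshape round trip: reshape the ND input to an equivalent 4D tensor, run the fused kernel, then reshape the output back to its original shape. A minimal standalone sketch of that pattern using plain TF ops (a hypothetical helper, not the layer's actual code path):

import tensorflow as tf

def fused_batch_norm_nd(x, scale, offset, fused_shape, data_format='NHWC'):
  # Remember the original shape, with -1 standing in for the (possibly unknown) batch dim.
  original_shape = [-1] + x.shape.as_list()[1:]
  # Reshape to the equivalent 4D NCHW/NHWC shape computed for the fused kernel.
  x4 = tf.reshape(x, fused_shape)
  y4, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
      x4, scale, offset, data_format=data_format, is_training=True)
  # Restore the original rank before handing the result back.
  return tf.reshape(y4, original_shape), batch_mean, batch_var
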
48 changes: 34 additions & 14 deletions tensorflow/python/keras/layers/normalization_test.py
@@ -138,6 +138,27 @@ def test_batchnorm_convnet_channel_last(self):
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused(self):
Member: I would also add a test with simple hand-written inputs. No need to test gradients. I'm worried this won't catch errors in _get_shape_and_axis_for_fused. Note gamma and beta default to ones and zeros, so you can effectively ignore them in tests that don't take gradients.

Contributor Author: Done.

model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(
axis=-1, input_shape=(4, 4, 4, 3), momentum=0.8, fused=True)
model.add(norm)
model.compile(
loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
run_eagerly=testing_utils.should_run_eagerly())

# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 4, 3))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= np.reshape(keras.backend.eval(norm.beta), (1, 1, 1, 3))
out /= np.reshape(keras.backend.eval(norm.gamma), (1, 1, 1, 3))

np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
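
A hand-written-input test of the kind the reviewer suggests above might look like the following eager-mode sketch (illustrative values only, not the test actually added in the PR; it assumes the layer's default epsilon of 1e-3 and that the fused kernel normalizes with the biased batch variance):

  def test_batchnorm_simple_values_3d_fused(self):
    # A 3D input exercises the new reshape-to-4D path of the fused implementation.
    norm = normalization_v2.BatchNormalization(axis=-1, fused=True)
    x = np.array([[[1.0], [3.0]],
                  [[5.0], [7.0]]], dtype=np.float32)  # shape (2, 2, 1)
    out = norm(x, training=True)
    # With the default gamma=1 and beta=0, the output is (x - mean) / sqrt(var + eps),
    # where mean = 4 and the biased batch variance is 5 over the non-channel axes.
    expected = (x - 4.0) / np.sqrt(5.0 + 1e-3)
    self.assertAllClose(self.evaluate(out), expected, atol=1e-3)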

@keras_parameterized.run_all_keras_modes
def test_batchnorm_correctness(self):
_run_batchnorm_correctness_test(
@@ -213,7 +234,7 @@ def call(self, x, training):
model = MyModel()

for _ in range(10):
x = constant_op.constant(0.5, shape=[1, 1])
x = constant_op.constant(0.5, shape=[2, 1])
Member: Why this change?

Contributor Author: The reason is... complicated :)

First, Bessel's correction (N/(N-1)) is used for the fused implementation but not for the non-fused. There is a discussion of this here.

Then, for the fused case, the CPU implementation avoids returning NaN when N < 2, while the GPU implementation (CUDNN) explicitly returns NaN in these cases.

So when the test calls batchnorm with N = 1 and it (with this PR) calls the fused implementation, it returns NaN (on the GPU) and the test fails. Changing the test size to N = 2 avoids this problem.

This probably isn't ever observed in the real world because batchnorm doesn't work when N is small anyway.

Let me know if you'd like me to take another action here.

Member: If we know that N = 1, we should probably not allow calling the fused version on GPU. I'd prefer an error message over silently producing NaNs.

Contributor Author: Unfortunately this can't easily be done in the Python layer because it only applies in training mode. Adding a check to build() causes a bunch of tests to fail (ones that don't use training mode). It would have to be an in-graph check, which I don't think there's much precedent for(?).

Member (rmlarsen, Sep 1, 2020): You can add a tf.assert in-graph check here. This is not unprecedented and no different from adding a check in the kernel itself.

Contributor Author: I've added an in-graph check using Assert.

Contributor Author: It caused a bunch of other tests to fail. I'll have to look into them.
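
For reference, a small numeric illustration of the N/(N-1) correction discussed above (my own example): the fused kernel normalizes with the biased variance but reports the Bessel-corrected one, and the correction is undefined for a single-element batch.

import numpy as np

x = np.array([1.0, 3.0])                    # batch of N = 2
biased = x.var()                            # 1.0, the variance used to normalize
corrected = biased * len(x) / (len(x) - 1)  # 2.0, the variance the fused kernel reports
# With N = 1 the factor N / (N - 1) divides by zero, which is why cuDNN returns NaN.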

model(x, training=True)

# Make sure the moving mean and variance have been updated
@@ -255,20 +276,28 @@ def test_basic_batchnorm_v2(self):
normalization_v2.BatchNormalization,
kwargs={'fused': None},
input_shape=(3, 3, 3))
testing_utils.layer_test(
normalization_v2.BatchNormalization,
kwargs={'fused': True},
input_shape=(3, 3, 3, 3, 3))
testing_utils.layer_test(
normalization_v2.BatchNormalization,
kwargs={'fused': True},
input_shape=(3, 3))

@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_v2_fused_attribute(self):
norm = normalization_v2.BatchNormalization()
self.assertEqual(norm.fused, None)
self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp)
self.assertEqual(norm.fused, True)

norm = normalization_v2.BatchNormalization()
self.assertEqual(norm.fused, None)
self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4))
norm(inp)
self.assertEqual(norm.fused, False)
self.assertEqual(norm.fused, True)

norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
self.assertEqual(norm.fused, False)
@@ -291,10 +320,7 @@ def test_v2_fused_attribute(self):
with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
normalization_v2.BatchNormalization(fused=True, renorm=True)

with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
normalization_v2.BatchNormalization(fused=True, axis=2)

with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
with self.assertRaisesRegexp(ValueError, 'fused.*over a single axis'):
normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
@@ -304,12 +330,6 @@
normalization_v2.BatchNormalization(fused=True,
adjustment=lambda _: (1, 0))

norm = normalization_v2.BatchNormalization(fused=True)
self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4))
with self.assertRaisesRegexp(ValueError, '4D input tensors'):
norm(inp)

def test_updates_in_wrap_function(self):
with context.eager_mode():
layer = keras.layers.BatchNormalization()