New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support fused batchnorm with any ndims and axis #40338
Changes from 4 commits
63bd4d2
acd6126
4c81be3
cd26114
363b99c
0501f8f
e68ef3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,10 +198,8 @@ def __init__(self, | |
if self._USE_V2_BEHAVIOR: | ||
if fused: | ||
self._raise_if_fused_cannot_be_used() | ||
# We leave fused as None if self._fused_can_be_used()==True, since we | ||
# still may set it to False in self.build() if the input rank is not 4. | ||
elif fused is None and not self._fused_can_be_used(): | ||
fused = False | ||
elif fused is None: | ||
fused = self._fused_can_be_used() | ||
elif fused is None: | ||
fused = True | ||
self.supports_masking = True | ||
|
@@ -221,26 +219,20 @@ def __init__(self, | |
|
||
def _raise_if_fused_cannot_be_used(self): | ||
"""Raises a ValueError if fused implementation cannot be used. | ||
|
||
In addition to the checks done in this function, the input tensors rank must | ||
be 4. The input rank check can only be done once the input shape is known. | ||
""" | ||
# Note the ValueErrors in this function are caught and not reraised in | ||
# _fused_can_be_used(). No other exception besides ValueError should be | ||
# raised here. | ||
|
||
# Currently fused batch norm doesn't support renorm. It also only supports a | ||
# channel dimension on axis 1 or 3, when no virtual batch size or adjustment | ||
# is used. | ||
# single axis, when no virtual batch size or adjustment is used. | ||
if self.renorm: | ||
raise ValueError('Passing both fused=True and renorm=True is ' | ||
'unsupported') | ||
axis = [self.axis] if isinstance(self.axis, int) else self.axis | ||
# Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the | ||
# input rank is required to be 4 (which is checked later). | ||
if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3): | ||
raise ValueError('Passing fused=True is only supported when axis is 1 ' | ||
'or 3') | ||
if len(axis) > 1: | ||
raise ValueError('Passing fused=True is only supported when operating ' | ||
'over a single axis.') | ||
if self.virtual_batch_size is not None: | ||
raise ValueError('Passing fused=True is unsupported when ' | ||
'virtual_batch_size is specified.') | ||
|
@@ -281,6 +273,66 @@ def _support_zero_size_input(self): | |
distribution_strategy_context.get_strategy().extended, | ||
'experimental_enable_get_next_as_optional', False) | ||
|
||
def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
  """Computes an equivalent 4D shape and axis for the fused implementation.

  The input/output of the layer can be reshaped to/from the shape returned by
  this function without affecting the correctness of the computation.

  Arguments:
    nd_shape: Tensor. The original shape of the operation. Its static length
      (`nd_shape.shape[0]`) is used as the rank of the input.
    nd_axis: Integer. The original axis of the operation. A negative value
      counts back from the last dimension.

  Returns:
    shape: Tensor. A 4D shape.
    axis: Integer. An axis (always 1 or 3).
  """
  assert isinstance(nd_axis, int)
  ndims = nd_shape.shape[0]
  shape = nd_shape[:]
  # Resolve a negative axis against the rank. This must use `ndims`: adding
  # the axis to the `nd_shape` tensor would produce a tensor rather than an
  # index, which breaks the membership test and the slices below.
  axis = ndims + nd_axis if nd_axis < 0 else nd_axis
  # First check if the axis needs to be moved.
  if axis not in (1, ndims - 1):
    # Move axis to dim 1.
    if axis == 0:
      # Transform [C, ...] to [1, C, ...].
      shape = array_ops.concat([constant_op.constant([1]), shape], axis=0)
      ndims += 1
    else:
      # Merge excess pre-axis dims into first dim.
      # Transform [N, ..., C, ...] to [product(N, ...), C, ...].
      product = math_ops.reduce_prod(shape[:axis])
      shape = array_ops.concat(
          [array_ops.reshape(product, constant_op.constant([1])),
           shape[axis:]], axis=0)
      ndims -= (axis - 1)
    axis = 1
  # Now change shape to 4D.
  is_channels_last = axis == ndims - 1
  if ndims < 4:
    # Insert new dims after existing spatial dim or before channel dim.
    new_dims = constant_op.constant([1] * (4 - ndims))
    if is_channels_last:
      # Transform [..., C] to [..., 1..., C] (ndims=4).
      shape = array_ops.concat([shape[:-1], new_dims, shape[-1:]], axis=0)
    else:
      # Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
      shape = array_ops.concat([shape, new_dims], axis=0)
  elif ndims > 4:
    # Merge excess spatial dims into the second spatial dim.
    # Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
    # Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
    # NOTE(review): a reviewer suggested computing this product with a plain
    # Python loop over the merged dims instead of graph ops.
    merge_dim = 2 if is_channels_last else 3
    merge_end = merge_dim + 1 + (ndims - 4)
    product = math_ops.reduce_prod(shape[merge_dim:merge_end])
    shape = array_ops.concat(
        [shape[:merge_dim],
         array_ops.reshape(product, constant_op.constant([1])),
         shape[merge_end:]], axis=0)
  # After the steps above the channel dim is either last (NHWC) or at 1 (NCHW).
  axis = 3 if is_channels_last else 1
  return shape, axis
|
||
def build(self, input_shape): | ||
input_shape = tensor_shape.TensorShape(input_shape) | ||
if not input_shape.ndims: | ||
|
@@ -315,33 +367,8 @@ def build(self, input_shape): | |
raise ValueError('When using virtual_batch_size, adjustment cannot ' | ||
'be specified') | ||
|
||
if self.fused in (None, True): | ||
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the | ||
# output back to its original shape accordingly. | ||
if self._USE_V2_BEHAVIOR: | ||
if self.fused is None: | ||
self.fused = (ndims == 4) | ||
elif self.fused and ndims != 4: | ||
raise ValueError('Batch normalization layers with fused=True only ' | ||
'support 4D input tensors.') | ||
else: | ||
assert self.fused is not None | ||
self.fused = (ndims == 4 and self._fused_can_be_used()) | ||
# TODO(chrisying): fused batch norm is currently not supported for | ||
# multi-axis batch norm and by extension virtual batches. In some cases, | ||
# it might be possible to use fused batch norm but would require reshaping | ||
# the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is | ||
# particularly tricky. A compromise might be to just support the most | ||
# common use case (turning 5D w/ virtual batch to NCHW) | ||
|
||
if self.fused: | ||
if self.axis == [1]: | ||
self._data_format = 'NCHW' | ||
elif self.axis == [3]: | ||
self._data_format = 'NHWC' | ||
else: | ||
raise ValueError('Unsupported axis, fused batch norm only supports ' | ||
'axis == [1] or axis == [3]') | ||
if self.fused and not self._USE_V2_BEHAVIOR: | ||
self.fused = self._fused_can_be_used() | ||
|
||
axis_to_dim = {x: input_shape.dims[x].value for x in self.axis} | ||
for x in axis_to_dim: | ||
|
@@ -499,6 +526,30 @@ def _fused_batch_norm(self, inputs, training): | |
beta = self.beta if self.center else self._beta_const | ||
gamma = self.gamma if self.scale else self._gamma_const | ||
|
||
original_shape = None | ||
fused_axis = self.axis[0] | ||
input_shape = array_ops.shape(inputs) | ||
ndims = len(inputs.shape) | ||
if self.axis[0] not in (1, ndims - 1) or ndims != 4: | ||
# The fused implementation only supports NCHW or NHWC, so we reshape the | ||
# input/output tensor to/from an equivalent 4D shape. | ||
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(input_shape, | ||
self.axis[0]) | ||
original_shape = array_ops.concat( | ||
[constant_op.constant([-1]), input_shape[1:]], axis=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you pass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed. |
||
fused_shape = array_ops.concat( | ||
[constant_op.constant([-1]), fused_shape[1:]], axis=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need to replace the first element of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I've removed these operations. |
||
inputs = array_ops.reshape(inputs, fused_shape) | ||
|
||
# TODO(chrisying): fused batch norm is currently not supported for | ||
# multi-axis batch norm and by extension virtual batches. In some cases, | ||
# it might be possible to use fused batch norm but would require reshaping | ||
# the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is | ||
# particularly tricky. A compromise might be to just support the most | ||
# common use case (turning 5D w/ virtual batch to NCHW) | ||
|
||
data_format = 'NCHW' if fused_axis == 1 else 'NHWC' | ||
|
||
# TODO(b/129279393): Support zero batch input in non DistributionStrategy | ||
# code as well. | ||
if self._support_zero_size_input(): | ||
|
@@ -548,7 +599,7 @@ def _fused_batch_norm_training(): | |
self.moving_variance, remove=False), | ||
epsilon=self.epsilon, | ||
is_training=True, | ||
data_format=self._data_format, | ||
data_format=data_format, | ||
exponential_avg_factor=exponential_avg_factor) | ||
|
||
def _fused_batch_norm_training_empty(): | ||
|
@@ -563,7 +614,7 @@ def _fused_batch_norm_inference(): | |
variance=self.moving_variance, | ||
epsilon=self.epsilon, | ||
is_training=False, | ||
data_format=self._data_format) | ||
data_format=data_format) | ||
|
||
train_op = _fused_batch_norm_training | ||
if use_fused_avg_updates and input_batch_size is not None: | ||
|
@@ -575,8 +626,11 @@ def _fused_batch_norm_inference(): | |
|
||
output, mean, variance = tf_utils.smart_cond(training, train_op, | ||
_fused_batch_norm_inference) | ||
variance = _maybe_add_or_remove_bessels_correction(variance, remove=True) | ||
|
||
if original_shape is not None: | ||
output = array_ops.reshape(output, original_shape) | ||
|
||
variance = _maybe_add_or_remove_bessels_correction(variance, remove=True) | ||
training_value = tf_utils.constant_value(training) | ||
if training_value or training_value is None: | ||
if not use_fused_avg_updates: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,70 @@ def test_batchnorm_convnet_channel_last(self): | |
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1) | ||
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1) | ||
|
||
@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused(self):
  """Fused BN on 5D channels-last input yields ~zero-mean, unit-std output.

  NOTE(review): a reviewer asked for an additional test with simple
  hand-written inputs (no gradient checks needed); see
  test_batchnorm_convnet_channel_last_3d_fused_correctness.
  """
  reduce_axes = (0, 1, 2, 3)
  param_shape = (1, 1, 1, 1, 3)

  model = keras.models.Sequential()
  norm = keras.layers.BatchNormalization(
      axis=-1, input_shape=(4, 4, 4, 3), momentum=0.8, fused=True)
  model.add(norm)
  model.compile(
      loss='mse',
      optimizer=gradient_descent.GradientDescentOptimizer(0.01),
      run_eagerly=testing_utils.should_run_eagerly())

  # Samples drawn with loc=5.0 and scale=10.0.
  x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 4, 3))
  model.fit(x, x, epochs=4, verbose=0)

  # Undo the learned affine transform so only the normalization remains.
  out = model.predict(x)
  out -= np.reshape(keras.backend.eval(norm.beta), param_shape)
  out /= np.reshape(keras.backend.eval(norm.gamma), param_shape)

  np.testing.assert_allclose(np.mean(out, axis=reduce_axes), 0.0, atol=1e-1)
  np.testing.assert_allclose(np.std(out, axis=reduce_axes), 1.0, atol=1e-1)
|
||
@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last_3d_fused_correctness(self):
  """Checks fused BN statistics on 5D channels-last input against NumPy.

  NOTE(review): a reviewer pointed out that no non-4D fused test uses an axis
  other than the last one and asked for such coverage (channels-first
  variants).
  """
  reduce_axes = (0, 1, 2, 3)
  param_shape = (1, 1, 1, 1, 3)

  model = keras.models.Sequential()
  norm = keras.layers.BatchNormalization(axis=-1, momentum=0.8, fused=True)
  model.add(norm)
  model.compile(
      loss='mse',
      optimizer=gradient_descent.GradientDescentOptimizer(0.01))

  # Sequential values ensure the result is axis-dependent.
  x = np.arange(5 * 5 * 5 * 5 * 3).reshape([5, 5, 5, 5, 3])
  x = x.astype(np.float32)
  model.fit(x, x, epochs=1000, verbose=0)

  # The moving statistics should have converged to the data statistics.
  moving_mean = keras.backend.eval(norm.moving_mean)
  moving_variance = keras.backend.eval(norm.moving_variance)
  x_mean = x.mean(axis=reduce_axes)
  x_var = x.var(axis=reduce_axes)
  np.testing.assert_allclose(moving_mean, x_mean, rtol=1e-5)
  np.testing.assert_allclose(moving_variance, x_var, rtol=1e-2)

  beta = np.reshape(keras.backend.eval(norm.beta), param_shape)
  gamma = np.reshape(keras.backend.eval(norm.gamma), param_shape)

  # With the affine transform undone, the training data is normalized.
  out = (model.predict(x) - beta) / gamma
  np.testing.assert_allclose(np.mean(out, axis=reduce_axes), 0.0, atol=1e-2)
  np.testing.assert_allclose(np.std(out, axis=reduce_axes), 1.0, atol=1e-2)

  # Test with changed input shape.
  y = np.arange(7 * 7 * 7 * 7 * 3).reshape([7, 7, 7, 7, 3])
  y = y.astype(np.float32)
  out = (model.predict(y) - beta) / gamma
  # New data is normalized with the statistics learned from x.
  x_std = x.std(axis=reduce_axes)
  expected_mean = (y.mean(axis=reduce_axes) - x_mean) / x_std
  expected_std = y.std(axis=reduce_axes) / x_std
  np.testing.assert_allclose(
      np.mean(out, axis=reduce_axes), expected_mean, atol=1e-2)
  np.testing.assert_allclose(
      np.std(out, axis=reduce_axes), expected_std, atol=1e-2)
|
||
@keras_parameterized.run_all_keras_modes | ||
def test_batchnorm_correctness(self): | ||
_run_batchnorm_correctness_test( | ||
|
@@ -213,7 +277,7 @@ def call(self, x, training): | |
model = MyModel() | ||
|
||
for _ in range(10): | ||
x = constant_op.constant(0.5, shape=[1, 1]) | ||
x = constant_op.constant(0.5, shape=[2, 1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason is... complicated :) First, Bessel's correction ( Then, for the fused case, the CPU implementation avoids returning NaN when N < 2, while the GPU implementation (CUDNN) explicitly returns NaN in these cases. So when the test calls batchnorm with N = 1 and it (with this PR) calls the fused implementation, it returns NaN (on the GPU) and the test fails. This probably isn't ever observed in the real world because batchnorm doesn't work when N is small anyway. Let me know if you'd like me to take another action here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we know that N = 1, we should probably not allow calling the fused version on GPU. I'd prefer an error message over silently producing NaNs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately this can't easily be done in the Python layer because it only applies in training mode. Adding a check to build() causes a bunch of tests to fail (ones that don't use training mode). It would have to be an in-graph check, which I don't think there's much precedent for(?). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can add a tf.assert in-graph check here. This is not unprecedented and no different from adding a check in the kernel itself. There was a problem hiding this comment. 
Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added an in-graph check using Assert. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It caused a bunch of other tests to fail. I'll have to look into them. |
||
model(x, training=True) | ||
|
||
# Make sure the moving mean and variance have been updated | ||
|
@@ -255,20 +319,28 @@ def test_basic_batchnorm_v2(self): | |
normalization_v2.BatchNormalization, | ||
kwargs={'fused': None}, | ||
input_shape=(3, 3, 3)) | ||
testing_utils.layer_test( | ||
normalization_v2.BatchNormalization, | ||
kwargs={'fused': True}, | ||
input_shape=(3, 3, 3, 3, 3)) | ||
testing_utils.layer_test( | ||
normalization_v2.BatchNormalization, | ||
kwargs={'fused': True}, | ||
input_shape=(3, 3)) | ||
|
||
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) | ||
def test_v2_fused_attribute(self): | ||
norm = normalization_v2.BatchNormalization() | ||
self.assertEqual(norm.fused, None) | ||
self.assertEqual(norm.fused, True) | ||
inp = keras.layers.Input(shape=(4, 4, 4)) | ||
norm(inp) | ||
self.assertEqual(norm.fused, True) | ||
|
||
norm = normalization_v2.BatchNormalization() | ||
self.assertEqual(norm.fused, None) | ||
self.assertEqual(norm.fused, True) | ||
inp = keras.layers.Input(shape=(4, 4)) | ||
norm(inp) | ||
self.assertEqual(norm.fused, False) | ||
self.assertEqual(norm.fused, True) | ||
|
||
norm = normalization_v2.BatchNormalization(virtual_batch_size=2) | ||
self.assertEqual(norm.fused, False) | ||
|
@@ -291,10 +363,7 @@ def test_v2_fused_attribute(self): | |
with self.assertRaisesRegexp(ValueError, 'fused.*renorm'): | ||
normalization_v2.BatchNormalization(fused=True, renorm=True) | ||
|
||
with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): | ||
normalization_v2.BatchNormalization(fused=True, axis=2) | ||
|
||
with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): | ||
with self.assertRaisesRegexp(ValueError, 'fused.*over a single axis'): | ||
normalization_v2.BatchNormalization(fused=True, axis=[1, 3]) | ||
|
||
with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'): | ||
|
@@ -304,12 +373,6 @@ def test_v2_fused_attribute(self): | |
normalization_v2.BatchNormalization(fused=True, | ||
adjustment=lambda _: (1, 0)) | ||
|
||
norm = normalization_v2.BatchNormalization(fused=True) | ||
self.assertEqual(norm.fused, True) | ||
inp = keras.layers.Input(shape=(4, 4)) | ||
with self.assertRaisesRegexp(ValueError, '4D input tensors'): | ||
norm(inp) | ||
|
||
def test_updates_in_wrap_function(self): | ||
with context.eager_mode(): | ||
layer = keras.layers.BatchNormalization() | ||
|
@@ -362,6 +425,23 @@ def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False): | |
class NormalizationLayersGraphModeOnlyTest( | ||
test.TestCase, parameterized.TestCase): | ||
|
||
def test_unknown_shape_batchnorm(self, layer):
  """Test that a BN layer supports unknown input shapes in graph mode."""
  with self.cached_session():

    def build_and_predict(input_shape):
      # Build a one-layer model around a fresh BN layer and run inference.
      norm_layer = layer()
      inputs = keras.layers.Input(shape=input_shape)
      outputs = norm_layer(inputs)
      model = keras.models.Model(inputs, outputs)
      model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
      # Feed a batch of 2, with every unknown (None) dim given size 5.
      concrete_shape = [2]
      concrete_shape.extend(5 if dim is None else dim for dim in input_shape)
      model.predict(np.random.random(concrete_shape))

    for unknown_shape in ((None, 10),
                          (None, None, 10),
                          (None, None, None, 10)):
      build_and_predict(unknown_shape)
|
||
def test_shared_batchnorm(self, layer): | ||
"""Test that a BN layer can be shared across different data streams.""" | ||
with self.cached_session(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In Python 2, a `long` is also acceptable. Use `six.integer_types` or just remove the assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`isinstance(axis, int)` is already used in several places in this class, including in an error check at the beginning of the `__init__` function.