Support fused batchnorm with any ndims and axis #40338

Closed
144 changes: 101 additions & 43 deletions tensorflow/python/keras/layers/normalization.py
@@ -198,10 +198,8 @@ def __init__(self,
if self._USE_V2_BEHAVIOR:
if fused:
self._raise_if_fused_cannot_be_used()
# We leave fused as None if self._fused_can_be_used()==True, since we
# still may set it to False in self.build() if the input rank is not 4.
elif fused is None and not self._fused_can_be_used():
fused = False
elif fused is None:
fused = self._fused_can_be_used()
elif fused is None:
fused = True
self.supports_masking = True
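With this change, V2 behavior resolves fused=None eagerly in __init__ via _fused_can_be_used() instead of deferring the decision to build(), where it previously depended on the input being 4D. A minimal sketch of the resulting defaults (hypothetical usage of the public tf.keras API, for illustration only):

import tensorflow as tf

# fused is left unspecified and resolves to True because the fused kernel is usable.
bn = tf.keras.layers.BatchNormalization(axis=-1)

# renorm rules out the fused kernel, so fused=None resolves to False here.
bn_renorm = tf.keras.layers.BatchNormalization(renorm=True)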
@@ -221,26 +219,20 @@ def __init__(self,

def _raise_if_fused_cannot_be_used(self):
"""Raises a ValueError if fused implementation cannot be used.

In addition to the checks done in this function, the input tensors rank must
be 4. The input rank check can only be done once the input shape is known.
"""
# Note the ValueErrors in this function are caught and not reraised in
# _fused_can_be_used(). No other exception besides ValueError should be
# raised here.

# Currently fused batch norm doesn't support renorm. It also only supports a
# channel dimension on axis 1 or 3, when no virtual batch size or adjustment
# is used.
# single axis, when no virtual batch size or adjustment is used.
if self.renorm:
raise ValueError('Passing both fused=True and renorm=True is '
'unsupported')
axis = [self.axis] if isinstance(self.axis, int) else self.axis
# Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
# input rank is required to be 4 (which is checked later).
if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
raise ValueError('Passing fused=True is only supported when axis is 1 '
'or 3')
if len(axis) > 1:
raise ValueError('Passing fused=True is only supported when operating '
'over a single axis.')
if self.virtual_batch_size is not None:
raise ValueError('Passing fused=True is unsupported when '
'virtual_batch_size is specified.')
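After this change, fused=True is rejected only for multi-axis normalization (and for renorm, virtual_batch_size, or adjustment), not for non-4D inputs. A quick sketch of the intended behavior (hypothetical usage, not part of the diff):

import tensorflow as tf

tf.keras.layers.BatchNormalization(axis=1, fused=True)       # accepted; any input rank is handled by reshaping
tf.keras.layers.BatchNormalization(axis=[1, 2], fused=True)  # raises ValueError: only a single axis is supported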
@@ -281,6 +273,62 @@ def _support_zero_size_input(self):
distribution_strategy_context.get_strategy().extended,
'experimental_enable_get_next_as_optional', False)

def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
"""Compute an equivalent shape and axis that are compatible with the fused
implementation.

The input/output of the layer can be reshaped to/from the shape returned by
this function without affecting the correctness of the computation.

Arguments:
nd_shape: Tensor. The original shape of the operation.
nd_axis: Integer. The original axis of the operation.

Returns:
shape: Tensor. A 4D shape.
axis: Integer. An axis (always 1 or 3).
"""
assert(isinstance(nd_axis, int))
Review comment (Member):
In Python 2, a long is also acceptable. Use six.integer_types or just remove the assert.

Reply (Contributor Author):
isinstance(axis, int) is already used in several places in this class, including in an error check at the beginning of the __init__ function.
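For reference, if Python 2 longs did need to be accepted, a check along the lines the reviewer suggests could look like this (sketch using six; the PR keeps the plain int check):

import six

assert isinstance(nd_axis, six.integer_types)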

ndims = nd_shape.shape[0]
shape = nd_shape[:]
axis = ndims + nd_axis if nd_axis < 0 else nd_axis
# First check if the axis needs to be moved.
if axis not in (1, ndims - 1):
# Move axis to dim 1.
if axis == 0:
# Transform [C, ...] to [1, C, ...].
shape = array_ops.concat([constant_op.constant([1]), shape], axis=0)
ndims += 1
else:
# Merge excess pre-axis dims into first dim.
# Transform [N, ..., C, ...] to [product(N, ...), C, ...].
product = math_ops.reduce_prod(shape[:axis], keepdims=True)
shape = array_ops.concat([product, shape[axis:]], axis=0)
ndims -= (axis - 1)
axis = 1
# Now change shape to 4D.
is_channels_last = axis == ndims - 1
if ndims < 4:
# Insert new dims after existing spatial dim or before channel dim.
new_dims = constant_op.constant([1] * (4 - ndims))
if is_channels_last:
# Transform [..., C] to [..., 1..., C] (ndims=4).
shape = array_ops.concat([shape[:-1], new_dims, shape[-1:]], axis=0)
else:
# Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
shape = array_ops.concat([shape, new_dims], axis=0)
elif ndims > 4:
# Merge excess spatial dims into the second spatial dim.
# Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
# Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
merge_dim = 2 if is_channels_last else 3
Review comment (Member):
Similar to above, maybe do:

product = 1
for elem in shape[merge_dim:merge_dim + ndims - 4]:
  product *= elem
shape[merge_dim:merge_dim + ndims - 4] = [product]

Reply (Contributor Author):
Done.

product = math_ops.reduce_prod(
shape[merge_dim:merge_dim + 1 + (ndims - 4)], keepdims=True)
shape = array_ops.concat([shape[:merge_dim], product,
shape[merge_dim + 1 + (ndims - 4):]], axis=0)
axis = 3 if is_channels_last else 1
return shape, axis
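To make the reshaping concrete, here is a sketch of the mappings this helper is intended to produce (shapes written as plain lists for illustration; the real function operates on shape tensors):

# [N, W, C],        axis=-1 -> [N, W, 1, C],   axis=3 (NHWC)
# [N, C, D, H, W],  axis=1  -> [N, C, D, H*W], axis=1 (NCHW)
# [C, H, W],        axis=0  -> [1, C, H, W],   axis=1 (NCHW)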

def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -315,39 +363,16 @@ def build(self, input_shape):
raise ValueError('When using virtual_batch_size, adjustment cannot '
'be specified')

if self.fused in (None, True):
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
# output back to its original shape accordingly.
if self._USE_V2_BEHAVIOR:
if self.fused is None:
self.fused = (ndims == 4)
elif self.fused and ndims != 4:
raise ValueError('Batch normalization layers with fused=True only '
'support 4D input tensors.')
else:
assert self.fused is not None
self.fused = (ndims == 4 and self._fused_can_be_used())
# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
# the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
# particularly tricky. A compromise might be to just support the most
# common use case (turning 5D w/ virtual batch to NCHW)

if self.fused:
if self.axis == [1]:
self._data_format = 'NCHW'
elif self.axis == [3]:
self._data_format = 'NHWC'
else:
raise ValueError('Unsupported axis, fused batch norm only supports '
'axis == [1] or axis == [3]')
if self.fused and not self._USE_V2_BEHAVIOR:
self.fused = self._fused_can_be_used()

axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
self.depth = 1
for x in axis_to_dim:
if axis_to_dim[x] is None:
raise ValueError('Input has undefined `axis` dimension. Input shape: ',
input_shape)
self.depth *= axis_to_dim[x]
self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
@@ -499,6 +524,37 @@ def _fused_batch_norm(self, inputs, training):
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const

original_shape = None
fused_axis = self.axis[0]
input_shape = array_ops.shape(inputs)

# The use of Bessel's correction in training mode imposes the requirement
# that the number of elements in the reduced dimensions is > 1.
check = control_flow_ops.Assert(
training == False or math_ops.reduce_prod(input_shape) > self.depth,
Review comment (Member):
I think training can also be a tensor. The assert should be added to this function, I think.

Review comment (Member):
Also, probably better to raise an error immediately if the shape and the training boolean are statically known.

["Number of normalized-over elements must be > 1 (due to use of "
"Bessel's correction in training mode).",
input_shape, self.depth])
input_shape = control_flow_ops.with_dependencies([check], input_shape)
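To illustrate what the assert above guards against (editorial example, not part of the diff):

# With axis=-1 and depth == 8:
#   input shape [1, 8]: reduce_prod == 8, which is not > depth, so the Assert fires
#     (Bessel's correction would divide by N - 1 == 0 in training mode).
#   input shape [2, 8]: reduce_prod == 16 > 8, so the check passes.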

ndims = len(inputs.shape)
if self.axis[0] not in (1, ndims - 1) or ndims != 4:
# The fused implementation only supports NCHW or NHWC, so we reshape the
# input/output tensor to/from an equivalent 4D shape.
fused_shape, fused_axis = self._get_shape_and_axis_for_fused(input_shape,
self.axis[0])
original_shape = input_shape
inputs = array_ops.reshape(inputs, fused_shape)

# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
# the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
# particularly tricky. A compromise might be to just support the most
# common use case (turning 5D w/ virtual batch to NCHW)

data_format = 'NCHW' if fused_axis == 1 else 'NHWC'

# TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well.
if self._support_zero_size_input():
@@ -548,7 +604,7 @@ def _fused_batch_norm_training():
self.moving_variance, remove=False),
epsilon=self.epsilon,
is_training=True,
data_format=self._data_format,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)

def _fused_batch_norm_training_empty():
@@ -563,7 +619,7 @@ def _fused_batch_norm_inference():
variance=self.moving_variance,
epsilon=self.epsilon,
is_training=False,
data_format=self._data_format)
data_format=data_format)

train_op = _fused_batch_norm_training
if use_fused_avg_updates and input_batch_size is not None:
@@ -575,6 +631,8 @@ def _fused_batch_norm_inference():

output, mean, variance = control_flow_util.smart_cond(
training, train_op, _fused_batch_norm_inference)
if original_shape is not None:
output = array_ops.reshape(output, original_shape)
variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

training_value = control_flow_util.constant_value(training)
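Taken together, the change lets fused batch normalization run on inputs of any rank by reshaping them to an equivalent 4D problem. An end-to-end sketch of the use case this PR enables (hypothetical usage of the public tf.keras API; shapes are illustrative):

import tensorflow as tf

layer = tf.keras.layers.BatchNormalization(axis=1, fused=True)  # channels-first
x = tf.random.normal([2, 16, 8, 8, 8])     # 5D input, previously rejected when fused=True
y = layer(x, training=True)                # internally reshaped to [2, 16, 8, 64] for the fused kernel
assert y.shape == x.shape                  # output is reshaped back to the original shape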