[r2.2:Cherrypick] Change Keras batch normalization layer to use the running mean and average computation in fused_batch_norm. #37270

Merged 1 commit on Mar 5, 2020
79 changes: 50 additions & 29 deletions tensorflow/python/keras/layers/normalization.py
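The gist of the change: instead of blending `moving_mean`/`moving_variance` with the batch statistics in Python, the layer now asks the fused kernel to do it by passing `exponential_avg_factor = 1.0 - momentum` to `nn.fused_batch_norm`, gated behind a forward-compatibility date. The update rule itself is unchanged; below is a minimal NumPy sketch of that rule, with illustrative names that are not taken from the patch:

```python
import numpy as np

def running_average_update(moving_stat, batch_stat, momentum=0.99):
  """Moving-statistics update that fused_batch_norm now performs internally.

  With exponential_avg_factor = 1 - momentum, the op computes
  momentum * moving_stat + (1 - momentum) * batch_stat.
  """
  exponential_avg_factor = 1.0 - momentum
  return ((1.0 - exponential_avg_factor) * moving_stat
          + exponential_avg_factor * batch_stat)

# With momentum = 0.99, the running mean drifts 1% toward the batch mean.
print(running_average_update(np.float32(0.0), np.float32(1.0)))  # ~0.01
```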
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
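The `compat` import added above is what gates the behavior change: `compat.forward_compatible(2020, 3, 6)` stays False until that date has passed (or until the horizon is moved forward), so existing consumers keep the previous update path in the meantime. A hedged sketch of how the gate behaves; the horizon context manager is how such changes are typically exercised in tests, and the wrapper function here is illustrative:

```python
from tensorflow.python.compat import compat

def uses_fused_running_stats():
  # True once the forward-compatibility window for this change has passed.
  return compat.forward_compatible(2020, 3, 6)

# Tests can usually force the new path ahead of the date like this:
with compat.forward_compatibility_horizon(2020, 3, 7):
  assert uses_fused_running_stats()
```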
@@ -513,9 +514,11 @@ def _assign_moving_average(self, variable, value, momentum, inputs_size):
                                         K.zeros_like(update_delta))
        return state_ops.assign_sub(variable, update_delta, name=scope)

  def _assign_new_value(self, variable, value):
  def _assign_new_value(self, variable, value, inputs_size=None):
    with K.name_scope('AssignNewValue') as scope:
      with ops.colocate_with(variable):
        if inputs_size is not None:
          value = array_ops.where(inputs_size > 0, value, variable)
        return state_ops.assign(variable, value, name=scope)

  def _fused_batch_norm(self, inputs, training):
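`_assign_new_value` now takes an optional `inputs_size` so that, on replicas that received an empty batch (possible under `tf.distribute`), the moving statistic keeps its previous value rather than being overwritten. A small standalone sketch of the same guard, assuming TF 2.x eager execution; the names here are illustrative:

```python
import tensorflow as tf

def guarded_assign(variable, value, inputs_size=None):
  """Assign `value`, unless the local batch was empty (inputs_size == 0)."""
  if inputs_size is not None:
    value = tf.where(inputs_size > 0, value, variable)
  return variable.assign(value)

v = tf.Variable([1.0, 2.0])
guarded_assign(v, tf.constant([5.0, 6.0]), inputs_size=tf.constant(0))
print(v.numpy())  # [1. 2.] -- unchanged, the empty batch was ignored
```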
@@ -530,13 +533,41 @@ def _fused_batch_norm(self, inputs, training):
    else:
      inputs_size = None

    if compat.forward_compatible(2020, 3, 6):
      exponential_avg_factor = 1.0 - self.momentum
    else:
      exponential_avg_factor = None

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
      r"""Add or remove Bessel's correction."""
      # Removes Bessel's correction if remove == True, adds it otherwise.
      # This is to be consistent with non-fused batch norm. Note that the
      # variance computed by fused batch norm is with Bessel's correction.
      # This is only used in legacy V1 batch norm tests.
      if self._bessels_correction_test_only:
        return variance
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      if remove:
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
      else:
        factor = sample_size / (
            sample_size - math_ops.cast(1.0, variance.dtype))
      return variance * factor
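A quick numeric check of the correction factor the helper applies: fused_batch_norm reports the Bessel-corrected (divide by N - 1) variance, while the non-fused path uses the population (divide by N) variance, so converting between the two is a multiplication by (N - 1)/N or N/(N - 1). A worked example with N = 4, using NumPy for illustration:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
sample_size = x.size              # N elements reduced per statistic

biased = x.var()                  # population variance, divides by N      -> 1.25
corrected = x.var(ddof=1)         # Bessel-corrected, divides by N - 1     -> 1.6667

# remove=True in the helper: corrected -> biased
assert np.isclose(corrected * (sample_size - 1.0) / sample_size, biased)
# remove=False in the helper: biased -> corrected
assert np.isclose(biased * sample_size / (sample_size - 1.0), corrected)
```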

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=_maybe_add_or_remove_bessels_correction(
              self.moving_variance, remove=False),
          epsilon=self.epsilon,
          data_format=self._data_format)
          is_training=True,
          data_format=self._data_format,
          exponential_avg_factor=exponential_avg_factor)
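When `exponential_avg_factor` is not None, the training call above hands the current moving statistics to the op and receives mean/variance values that are already the blended running statistics rather than the raw batch statistics, which is why the update helpers below can simply overwrite the variables. A hedged sketch of the call shape, assuming a TF build whose `nn.fused_batch_norm` already accepts `exponential_avg_factor` and an NHWC input:

```python
import tensorflow as tf
from tensorflow.python.ops import nn

x = tf.random.normal([2, 4, 4, 3])             # NHWC batch
scale, offset = tf.ones([3]), tf.zeros([3])
moving_mean, moving_variance = tf.zeros([3]), tf.ones([3])

y, new_mean, new_var = nn.fused_batch_norm(
    x, scale, offset,
    mean=moving_mean,                          # fed in so the op can blend
    variance=moving_variance,
    epsilon=1e-3,
    is_training=True,
    data_format='NHWC',
    exponential_avg_factor=0.01)               # == 1 - momentum for momentum=0.99
# new_mean / new_var are running statistics, ready to assign back.
```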

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
@@ -551,40 +582,30 @@ def _fused_batch_norm_inference():

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor_v2(self.momentum)
    if training_value or training_value is None:
      if not compat.forward_compatible(2020, 3, 6):
        if training_value is None:
          momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                         lambda: 1.0)
        else:
          momentum = ops.convert_to_tensor_v2(self.momentum)

      def mean_update():
        return self._assign_moving_average(self.moving_mean, mean, momentum,
                                           inputs_size)
        """Update self.moving_mean with the most recent data point."""
        if compat.forward_compatible(2020, 3, 6):
          return self._assign_new_value(self.moving_mean, mean, inputs_size)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
                                             inputs_size)

      def variance_update():
        """Update self.moving_variance with the most recent data point."""
        if self.renorm:
          # We apply epsilon as part of the moving_stddev to mirror the training
          # code path.
          moving_stddev = self._assign_moving_average(
              self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
              momentum, inputs_size)
          return self._assign_new_value(
              self.moving_variance,
              # Apply relu in case floating point rounding causes it to go
              # negative.
              K.relu(moving_stddev * moving_stddev - self.epsilon))
        if compat.forward_compatible(2020, 3, 6):
          return self._assign_new_value(self.moving_variance, variance,
                                        inputs_size)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
                                             momentum, inputs_size)
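After the edit, `mean_update` and `variance_update` pick between two bookkeeping strategies: the new path simply assigns the running statistics that the fused op already computed, while the legacy path keeps blending raw batch statistics in Python (with `momentum` forced to 1.0 when not training, which makes the update a no-op). A condensed, illustrative restatement of the two paths, not the layer's actual code:

```python
import tensorflow as tf

def update_moving_stats(moving_mean, moving_variance, mean, variance,
                        momentum, fused_op_computed_running_stats):
  """Contrast of the two update paths; all names here are illustrative."""
  if fused_op_computed_running_stats:
    # New path: `mean`/`variance` are already running statistics, so the
    # variables are overwritten directly (cf. _assign_new_value).
    moving_mean.assign(mean)
    moving_variance.assign(variance)
  else:
    # Legacy path: blend raw batch statistics, matching what
    # _assign_moving_average does with assign_sub.
    moving_mean.assign_sub((moving_mean - mean) * (1.0 - momentum))
    moving_variance.assign_sub((moving_variance - variance) * (1.0 - momentum))
  return moving_mean, moving_variance

m, v = tf.Variable([0.0]), tf.Variable([1.0])
update_moving_stats(m, v, tf.constant([1.0]), tf.constant([2.0]),
                    momentum=0.99, fused_op_computed_running_stats=False)
print(m.numpy(), v.numpy())  # [0.01] [1.01]
```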