Allow batch_norm to use an is_training Tensor, which can be modified at run time.

Added tf.Assert() to make sure the update_ops are not executed when is_training is False.
Change: 127443556
tensorflower-gardener committed Jul 14, 2016
1 parent 6472759 commit 9da5fc8
Showing 5 changed files with 492 additions and 64 deletions.
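
For context, here is a minimal usage sketch of what this change enables: batch_norm driven by an is_training value that is only known at run time. This is illustrative only, written against the tf.contrib.layers API of this era; the tensor names and shapes are made up.

import numpy as np
import tensorflow as tf

# is_training is fed at run time, so one graph serves both training and eval.
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
x = tf.placeholder(tf.float32, shape=[None, 3, 3, 8])

# With updates_collections=None the moving-average updates are forced to run
# inline; otherwise they are collected into tf.GraphKeys.UPDATE_OPS.
net = tf.contrib.layers.batch_norm(x, decay=0.9, is_training=is_training,
                                   updates_collections=None)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  batch = np.random.rand(10, 3, 3, 8).astype(np.float32)
  train_out = sess.run(net, {x: batch, is_training: True})
  eval_out = sess.run(net, {x: batch, is_training: False})
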
13 changes: 13 additions & 0 deletions tensorflow/contrib/layers/BUILD
@@ -226,6 +226,19 @@ py_test(
],
)

py_test(
name = "utils_test",
size = "small",
srcs = ["python/layers/utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":layers_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)

filegroup(
name = "all_files",
srcs = glob(
73 changes: 41 additions & 32 deletions tensorflow/contrib/layers/python/layers/layers.py
@@ -31,7 +31,6 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
@@ -206,29 +205,48 @@ def batch_norm(inputs,
initializer=init_ops.ones_initializer,
trainable=False,
collections=moving_variance_collections)
if is_training:
# Calculate the moments based on the individual batch.

is_training_value = utils.constant_value(is_training)
# Calculate the moments based on the individual batch.
need_moments = is_training_value is None or is_training_value
if need_moments:
mean, variance = nn.moments(inputs, axis, shift=moving_mean)
# Update the moving_mean and moving_variance moments.
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay)
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
# Make sure the updates are computed here.
with ops.control_dependencies([update_moving_mean,
update_moving_variance]):
outputs = nn.batch_normalization(
inputs, mean, variance, beta, gamma, epsilon)
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay)
with ops.control_dependencies([update_moving_mean,
update_moving_variance]):
return array_ops.identity(mean), array_ops.identity(variance)
mean, variance = utils.smart_cond(is_training,
_force_updates,
moving_vars_fn)
else:
# Collect the updates to be computed later.
ops.add_to_collections(updates_collections, update_moving_mean)
ops.add_to_collections(updates_collections, update_moving_variance)
outputs = nn.batch_normalization(
inputs, mean, variance, beta, gamma, epsilon)
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay)
return update_moving_mean, update_moving_variance

update_mean, update_variance = utils.smart_cond(is_training,
_delay_updates,
moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
# Use computed moments during training and moving_vars otherwise.
vars_fn = lambda: (mean, variance)
mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
else:
outputs = nn.batch_normalization(
inputs, moving_mean, moving_variance, beta, gamma, epsilon)
mean, variance = moving_mean, moving_variance
# Compute batch_normalization.
outputs = nn.batch_normalization(
inputs, mean, variance, beta, gamma, epsilon)
outputs.set_shape(inputs_shape)
if activation_fn:
outputs = activation_fn(outputs)
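
The branching above relies on utils.smart_cond and utils.constant_value, which live in tensorflow/contrib/layers/python/layers/utils.py (the remaining changed files in this commit are not shown here). A rough sketch of the intended behavior, offered as an approximation rather than the actual implementation:

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops

def constant_value(pred):
  """Returns the bool value of `pred` if statically known, else None."""
  if isinstance(pred, bool):
    return pred
  if isinstance(pred, ops.Tensor):
    # None unless the tensor can be folded to a constant at graph time.
    return tensor_util.constant_value(pred)
  return None  # e.g. a tf.Variable: value only known at run time

def smart_cond(pred, fn1, fn2, name=None):
  """Builds only one branch when `pred` is static; otherwise emits tf.cond."""
  pred_value = constant_value(pred)
  if pred_value is not None:
    return fn1() if pred_value else fn2()
  return control_flow_ops.cond(pred, fn1, fn2, name=name)

Under this reading, a constant is_training=True/False costs nothing extra at run time, while a Variable or placeholder produces a cond op that picks the batch moments or the moving averages on each step.
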
@@ -430,18 +448,9 @@ def dropout(inputs,
"""
with ops.op_scope([inputs], scope, 'Dropout') as sc:
inputs = ops.convert_to_tensor(inputs)
is_training_value = utils.constant_value(is_training, dtypes.bool)
if is_training_value is not None:
if is_training_value:
outputs = nn.dropout(inputs, keep_prob, noise_shape)
else:
outputs = inputs
else:
def _dropout():
return nn.dropout(inputs, keep_prob, noise_shape)
outputs = control_flow_ops.cond(is_training,
_dropout,
lambda: inputs)
dropout_fn = lambda: nn.dropout(inputs, keep_prob, noise_shape)
id_fn = lambda: inputs
outputs = utils.smart_cond(is_training, dropout_fn, id_fn)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
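
The same smart_cond path lets dropout take a run-time is_training as well. A small illustrative sketch; the names here are made up:

import tensorflow as tf

is_training = tf.placeholder(tf.bool, shape=[])
images = tf.random_uniform((5, 3, 3, 3), seed=1)
# A non-constant is_training produces a cond op that chooses between
# nn.dropout and the identity branch at run time; a constant True/False
# builds only the corresponding branch.
out = tf.contrib.layers.dropout(images, keep_prob=0.5, is_training=is_training)

with tf.Session() as sess:
  dropped = sess.run(out, {is_training: True})       # dropout applied
  passthrough = sess.run(out, {is_training: False})  # inputs pass through
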


196 changes: 181 additions & 15 deletions tensorflow/contrib/layers/python/layers/layers_test.py
@@ -429,15 +429,24 @@ def testCreateDropout(self):
output.get_shape().assert_is_compatible_with(
tf.convert_to_tensor(images).get_shape())

def testCreateDropoutWithConstant(self):
def testCreateDropoutWithConstantTrue(self):
height, width = 3, 3
with self.test_session():
is_training = tf.constant(False)
is_training = tf.constant(True)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
output.get_shape().assert_is_compatible_with(images.get_shape())

def testCreateDropoutWithConstantFalse(self):
height, width = 3, 3
with self.test_session():
is_training = tf.constant(False)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
self.assertEquals(output, images)
output.get_shape().assert_is_compatible_with(images.get_shape())

def testCreateDropoutWithPlaceholder(self):
height, width = 3, 3
with self.test_session():
@@ -796,7 +805,7 @@ def testCreateMovingVars(self):
self.assertEquals(len(moving_variance), 1)
self.assertEquals(moving_variance[0].op.name, 'BatchNorm/moving_variance')

def testForceUpdateMovingVars(self):
def testNoneUpdatesCollections(self):
height, width = 3, 3
with self.test_session() as sess:
image_shape = (10, height, width, 3)
@@ -806,6 +815,8 @@ def testForceUpdateMovingVars(self):
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
output = tf.contrib.layers.batch_norm(images, decay=0.1,
updates_collections=None)
# updates_ops are not added to UPDATE_OPS collection.
self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
@@ -835,6 +846,8 @@ def testDelayedUpdateMovingVars(self):
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
output = tf.contrib.layers.batch_norm(images, decay=0.1)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# updates_ops are added to UPDATE_OPS collection.
self.assertEquals(len(update_ops), 2)
with tf.control_dependencies(update_ops):
barrier = tf.no_op(name='barrier')
output = control_flow_ops.with_dependencies([barrier], output)
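
The pattern this test exercises is the usual way to consume the delayed updates in a training loop: take a control dependency on UPDATE_OPS so the moving averages advance with every optimizer step. A self-contained sketch; the optimizer, loss, and decay value are illustrative:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(10, 3, 3, 3), dtype=tf.float32)
net = tf.contrib.layers.batch_norm(x, decay=0.1, is_training=True)
loss = tf.reduce_mean(tf.square(net))

# The moving-average updates were added to UPDATE_OPS; tie them to the
# train op so they run on every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for _ in range(10):
    sess.run(train_op)
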
@@ -868,8 +881,7 @@ def testEvalMovingVars(self):
output = tf.contrib.layers.batch_norm(images,
decay=0.1,
is_training=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
self.assertEquals(update_ops, [])
self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
@@ -901,11 +913,55 @@ def testReuseVars(self):
expected_mean = np.mean(image_values, axis=(0, 1, 2))
expected_var = np.var(image_values, axis=(0, 1, 2))
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
output_train = tf.contrib.layers.batch_norm(images,
decay=0.1,
is_training=True,
scope='BN')
output_eval = tf.contrib.layers.batch_norm(images,
decay=0.1,
is_training=False,
scope='BN',
reuse=True)
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
'BN/moving_mean')[0]
moving_variance = tf.contrib.framework.get_variables(
'BN/moving_variance')[0]
mean, variance = sess.run([moving_mean, moving_variance])
# After initialization moving_mean == 0 and moving_variance == 1.
self.assertAllClose(mean, [0] * 3)
self.assertAllClose(variance, [1] * 3)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
barrier = tf.no_op(name='barrier')
train_op = control_flow_ops.with_dependencies([barrier], output_train)
# Before updates the outputs are different for train and eval.
self.assertFalse(np.allclose(sess.run([output_train]),
sess.run([output_eval])))
for _ in range(10):
sess.run([train_op])
mean = moving_mean.eval()
variance = moving_variance.eval()
# After 10 updates with decay 0.1 moving_mean == expected_mean and
# moving_variance == expected_var.
self.assertAllClose(mean, expected_mean)
self.assertAllClose(variance, expected_var)
# After convergence output_train and output_eval should be the same.
self.assertAllClose(sess.run([output_train]), sess.run([output_eval]))

def testIsTrainingVariable(self):
height, width = 3, 3
with self.test_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
expected_var = np.var(image_values, axis=(0, 1, 2))
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
is_training = tf.Variable(True)
output = tf.contrib.layers.batch_norm(images,
decay=0.1,
is_training=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
self.assertEquals(update_ops, [])
is_training=is_training)
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
Expand All @@ -916,18 +972,128 @@ def testReuseVars(self):
# After initialization moving_mean == 0 and moving_variance == 1.
self.assertAllClose(mean, [0] * 3)
self.assertAllClose(variance, [1] * 3)
# Simulate assigment from saver restore.
init_assigns = [tf.assign(moving_mean, expected_mean),
tf.assign(moving_variance, expected_var)]
sess.run(init_assigns)
# Before updates the outputs are different depending of is_training.
output_true = sess.run([output], {is_training: True})
output_false = sess.run([output], {is_training: False})
self.assertFalse(np.allclose(output_true, output_false))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
barrier = tf.no_op(name='barrier')
train_op = control_flow_ops.with_dependencies([barrier], output)
for _ in range(10):
sess.run([output], {images: np.random.rand(*image_shape)})
sess.run([train_op])
mean = moving_mean.eval()
variance = moving_variance.eval()
# Although we feed different images, the moving_mean and moving_variance
# shouldn't change.
# After 10 updates with decay 0.1 moving_mean == expected_mean and
# moving_variance == expected_var.
self.assertAllClose(mean, expected_mean)
self.assertAllClose(variance, expected_var)
# After updates to convergence the outputs don't depend on is_training.
output_true = sess.run([output], {is_training: True})
output_false = sess.run([output], {is_training: False})
self.assertAllClose(output_true, output_false)

def testNoUpdatesWhenIsTrainingFalse(self):
height, width = 3, 3
with self.test_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
output = tf.contrib.layers.batch_norm(images,
decay=0.1,
is_training=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# updates_ops are not added to UPDATE_OPS collection.
self.assertEquals(len(update_ops), 0)
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
'BatchNorm/moving_mean')[0]
moving_variance = tf.contrib.framework.get_variables(
'BatchNorm/moving_variance')[0]
mean, variance = sess.run([moving_mean, moving_variance])
# After initialization moving_mean == 0 and moving_variance == 1.
self.assertAllClose(mean, [0] * 3)
self.assertAllClose(variance, [1] * 3)
# When is_training is False batch_norm doesn't update moving_vars.
for _ in range(10):
sess.run([output])
self.assertAllClose(moving_mean.eval(), [0] * 3)
self.assertAllClose(moving_variance.eval(), [1] * 3)

def testNoneUpdatesCollectionNoTraining(self):
height, width = 3, 3
with self.test_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
output = tf.contrib.layers.batch_norm(images,
decay=0.1,
updates_collections=None,
is_training=False)
# updates_ops are not added to UPDATE_OPS collection.
self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
'BatchNorm/moving_mean')[0]
moving_variance = tf.contrib.framework.get_variables(
'BatchNorm/moving_variance')[0]
mean, variance = sess.run([moving_mean, moving_variance])
# After initialization moving_mean == 0 and moving_variance == 1.
self.assertAllClose(mean, [0] * 3)
self.assertAllClose(variance, [1] * 3)
# When is_training is False batch_norm doesn't update moving_vars.
for _ in range(10):
sess.run([output])
self.assertAllClose(moving_mean.eval(), [0] * 3)
self.assertAllClose(moving_variance.eval(), [1] * 3)

def testNoneUpdatesCollectionIsTrainingVariable(self):
height, width = 3, 3
with self.test_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
expected_var = np.var(image_values, axis=(0, 1, 2))
images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
is_training = tf.Variable(True)
output = tf.contrib.layers.batch_norm(images,
decay=0.1,
updates_collections=None,
is_training=is_training)
# updates_ops are not added to UPDATE_OPS collection.
self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
# Initialize all variables
sess.run(tf.initialize_all_variables())
moving_mean = tf.contrib.framework.get_variables(
'BatchNorm/moving_mean')[0]
moving_variance = tf.contrib.framework.get_variables(
'BatchNorm/moving_variance')[0]
mean, variance = sess.run([moving_mean, moving_variance])
# After initialization moving_mean == 0 and moving_variance == 1.
self.assertAllClose(mean, [0] * 3)
self.assertAllClose(variance, [1] * 3)
# When is_training is False batch_norm doesn't update moving_vars.
for _ in range(10):
sess.run([output], {is_training: False})
self.assertAllClose(moving_mean.eval(), [0] * 3)
self.assertAllClose(moving_variance.eval(), [1] * 3)
# Before updates the outputs are different depending of is_training.
output_true = sess.run([output], {is_training: True})
output_false = sess.run([output], {is_training: False})
self.assertFalse(np.allclose(output_true, output_false))
# When is_training is True update moving_vars.
for _ in range(10):
sess.run([output], {is_training: True})
# After 10 updates with decay 0.1 moving_mean == expected_mean and
# moving_variance == expected_var.
self.assertAllClose(moving_mean.eval(), expected_mean)
self.assertAllClose(moving_variance.eval(), expected_var)
# After updates to convergence the outputs don't depend on is_training.
output_true = sess.run([output], {is_training: True})
output_false = sess.run([output], {is_training: False})
self.assertTrue(np.allclose(output_true, output_false))


class MaxPool2DTest(tf.test.TestCase):
