From 9da5fc8e6425cabd61fc36f0dcc1823a093d5c1d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 14 Jul 2016 09:06:17 -0800
Subject: [PATCH] Allow batch_norm to use an is_training Tensor, which can be
 modified at run time.

Added utils.smart_cond() to make sure the update_ops are not executed when
is_training is False.
Change: 127443556
---
 tensorflow/contrib/layers/BUILD                |  13 ++
 .../contrib/layers/python/layers/layers.py     |  73 ++++---
 .../layers/python/layers/layers_test.py        | 196 ++++++++++++++++--
 .../contrib/layers/python/layers/utils.py      |  90 ++++++--
 .../layers/python/layers/utils_test.py         | 184 ++++++++++++++++
 5 files changed, 492 insertions(+), 64 deletions(-)
 create mode 100644 tensorflow/contrib/layers/python/layers/utils_test.py

diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index bd05720b264f6f..a41c5efb6e97d0 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -226,6 +226,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "utils_test",
+    size = "small",
+    srcs = ["python/layers/utils_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index feb6ca9e7591c9..bedb13fad0c3e1 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -31,7 +31,6 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import standard_ops
@@ -206,29 +205,48 @@ def batch_norm(inputs,
         initializer=init_ops.ones_initializer,
         trainable=False,
         collections=moving_variance_collections)
-    if is_training:
-      # Calculate the moments based on the individual batch.
+
+    is_training_value = utils.constant_value(is_training)
+    # Calculate the moments based on the individual batch.
+    need_moments = is_training_value is None or is_training_value
+    if need_moments:
       mean, variance = nn.moments(inputs, axis, shift=moving_mean)
-      # Update the moving_mean and moving_variance moments.
-      update_moving_mean = moving_averages.assign_moving_average(
-          moving_mean, mean, decay)
-      update_moving_variance = moving_averages.assign_moving_average(
-          moving_variance, variance, decay)
+      moving_vars_fn = lambda: (moving_mean, moving_variance)
       if updates_collections is None:
-        # Make sure the updates are computed here.
-        with ops.control_dependencies([update_moving_mean,
-                                       update_moving_variance]):
-          outputs = nn.batch_normalization(
-              inputs, mean, variance, beta, gamma, epsilon)
+        def _force_updates():
+          """Internal function that forces updates of moving_vars if is_training."""
+          update_moving_mean = moving_averages.assign_moving_average(
+              moving_mean, mean, decay)
+          update_moving_variance = moving_averages.assign_moving_average(
+              moving_variance, variance, decay)
+          with ops.control_dependencies([update_moving_mean,
+                                         update_moving_variance]):
+            return array_ops.identity(mean), array_ops.identity(variance)
+        mean, variance = utils.smart_cond(is_training,
+                                          _force_updates,
+                                          moving_vars_fn)
       else:
-        # Collect the updates to be computed later.
-        ops.add_to_collections(updates_collections, update_moving_mean)
-        ops.add_to_collections(updates_collections, update_moving_variance)
-        outputs = nn.batch_normalization(
-            inputs, mean, variance, beta, gamma, epsilon)
+        def _delay_updates():
+          """Internal function that delays updates of moving_vars if is_training."""
+          update_moving_mean = moving_averages.assign_moving_average(
+              moving_mean, mean, decay)
+          update_moving_variance = moving_averages.assign_moving_average(
+              moving_variance, variance, decay)
+          return update_moving_mean, update_moving_variance
+
+        update_mean, update_variance = utils.smart_cond(is_training,
+                                                        _delay_updates,
+                                                        moving_vars_fn)
+        ops.add_to_collections(updates_collections, update_mean)
+        ops.add_to_collections(updates_collections, update_variance)
+        # Use computed moments during training and moving_vars otherwise.
+        vars_fn = lambda: (mean, variance)
+        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
     else:
-      outputs = nn.batch_normalization(
-          inputs, moving_mean, moving_variance, beta, gamma, epsilon)
+      mean, variance = moving_mean, moving_variance
+    # Compute batch_normalization.
+    outputs = nn.batch_normalization(
+        inputs, mean, variance, beta, gamma, epsilon)
     outputs.set_shape(inputs_shape)
     if activation_fn:
       outputs = activation_fn(outputs)
@@ -430,18 +448,9 @@ def dropout(inputs,
   """
   with ops.op_scope([inputs], scope, 'Dropout') as sc:
     inputs = ops.convert_to_tensor(inputs)
-    is_training_value = utils.constant_value(is_training, dtypes.bool)
-    if is_training_value is not None:
-      if is_training_value:
-        outputs = nn.dropout(inputs, keep_prob, noise_shape)
-      else:
-        outputs = inputs
-    else:
-      def _dropout():
-        return nn.dropout(inputs, keep_prob, noise_shape)
-      outputs = control_flow_ops.cond(is_training,
-                                      _dropout,
-                                      lambda: inputs)
+    dropout_fn = lambda: nn.dropout(inputs, keep_prob, noise_shape)
+    id_fn = lambda: inputs
+    outputs = utils.smart_cond(is_training, dropout_fn, id_fn)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index f3407ee1c80d24..4380eac8c643fe 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -429,15 +429,24 @@ def testCreateDropout(self):
       output.get_shape().assert_is_compatible_with(
           tf.convert_to_tensor(images).get_shape())
 
-  def testCreateDropoutWithConstant(self):
+  def testCreateDropoutWithConstantTrue(self):
     height, width = 3, 3
     with self.test_session():
-      is_training = tf.constant(False)
+      is_training = tf.constant(True)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       output = tf.contrib.layers.dropout(images, is_training=is_training)
       self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
       output.get_shape().assert_is_compatible_with(images.get_shape())
 
+  def testCreateDropoutWithConstantFalse(self):
+    height, width = 3, 3
+    with self.test_session():
+      is_training = tf.constant(False)
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = tf.contrib.layers.dropout(images, is_training=is_training)
+      self.assertEquals(output, images)
+      output.get_shape().assert_is_compatible_with(images.get_shape())
+
   def testCreateDropoutWithPlaceholder(self):
     height, width = 3, 3
     with self.test_session():
@@ -796,7 +805,7 @@ def testCreateMovingVars(self):
     self.assertEquals(len(moving_variance), 1)
     self.assertEquals(moving_variance[0].op.name, 'BatchNorm/moving_variance')
 
-  def testForceUpdateMovingVars(self):
+  def testNoneUpdatesCollections(self):
     height, width = 3, 3
     with self.test_session() as sess:
       image_shape = (10, height, width, 3)
@@ -806,6 +815,8 @@
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
       output = tf.contrib.layers.batch_norm(images,
                                             decay=0.1,
                                             updates_collections=None)
+      # update_ops are not added to UPDATE_OPS collection.
+      self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(
@@ -835,6 +846,8 @@ def testDelayedUpdateMovingVars(self):
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
       output = tf.contrib.layers.batch_norm(images, decay=0.1)
       update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      # update_ops are added to UPDATE_OPS collection.
+      self.assertEquals(len(update_ops), 2)
       with tf.control_dependencies(update_ops):
         barrier = tf.no_op(name='barrier')
       output = control_flow_ops.with_dependencies([barrier], output)
@@ -868,8 +881,7 @@ def testEvalMovingVars(self):
       output = tf.contrib.layers.batch_norm(images,
                                             decay=0.1,
                                             is_training=False)
-      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-      self.assertEquals(update_ops, [])
+      self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(
@@ -901,11 +913,55 @@ def testReuseVars(self):
       expected_mean = np.mean(image_values, axis=(0, 1, 2))
       expected_var = np.var(image_values, axis=(0, 1, 2))
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+      output_train = tf.contrib.layers.batch_norm(images,
+                                                  decay=0.1,
+                                                  is_training=True,
+                                                  scope='BN')
+      output_eval = tf.contrib.layers.batch_norm(images,
+                                                 decay=0.1,
+                                                 is_training=False,
+                                                 scope='BN',
+                                                 reuse=True)
+      # Initialize all variables
+      sess.run(tf.initialize_all_variables())
+      moving_mean = tf.contrib.framework.get_variables(
+          'BN/moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables(
+          'BN/moving_variance')[0]
+      mean, variance = sess.run([moving_mean, moving_variance])
+      # After initialization moving_mean == 0 and moving_variance == 1.
+      self.assertAllClose(mean, [0] * 3)
+      self.assertAllClose(variance, [1] * 3)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      with tf.control_dependencies(update_ops):
+        barrier = tf.no_op(name='barrier')
+      train_op = control_flow_ops.with_dependencies([barrier], output_train)
+      # Before updates the outputs are different for train and eval.
+      self.assertFalse(np.allclose(sess.run([output_train]),
+                                   sess.run([output_eval])))
+      for _ in range(10):
+        sess.run([train_op])
+      mean = moving_mean.eval()
+      variance = moving_variance.eval()
+      # After 10 updates with decay 0.1 moving_mean == expected_mean and
+      # moving_variance == expected_var.
+      self.assertAllClose(mean, expected_mean)
+      self.assertAllClose(variance, expected_var)
+      # After convergence output_train and output_eval should be the same.
+      self.assertAllClose(sess.run([output_train]), sess.run([output_eval]))
+
+  def testIsTrainingVariable(self):
+    height, width = 3, 3
+    with self.test_session() as sess:
+      image_shape = (10, height, width, 3)
+      image_values = np.random.rand(*image_shape)
+      expected_mean = np.mean(image_values, axis=(0, 1, 2))
+      expected_var = np.var(image_values, axis=(0, 1, 2))
+      images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+      is_training = tf.Variable(True)
       output = tf.contrib.layers.batch_norm(images,
                                             decay=0.1,
-                                            is_training=False)
-      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-      self.assertEquals(update_ops, [])
+                                            is_training=is_training)
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(
@@ -916,18 +972,128 @@ def testReuseVars(self):
       # After initialization moving_mean == 0 and moving_variance == 1.
       self.assertAllClose(mean, [0] * 3)
       self.assertAllClose(variance, [1] * 3)
-      # Simulate assigment from saver restore.
-      init_assigns = [tf.assign(moving_mean, expected_mean),
-                      tf.assign(moving_variance, expected_var)]
-      sess.run(init_assigns)
+      # Before updates the outputs are different depending on is_training.
+      output_true = sess.run([output], {is_training: True})
+      output_false = sess.run([output], {is_training: False})
+      self.assertFalse(np.allclose(output_true, output_false))
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      with tf.control_dependencies(update_ops):
+        barrier = tf.no_op(name='barrier')
+      train_op = control_flow_ops.with_dependencies([barrier], output)
       for _ in range(10):
-        sess.run([output], {images: np.random.rand(*image_shape)})
+        sess.run([train_op])
       mean = moving_mean.eval()
       variance = moving_variance.eval()
-      # Although we feed different images, the moving_mean and moving_variance
-      # shouldn't change.
+      # After 10 updates with decay 0.1 moving_mean == expected_mean and
+      # moving_variance == expected_var.
       self.assertAllClose(mean, expected_mean)
       self.assertAllClose(variance, expected_var)
+      # After updates to convergence the outputs don't depend on is_training.
+      output_true = sess.run([output], {is_training: True})
+      output_false = sess.run([output], {is_training: False})
+      self.assertAllClose(output_true, output_false)
+
+  def testNoUpdatesWhenIsTrainingFalse(self):
+    height, width = 3, 3
+    with self.test_session() as sess:
+      image_shape = (10, height, width, 3)
+      image_values = np.random.rand(*image_shape)
+      images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+      output = tf.contrib.layers.batch_norm(images,
+                                            decay=0.1,
+                                            is_training=False)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      # update_ops are not added to UPDATE_OPS collection.
+      self.assertEquals(len(update_ops), 0)
+      # Initialize all variables
+      sess.run(tf.initialize_all_variables())
+      moving_mean = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_variance')[0]
+      mean, variance = sess.run([moving_mean, moving_variance])
+      # After initialization moving_mean == 0 and moving_variance == 1.
+      self.assertAllClose(mean, [0] * 3)
+      self.assertAllClose(variance, [1] * 3)
+      # When is_training is False batch_norm doesn't update moving_vars.
+      for _ in range(10):
+        sess.run([output])
+      self.assertAllClose(moving_mean.eval(), [0] * 3)
+      self.assertAllClose(moving_variance.eval(), [1] * 3)
+
+  def testNoneUpdatesCollectionNoTraining(self):
+    height, width = 3, 3
+    with self.test_session() as sess:
+      image_shape = (10, height, width, 3)
+      image_values = np.random.rand(*image_shape)
+      images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+      output = tf.contrib.layers.batch_norm(images,
+                                            decay=0.1,
+                                            updates_collections=None,
+                                            is_training=False)
+      # update_ops are not added to UPDATE_OPS collection.
+      self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
+      # Initialize all variables
+      sess.run(tf.initialize_all_variables())
+      moving_mean = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_variance')[0]
+      mean, variance = sess.run([moving_mean, moving_variance])
+      # After initialization moving_mean == 0 and moving_variance == 1.
+      self.assertAllClose(mean, [0] * 3)
+      self.assertAllClose(variance, [1] * 3)
+      # When is_training is False batch_norm doesn't update moving_vars.
+      for _ in range(10):
+        sess.run([output])
+      self.assertAllClose(moving_mean.eval(), [0] * 3)
+      self.assertAllClose(moving_variance.eval(), [1] * 3)
+
+  def testNoneUpdatesCollectionIsTrainingVariable(self):
+    height, width = 3, 3
+    with self.test_session() as sess:
+      image_shape = (10, height, width, 3)
+      image_values = np.random.rand(*image_shape)
+      expected_mean = np.mean(image_values, axis=(0, 1, 2))
+      expected_var = np.var(image_values, axis=(0, 1, 2))
+      images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
+      is_training = tf.Variable(True)
+      output = tf.contrib.layers.batch_norm(images,
+                                            decay=0.1,
+                                            updates_collections=None,
+                                            is_training=is_training)
+      # update_ops are not added to UPDATE_OPS collection.
+      self.assertEquals(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
+      # Initialize all variables
+      sess.run(tf.initialize_all_variables())
+      moving_mean = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables(
+          'BatchNorm/moving_variance')[0]
+      mean, variance = sess.run([moving_mean, moving_variance])
+      # After initialization moving_mean == 0 and moving_variance == 1.
+      self.assertAllClose(mean, [0] * 3)
+      self.assertAllClose(variance, [1] * 3)
+      # When is_training is False batch_norm doesn't update moving_vars.
+      for _ in range(10):
+        sess.run([output], {is_training: False})
+      self.assertAllClose(moving_mean.eval(), [0] * 3)
+      self.assertAllClose(moving_variance.eval(), [1] * 3)
+      # Before updates the outputs are different depending on is_training.
+      output_true = sess.run([output], {is_training: True})
+      output_false = sess.run([output], {is_training: False})
+      self.assertFalse(np.allclose(output_true, output_false))
+      # When is_training is True batch_norm updates moving_vars.
+      for _ in range(10):
+        sess.run([output], {is_training: True})
+      # After 10 updates with decay 0.1 moving_mean == expected_mean and
+      # moving_variance == expected_var.
+      self.assertAllClose(moving_mean.eval(), expected_mean)
+      self.assertAllClose(moving_variance.eval(), expected_var)
+      # After updates to convergence the outputs don't depend on is_training.
+      output_true = sess.run([output], {is_training: True})
+      output_false = sess.run([output], {is_training: False})
+      self.assertTrue(np.allclose(output_true, output_false))
 
 
 class MaxPool2DTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/layers/python/layers/utils.py b/tensorflow/contrib/layers/python/layers/utils.py
index c500675789e448..045fdb87418f53 100644
--- a/tensorflow/contrib/layers/python/layers/utils.py
+++ b/tensorflow/contrib/layers/python/layers/utils.py
@@ -20,9 +20,14 @@
 
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
 
 __all__ = ['collect_named_outputs',
+           'constant_value',
+           'static_cond',
+           'smart_cond',
            'get_variable_collections',
            'two_element_tuple',
            'last_dimension',
@@ -52,34 +57,85 @@ def collect_named_outputs(collections, name, outputs):
   return outputs
 
 
-def constant_value(value_or_tensor, tensor_dtype=None):
-  """Returns value if value_or_tensor has a constant value.
+def constant_value(value_or_tensor_or_var, dtype=None):
+  """Returns value if value_or_tensor_or_var has a constant value.
 
   Args:
-    value_or_tensor: A value or a `Tensor`.
-    tensor_dtype: Optional `tf.dtype`, if set it would check the tensor type.
+    value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
+    dtype: Optional `tf.dtype`; if set, it checks that the value has the
+      right dtype.
 
   Returns:
     The constant value or None if it not constant.
 
   Raises:
-    ValueError: if value_or_tensor is None or the tensor has the wrong dtype.
+    ValueError: if value_or_tensor_or_var is None or the tensor or variable
+      has the wrong dtype.
   """
-  if value_or_tensor is None:
-    raise ValueError('value_or_tensor cannot be None')
-  value = value_or_tensor
-  if isinstance(value_or_tensor, ops.Tensor):
-    if tensor_dtype and value_or_tensor.dtype != tensor_dtype:
-      raise ValueError('The tensor has the wrong type %s instead of %s' % (
-          value_or_tensor.dtype, tensor_dtype))
-    if value_or_tensor.op.type == 'Const':
-      value_or_tensor.graph.prevent_feeding(value_or_tensor)
-      value = value_or_tensor.op.get_attr('value')
-    else:
+  if value_or_tensor_or_var is None:
+    raise ValueError('value_or_tensor_or_var cannot be None')
+  value = value_or_tensor_or_var
+  if isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
+    if dtype and value_or_tensor_or_var.dtype != dtype:
+      raise ValueError('It has the wrong type %s instead of %s' % (
+          value_or_tensor_or_var.dtype, dtype))
+    if isinstance(value_or_tensor_or_var, variables.Variable):
       value = None
+    else:
+      value = tensor_util.constant_value(value_or_tensor_or_var)
   return value
 
 
+def static_cond(pred, fn1, fn2, name=None):
+  """Return either fn1() or fn2() based on the boolean value of `pred`.
+
+  Same signature as `control_flow_ops.cond()` but requires pred to be a bool.
+
+  Args:
+    pred: A value determining whether to return the result of `fn1` or `fn2`.
+    fn1: The callable to be performed if pred is true.
+    fn2: The callable to be performed if pred is false.
+    name: Optional name prefix for the returned tensors.
+
+  Returns:
+    Tensors returned by the call to either `fn1` or `fn2`.
+
+  Raises:
+    TypeError: if `fn1` or `fn2` is not callable.
+ """ + if not callable(fn1): + raise TypeError('fn1 must be callable.') + if not callable(fn2): + raise TypeError('fn2 must be callable.') + if pred: + return fn1() + else: + return fn2() + + +def smart_cond(pred, fn1, fn2, name=None): + """Return either fn1() or fn2() based on the boolean predicate/value `pred`. + + If `pred` is bool or has a constant value it would use `static_cond`, + otherwise it would use `tf.cond`. + + Args: + pred: A scalar determining whether to return the result of `fn1` or `fn2`. + fn1: The callable to be performed if pred is true. + fn2: The callable to be performed if pred is false. + name: Optional name prefix when using tf.cond + Returns: + Tensors returned by the call to either `fn1` or `fn2`. + """ + pred_value = constant_value(pred) + if pred_value is not None: + # Use static_cond if pred has a constant value. + return static_cond(pred_value, fn1, fn2) + else: + # Use dynamic cond otherwise. + return control_flow_ops.cond(pred, fn1, fn2, name) + + def get_variable_collections(variables_collections, name): if isinstance(variables_collections, dict): variable_collections = variables_collections.get(name, None) diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py new file mode 100644 index 00000000000000..89f6722524415c --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/utils_test.py @@ -0,0 +1,184 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for tensorflow.contrib.layers.python.layers.utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.layers.python.layers import utils
+
+
+class ConstantValueTest(tf.test.TestCase):
+
+  def test_value(self):
+    for v in [True, False, 1, 0, 1.0]:
+      value = utils.constant_value(v)
+      self.assertEqual(value, v)
+
+  def test_constant(self):
+    for v in [True, False, 1, 0, 1.0]:
+      c = tf.constant(v)
+      value = utils.constant_value(c)
+      self.assertEqual(value, v)
+      with self.test_session():
+        self.assertEqual(c.eval(), v)
+
+  def test_variable(self):
+    for v in [True, False, 1, 0, 1.0]:
+      with tf.Graph().as_default() as g, self.test_session(g) as sess:
+        x = tf.Variable(v)
+        value = utils.constant_value(x)
+        self.assertEqual(value, None)
+        sess.run(tf.initialize_all_variables())
+        self.assertEqual(x.eval(), v)
+
+  def test_placeholder(self):
+    for v in [True, False, 1, 0, 1.0]:
+      p = tf.placeholder(np.dtype(type(v)), [])
+      x = tf.identity(p)
+      value = utils.constant_value(p)
+      self.assertEqual(value, None)
+      with self.test_session():
+        self.assertEqual(x.eval(feed_dict={p: v}), v)
+
+
+class StaticCondTest(tf.test.TestCase):
+
+  def test_value(self):
+    fn1 = lambda: 'fn1'
+    fn2 = lambda: 'fn2'
+    expected = lambda v: 'fn1' if v else 'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.static_cond(v, fn1, fn2)
+      self.assertEqual(o, expected(v))
+
+  def test_constant(self):
+    fn1 = lambda: tf.constant('fn1')
+    fn2 = lambda: tf.constant('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.static_cond(v, fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(), expected(v))
+
+  def test_variable(self):
+    fn1 = lambda: tf.Variable('fn1')
+    fn2 = lambda: tf.Variable('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.static_cond(v, fn1, fn2)
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        self.assertEqual(o.eval(), expected(v))
+
+  def test_tensors(self):
+    fn1 = lambda: tf.constant(0) - tf.constant(1)
+    fn2 = lambda: tf.constant(0) - tf.constant(2)
+    expected = lambda v: -1 if v else -2
+    for v in [True, False, 1, 0]:
+      o = utils.static_cond(v, fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(), expected(v))
+
+
+class SmartCondStaticTest(tf.test.TestCase):
+
+  def test_value(self):
+    fn1 = lambda: 'fn1'
+    fn2 = lambda: 'fn2'
+    expected = lambda v: 'fn1' if v else 'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(tf.constant(v), fn1, fn2)
+      self.assertEqual(o, expected(v))
+
+  def test_constant(self):
+    fn1 = lambda: tf.constant('fn1')
+    fn2 = lambda: tf.constant('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(tf.constant(v), fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(), expected(v))
+
+  def test_variable(self):
+    fn1 = lambda: tf.Variable('fn1')
+    fn2 = lambda: tf.Variable('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(tf.constant(v), fn1, fn2)
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        self.assertEqual(o.eval(), expected(v))
+
+  def test_tensors(self):
+    fn1 = lambda: tf.constant(0) - tf.constant(1)
+    fn2 = lambda: tf.constant(0) - tf.constant(2)
+    expected = lambda v: -1 if v else -2
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(tf.constant(v), fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(), expected(v))
+
+
+class SmartCondDynamicTest(tf.test.TestCase):
+
+  def test_value(self):
+    fn1 = lambda: tf.convert_to_tensor('fn1')
+    fn2 = lambda: tf.convert_to_tensor('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    p = tf.placeholder(tf.bool, [])
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(p, fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
+
+  def test_constant(self):
+    fn1 = lambda: tf.constant('fn1')
+    fn2 = lambda: tf.constant('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    p = tf.placeholder(tf.bool, [])
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(p, fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
+
+  def test_variable(self):
+    fn1 = lambda: tf.Variable('fn1')
+    fn2 = lambda: tf.Variable('fn2')
+    expected = lambda v: b'fn1' if v else b'fn2'
+    p = tf.placeholder(tf.bool, [])
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(p, fn1, fn2)
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
+
+  def test_tensors(self):
+    fn1 = lambda: tf.constant(0) - tf.constant(1)
+    fn2 = lambda: tf.constant(0) - tf.constant(2)
+    expected = lambda v: -1 if v else -2
+    p = tf.placeholder(tf.bool, [])
+    for v in [True, False, 1, 0]:
+      o = utils.smart_cond(p, fn1, fn2)
+      with self.test_session():
+        self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
+
+
+if __name__ == '__main__':
+  tf.test.main()
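
With this change, `is_training` no longer has to be a Python bool: it can be a
`Tensor` or a `Variable` that is fed or reassigned at run time, so a single
graph can switch between training and inference behavior. A minimal usage
sketch against the contrib API in this patch (the placeholder names, shapes,
and hyperparameters below are illustrative, not taken from the patch):

    import numpy as np
    import tensorflow as tf

    images = tf.placeholder(tf.float32, [None, 3, 3, 3])
    is_training = tf.placeholder(tf.bool, [])  # toggled per sess.run() call

    net = tf.contrib.layers.batch_norm(images, decay=0.9,
                                       is_training=is_training)
    net = tf.contrib.layers.dropout(net, keep_prob=0.5,
                                    is_training=is_training)

    # By default batch_norm collects its moving-average updates into
    # tf.GraphKeys.UPDATE_OPS, so they must be run explicitly when training.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    updates = tf.group(*update_ops)

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      batch = np.random.rand(8, 3, 3, 3).astype(np.float32)
      # Training step: batch statistics are used and moving_vars are updated.
      sess.run([net, updates], {images: batch, is_training: True})
      # Eval step: moving_mean/moving_variance are used; nothing is updated.
      sess.run(net, {images: batch, is_training: False})

When `is_training` has no constant value, `utils.smart_cond` falls back to
`control_flow_ops.cond`, so the update ops above are still created but return
the unchanged moving_vars whenever is_training is fed as False.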