In [40]:
import time

import numpy as np
import tensorflow as tf

# Conditional and Case Statements

## `tf.cond`

`tf.cond` is basically an `if/else` statement. You provide a boolean predicate and two functions which return tensors. One will run if the predicate is `True`, the other if it is `False`. Here's a simple example:

In [13]:
tf.reset_default_graph()
pred = tf.placeholder(tf.bool)
def run_if_true():
    return tf.add(3, 3)
def run_if_false():
    return tf.square(3)
out = tf.cond(pred, run_if_true, run_if_false)

In [24]:
with tf.Session() as sess:
    choice = np.random.choice([True, False])
    feed_dict = {pred: choice}
    res = sess.run(out, feed_dict)
    print('Choice: {}\tResult: {}'.format(choice, res))

Choice: False	Result: 9


For simple functions, we can use lambdas instead:

In [18]:
tf.reset_default_graph()
pred = tf.placeholder(tf.bool)
out = tf.cond(pred, lambda: tf.add(3, 3), lambda: tf.square(3))

In [35]:
def stochastic_depth_conv2d(inputs, filters, kernel_size, keep_prob, activation=tf.nn.relu, name=None):
    with tf.variable_scope(name, 'stochastic_depth_conv'):
        def full_layer():
            return tf.layers.conv2d(inputs, filters, kernel_size, activation=activation)
        def skip_layer():
            if inputs.get_shape().as_list()[-1] != filters:
                return tf.layers.conv2d(inputs, filters, [1, 1], activation=activation)
            else:
                return inputs
        pred = tf.random_uniform([]) < keep_prob
        return tf.cond(pred, full_layer, skip_layer)

In [None]:
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, [None, 228, 228, 3], name='inputs')
keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
conv = stochastic_depth_conv2d(inputs, 32, [3, 3], keep_prob)
conv = stochastic_depth_conv2d(conv, 32, [3, 3], keep_prob)
conv = stochastic_depth_conv2d(conv, 32, [3, 3], keep_prob)
conv = stochastic_depth_conv2d(conv, 64, [3, 3], keep_prob)
conv = stochastic_depth_conv2d(conv, 64, [3, 3], keep_prob)
conv = stochastic_depth_conv2d(conv, 64, [3, 3], keep_prob)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    feed_dict = {
        inputs: np.random.normal(size=[32, 228, 228, 3]),
        keep_prob: 0.2
    }
    start_t = time.time()
    sess.run(conv, feed_dict)
    end_t = time.time()
    print(end_t - start_t)

In [None]:
tf.summary.FileWriter('stochastic_graph', graph=tf.get_default_graph()).close()

## Efficiency

In [79]:
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, [None, 228, 228, 3], name='inputs')
keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
conv = stochastic_depth_conv2d(inputs, 32, [3, 3], keep_prob)
for i in range(100):
    conv = stochastic_depth_conv2d(conv, 32, [3, 3], keep_prob)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    feed_dict = {
        inputs: np.random.normal(size=[32, 228, 228, 3]),
        keep_prob: 0.2
    }
    start_t = time.time()
    sess.run(conv, feed_dict)
    end_t = time.time()
    print(end_t - start_t)

0.3996918201446533


In [81]:
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, [None, 228, 228, 3], name='inputs')
conv = tf.layers.conv2d(inputs, 32, [3, 3])
for i in range(100):
    conv = tf.layers.conv2d(conv, 32, [3, 3])
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    feed_dict = {inputs: np.random.normal(size=[32, 228, 228, 3])}
    start_t = time.time()
    sess.run(conv, feed_dict)
    end_t = time.time()
    print(end_t - start_t)

0.45209479331970215
