
Commit

Merge pull request #43 from mrry/newslim
Updated to the latest version of TF-Slim
mrry committed Apr 12, 2016
2 parents c74897b + c74d438 commit 9a1dfdf
Showing 4 changed files with 309 additions and 27 deletions.
50 changes: 28 additions & 22 deletions inception/inception/slim/ops.py
@@ -42,6 +42,7 @@
@scopes.add_arg_scope
def batch_norm(inputs,
decay=0.999,
center=True,
scale=False,
epsilon=0.001,
moving_vars='moving_vars',
@@ -57,6 +58,7 @@ def batch_norm(inputs,
inputs: a tensor of size [batch_size, height, width, channels]
or [batch_size, channels].
decay: decay for the moving average.
center: If True, subtract beta. If False, beta is not created and ignored.
scale: If True, multiply by gamma. If False, gamma is
not used. When the next layer is linear (also e.g. ReLU), this can be
disabled since the scaling can be done by the next layer.
@@ -78,31 +80,35 @@ def batch_norm(inputs,
with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
axis = list(range(len(inputs_shape) - 1))
params_shape = inputs_shape[-1:]
-    with scopes.arg_scope([variables.variable], restore=restore):
-      # Allocate parameters for the beta and gamma of the normalization.
-      beta = variables.variable('beta',
-                                params_shape,
-                                initializer=tf.zeros_initializer,
-                                trainable=trainable)
-      if scale:
-        gamma = variables.variable('gamma',
-                                   params_shape,
-                                   initializer=tf.ones,
-                                   trainable=trainable)
-      else:
-        gamma = None
-      # Create moving_mean and moving_variance add them to moving_vars and
-      # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
-      with scopes.arg_scope([variables.variable], trainable=False,
-                            collections=[
-                                moving_vars,
-                                tf.GraphKeys.MOVING_AVERAGE_VARIABLES]):
-        moving_mean = variables.variable('moving_mean',
-                                         params_shape,
-                                         initializer=tf.zeros_initializer)
-        moving_variance = variables.variable('moving_variance',
-                                             params_shape,
-                                             initializer=tf.ones)
+    # Allocate parameters for the beta and gamma of the normalization.
+    beta, gamma = None, None
+    if center:
+      beta = variables.variable('beta',
+                                params_shape,
+                                initializer=tf.zeros_initializer,
+                                trainable=trainable,
+                                restore=restore)
+    if scale:
+      gamma = variables.variable('gamma',
+                                 params_shape,
+                                 initializer=tf.ones_initializer,
+                                 trainable=trainable,
+                                 restore=restore)
+    # Create moving_mean and moving_variance add them to
+    # GraphKeys.MOVING_AVERAGE_VARIABLES collections.
+    moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
+    moving_mean = variables.variable('moving_mean',
+                                     params_shape,
+                                     initializer=tf.zeros_initializer,
+                                     trainable=False,
+                                     restore=restore,
+                                     collections=moving_collections)
+    moving_variance = variables.variable('moving_variance',
+                                         params_shape,
+                                         initializer=tf.ones_initializer,
+                                         trainable=False,
+                                         restore=restore,
+                                         collections=moving_collections)
if is_training:
# Calculate the moments based on the individual batch.
mean, variance = tf.nn.moments(inputs, axis)
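
Below is a rough usage sketch, not part of this commit, assuming the inception.slim package layout in this repo and TF 0.8-era APIs: the new center flag combines with scale to decide which batch-norm parameters are created.

import tensorflow as tf
from inception.slim import ops

# Hypothetical 4-D activations: [batch_size, height, width, channels].
net = tf.random_uniform((8, 32, 32, 64), seed=1)

# Defaults (center=True, scale=False): only beta is created.
net = ops.batch_norm(net, scope='bn_default')

# center=False, scale=True: beta is skipped entirely, only gamma is created.
net = ops.batch_norm(net, center=False, scale=True, scope='bn_scale_only')

This mirrors the new tests in ops_test.py below, which assert that a skipped parameter is simply absent from the graph.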
@@ -400,7 +406,7 @@ def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
Args:
inputs: the tensor to pass to the Dropout layer.
-    keep_prob: the probability of dropping each input unit.
+    keep_prob: the probability of keeping each input unit.
is_training: whether or not the model is in training mode. If so, dropout is
applied and values scaled. Otherwise, inputs is returned.
scope: Optional scope for op_scope.
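
Continuing the hedged sketch above (same assumptions), the corrected wording matters in practice: keep_prob is the fraction of units that survive, and dropout is a pass-through when is_training is False.

# Roughly 80% of units are kept (and rescaled by 1/keep_prob) during training.
net = ops.dropout(net, keep_prob=0.8, is_training=True)

# At evaluation time the inputs are returned unchanged.
net = ops.dropout(net, keep_prob=0.8, is_training=False)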
42 changes: 42 additions & 0 deletions inception/inception/slim/ops_test.py
@@ -476,6 +476,20 @@ def testCreateOp(self):
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])

def testCreateVariables(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.batch_norm(images)
beta = variables.get_variables_by_name('beta')[0]
self.assertEquals(beta.op.name, 'BatchNorm/beta')
gamma = variables.get_variables_by_name('gamma')
self.assertEquals(gamma, [])
moving_mean = tf.moving_average_variables()[0]
moving_variance = tf.moving_average_variables()[1]
self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')

def testCreateVariablesWithScale(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
@@ -489,6 +503,34 @@ def testCreateVariables(self):
self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')

def testCreateVariablesWithoutCenterWithScale(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.batch_norm(images, center=False, scale=True)
beta = variables.get_variables_by_name('beta')
self.assertEquals(beta, [])
gamma = variables.get_variables_by_name('gamma')[0]
self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
moving_mean = tf.moving_average_variables()[0]
moving_variance = tf.moving_average_variables()[1]
self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')

def testCreateVariablesWithoutCenterWithoutScale(self):
height, width = 3, 3
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
ops.batch_norm(images, center=False, scale=False)
beta = variables.get_variables_by_name('beta')
self.assertEquals(beta, [])
gamma = variables.get_variables_by_name('gamma')
self.assertEquals(gamma, [])
moving_mean = tf.moving_average_variables()[0]
moving_variance = tf.moving_average_variables()[1]
self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')

def testMovingAverageVariables(self):
height, width = 3, 3
with self.test_session():
80 changes: 76 additions & 4 deletions inception/inception/slim/variables.py
@@ -84,6 +84,7 @@

import tensorflow as tf

from tensorflow.core.framework import graph_pb2
from inception.slim import scopes

# Collection containing all the variables created using slim.variables
@@ -171,6 +172,79 @@ def get_unique_variable(name):
raise ValueError('Variable %s does not uniquely identify a variable', name)


class VariableDeviceChooser(object):
"""Slim device chooser for variables.
When using a parameter server it will assign them in a round-robin fashion.
When not using a parameter server it allows GPU:0 placement otherwise CPU:0.
"""

def __init__(self,
num_parameter_servers=0,
ps_device='/job:ps',
placement='CPU:0'):
"""Initialize VariableDeviceChooser.
Args:
num_parameter_servers: number of parameter servers.
ps_device: string representing the parameter server device.
placement: string representing the placement of the variable either CPU:0
or GPU:0. When using parameter servers forced to CPU:0.
"""
self._num_ps = num_parameter_servers
self._ps_device = ps_device
self._placement = placement if num_parameter_servers == 0 else 'CPU:0'
self._next_task_id = 0

def __call__(self, op):
device_string = ''
if self._num_ps > 0:
task_id = self._next_task_id
self._next_task_id = (self._next_task_id + 1) % self._num_ps
device_string = '%s/task:%d' % (self._ps_device, task_id)
device_string += '/%s' % self._placement
return device_string

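As a hedged illustration (the wiring below is an assumed usage pattern, not code from this diff), the chooser is a callable handed to slim variables through their device argument, typically via an arg_scope, so that parameters rotate across parameter-server tasks:

import tensorflow as tf
from inception.slim import scopes, variables

# Round-robin variables across two ps tasks; placement is forced to CPU:0.
device_fn = variables.VariableDeviceChooser(num_parameter_servers=2,
                                            ps_device='/job:ps')
with scopes.arg_scope([variables.variable], device=device_fn):
  weights = variables.variable('weights', shape=[3, 3, 64, 64],
                               initializer=tf.truncated_normal_initializer(stddev=0.01))
  biases = variables.variable('biases', shape=[64],
                              initializer=tf.zeros_initializer)

With num_parameter_servers=0 the chooser simply returns the placement string, e.g. '/CPU:0' or '/GPU:0'.
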

# TODO(sguada) Remove once get_variable is able to colocate op.devices.
def variable_device(device, name):
"""Fix the variable device to colocate its ops."""
if callable(device):
var_name = tf.get_variable_scope().name + '/' + name
var_def = graph_pb2.NodeDef(name=var_name, op='Variable')
device = device(var_def)
if device is None:
device = ''
return device


@scopes.add_arg_scope
def global_step(device=''):
"""Returns the global step variable.
Args:
device: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
Returns:
the tensor representing the global step variable.
"""
global_step_ref = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
if global_step_ref:
return global_step_ref[0]
else:
collections = [
VARIABLES_TO_RESTORE,
tf.GraphKeys.VARIABLES,
tf.GraphKeys.GLOBAL_STEP,
]
# Get the device for the variable.
with tf.device(variable_device(device, 'global_step')):
return tf.get_variable('global_step', shape=[], dtype=tf.int64,
initializer=tf.zeros_initializer,
trainable=False, collections=collections)


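A short, hedged sketch of the intended call pattern (the optimizer wiring is an assumption, not part of this diff): the first call creates the counter and registers it in GraphKeys.GLOBAL_STEP, and later calls return that same variable.

import tensorflow as tf
from inception.slim import variables

w = variables.variable('w', shape=[1], initializer=tf.zeros_initializer)
loss = tf.reduce_mean(tf.square(w - 1.0))

step = variables.global_step()
assert step is variables.global_step()  # subsequent calls reuse the same variable

train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=step)
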
@scopes.add_arg_scope
def variable(name, shape=None, dtype=tf.float32, initializer=None,
regularizer=None, trainable=True, collections=None, device='',
@@ -200,9 +274,6 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
Returns:
The created or existing variable.
"""
-  # Instantiate the device for this variable if it is passed as a function.
-  if device and callable(device):
-    device = device()
collections = list(collections or [])

# Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
@@ -212,7 +283,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
collections.append(VARIABLES_TO_RESTORE)
# Remove duplicates
collections = set(collections)
-  with tf.device(device):
+  # Get the device for the variable.
+  with tf.device(variable_device(device, name)):
return tf.get_variable(name, shape=shape, dtype=dtype,
initializer=initializer, regularizer=regularizer,
trainable=trainable, collections=collections)
