Skip to content

Commit

Permalink
Support NCHW
Browse files Browse the repository at this point in the history
  • Loading branch information
pudae committed Mar 16, 2018
1 parent bb478af commit ab41685
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
4 changes: 4 additions & 0 deletions eval_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
tf.app.flags.DEFINE_string(
'model_name', 'densenet121', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
'data_format', 'NHWC', 'The structure of the Tensor. NHWC or NCHW.')

tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
Expand Down Expand Up @@ -102,6 +105,7 @@ def main(_):
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(dataset.num_classes - FLAGS.labels_offset),
data_format=FLAGS.data_format,
is_training=False)

##############################################################
Expand Down
55 changes: 39 additions & 16 deletions nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
slim = tf.contrib.slim


@slim.add_arg_scope
def _global_avg_pool2d(inputs, data_format='NHWC', scope=None, outputs_collections=None):
with tf.variable_scope(scope, 'xx', [inputs]) as sc:
axis = [1, 2] if data_format == 'NHWC' else [2, 3]
net = tf.reduce_mean(inputs, axis=axis, keep_dims=True)
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
return net


@slim.add_arg_scope
def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None,
scope=None, outputs_collections=None):
Expand All @@ -45,12 +54,15 @@ def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None,


@slim.add_arg_scope
def _conv_block(inputs, num_filters, scope=None, outputs_collections=None):
def _conv_block(inputs, num_filters, data_format='NHWC', scope=None, outputs_collections=None):
with tf.variable_scope(scope, 'conv_blockx', [inputs]) as sc:
net = inputs
net = _conv(net, num_filters*4, 1, scope='x1')
net = _conv(net, num_filters, 3, scope='x2')
net = tf.concat([inputs, net], axis=3)
if data_format == 'NHWC':
net = tf.concat([inputs, net], axis=3)
else: # "NCHW"
net = tf.concat([inputs, net], axis=1)

net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

Expand Down Expand Up @@ -98,6 +110,7 @@ def densenet(inputs,
num_filters=None,
num_layers=None,
dropout_rate=None,
data_format='NHWC',
is_training=True,
reuse=None,
scope=None):
Expand All @@ -109,6 +122,9 @@ def densenet(inputs,
compression = 1.0 - reduction
num_dense_blocks = len(num_layers)

if data_format == 'NCHW':
inputs = tf.transpose(inputs, [0, 3, 1, 2])

with tf.variable_scope(scope, 'densenetxxx', [inputs, num_classes],
reuse=reuse) as sc:
end_points_collection = sc.name + '_end_points'
Expand Down Expand Up @@ -147,7 +163,7 @@ def densenet(inputs,
with tf.variable_scope('final_block', [inputs]):
net = slim.batch_norm(net)
net = tf.nn.relu(net)
net = tf.reduce_mean(net, [1,2], name='global_avg_pool', keep_dims=True)
net = _global_avg_pool2d(net, scope='global_avg_pool')

net = slim.conv2d(net, num_classes, 1,
biases_initializer=tf.zeros_initializer(),
Expand All @@ -162,39 +178,42 @@ def densenet(inputs,
return net, end_points


def densenet121(inputs, num_classes=1000, is_training=True, reuse=None):
def densenet121(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None):
return densenet(inputs,
num_classes=num_classes,
reduction=0.5,
growth_rate=32,
num_filters=64,
num_layers=[6,12,24,16],
data_format=data_format,
is_training=is_training,
reuse=reuse,
scope='densenet121')
densenet121.default_image_size = 224


def densenet161(inputs, num_classes=1000, is_training=True, reuse=None):
def densenet161(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None):
return densenet(inputs,
num_classes=num_classes,
reduction=0.5,
growth_rate=48,
num_filters=96,
num_layers=[6,12,36,24],
data_format=data_format,
is_training=is_training,
reuse=reuse,
scope='densenet161')
densenet161.default_image_size = 224


def densenet169(inputs, num_classes=1000, is_training=True, reuse=None):
def densenet169(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None):
return densenet(inputs,
num_classes=num_classes,
reduction=0.5,
growth_rate=32,
num_filters=64,
num_layers=[6,12,32,32],
data_format=data_format,
is_training=is_training,
reuse=reuse,
scope='densenet169')
Expand All @@ -203,15 +222,19 @@ def densenet169(inputs, num_classes=1000, is_training=True, reuse=None):

def densenet_arg_scope(weight_decay=1e-4,
batch_norm_decay=0.99,
batch_norm_epsilon=1.1e-5):
with slim.arg_scope([slim.conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=None,
biases_initializer=None):
with slim.arg_scope([slim.batch_norm],
scale=True,
decay=batch_norm_decay,
epsilon=batch_norm_epsilon) as scope:
return scope
batch_norm_epsilon=1.1e-5,
data_format='NHWC'):
with slim.arg_scope([slim.conv2d, slim.batch_norm, slim.avg_pool2d, slim.max_pool2d,
_conv_block, _global_avg_pool2d],
data_format=data_format):
with slim.arg_scope([slim.conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=None,
biases_initializer=None):
with slim.arg_scope([slim.batch_norm],
scale=True,
decay=batch_norm_decay,
epsilon=batch_norm_epsilon) as scope:
return scope


7 changes: 4 additions & 3 deletions nets/nets_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
}


def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
def get_network_fn(name, num_classes, weight_decay=0.0, data_format='NHWC',
is_training=False):
"""Returns a network_fn such as `logits, end_points = network_fn(images)`.
Args:
Expand All @@ -57,12 +58,12 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
"""
if name not in networks_map:
raise ValueError('Name of network unknown %s' % name)
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
arg_scope = arg_scopes_map[name](weight_decay=weight_decay, data_format=data_format)
func = networks_map[name]
@functools.wraps(func)
def network_fn(images):
with slim.arg_scope(arg_scope):
return func(images, num_classes, is_training=is_training)
return func(images, num_classes, data_format=data_format, is_training=is_training)
if hasattr(func, 'default_image_size'):
network_fn.default_image_size = func.default_image_size

Expand Down
4 changes: 4 additions & 0 deletions train_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@
tf.app.flags.DEFINE_string(
'model_name', 'densenet121', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
'data_format', 'NHWC', 'The structure of the Tensor. NHWC or NCHW.')

tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
Expand Down Expand Up @@ -409,6 +412,7 @@ def main(_):
FLAGS.model_name,
num_classes=(dataset.num_classes - FLAGS.labels_offset),
weight_decay=FLAGS.weight_decay,
data_format=FLAGS.data_format,
is_training=True)

#####################################
Expand Down

0 comments on commit ab41685

Please sign in to comment.