From 4acf7ab8e099c2141e0e5e858a955bcff249e17e Mon Sep 17 00:00:00 2001
From: Ryan Dahl
Date: Tue, 31 May 2016 12:14:42 -0700
Subject: [PATCH] add API for calling inference_small with a config

---
 resnet.py | 53 ++++++++++++++++++++++++++++-------------------------
 1 file changed, 28 insertions(+), 25 deletions(-)

diff --git a/resnet.py b/resnet.py
index 8b40d9e..a056c41 100644
--- a/resnet.py
+++ b/resnet.py
@@ -49,8 +49,8 @@ def inference(x, is_training,
         c['conv_filters_out'] = 64
         c['ksize'] = 7
         c['stride'] = 2
-        x = _conv(x, c)
-        x = _bn(x, c)
+        x = conv(x, c)
+        x = bn(x, c)
         x = activation(x)
 
     with tf.variable_scope('scale2'):
@@ -81,7 +81,7 @@ def inference(x, is_training,
 
     if num_classes != None:
         with tf.variable_scope('fc'):
-            x = _fc(x, c)
+            x = fc(x, c)
 
     return x
 
@@ -94,22 +94,25 @@ def inference_small(x,
                     use_bias=False, # defaults to using batch norm
                     num_classes=10):
     c = Config()
-    c['bottleneck'] = False
     c['is_training'] = tf.convert_to_tensor(is_training,
                                             dtype='bool',
                                             name='is_training')
-    c['ksize'] = 3
-    c['stride'] = 1
     c['use_bias'] = use_bias
     c['fc_units_out'] = num_classes
     c['num_blocks'] = num_blocks
+    c['num_classes'] = num_classes
+    inference_small_config(x, c)
 
+def inference_small_config(x, c):
+    c['bottleneck'] = False
+    c['ksize'] = 3
+    c['stride'] = 1
     with tf.variable_scope('scale1'):
         c['conv_filters_out'] = 16
         c['block_filters_internal'] = 16
         c['stack_stride'] = 1
-        x = _conv(x, c)
-        x = _bn(x, c)
+        x = conv(x, c)
+        x = bn(x, c)
         x = activation(x)
         x = stack(x, c)
 
@@ -126,9 +129,9 @@ def inference_small(x,
     # post-net
     x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")
 
-    if num_classes != None:
+    if c['num_classes'] != None:
         with tf.variable_scope('fc'):
-            x = _fc(x, c)
+            x = fc(x, c)
 
     return x
 
@@ -180,48 +183,48 @@ def block(x, c):
         with tf.variable_scope('a'):
             c['ksize'] = 1
             c['stride'] = c['block_stride']
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)
 
         with tf.variable_scope('b'):
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)
 
         with tf.variable_scope('c'):
             c['conv_filters_out'] = filters_out
             c['ksize'] = 1
             assert c['stride'] == 1
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
     else:
         with tf.variable_scope('A'):
             c['stride'] = c['block_stride']
             assert c['ksize'] == 3
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)
 
         with tf.variable_scope('B'):
             c['conv_filters_out'] = filters_out
             assert c['ksize'] == 3
             assert c['stride'] == 1
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
 
     with tf.variable_scope('shortcut'):
         if filters_out != filters_in or c['block_stride'] != 1:
             c['ksize'] = 1
             c['stride'] = c['block_stride']
             c['conv_filters_out'] = filters_out
-            shortcut = _conv(shortcut, c)
-            shortcut = _bn(shortcut, c)
+            shortcut = conv(shortcut, c)
+            shortcut = bn(shortcut, c)
 
     return activation(x + shortcut)
 
 
-def _bn(x, c):
+def bn(x, c):
     x_shape = x.get_shape()
     params_shape = x_shape[-1:]
 
@@ -268,7 +271,7 @@
     return x
 
 
-def _fc(x, c):
+def fc(x, c):
     num_units_in = x.get_shape()[1]
     num_units_out = c['fc_units_out']
     weights_initializer = tf.truncated_normal_initializer(
@@ -307,7 +310,7 @@ def _get_variable(name,
                           trainable=trainable)
 
 
-def _conv(x, c):
+def conv(x, c):
     ksize = c['ksize']
     stride = c['stride']
     filters_out = c['conv_filters_out']
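
A minimal usage sketch of the config-based entry point this patch introduces (not part of the patch itself). It assumes the patched resnet.py, with its Config class and the renamed conv/bn/fc helpers, is importable, and uses TensorFlow 0.x-era placeholders; the `images` tensor and the 32x32x3 CIFAR-style input shape are illustrative only.

    import tensorflow as tf
    import resnet

    # Illustrative input batch (CIFAR-sized images); any NHWC tensor works.
    images = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')

    # Build the same config dict that inference_small() assembles from its
    # keyword arguments, then call the new config-based entry point directly.
    c = resnet.Config()
    c['is_training'] = tf.convert_to_tensor(True, dtype='bool', name='is_training')
    c['use_bias'] = False        # use batch norm instead of conv biases
    c['fc_units_out'] = 10
    c['num_blocks'] = 3          # 6n+2 total weight layers
    c['num_classes'] = 10
    logits = resnet.inference_small_config(images, c)

This mirrors exactly the keys inference_small() sets before delegating, so a config built elsewhere (for example, one shared with a training script) can drive the network without going through the keyword-argument wrapper.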