Permalink
Browse files

add API for using calling inference_small with a config

  • Loading branch information...
1 parent aad0194 commit 4acf7ab8e099c2141e0e5e858a955bcff249e17e @ry committed May 31, 2016
Showing with 28 additions and 25 deletions.
  1. +28 −25 resnet.py
View
@@ -49,8 +49,8 @@ def inference(x, is_training,
c['conv_filters_out'] = 64
c['ksize'] = 7
c['stride'] = 2
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
x = activation(x)
with tf.variable_scope('scale2'):
@@ -81,7 +81,7 @@ def inference(x, is_training,
if num_classes != None:
with tf.variable_scope('fc'):
- x = _fc(x, c)
+ x = fc(x, c)
return x
@@ -94,22 +94,25 @@ def inference_small(x,
use_bias=False, # defaults to using batch norm
num_classes=10):
c = Config()
- c['bottleneck'] = False
c['is_training'] = tf.convert_to_tensor(is_training,
dtype='bool',
name='is_training')
- c['ksize'] = 3
- c['stride'] = 1
c['use_bias'] = use_bias
c['fc_units_out'] = num_classes
c['num_blocks'] = num_blocks
+ c['num_classes'] = num_classes
+ inference_small_config(x, c)
+def inference_small_config(x, c):
+ c['bottleneck'] = False
+ c['ksize'] = 3
+ c['stride'] = 1
with tf.variable_scope('scale1'):
c['conv_filters_out'] = 16
c['block_filters_internal'] = 16
c['stack_stride'] = 1
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
x = activation(x)
x = stack(x, c)
@@ -126,9 +129,9 @@ def inference_small(x,
# post-net
x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")
- if num_classes != None:
+ if c['num_classes'] != None:
with tf.variable_scope('fc'):
- x = _fc(x, c)
+ x = fc(x, c)
return x
@@ -180,48 +183,48 @@ def block(x, c):
with tf.variable_scope('a'):
c['ksize'] = 1
c['stride'] = c['block_stride']
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
x = activation(x)
with tf.variable_scope('b'):
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
x = activation(x)
with tf.variable_scope('c'):
c['conv_filters_out'] = filters_out
c['ksize'] = 1
assert c['stride'] == 1
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
else:
with tf.variable_scope('A'):
c['stride'] = c['block_stride']
assert c['ksize'] == 3
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
x = activation(x)
with tf.variable_scope('B'):
c['conv_filters_out'] = filters_out
assert c['ksize'] == 3
assert c['stride'] == 1
- x = _conv(x, c)
- x = _bn(x, c)
+ x = conv(x, c)
+ x = bn(x, c)
with tf.variable_scope('shortcut'):
if filters_out != filters_in or c['block_stride'] != 1:
c['ksize'] = 1
c['stride'] = c['block_stride']
c['conv_filters_out'] = filters_out
- shortcut = _conv(shortcut, c)
- shortcut = _bn(shortcut, c)
+ shortcut = conv(shortcut, c)
+ shortcut = bn(shortcut, c)
return activation(x + shortcut)
-def _bn(x, c):
+def bn(x, c):
x_shape = x.get_shape()
params_shape = x_shape[-1:]
@@ -268,7 +271,7 @@ def _bn(x, c):
return x
-def _fc(x, c):
+def fc(x, c):
num_units_in = x.get_shape()[1]
num_units_out = c['fc_units_out']
weights_initializer = tf.truncated_normal_initializer(
@@ -307,7 +310,7 @@ def _get_variable(name,
trainable=trainable)
-def _conv(x, c):
+def conv(x, c):
ksize = c['ksize']
stride = c['stride']
filters_out = c['conv_filters_out']

0 comments on commit 4acf7ab

Please sign in to comment.