This repository has been archived by the owner on Oct 19, 2019. It is now read-only.

Commit

add API for calling inference_small with a config
ry committed May 31, 2016
1 parent aad0194 commit 4acf7ab
Showing 1 changed file with 28 additions and 25 deletions.
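
This change splits the small-model builder into two entry points: inference_small keeps its old signature and fills in a default Config, while the new inference_small_config builds the network from a caller-supplied config. Below is a minimal sketch of the config-driven path, assuming Config is importable from resnet.py and behaves like the dict-style object the diff suggests; the placeholder shape and the specific settings are hypothetical.

import tensorflow as tf
from resnet import Config, inference_small_config

# Hypothetical CIFAR-sized input batch.
images = tf.placeholder(tf.float32, [None, 32, 32, 3])

c = Config()
c['is_training'] = tf.convert_to_tensor(True, dtype='bool', name='is_training')
c['use_bias'] = False       # keep batch norm instead of per-layer biases
c['fc_units_out'] = 10
c['num_blocks'] = 3         # 6n+2 total weight layers
c['num_classes'] = 10

# inference_small_config fills in 'bottleneck', 'ksize' and 'stride' itself
# and returns the final tensor.
logits = inference_small_config(images, c)
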
resnet.py: 53 changes (28 additions & 25 deletions)
@@ -49,8 +49,8 @@ def inference(x, is_training,
         c['conv_filters_out'] = 64
         c['ksize'] = 7
         c['stride'] = 2
-        x = _conv(x, c)
-        x = _bn(x, c)
+        x = conv(x, c)
+        x = bn(x, c)
         x = activation(x)

     with tf.variable_scope('scale2'):
@@ -81,7 +81,7 @@ def inference(x, is_training,

     if num_classes != None:
         with tf.variable_scope('fc'):
-            x = _fc(x, c)
+            x = fc(x, c)

     return x

@@ -94,22 +94,25 @@ def inference_small(x,
                     use_bias=False, # defaults to using batch norm
                     num_classes=10):
     c = Config()
-    c['bottleneck'] = False
     c['is_training'] = tf.convert_to_tensor(is_training,
                                             dtype='bool',
                                             name='is_training')
-    c['ksize'] = 3
-    c['stride'] = 1
     c['use_bias'] = use_bias
     c['fc_units_out'] = num_classes
     c['num_blocks'] = num_blocks
     c['num_classes'] = num_classes
+    inference_small_config(x, c)
+
+def inference_small_config(x, c):
+    c['bottleneck'] = False
+    c['ksize'] = 3
+    c['stride'] = 1
     with tf.variable_scope('scale1'):
         c['conv_filters_out'] = 16
         c['block_filters_internal'] = 16
         c['stack_stride'] = 1
-        x = _conv(x, c)
-        x = _bn(x, c)
+        x = conv(x, c)
+        x = bn(x, c)
         x = activation(x)
         x = stack(x, c)

@@ -126,9 +129,9 @@ def inference_small(x,
     # post-net
     x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")

-    if num_classes != None:
+    if c['num_classes'] != None:
         with tf.variable_scope('fc'):
-            x = _fc(x, c)
+            x = fc(x, c)

     return x

@@ -180,48 +183,48 @@ def block(x, c):
         with tf.variable_scope('a'):
             c['ksize'] = 1
             c['stride'] = c['block_stride']
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)

         with tf.variable_scope('b'):
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)

         with tf.variable_scope('c'):
             c['conv_filters_out'] = filters_out
             c['ksize'] = 1
             assert c['stride'] == 1
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
     else:
         with tf.variable_scope('A'):
             c['stride'] = c['block_stride']
             assert c['ksize'] == 3
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)
             x = activation(x)

         with tf.variable_scope('B'):
             c['conv_filters_out'] = filters_out
             assert c['ksize'] == 3
             assert c['stride'] == 1
-            x = _conv(x, c)
-            x = _bn(x, c)
+            x = conv(x, c)
+            x = bn(x, c)

     with tf.variable_scope('shortcut'):
         if filters_out != filters_in or c['block_stride'] != 1:
             c['ksize'] = 1
             c['stride'] = c['block_stride']
             c['conv_filters_out'] = filters_out
-            shortcut = _conv(shortcut, c)
-            shortcut = _bn(shortcut, c)
+            shortcut = conv(shortcut, c)
+            shortcut = bn(shortcut, c)

     return activation(x + shortcut)


-def _bn(x, c):
+def bn(x, c):
     x_shape = x.get_shape()
     params_shape = x_shape[-1:]

@@ -268,7 +271,7 @@ def _bn(x, c):
     return x


-def _fc(x, c):
+def fc(x, c):
     num_units_in = x.get_shape()[1]
     num_units_out = c['fc_units_out']
     weights_initializer = tf.truncated_normal_initializer(
@@ -307,7 +310,7 @@ def _get_variable(name,
                            trainable=trainable)


-def _conv(x, c):
+def conv(x, c):
     ksize = c['ksize']
     stride = c['stride']
     filters_out = c['conv_filters_out']
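
With the leading underscores dropped, conv, bn and fc read as public helpers alongside the inference functions, all following the same (x, c) convention. A small sketch of driving conv directly with a config, under the same dict-style Config assumption; the kernel, stride and filter settings are hypothetical and simply mirror the scale1 stem above, and the helpers may read further config keys that this diff does not show.

import tensorflow as tf
from resnet import Config, conv

# Hypothetical input; the scope name is illustrative only.
x = tf.placeholder(tf.float32, [None, 32, 32, 3])

c = Config()
c['ksize'] = 3
c['stride'] = 1
c['conv_filters_out'] = 16   # the keys conv reads, per the last hunk

with tf.variable_scope('stem'):
    y = conv(x, c)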
