Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update slim/nets/resnet #1559

Merged
merged 1 commit into from Jun 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 4 additions & 14 deletions slim/nets/resnet_utils.py
Expand Up @@ -178,26 +178,16 @@ def stack_blocks_dense(net, blocks, output_stride=None,
raise ValueError('The target output_stride cannot be reached.')

with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
unit_depth, unit_depth_bottleneck, unit_stride = unit

# If we have reached the target output_stride, then we need to employ
# atrous convolution with stride=1 and multiply the atrous rate by the
# current unit's stride for use in subsequent layers.
if output_stride is not None and current_stride == output_stride:
net = block.unit_fn(net,
depth=unit_depth,
depth_bottleneck=unit_depth_bottleneck,
stride=1,
rate=rate)
rate *= unit_stride
net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
rate *= unit.get('stride', 1)

else:
net = block.unit_fn(net,
depth=unit_depth,
depth_bottleneck=unit_depth_bottleneck,
stride=unit_stride,
rate=1)
current_stride *= unit_stride
net = block.unit_fn(net, rate=1, **unit)
current_stride *= unit.get('stride', 1)
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

if output_stride is not None and current_stride != output_stride:
Expand Down
79 changes: 45 additions & 34 deletions slim/nets/resnet_v1.py
Expand Up @@ -119,7 +119,7 @@ def resnet_v1(inputs,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
spatial_squeeze=False,
reuse=None,
scope=None):
"""Generator for v1 ResNet models.
Expand Down Expand Up @@ -205,13 +205,38 @@ def resnet_v1(inputs,
else:
logits = net
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
end_points = slim.utils.convert_collection_to_dict(
end_points_collection)
if num_classes is not None:
end_points['predictions'] = slim.softmax(logits, scope='predictions')
return logits, end_points
resnet_v1.default_image_size = 224


def resnet_v1_block(scope, base_depth, num_units, stride):
"""Helper function for creating a resnet_v1 bottleneck block.

Args:
scope: The scope of the block.
base_depth: The depth of the bottleneck layer for each unit.
num_units: The number of units in the block.
stride: The stride of the block, implemented as a stride in the last unit.
All other units have stride=1.

Returns:
A resnet_v1 bottleneck block.
"""
return resnet_utils.Block(scope, bottleneck, [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': 1
}] * (num_units - 1) + [{
'depth': base_depth * 4,
'depth_bottleneck': base_depth,
'stride': stride
}])


def resnet_v1_50(inputs,
num_classes=None,
is_training=True,
Expand All @@ -222,14 +247,10 @@ def resnet_v1_50(inputs,
scope='resnet_v1_50'):
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block(
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block(
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)
resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
resnet_v1_block('block3', base_depth=256, num_units=6, stride=2),
resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
Expand All @@ -248,14 +269,10 @@ def resnet_v1_101(inputs,
scope='resnet_v1_101'):
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block(
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block(
'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)
resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
resnet_v1_block('block3', base_depth=256, num_units=23, stride=2),
resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
Expand All @@ -274,14 +291,11 @@ def resnet_v1_152(inputs,
scope='resnet_v1_152'):
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block(
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block(
'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
resnet_v1_block('block2', base_depth=128, num_units=8, stride=2),
resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze,
Expand All @@ -299,14 +313,11 @@ def resnet_v1_200(inputs,
scope='resnet_v1_200'):
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block(
'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block(
'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
resnet_utils.Block(
'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(2048, 512, 1)] * 3)]
resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
resnet_v1_block('block2', base_depth=128, num_units=24, stride=2),
resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, spatial_squeeze=spatial_squeeze,
Expand Down
50 changes: 20 additions & 30 deletions slim/nets/resnet_v1_test.py
Expand Up @@ -156,14 +156,17 @@ def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
with tf.variable_scope(scope, values=[inputs]):
with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
end_points = dict(tf.get_collection('end_points'))
end_points = slim.utils.convert_collection_to_dict('end_points')
return net, end_points

def testEndPointsV1(self):
"""Test the end points of a tiny v1 bottleneck network."""
bottleneck = resnet_v1.bottleneck
blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])]
blocks = [
resnet_v1.resnet_v1_block(
'block1', base_depth=1, num_units=2, stride=2),
resnet_v1.resnet_v1_block(
'block2', base_depth=2, num_units=2, stride=1),
]
inputs = create_test_input(2, 32, 16, 3)
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
Expand All @@ -189,30 +192,23 @@ def _stack_blocks_nondense(self, net, blocks):
for block in blocks:
with tf.variable_scope(block.scope, 'block', [net]):
for i, unit in enumerate(block.args):
depth, depth_bottleneck, stride = unit
with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
net = block.unit_fn(net,
depth=depth,
depth_bottleneck=depth_bottleneck,
stride=stride,
rate=1)
net = block.unit_fn(net, rate=1, **unit)
return net

def _atrousValues(self, bottleneck):
def testAtrousValuesBottleneck(self):
"""Verify the values of dense feature extraction by atrous convolution.

Make sure that dense feature extraction by stack_blocks_dense() followed by
subsampling gives identical results to feature extraction at the nominal
network output stride using the simple self._stack_blocks_nondense() above.

Args:
bottleneck: The bottleneck function.
"""
block = resnet_v1.resnet_v1_block
blocks = [
resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)])
block('block1', base_depth=1, num_units=2, stride=2),
block('block2', base_depth=2, num_units=2, stride=2),
block('block3', base_depth=4, num_units=2, stride=2),
block('block4', base_depth=8, num_units=2, stride=1),
]
nominal_stride = 8

Expand Down Expand Up @@ -244,9 +240,6 @@ def _atrousValues(self, bottleneck):
output, expected = sess.run([output, expected])
self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)

def testAtrousValuesBottleneck(self):
self._atrousValues(resnet_v1.bottleneck)


class ResnetCompleteNetworkTest(tf.test.TestCase):
"""Tests with complete small ResNet v1 networks."""
Expand All @@ -261,16 +254,13 @@ def _resnet_small(self,
reuse=None,
scope='resnet_v1_small'):
"""A shallow and thin ResNet v1 for faster tests."""
bottleneck = resnet_v1.bottleneck
block = resnet_v1.resnet_v1_block
blocks = [
resnet_utils.Block(
'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]),
resnet_utils.Block(
'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]),
resnet_utils.Block(
'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
resnet_utils.Block(
'block4', bottleneck, [(32, 8, 1)] * 2)]
block('block1', base_depth=1, num_units=3, stride=2),
block('block2', base_depth=2, num_units=3, stride=2),
block('block3', base_depth=4, num_units=3, stride=2),
block('block4', base_depth=8, num_units=2, stride=1),
]
return resnet_v1.resnet_v1(inputs, blocks, num_classes,
is_training=is_training,
global_pool=global_pool,
Expand Down