Skip to content

Commit

Permalink
SSD: Fix pretrained VGG_16 weights loading
Browse files Browse the repository at this point in the history
  • Loading branch information
joaqo authored and nagitsu committed Mar 20, 2018
1 parent 26abdb1 commit d43fb59
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 36 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -92,7 +92,11 @@ tags
Pipfile
Pipfile.lock

# Flake8
.flake8

# Luminoth
/datasets
/logs
/models
/jobs
15 changes: 12 additions & 3 deletions luminoth/models/base/base_network.py
Expand Up @@ -183,11 +183,20 @@ def load_weights(self):
# config file.
# Weights are downloaded by default to the $LUMI_HOME folder if
# running locally, or to the job bucket if running in Google Cloud.

# TODO: Shouldn't _config['weights'] be called weights_path or
# something similar?
self._config['weights'] = get_checkpoint_file(self._architecture)

module_variables = snt.get_variables_in_module(
self, tf.GraphKeys.MODEL_VARIABLES
)
if self.pretrained_weights_scope:
module_variables = tf.get_collection(
tf.GraphKeys.MODEL_VARIABLES,
scope=self.pretrained_weights_scope
)
else:
module_variables = snt.get_variables_in_module(
self, tf.GraphKeys.MODEL_VARIABLES
)
assert len(module_variables) > 0

load_variables = []
Expand Down
43 changes: 10 additions & 33 deletions luminoth/models/base/ssd_feature_extractor.py
Expand Up @@ -45,17 +45,13 @@ def _build(self, inputs, is_training=True):
# The original SSD paper uses a modified version of the vgg16 network,
# which we'll build here
if self.vgg_type:
# TODO: there is a problem with the scope, so I hardcoded this
# in the meantime, check bottom of this file [1] for more info
base_network_truncation_endpoint = base_net_endpoints[
'ssd_feature_extractor/vgg_16/conv5/conv5_3']
scope + '/vgg_16/conv5/conv5_3']

# TODO: there is a problem with the scope, so I hardcoded this
# in the meantime, check bottom of this file [1] for more info
# We'll add the feature maps to a collection. In the paper they use
# one of vgg16's layers as a feature map, so we start by adding it.
tf.add_to_collection('FEATURE_MAPS', base_net_endpoints[
'ssd_feature_extractor/vgg_16/conv4/conv4_3']
scope + '/vgg_16/conv4/conv4_3']
)

# TODO: check that the usage of `padding='VALID'` is correct
Expand All @@ -76,35 +72,16 @@ def _build(self, inputs, is_training=True):
outputs_collections='FEATURE_MAPS')
net = slim.conv2d(net, 128, [1, 1], scope='conv10_1')
net = slim.conv2d(net, 256, [3, 3], scope='conv10_2',
padding='VALID', outputs_collections='FEATURE_MAPS')
padding='VALID',
outputs_collections='FEATURE_MAPS')
net = slim.conv2d(net, 128, [1, 1], scope='conv11_1')
# import ipdb; ipdb.set_trace()
net = slim.conv2d(net, 256, [3, 3], scope='conv11_2',
padding='VALID', outputs_collections='FEATURE_MAPS')
padding='VALID',
outputs_collections='FEATURE_MAPS')

# This parameter determines which variables the pretrained
# weights are loaded onto
self.pretrained_weights_scope = 'ssd_feature_extractor/vgg_16'

# It's actually an ordered dict
return utils.convert_collection_to_dict('FEATURE_MAPS')

# [1]:
# ipdb> for k in base_net_endpoints.keys(): print(k)
# ssd_feature_extractor/vgg_16/conv1/conv1_1
# ssd_feature_extractor/vgg_16/conv1/conv1_2
# ssd/ssd_feature_extractor/vgg_16/pool1
# ssd_feature_extractor/vgg_16/conv2/conv2_1
# ssd_feature_extractor/vgg_16/conv2/conv2_2
# ssd/ssd_feature_extractor/vgg_16/pool2
# ssd_feature_extractor/vgg_16/conv3/conv3_1
# ssd_feature_extractor/vgg_16/conv3/conv3_2
# ssd_feature_extractor/vgg_16/conv3/conv3_3
# ssd/ssd_feature_extractor/vgg_16/pool3
# ssd_feature_extractor/vgg_16/conv4/conv4_1
# ssd_feature_extractor/vgg_16/conv4/conv4_2
# ssd_feature_extractor/vgg_16/conv4/conv4_3
# ssd/ssd_feature_extractor/vgg_16/pool4
# ssd_feature_extractor/vgg_16/conv5/conv5_1
# ssd_feature_extractor/vgg_16/conv5/conv5_2
# ssd_feature_extractor/vgg_16/conv5/conv5_3
# ssd/ssd_feature_extractor/vgg_16/pool5
# ssd_feature_extractor/vgg_16/fc6
# ssd_feature_extractor/vgg_16/fc7
# ssd_feature_extractor/vgg_16/fc8

0 comments on commit d43fb59

Please sign in to comment.