Skip to content

Commit

Permalink
Scope handling changes for TF 1.4 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja authored and nagitsu committed Nov 2, 2017
1 parent 826834a commit 11e79bf
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
42 changes: 28 additions & 14 deletions luminoth/models/base/truncated_base_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,14 @@ class TruncatedBaseNetwork(BaseNetwork):
a good image representation for other ML tasks.
"""

def __init__(self, config, parent_name=None, name='truncated_base_network',
**kwargs):
def __init__(self, config, name='truncated_base_network', **kwargs):
super(TruncatedBaseNetwork, self).__init__(config, name=name, **kwargs)
self._endpoint = (
config.endpoint or DEFAULT_ENDPOINTS[config.architecture]
)
self._parent_name = parent_name
self._scope_endpoint = '{}/{}/{}'.format(
self.module_name, config.architecture, self._endpoint
)
if parent_name:
self._scope_endpoint = '{}/{}'.format(
parent_name, self._scope_endpoint
)

def _build(self, inputs, is_training=True):
"""
Expand All @@ -51,13 +45,8 @@ def _build(self, inputs, is_training=True):
pred = super(TruncatedBaseNetwork, self)._build(
inputs, is_training=is_training
)
try:
return dict(pred['end_points'])[self._scope_endpoint]
except KeyError:
raise ValueError(
'"{}" is an invalid value of endpoint for this '
'architecture.'.format(self._endpoint)
)

return self._get_endpoint(dict(pred['end_points']))

def get_trainable_vars(self):
"""
Expand Down Expand Up @@ -90,3 +79,28 @@ def get_trainable_vars(self):
)

return all_trainable[:index + 1]

def _get_endpoint(self, endpoints):
"""
Returns the endpoint tensor from the list of possible endpoints.
Since we already have a dictionary with variable names we should be
able to get the desired tensor directly. Unfortunately the variable
names change with scope and the scope changes between TensorFlow
versions. We opted to just select the tensor for which the variable
name ends with the endpoint name we want (it should be just one).
Args:
endpoints: a dictionary with {variable_name: tensor}.
Returns:
endpoint_value: a tensor.
"""
for endpoint_key, endpoint_value in endpoints.items():
if endpoint_key.endswith(self._scope_endpoint):
return endpoint_value

raise ValueError(
'"{}" is an invalid value of endpoint for this '
'architecture.'.format(self._scope_endpoint)
)
17 changes: 9 additions & 8 deletions luminoth/models/base/truncated_base_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,27 @@ def testTrainableVariables(self):
)
model(inputs)
# Variables in ResNet-50:
# (the order of beta and gamma depends on the TensorFlow's version)
# 0 conv1/weights:0
# 1 conv1/BatchNorm/beta:0
# 2 conv1/BatchNorm/gamma:0
# 1 conv1/BatchNorm/(beta|gamma):0
# 2 conv1/BatchNorm/(beta|gamma):0
# 3 block1/unit_1/bottleneck_v1/shortcut/weights:0
# (...)
# 153 block4/unit_3/bottleneck_v1/conv2/weights:0
# 154 block4/unit_3/bottleneck_v1/conv2/BatchNorm/beta:0
# 155 block4/unit_3/bottleneck_v1/conv2/BatchNorm/gamma:0
# 154 block4/unit_3/bottleneck_v1/conv2/BatchNorm/(beta|gamma):0
# 155 block4/unit_3/bottleneck_v1/conv2/BatchNorm/(beta|gamma):0
# --- endpoint ---
# 156 block4/unit_3/bottleneck_v1/conv3/weights:0
# 157 block4/unit_3/bottleneck_v1/conv3/BatchNorm/beta:0
# 158 block4/unit_3/bottleneck_v1/conv3/BatchNorm/gamma:0
# 157 block4/unit_3/bottleneck_v1/conv3/BatchNorm/(beta|gamma):0
# 158 block4/unit_3/bottleneck_v1/conv3/BatchNorm/(beta|gamma):0
# 159 logits/weights:0
# 160 logits/biases:0
trainable_vars = model.get_trainable_vars()
self.assertEqual(len(trainable_vars), 156)
self.assertEqual(
trainable_vars[-1].name,
trainable_vars[-3].name,
'truncated_base_network/resnet_v1_50/' +
'block4/unit_3/bottleneck_v1/conv2/BatchNorm/gamma:0'
'block4/unit_3/bottleneck_v1/conv2/weights:0'
)

model = TruncatedBaseNetwork(
Expand Down
4 changes: 1 addition & 3 deletions luminoth/models/fasterrcnn/fasterrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def __init__(self, config, name='fasterrcnn'):
self._losses_collections = ['fastercnn_losses']

# We want the pretrained model to be outside the FasterRCNN name scope.
self.base_network = TruncatedBaseNetwork(
config.model.base_network, parent_name=self.module_name
)
self.base_network = TruncatedBaseNetwork(config.model.base_network)

def _build(self, image, gt_boxes=None, is_training=True):
"""
Expand Down

0 comments on commit 11e79bf

Please sign in to comment.