Make op unique name generation case insensitive #18413

Merged · 4 commits · May 14, 2018
@@ -570,7 +570,7 @@ def setUp(self):
         'predicted_distributions': self._predicted_distributions,
     }
     self._expected_loss = 1.61610
-    self._expected_op_name = 'mutual_information_loss/mul'
+    self._expected_op_name = 'mutual_information_loss/mul_1'
     self._batch_size = 2
 
2 changes: 1 addition & 1 deletion tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1333,7 +1333,7 @@ def testCreateDropout(self):
     with self.test_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = _layers.dropout(images)
-      self.assertEqual(output.op.name, 'Dropout/dropout/mul')
+      self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
       output.get_shape().assert_is_compatible_with(
           ops.convert_to_tensor(images).get_shape())
14 changes: 7 additions & 7 deletions tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -474,7 +474,7 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
 def _IsValidUnfusedBatchNorm(graph, context):
   """Checks that the output of the unfused batch norm has consumers."""
   add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm/add_1')
+      context + '/BatchNorm/batchnorm_1/add_1')
   # Ensure that the output tensor of batch norm has consumers, otherwise this
   # is a dangling node and not a match.
   return bool(add_shift.outputs[0].consumers())
@@ -567,7 +567,7 @@ def _GetBatchNormParams(graph, context, has_scaling):
 
   op_suffix_mean = '/BatchNorm/moments/Squeeze'
   op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
-  op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+  op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
   op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
   op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
 
@@ -642,12 +642,12 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
 
   Returns:
     A pair of Operations, the first is the original consumer node of the batch
-      norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of
+      norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of
       the folded graph (add_fold).
   """
   mul_scale_name = 'mul_1' if has_scaling else 'mul'
   mul_scale = graph.get_operation_by_name(context +
-                                          '/BatchNorm/batchnorm/' +
+                                          '/BatchNorm/batchnorm_1/' +
                                           mul_scale_name)
   op_below = mul_scale.inputs[0].op
   weights = op_below.inputs[1]
@@ -669,7 +669,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
   ]
   scale_name = 'mul' if has_scaling else 'Rsqrt'
   scale = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm/' + scale_name)
+      context + '/BatchNorm/batchnorm_1/' + scale_name)
   scale = array_ops.reshape(scale.outputs[0], new_shape,
                             context + '/scale_reshape')
 
@@ -697,7 +697,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
                                   [(1, mul_fold.outputs[0])])
 
   add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm/add_1')
+      context + '/BatchNorm/batchnorm_1/add_1')
 
   corrected_output = conv_or_fc_folded.outputs[0]
   if correction_offset is not None:
@@ -885,7 +885,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
   Returns:
     A boolean indicating whether this batch norm layer has scaling enabled.
   """
-  rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
+  rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
   rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
 
   return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
6 changes: 3 additions & 3 deletions tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -516,13 +516,13 @@ def _BatchNormMultiplierName(self, scope, has_scaling, fused):
     if has_scaling:
       if fused:
         return scope + '/BatchNorm_Fold/mul'
-      return scope + '/BatchNorm/batchnorm/mul'
-    return scope + '/BatchNorm/batchnorm/Rsqrt'
+      return scope + '/BatchNorm/batchnorm_1/mul'
+    return scope + '/BatchNorm/batchnorm_1/Rsqrt'
 
   def _BathNormBiasName(self, scope, fused):
     if fused:
       return scope + '/BatchNorm_Fold/bias'
-    return scope + '/BatchNorm/batchnorm/sub'
+    return scope + '/BatchNorm/batchnorm_1/sub'
 
   def _WeightInit(self, stddev):
     """Returns a truncated normal variable initializer.
@@ -385,7 +385,7 @@ def testComputeRFFromGraphDefStopPropagation(self):
      effective_stride_y, effective_padding_x, effective_padding_y) = (
          receptive_field.compute_receptive_field_from_graph_def(
              graph_def, input_node, output_node,
-             ['Dropout/dropout/random_uniform']))
+             ['Dropout/dropout_1/random_uniform']))
     self.assertEqual(receptive_field_x, 3)
     self.assertEqual(receptive_field_y, 3)
     self.assertEqual(effective_stride_x, 4)
30 changes: 19 additions & 11 deletions tensorflow/python/framework/ops.py
@@ -3446,8 +3446,9 @@ def _create_op_from_tf_operation(self, c_op, compute_device=True):
     # the name will still appear in _names_in_use even though the name hasn't
     # been used. This is ok, just leave _names_in_use as-is in this case.
     # TODO(skyewm): make the C API guarantee no name conflicts.
-    if ret.name not in self._names_in_use:
-      self._names_in_use[ret.name] = 1
+    name_key = ret.name.lower()
+    if name_key not in self._names_in_use:
+      self._names_in_use[name_key] = 1
     self._create_op_helper(ret, compute_device=compute_device)
     return ret
@@ -4163,20 +4164,27 @@ def unique_name(self, name, mark_as_used=True):
     """
     if self._name_stack:
       name = self._name_stack + "/" + name
-    i = self._names_in_use.get(name, 0)
-    # Increment the number for "name".
+
+    # For the sake of checking for names in use, we treat names as case
+    # insensitive (e.g. foo = Foo).
+    name_key = name.lower()
+    i = self._names_in_use.get(name_key, 0)
+    # Increment the number for "name_key".
     if mark_as_used:
-      self._names_in_use[name] = i + 1
+      self._names_in_use[name_key] = i + 1
     if i > 0:
-      base_name = name
-      # Make sure the composed name is not already used.
-      while name in self._names_in_use:
-        name = "%s_%d" % (base_name, i)
+      base_name_key = name_key
+      # Make sure the composed name key is not already used.
+      while name_key in self._names_in_use:
+        name_key = "%s_%d" % (base_name_key, i)
         i += 1
-      # Mark the composed name as used in case someone wants
+      # Mark the composed name_key as used in case someone wants
       # to call unique_name("name_1").
       if mark_as_used:
-        self._names_in_use[name] = 1
+        self._names_in_use[name_key] = 1
+
+      # Return the new name with the original capitalization of the given name.
+      name = "%s_%d" % (name, i-1)
     return name
 
   def get_name_scope(self):
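
To see the new scheme in isolation, here is a minimal standalone sketch of the uniquification logic above (a hypothetical helper, not part of the patch; it simplifies Graph.unique_name by dropping the name-scope prefixing and the mark_as_used flag):

def unique_name(names_in_use, name):
  # Uniqueness is tracked against a lowercased key, so "foo" and "Foo"
  # collide, but the returned name keeps the caller's capitalization.
  name_key = name.lower()
  i = names_in_use.get(name_key, 0)
  names_in_use[name_key] = i + 1
  if i > 0:
    base_name_key = name_key
    # Skip over suffixes that are already taken, e.g. an explicit "foo_1".
    while name_key in names_in_use:
      name_key = "%s_%d" % (base_name_key, i)
      i += 1
    names_in_use[name_key] = 1
    name = "%s_%d" % (name, i - 1)
  return name

used = {}
print(unique_name(used, "foo"))  # foo
print(unique_name(used, "Foo"))  # Foo_1 -- "Foo" collides with "foo"

Note that only the bookkeeping key is lowercased; the returned name keeps the capitalization the caller supplied, which is why the test below expects "Foo_1" rather than "foo_1".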
9 changes: 9 additions & 0 deletions tensorflow/python/framework/ops_test.py
@@ -1063,6 +1063,15 @@ def testOutOfOrderUniqueName(self):
     self.assertEqual("foo_1", g.unique_name("foo"))
     self.assertEqual("foo_3", g.unique_name("foo"))
 
+  def testUniqueNameCaseInsensitivity(self):
+    g = ops.Graph()
+    self.assertEqual("foo", g.unique_name("foo"))
+    self.assertEqual("Foo_1", g.unique_name("Foo"))
+    with g.name_scope("bar"):
+      self.assertEqual("bar/foo", g.unique_name("foo"))
+    with g.name_scope("Bar"):
+      self.assertEqual("Bar_1/foo", g.unique_name("foo"))
+
   def testInvalidNameRaisesError(self):
     g = ops.Graph()
     with g.name_scope(""):  # Should not raise
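
Assuming a build that includes this change, the user-visible effect is that op names differing only in case now pick up numeric suffixes. A quick sketch mirroring the test expectations above:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="foo")
  b = tf.constant(2.0, name="Foo")

print(a.op.name)  # foo
print(b.op.name)  # Foo_1 (previously "Foo", when the check was case sensitive)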