Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variable_scope use auxiliary_name_scope to control whether to create new name scope #14390

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
109 changes: 109 additions & 0 deletions tensorflow/python/kernel_tests/variable_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,115 @@ def testVarOpScopeNestedOuterScope(self):
with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer_1/default/scope2/")

def testBasicWhenAuxiliaryNameScopeIsFalse(self):
with self.test_session():
with variable_scope.variable_scope("scope",
auxiliary_name_scope=False) as scope:
self.assertEqual(scope.original_name_scope, "")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this value of `original_name_scope` is reasonable (though perhaps counterintuitive) behavior. We almost never create a variable_scope this way. I'm afraid that users may misuse the argument and get lost in the name scopes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it's better to restrict `auxiliary_name_scope` to the re-entering case; namely, `auxiliary_name_scope` would only take effect when a `VariableScope` object is passed.

self.assertEqual(variable_scope.get_variable("w", []).name, "scope/w:0")
self.assertEqual(constant_op.constant([], name="c").name, "c:0")
with variable_scope.variable_scope(scope,
auxiliary_name_scope=False) as scope1:
self.assertEqual(scope.original_name_scope, "")
self.assertEqual(variable_scope.get_variable("w1", []).name, "scope/w1:0")
self.assertEqual(constant_op.constant([], name="c1").name, "c1:0")
# Recheck: new name scope is NOT created before
with ops.name_scope("scope"):
self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0")

with variable_scope.variable_scope("outer"):
with variable_scope.variable_scope("inner",
auxiliary_name_scope=False) as inner:
self.assertEqual(inner.original_name_scope, "outer/")
self.assertEqual(variable_scope.get_variable("w", []).name, "outer/inner/w:0")
self.assertEqual(constant_op.constant([], name="c").name, "outer/c:0")
with variable_scope.variable_scope(inner,
auxiliary_name_scope=False) as inner1:
self.assertEqual(inner1.original_name_scope, "outer/")
self.assertEqual(variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
self.assertEqual(constant_op.constant([], name="c1").name, "outer/c1:0")
# Recheck: new name scope is NOT created before
with ops.name_scope("inner"):
self.assertEqual(constant_op.constant([], name="c").name, "outer/inner/c:0")

def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
  # A variable scope created from `default_name` with
  # auxiliary_name_scope=False is expected to prefix variable names while
  # leaving the name scope applied to ops untouched.
  with self.test_session():
    # Case 1: entered from the root name scope.
    with variable_scope.variable_scope(
        None, default_name="default", auxiliary_name_scope=False) as scope:
      self.assertEqual(scope.original_name_scope, "")
      var = variable_scope.get_variable("w", [])
      self.assertEqual(var.name, "default/w:0")
      const = constant_op.constant([], name="c")
      self.assertEqual(const.name, "c:0")
    # The name scope "default" must still be unused at this point, so
    # entering it now yields the plain, un-suffixed prefix.
    with ops.name_scope("default"):
      const = constant_op.constant([], name="c")
      self.assertEqual(const.name, "default/c:0")

    # Case 2: entered from inside an enclosing "outer" scope.
    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope(
          None, default_name="default", auxiliary_name_scope=False) as inner:
        self.assertEqual(inner.original_name_scope, "outer/")
        var = variable_scope.get_variable("w", [])
        self.assertEqual(var.name, "outer/default/w:0")
        const = constant_op.constant([], name="c")
        self.assertEqual(const.name, "outer/c:0")
      # Likewise "outer/default" must not have been claimed above.
      with ops.name_scope("default"):
        const = constant_op.constant([], name="c")
        self.assertEqual(const.name, "outer/default/c:0")

def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
  # Re-entering the root variable scope with auxiliary_name_scope=False is
  # expected to reset the variable prefix to the root while keeping whatever
  # name scope is currently active for ops.
  with self.test_session():
    root_scope = variable_scope.get_variable_scope()

    # Case 1: re-enter the root scope from the root itself.
    with variable_scope.variable_scope(
        root_scope, auxiliary_name_scope=False) as scope:
      self.assertEqual(scope.original_name_scope, "")
      var = variable_scope.get_variable("w", [])
      self.assertEqual(var.name, "w:0")
      const = constant_op.constant([], name="c")
      self.assertEqual(const.name, "c:0")

    # Case 2: re-enter the root scope from inside "outer"; the variable is
    # expected at the root, the op under the surrounding "outer/" name scope.
    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope(
          root_scope, auxiliary_name_scope=False) as inner:
        self.assertEqual(inner.original_name_scope, "")
        var = variable_scope.get_variable("w1", [])
        self.assertEqual(var.name, "w1:0")
        const = constant_op.constant([], name="c1")
        self.assertEqual(const.name, "outer/c1:0")

def testAuxiliaryNameScopeIsInvalid(self):
  # A non-bool `auxiliary_name_scope` is expected to raise a TypeError that
  # mentions the argument, whichever way the variable scope is specified.
  error_pattern = "auxiliary_name_scope"
  with self.test_session():
    # Scope specified through `default_name` only.
    with self.assertRaisesRegexp(TypeError, error_pattern):
      with variable_scope.variable_scope(
          None, default_name="scope", auxiliary_name_scope="invalid"):
        pass

    # Scope specified by a string name.
    with self.assertRaisesRegexp(TypeError, error_pattern):
      with variable_scope.variable_scope(
          "scope", auxiliary_name_scope="invalid"):
        pass

    # Scope specified by re-entering a previously captured scope object.
    with variable_scope.variable_scope("scope") as scope:
      pass
    with self.assertRaisesRegexp(TypeError, error_pattern):
      with variable_scope.variable_scope(
          scope, auxiliary_name_scope="invalid"):
        pass

def testReuseScopeWithoutNameScopeCollision(self):
  # Github issue: #13429
  # Re-entering a captured variable scope with auxiliary_name_scope=False and
  # then entering its `original_name_scope` explicitly is expected to place
  # both variables and ops under the original "outer/inner/" prefix.
  with self.test_session():
    # Capture the scope object for "outer/inner".
    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope("inner") as inner:
        pass

    # Case 1: re-enter the captured scope from the root.
    with variable_scope.variable_scope(
        inner, auxiliary_name_scope=False) as scope:
      with ops.name_scope(scope.original_name_scope):
        var = variable_scope.get_variable("w", [])
        self.assertEqual(var.name, "outer/inner/w:0")
        const = constant_op.constant([], name="c")
        self.assertEqual(const.name, "outer/inner/c:0")
      # A plain "inner" name scope here resolves against the current (root)
      # name scope, not against the captured one.
      with ops.name_scope("inner"):
        const = constant_op.constant([], name="c")
        self.assertEqual(const.name, "inner/c:0")

    # Case 2: re-enter the captured scope from inside an unrelated scope.
    with variable_scope.variable_scope("another"):
      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as scope1:
        with ops.name_scope(scope1.original_name_scope):
          var = variable_scope.get_variable("w1", [])
          self.assertEqual(var.name, "outer/inner/w1:0")
          const = constant_op.constant([], name="c1")
          self.assertEqual(const.name, "outer/inner/c1:0")
        # A plain "inner" name scope resolves under the surrounding
        # "another/" name scope.
        with ops.name_scope("inner"):
          const = constant_op.constant([], name="c")
          self.assertEqual(const.name, "another/inner/c:0")

@test_util.run_in_graph_and_eager_modes()
def testGetLocalVar(self):
# Check that local variable respects naming.
Expand Down
36 changes: 31 additions & 5 deletions tensorflow/python/ops/variable_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,10 @@ def __enter__(self):
else self._name_or_scope)
self._reuse = (self._reuse
or self._old.reuse) # Re-using is inherited by sub-scopes.
if self._old_name_scope is None:
name_scope = self._name_or_scope
else:
name_scope = self._old_name_scope
variable_scope_object = VariableScope(
self._reuse,
name=self._new_name,
Expand All @@ -1588,7 +1592,7 @@ def __enter__(self):
dtype=self._old.dtype,
use_resource=self._old.use_resource,
custom_getter=self._old.custom_getter,
name_scope=self._old_name_scope or self._name_or_scope,
name_scope=name_scope,
constraint=self._constraint)
if self._initializer is not None:
variable_scope_object.set_initializer(self._initializer)
Expand Down Expand Up @@ -1757,7 +1761,8 @@ def __init__(self,
reuse=None,
dtype=None,
use_resource=None,
constraint=None):
constraint=None,
auxiliary_name_scope=True):
"""Initialize the context manager.

Args:
Expand Down Expand Up @@ -1789,6 +1794,8 @@ def __init__(self,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
auxiliary_name_scope: If `True`, we create an auxiliary name scope with
the scope. If `False`, we don't touch name scope.

Returns:
A scope that can be captured and reused.
Expand Down Expand Up @@ -1826,6 +1833,10 @@ def __init__(self,
self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access
self._cached_pure_variable_scope = None
self._current_name_scope = None
if not isinstance(auxiliary_name_scope, bool):
raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
"while get {}".format(auxiliary_name_scope))
self._auxiliary_name_scope = auxiliary_name_scope

def __enter__(self):
# If the default graph is building a function, then we should not replace it
Expand All @@ -1844,6 +1855,21 @@ def __enter__(self):
if self._current_name_scope is not None:
self._current_name_scope.__enter__()
return self._cached_pure_variable_scope.__enter__()

if self._auxiliary_name_scope:
# Create a new name scope later
current_name_scope = None
else:
# Reenter the current name scope
name_scope = ops.get_name_scope()
Copy link
Member Author

@facaiy facaiy Nov 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried that `get_name_scope` here might be a performance bottleneck. That is why I introduced the name scope support: so that `variable_scope.original_name_scope` can be reused directly.

if name_scope:
# Hack to reenter
name_scope = name_scope + "/"
current_name_scope = ops.name_scope(name_scope)
else:
# Root scope
current_name_scope = ops.name_scope(name_scope)

if self._name_or_scope is not None:
if not isinstance(self._name_or_scope,
(VariableScope,) + six.string_types):
Expand All @@ -1853,8 +1879,8 @@ def __enter__(self):
name_scope = self._name_or_scope
else:
name_scope = self._name_or_scope.name.split("/")[-1]
if name_scope:
self._current_name_scope = ops.name_scope(name_scope)
if name_scope or current_name_scope:
self._current_name_scope = current_name_scope or ops.name_scope(name_scope)
current_name_scope_name = self._current_name_scope.__enter__()
if isinstance(self._name_or_scope, six.string_types):
old_name_scope = current_name_scope_name
Expand Down Expand Up @@ -1892,7 +1918,7 @@ def __enter__(self):
else: # Here name_or_scope is None. Using default name, but made unique.
if self._reuse:
raise ValueError("reuse=True cannot be used without a name_or_scope")
self._current_name_scope = ops.name_scope(self._default_name)
self._current_name_scope = current_name_scope or ops.name_scope(self._default_name)
current_name_scope_name = self._current_name_scope.__enter__()
unique_default_name = _get_unique_variable_scope(self._default_name)
self._cached_pure_variable_scope = _pure_variable_scope(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\', \'auxiliary_name_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], "
}
}