Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[r2.0 CherryPick]: Use experimental_ref() in moving_averages #32399

Merged
merged 1 commit into from
Sep 15, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions tensorflow/python/training/moving_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


Expand Down Expand Up @@ -369,7 +368,7 @@ def __init__(self,
self._num_updates = num_updates
self._zero_debias = zero_debias
self._name = name
self._averages = object_identity.ObjectIdentityDictionary()
self._averages = {}

@property
def name(self):
Expand Down Expand Up @@ -423,7 +422,7 @@ def apply(self, var_list=None):
raise TypeError("The variables must be half, float, or double: %s" %
var.name)

if var not in self._averages:
if var.experimental_ref() not in self._averages:
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
Expand All @@ -445,8 +444,8 @@ def apply(self, var_list=None):
"Variable", "VariableV2", "VarHandleOp"
]))
if self._zero_debias:
zero_debias_true.add(avg)
self._averages[var] = avg
zero_debias_true.add(avg.experimental_ref())
self._averages[var.experimental_ref()] = avg

with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
Expand All @@ -457,10 +456,9 @@ def apply(self, var_list=None):
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
zero_debias = any(self._averages[var] is v for v in zero_debias_true)
updates.append(
assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
avg = self._averages[var.experimental_ref()]
zero_debias = avg.experimental_ref() in zero_debias_true
updates.append(assign_moving_average(avg, var, decay, zero_debias))
return control_flow_ops.group(*updates, name=scope)

def average(self, var):
Expand All @@ -473,7 +471,7 @@ def average(self, var):
A `Variable` object or `None` if the moving average of `var`
is not maintained.
"""
return self._averages.get(var, None)
return self._averages.get(var.experimental_ref(), None)

def average_name(self, var):
"""Returns the name of the `Variable` holding the average for `var`.
Expand All @@ -497,8 +495,8 @@ def average_name(self, var):
by the `ExponentialMovingAverage class` to hold the moving average of
`var`.
"""
if var in self._averages:
return self._averages[var].op.name
if var.experimental_ref() in self._averages:
return self._averages[var.experimental_ref()].op.name
return ops.get_default_graph().unique_name(
var.op.name + "/" + self.name, mark_as_used=False)

Expand Down