Commit

Internal change
PiperOrigin-RevId: 317950322
Change-Id: I83c81973a220b74c015a8571c4f1d50b4ede91db
tensorflower-gardener committed Jun 23, 2020
1 parent 908664e commit 8535daf
Showing 3 changed files with 58 additions and 78 deletions.
tensorflow/python/framework/func_graph.py (2 changes: 1 addition & 1 deletion)
@@ -187,7 +187,7 @@ def __init__(self, name, collections=None, capture_by_value=None):
     self.inputs = []
     self.outputs = []
     self.control_outputs = []
-    self.control_captures = object_identity.ObjectIdentitySet()
+    self.control_captures = set()
     self.structured_input_signature = None
     self.structured_outputs = None
     self._weak_variables = []
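For context on the one-line func_graph.py change above: a plain set() deduplicates members by __eq__/__hash__, while object_identity.ObjectIdentitySet keys members by object identity, which matters for objects (such as tensors) that override equality. The sketch below is a minimal, hypothetical illustration of that distinction in plain Python; it is not TensorFlow's object_identity implementation.

class AlwaysEqual:
  """Toy object whose instances all compare and hash equal."""

  def __eq__(self, other):
    return isinstance(other, AlwaysEqual)

  def __hash__(self):
    return 0


class IdentitySet:
  """Set-like container keyed on id(obj) instead of __eq__/__hash__."""

  def __init__(self):
    self._members = {}  # id(obj) -> obj

  def add(self, obj):
    self._members[id(obj)] = obj

  def __contains__(self, obj):
    return id(obj) in self._members

  def __len__(self):
    return len(self._members)


a, b = AlwaysEqual(), AlwaysEqual()

equality_set = set()
equality_set.update([a, b])
print(len(equality_set))  # 1: a and b collapse because they compare equal

identity_set = IdentitySet()
identity_set.add(a)
identity_set.add(b)
print(len(identity_set))  # 2: distinct objects stay distinct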
tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -188,87 +188,61 @@ def initial_value(self):
   def constraint(self):
     return self._variable.constraint
 
-  def _apply_assign_update(self,
-                           update_fn,
-                           value,
-                           use_locking=None,
-                           name=None,
-                           read_value=True):
-    if ops.executing_eagerly_outside_functions():
-      assign_op = update_fn(value, use_locking, name, False)
-      return self if read_value else assign_op
-
-    # Fallback to wrapping the returned variable in graph mode if possible
-    assign_var = update_fn(value, use_locking, name, read_value)
-    if read_value and resource_variable_ops.is_resource_variable(assign_var):
-      return create_autocast_variable(assign_var)
-    return assign_var
-
-  def _apply_update(self, update_fn, *args, **kwargs):
-    update_var = update_fn(*args, **kwargs)
-    if ops.executing_eagerly_outside_functions():
-      return self
-
-    # Fallback to wrapping the returned variable in graph mode if possible
-    if resource_variable_ops.is_resource_variable(update_var):
-      return create_autocast_variable(update_var)
-    return update_var
-
   def assign(self, value, use_locking=None, name=None, read_value=True):
-    return self._apply_assign_update(self._variable.assign, value, use_locking,
-                                     name, read_value)
+    assign_op = self._variable.assign(value, use_locking, name, read_value)
+    return _maybe_wrap(assign_op, wrap=read_value)
 
   def assign_add(self, delta, use_locking=None, name=None, read_value=True):
-    return self._apply_assign_update(self._variable.assign_add, delta,
-                                     use_locking, name, read_value)
+    assign_op = self._variable.assign_add(delta, use_locking, name, read_value)
+    return _maybe_wrap(assign_op, wrap=read_value)
 
   def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
-    return self._apply_assign_update(self._variable.assign_sub, delta,
-                                     use_locking, name, read_value)
+    assign_op = self._variable.assign_sub(delta, use_locking, name, read_value)
+    return _maybe_wrap(assign_op, wrap=read_value)
 
   def scatter_sub(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_sub, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_sub(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_add(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_add, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_add(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_max(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_max, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_max(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_min(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_min, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_min(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_mul(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_mul, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_mul(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_div(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_div, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_div(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_update(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.scatter_update, sparse_delta,
-                              use_locking, name)
+    var = self._variable.scatter_update(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
-    return self._apply_update(self._variable.batch_scatter_update, sparse_delta,
-                              use_locking, name)
+    var = self._variable.batch_scatter_update(sparse_delta, use_locking, name)
+    return _maybe_wrap(var)
 
   def scatter_nd_sub(self, indices, updates, name=None):
-    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
-                              name)
+    var = self._variable.scatter_nd_sub(indices, updates, name)
+    return _maybe_wrap(var)
 
   def scatter_nd_add(self, indices, updates, name=None):
-    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
-                              name)
+    var = self._variable.scatter_nd_add(indices, updates, name)
+    return _maybe_wrap(var)
 
   def scatter_nd_update(self, indices, updates, name=None):
-    return self._apply_update(self._variable.scatter_nd_update, indices,
-                              updates, name)
+    var = self._variable.scatter_nd_update(indices, updates, name)
+    return _maybe_wrap(var)
 
   def load(self, value, session=None):
     return self._variable.load(value, session)
@@ -495,3 +469,24 @@ def __repr__(self):
   # pylint: enable=missing-format-attribute
 
   return AutoCastDistributedVariable(variable)
+
+
+def _maybe_wrap(variable, wrap=True):
+  """Creates an AutoCastVariable that wraps another variable if applicable.
+
+  This function is used to wrap the return value of AutoCastVariable.assign.
+  Unfortunately MirroredVariable.assign will (incorrectly) return a Mirrored
+  value instead of a MirroredVariable. So we cannot properly wrap it in an
+  AutoCastVariable. We return the original variable in that case.
+
+  Args:
+    variable: A tf.Variable or op.
+    wrap: A boolean to define whether to wrap the variable in an
+      AutoCastVariable or not.
+
+  Returns:
+    An AutoCastVariable if wrap is True and variable is a resource variable.
+  """
+  if wrap and resource_variable_ops.is_resource_variable(variable):
+    return create_autocast_variable(variable)
+  return variable
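To make the restored control flow concrete, here is a small, self-contained sketch of the wrap-after-assign pattern that _maybe_wrap implements: the wrapper delegates assign to the underlying variable and re-wraps the result only when a read value was requested and the result is itself variable-like. The classes below (InnerVar, SimpleWrapper) are hypothetical stand-ins, not TensorFlow's AutoCastVariable.

class InnerVar:
  """Stand-in for a resource variable whose assign returns the variable."""

  def __init__(self, value):
    self.value = value

  def assign(self, value, read_value=True):
    self.value = value
    return self if read_value else None


class SimpleWrapper:
  """Delegates assign to the wrapped variable and re-wraps the result."""

  def __init__(self, variable):
    self._variable = variable

  def assign(self, value, read_value=True):
    result = self._variable.assign(value, read_value)
    return _maybe_wrap_sketch(result, wrap=read_value)


def _maybe_wrap_sketch(variable, wrap=True):
  # Wrap only when a read value was requested and the result is
  # variable-like; otherwise return whatever the inner call produced.
  if wrap and isinstance(variable, InnerVar):
    return SimpleWrapper(variable)
  return variable


v = SimpleWrapper(InnerVar(0.0))
w = v.assign(1.0)
print(isinstance(w, SimpleWrapper), w._variable.value)  # True 1.0
print(v.assign(2.0, read_value=False))                  # None (nothing to wrap)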
tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
@@ -305,8 +305,8 @@ def run_and_check():
       self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))
 
       # Assign multiple times
-      # This currently doesn't work in graph mode if a strategy is used
-      if not ds_context.has_strategy() or context.executing_eagerly():
+      # This currently only works if no strategy is used
+      if not ds_context.has_strategy():
        assign = x.assign(1.)
        self.assertAllClose(1., self.evaluate(assign))
        self.assertAllClose(0., self.evaluate(assign.assign(0.)))
@@ -344,23 +344,6 @@ def run_and_check():
     # assign still expect float32 value even if in float16 scope
     run_and_check()
 
-  @combinations.generate(maybe_distribute)
-  def test_assign_tf_function(self, distribution):
-    if not context.executing_eagerly():
-      self.skipTest('Test is not compatible with graph mode')
-
-    with distribution.scope():
-      x = get_var(0., dtypes.float32)
-      x = autocast_variable.create_autocast_variable(x)
-
-      @def_function.function
-      def run_assign():
-        return x.assign(1.).assign_add(3.).assign_add(3.).assign_sub(2.)
-
-      with ops.get_default_graph()._enable_auto_casting_variables(
-          dtypes.float16):
-        self.assertAllClose(5., self.evaluate(run_assign()))
-
   @combinations.generate(maybe_distribute)
   def test_assign_stays_in_true_dtype(self, distribution):
     with distribution.scope():
@@ -375,16 +358,18 @@ def test_assign_stays_in_true_dtype(self, distribution):
           dtypes.float16):
         # Variable should be increased, despite it appearing to be the same
         # float16 value.
-        self.evaluate(x.assign(1. + small_tensor))
+        self.assertEqual(1. + small_val,
+                         self.evaluate(x.assign(1. + small_tensor)))
         self.assertEqual(1., self.evaluate(x.value()))
-      self.assertEqual(1. + small_val, self.evaluate(x))
+      self.assertEqual(1. + small_val, self.evaluate(x.value()))
 
       self.evaluate(x.assign(1.))
       with ops.get_default_graph()._enable_auto_casting_variables(
           dtypes.float16):
-        self.evaluate(x.assign_add(small_tensor))
+        self.assertEqual(1. + small_val,
+                         self.evaluate(x.assign_add(small_tensor)))
         self.assertEqual(1., self.evaluate(x.value()))
-      self.assertEqual(1. + small_val, self.evaluate(x))
+      self.assertEqual(1. + small_val, self.evaluate(x.value()))
 
   @combinations.generate(maybe_distribute)
   def test_checkpoint(self, distribution):
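The deleted test_assign_tf_function above chained assign calls (x.assign(1.).assign_add(3.)...), which only works when assign returns the variable itself rather than a read value. A toy sketch of that distinction follows; the classes are hypothetical and not TensorFlow APIs.

class ReturnsSelf:
  """assign returns the variable itself, so calls can be chained."""

  def __init__(self):
    self.value = 0.0

  def assign(self, v):
    self.value = v
    return self

  def assign_add(self, v):
    self.value += v
    return self

  def assign_sub(self, v):
    self.value -= v
    return self


class ReturnsReadValue:
  """assign returns the new value, so a second chained call fails."""

  def __init__(self):
    self.value = 0.0

  def assign(self, v):
    self.value = v
    return self.value


x = ReturnsSelf()
print(x.assign(1.).assign_add(3.).assign_add(3.).assign_sub(2.).value)  # 5.0

y = ReturnsReadValue()
try:
  y.assign(1.).assign_add(3.)
except AttributeError as err:
  print('cannot chain:', err)  # a float has no attribute 'assign_add'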
