
Commit

Unconditionally tag zero tensors
mknbv committed Dec 23, 2019
1 parent ea809e3 commit 6fe6391
Showing 2 changed files with 23 additions and 30 deletions.
tensorflow/python/ops/array_ops.py: 30 changes (19 additions, 11 deletions)
@@ -2300,13 +2300,20 @@ def _constant_if_small(value, shape, dtype, name):
   return None


-def _eager_mark_zeros_tensor(tensor):
-  if context.executing_eagerly():
-    setattr(tensor, "_is_zeros_tensor", True)
-  return tensor
+def _tag_zeros_tensor(fun):
+  """ Tags the result of function by setting _is_zeros_tensor attribute.
+  This is useful to compute Hessians of fused ops such as cross_entropy.
+  """
+  def wrapped(*args, **kwargs):
+    tensor = fun(*args, **kwargs)
+    tensor._is_zeros_tensor = True
+    return tensor
+  return wrapped


 @tf_export("zeros")
+@_tag_zeros_tensor
 def zeros(shape, dtype=dtypes.float32, name=None):
   """Creates a tensor with all elements set to zero.
@@ -2343,7 +2350,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
       # to prevent serialized GraphDefs from becoming too large.
       output = _constant_if_small(zero, shape, dtype, name)
       if output is not None:
-        return _eager_mark_zeros_tensor(output)
+        return output

     # Go through tensor shapes to get int64-if-needed semantics
     shape = constant_op._tensor_shape_tensor_conversion_function(
@@ -2355,7 +2362,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
       shape = reshape(shape, [-1])  # Ensure it's a vector
     output = fill(shape, constant(zero, dtype=dtype), name=name)
   assert output.dtype.base_dtype == dtype
-  return _eager_mark_zeros_tensor(output)
+  return output


 @tf_export(v1=["zeros_like"])
@@ -2430,15 +2437,16 @@ def zeros_like_v2(
   return zeros_like_impl(input, dtype, name, optimize=True)


+@_tag_zeros_tensor
 def zeros_like_impl(tensor, dtype, name, optimize=True):
   """Internal implementation for the v1/v2 zeros_like API calls."""
   with ops.name_scope(name, "zeros_like", [tensor]) as name:
     tensor = ops.convert_to_tensor(tensor, name="tensor")

     if context.executing_eagerly():
       if dtype is not None and dtype != tensor.dtype:
-        return _eager_mark_zeros_tensor(zeros(
-            shape_internal(tensor, optimize=optimize), dtype=dtype, name=name))
+        return zeros(
+            shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
       with ops.device(tensor.device):
         return gen_array_ops.zeros_like(tensor, name=name)

@@ -2452,10 +2460,10 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
       return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name)

     if dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant:
-      return _eager_mark_zeros_tensor(zeros(
-          shape_internal(tensor, optimize=optimize), dtype=dtype, name=name))
+      return zeros(
+          shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
     else:
-      return _eager_mark_zeros_tensor(gen_array_ops.zeros_like(tensor, name=name))
+      return gen_array_ops.zeros_like(tensor, name=name)


 @tf_export(v1=["ones_like"])
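Net effect in array_ops.py: the eager-only helper _eager_mark_zeros_tensor is replaced by the _tag_zeros_tensor decorator, so the _is_zeros_tensor attribute is now set on the results of zeros and zeros_like_impl in both graph and eager mode. The pattern itself is plain Python; the sketch below uses illustrative names (it is not TensorFlow code) and only shows how the tag travels with the returned object.

import functools


def _tag_zeros(fun):
  """Hypothetical stand-in for _tag_zeros_tensor: tag whatever `fun` returns."""

  @functools.wraps(fun)
  def wrapped(*args, **kwargs):
    result = fun(*args, **kwargs)
    result._is_zeros = True  # works for any object that allows attribute assignment
    return result

  return wrapped


class Box:
  """Toy container standing in for a tensor."""

  def __init__(self, values):
    self.values = values


@_tag_zeros
def make_zeros(n):
  return Box([0] * n)


z = make_zeros(3)
print(getattr(z, "_is_zeros", False))            # True: tagged by the decorator
print(getattr(Box([1, 2]), "_is_zeros", False))  # False: untagged objects fall back to the default

Because the decorator no longer checks context.executing_eagerly(), the nn_grad.py changes below can replace _IsZero with a plain getattr on the incoming gradient tensor.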
tensorflow/python/ops/nn_grad.py: 23 changes (4 additions, 19 deletions)
@@ -513,23 +513,6 @@ def _BroadcastMul(vec, mat):
   return vec * mat


-def _IsZero(tensor):
-  """Check if tensor contains only zeros.
-
-  Args:
-    tensor: tensor to check
-
-  Returns:
-    True if tensor contains only zeros and False otherwise
-  """
-  if context.executing_eagerly():
-    return getattr(tensor, "_is_zeros_tensor", False)
-  if tensor.op.type in ("ZerosLike", "Zeros"):
-    return True
-  const_fill_value = tensor_util.constant_value(tensor)
-  return const_fill_value is not None and (const_fill_value == 0).all()
-
-
 @ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
 def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   """Gradient function for SoftmaxCrossEntropyWithLogits."""
@@ -542,7 +525,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   grad = _BroadcastMul(grad_loss, softmax_grad)

   logits = op.inputs[0]
-  if grad_grad is not None and not _IsZero(grad_grad):
+  if (grad_grad is not None
+      and not getattr(grad_grad, "_is_zeros_tensor", False)):
     softmax = nn_ops.softmax(logits)

     grad += ((grad_grad - array_ops.squeeze(
@@ -567,7 +551,8 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   grad = _BroadcastMul(grad_loss, softmax_grad)

   logits = op.inputs[0]
-  if grad_grad is not None and not _IsZero(grad_grad):
+  if (grad_grad is not None
+      and not getattr(grad_grad, "_is_zeros_tensor", False)):
     softmax = nn_ops.softmax(logits)

     grad += ((grad_grad - array_ops.squeeze(
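The two gradient functions now inline the attribute check that _IsZero used to perform, and they no longer fall back to inspecting op types or constant values; only the provenance tag set by the decorated zeros/zeros_like matters. A small eager-mode sketch of the resulting condition (hypothetical values; assumes the tagging behaviour introduced by this commit):

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])

tagged = tf.zeros_like(logits)  # produced by the decorated zeros_like, so it carries the tag
untagged = logits * 0.0         # numerically zero, but created by a multiply, so no tag

for grad_grad in (None, tagged, untagged):
  # Mirrors the condition in _SoftmaxCrossEntropyWithLogitsGrad: the second-order
  # term is only added for a non-None grad_grad that is not a tagged zeros tensor.
  needs_second_order_term = (grad_grad is not None
                             and not getattr(grad_grad, "_is_zeros_tensor", False))
  print(needs_second_order_term)  # False, False, True

This is what the docstring in array_ops.py means by computing Hessians of fused ops: when the higher-order gradient pass feeds a freshly created zeros tensor into grad_grad, the softmax-based correction term can be skipped.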
