Implement Hessian for sparse softmax cross entropy #31700

Merged
merged 11 commits on Feb 10, 2020
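In short, this change makes second derivatives work through the fused sparse softmax cross entropy op. Below is a sketch of the newly supported use (an editor's illustration, not part of the diff; the tensor values are arbitrary):

```python
# Editor's sketch of what the change enables: a second derivative through
# the fused sparse softmax cross entropy op (values are illustrative).
import tensorflow as tf

labels = tf.constant([3, 0, 1])
logits = tf.constant([[0.1, 0.2, 0.3, 0.4],
                      [0.1, 0.4, 0.9, 1.6],
                      [0.1, 0.8, 2.7, 6.4]], dtype=tf.float64)

with tf.GradientTape() as outer_tape:
  outer_tape.watch(logits)
  with tf.GradientTape() as inner_tape:
    inner_tape.watch(logits)
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))
  grad = inner_tape.gradient(loss, logits)  # first derivative

# Before this PR the registered gradient routed through prevent_gradient and
# requesting a second derivative raised LookupError ("explicitly disabled",
# see the old testSecondGradient below); now this returns a tensor.
second = outer_tape.gradient(grad, logits)
```

As I read the exclusions entry in the first file below, switching to {false, {1}} keeps the logits (input 0) alive for the eager tape while still dropping the labels (input 1), which is what the new second-derivative path needs.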
2 changes: 1 addition & 1 deletion tensorflow/python/eager/pywrap_gradient_exclusions.cc
@@ -294,7 +294,7 @@ bool OpGradientDoesntRequireInputIndices(
{"SparseSegmentSumWithNumSegments", {false, {3}}},
{"SparseSlice", {false, {2, 4}}},
{"SparseSoftmax", {false, {1}}},
{"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
{"SparseSoftmaxCrossEntropyWithLogits", {false, {1}}},
{"SparseSparseMaximum", {true, {}}},
{"SparseSparseMinimum", {true, {}}},
{"SparseTensorDenseAdd", {false, {1, 2, 3}}},
2 changes: 1 addition & 1 deletion tensorflow/python/eager/tensor_test.py
@@ -527,7 +527,7 @@ def testSliceDimOutOfRange(self):

@test_util.assert_no_new_pyobjects_executing_eagerly
def testTensorDir(self):
t = array_ops.zeros(1)
t = array_ops.ones(1)
t.test_attr = "Test"

instance_dir = dir(t)
53 changes: 34 additions & 19 deletions tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -36,9 +36,7 @@
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import app
from tensorflow.python.platform import test
@@ -193,7 +191,7 @@ def testEmpty(self):

@test_util.run_deprecated_v1
def testGradient(self):
with self.session(use_gpu=True):
with self.session(use_gpu=True) as sess:
l = constant_op.constant([3, 0, 1], name="l")
f = constant_op.constant(
[0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
@@ -203,26 +201,43 @@ def testGradient(self):
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)

# Check that no extra computation is performed. When only the first
# derivative is requested, the second derivative must not be computed, so
# there is no `BatchMatMul` op in the graph.
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)

self.assertLess(err, 5e-8)

@test_util.run_deprecated_v1
def testSecondGradient(self):
images_placeholder = array_ops.placeholder(dtypes.float32, shape=(3, 2))
labels_placeholder = array_ops.placeholder(dtypes.int32, shape=(3))
weights = variables.Variable(random_ops.truncated_normal([2], stddev=1.0))
weights_with_zeros = array_ops.stack([array_ops.zeros([2]), weights],
axis=1)
logits = math_ops.matmul(images_placeholder, weights_with_zeros)
cross_entropy = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels_placeholder, logits=logits)
loss = math_ops.reduce_mean(cross_entropy)

# Taking the second gradient should fail, since it is not
# yet supported.
with self.assertRaisesRegexp(LookupError,
"explicitly disabled"):
_ = gradients_impl.hessians(loss, [weights])
with self.session() as sess:
l = constant_op.constant([3, 0, 1], name="l")
f = constant_op.constant(
[0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")

gradients = gradients_impl.gradients(x, [f])[0]
err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
[3, 4])

# Check that the second derivative is calculated.
# (Its presence shows up as a `BatchMatMul` op in the graph, because of how
# the xentropy grad is implemented.)
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertIn("BatchMatMulV2", op_names)

self.assertLess(err, 5e-8)

def _testHighDim(self, features, labels):
np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
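A standalone view of the property that testGradient and testSecondGradient above assert (an editor's sketch in TF1-style graph mode, mirroring the test constants; not part of the diff): the batched matmul only enters the graph once a gradient of the gradient is requested.

```python
# Editor's sketch: the second-derivative term is only built when asked for.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([3, 0, 1])
logits = tf.constant([[0.1, 0.2, 0.3, 0.4],
                      [0.1, 0.4, 0.9, 1.6],
                      [0.1, 0.8, 2.7, 6.4]], dtype=tf.float64)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                      logits=logits)

first = tf.gradients(loss, logits)[0]
op_names = {op.op_def.name
            for op in tf.get_default_graph().get_operations() if op.op_def}
print("BatchMatMulV2" in op_names)   # False: only the first derivative is built

second = tf.gradients(first, logits)[0]
op_names = {op.op_def.name
            for op in tf.get_default_graph().get_operations() if op.op_def}
print("BatchMatMulV2" in op_names)   # True: the second-derivative term was added
```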
1 change: 1 addition & 0 deletions tensorflow/python/kernel_tests/xent_op_test.py
@@ -241,6 +241,7 @@ def testGradient(self):
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)

print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
15 changes: 15 additions & 0 deletions tensorflow/python/ops/array_ops.py
@@ -42,6 +42,7 @@
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import

@@ -2657,7 +2658,20 @@ def _constant_if_small(value, shape, dtype, name):
return None


def _tag_zeros_tensor(fun):
""" Tags the result of function by setting _is_zeros_tensor attribute.

This is useful to compute Hessians of fused ops such as cross_entropy.
"""
def wrapped(*args, **kwargs):
tensor = fun(*args, **kwargs)
tensor._is_zeros_tensor = True
return tensor
return tf_decorator.make_decorator(fun, wrapped)


@tf_export("zeros")
@_tag_zeros_tensor
def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.

Expand Down Expand Up @@ -2790,6 +2804,7 @@ def zeros_like_v2(
return zeros_like_impl(input, dtype, name, optimize=True)


@_tag_zeros_tensor
def zeros_like_impl(tensor, dtype, name, optimize=True):
"""Internal implementation for the v1/v2 zeros_like API calls."""
with ops.name_scope(name, "zeros_like", [tensor]) as name:
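The `_tag_zeros_tensor` decorator above lets the gradient functions in nn_grad.py (next file) detect an all-zeros upstream gradient with a plain attribute check instead of graph introspection; it is presumably also why `testTensorDir` earlier switched from `zeros(1)` to `ones(1)`, since zeros tensors now carry an extra attribute. A minimal illustration (an editor's sketch, not part of the diff):

```python
# Editor's sketch: zeros produced via array_ops now carry a marker attribute.
from tensorflow.python.ops import array_ops

z = array_ops.zeros([3, 4])
print(getattr(z, "_is_zeros_tensor", False))  # True: tagged by _tag_zeros_tensor

o = array_ops.ones([3, 4])
print(getattr(o, "_is_zeros_tensor", False))  # False: only zeros/zeros_like are tagged

# Gradient code can therefore skip the second-derivative term cheaply:
#   if grad_grad is not None and not getattr(grad_grad, "_is_zeros_tensor", False):
#       ...build the extra term...
```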
46 changes: 21 additions & 25 deletions tensorflow/python/ops/nn_grad.py
@@ -19,10 +19,8 @@
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -524,18 +522,9 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax_grad = op.outputs[1]
grad = _BroadcastMul(grad_loss, softmax_grad)

def IsZero(g):
# Some introspection to check if the gradient is feeding zeros
if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here.
return False
if g.op.type in ("ZerosLike", "Zeros"):
return True
const_fill_value = tensor_util.constant_value(g)
return const_fill_value is not None and (const_fill_value == 0).all()

logits = op.inputs[0]
if grad_grad is not None and not IsZero(grad_grad):
if (grad_grad is not None
and not getattr(grad_grad, "_is_zeros_tensor", False)):
softmax = nn_ops.softmax(logits)

grad += ((grad_grad - array_ops.squeeze(
@@ -548,22 +537,29 @@ def IsZero(g):


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
"""Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
# grad_0 is the backprop for cost, and we multiply it with the gradients
# grad_loss is the backprop for cost, and we multiply it with the gradients
# (which is output[1])
# grad_grad is the backprop for softmax gradient.
# There is no gradient for the labels
#
# Currently there is no way to take the second derivative of this op
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
op.outputs[1],
message="Currently there is no way to take the second "
"derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
"implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
# Second derivative is just softmax derivative w.r.t. logits.
softmax_grad = op.outputs[1]
grad = _BroadcastMul(grad_loss, softmax_grad)

logits = op.inputs[0]
if (grad_grad is not None
and not getattr(grad_grad, "_is_zeros_tensor", False)):
softmax = nn_ops.softmax(logits)

grad += ((grad_grad - array_ops.squeeze(
math_ops.matmul(
array_ops.expand_dims(grad_grad, 1),
array_ops.expand_dims(softmax, 2)),
axis=1)) * softmax)

return grad, None
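For reference, the `grad +=` expression above is the vector-Jacobian product of the first derivative. Writing s = softmax(x) for one example's logits x, g_j = s_j - 1[j = y] for the op's second output, and v for grad_grad:

```latex
% Second-derivative term assembled by the gradient function above
% (per example; s = softmax(x), g_j = s_j - [j = y], v = grad_grad).
(J^{\top} v)_i
  = \sum_j v_j \frac{\partial g_j}{\partial x_i}
  = \sum_j v_j\, s_j (\delta_{ij} - s_i)
  = \bigl(v_i - \langle v, s\rangle\bigr)\, s_i .
```

The `expand_dims`/`matmul`/`squeeze` combination computes the per-example inner product of v and s as a batched matmul, which is why `BatchMatMulV2` shows up in the graph once the second derivative is requested.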
Contributor:
I think the None wrt grad_grad here just made this code silently return the wrong answer when taking third-order derivatives.

Either add a prevent_gradient or implement the third-order derivative.

Contributor Author:
This None refers to the gradient wrt the labels passed as the second input to the operation, doesn't it?

Because we break out of the fused implementation when the Hessian is computed, I thought this should work properly when higher-order derivatives are requested. I tested it locally by adding a compute_gradient_error check on the result of tf.hessians, similar to the testSecondGradient test case added with this PR, and got an error of around 2.12e-8.
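A rough reconstruction of the local check described above (an editor's sketch, not part of the PR; the constants mirror the new testSecondGradient, and the Hessian shape is an assumption):

```python
# Editor's sketch of the described tf.hessians + compute_gradient_error check.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

l = tf.constant([3, 0, 1], name="l")
f = tf.constant(
    [0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
    shape=[3, 4], dtype=tf.float64, name="f")
x = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l, logits=f)
loss = tf.reduce_sum(x)
hessian = tf.hessians(loss, [f])[0]          # shape [3, 4, 3, 4]

with tf.Session():
  # Numerically differentiating the Hessian exercises the third-order path;
  # the author reports an error of roughly 2.12e-8 for a check like this.
  err = tf.test.compute_gradient_error(f, [3, 4], hessian, [3, 4, 3, 4])
  print("hessian gradient err =", err)
```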



@ops.RegisterGradient("Conv2D")