diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc
index 51b37e3ec14bbd..968a7dc66225e1 100644
--- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc
+++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc
@@ -294,7 +294,7 @@ bool OpGradientDoesntRequireInputIndices(
       {"SparseSegmentSumWithNumSegments", {false, {3}}},
       {"SparseSlice", {false, {2, 4}}},
       {"SparseSoftmax", {false, {1}}},
-      {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+      {"SparseSoftmaxCrossEntropyWithLogits", {false, {1}}},
       {"SparseSparseMaximum", {true, {}}},
       {"SparseSparseMinimum", {true, {}}},
       {"SparseTensorDenseAdd", {false, {1, 2, 3}}},
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 342bd37eea5ff5..dd1f049cdcc378 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -527,7 +527,7 @@ def testSliceDimOutOfRange(self):
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testTensorDir(self):
-    t = array_ops.zeros(1)
+    t = array_ops.ones(1)
     t.test_attr = "Test"
 
     instance_dir = dir(t)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 76973add820ff8..9c65f75054fb7b 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -36,9 +36,7 @@
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import app
 from tensorflow.python.platform import test
@@ -193,7 +191,7 @@ def testEmpty(self):
 
   @test_util.run_deprecated_v1
   def testGradient(self):
-    with self.session(use_gpu=True):
+    with self.session(use_gpu=True) as sess:
       l = constant_op.constant([3, 0, 1], name="l")
       f = constant_op.constant(
           [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
@@ -203,26 +201,43 @@ def testGradient(self):
       x = nn_ops.sparse_softmax_cross_entropy_with_logits(
           labels=l, logits=f, name="xent")
       err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
-    print("cross entropy gradient err = ", err)
+
+    # Check that no extra computation is performed. When only the first
+    # derivative is requested, the second derivative must not be computed,
+    # so there should be no `BatchMatMul` op in the graph.
+    op_names = [
+        op.op_def.name for op in sess.graph.get_operations() if op.op_def
+    ]
+    self.assertNotIn("BatchMatMul", op_names)
+    self.assertNotIn("BatchMatMulV2", op_names)
+
     self.assertLess(err, 5e-8)
 
   @test_util.run_deprecated_v1
   def testSecondGradient(self):
-    images_placeholder = array_ops.placeholder(dtypes.float32, shape=(3, 2))
-    labels_placeholder = array_ops.placeholder(dtypes.int32, shape=(3))
-    weights = variables.Variable(random_ops.truncated_normal([2], stddev=1.0))
-    weights_with_zeros = array_ops.stack([array_ops.zeros([2]), weights],
-                                         axis=1)
-    logits = math_ops.matmul(images_placeholder, weights_with_zeros)
-    cross_entropy = nn_ops.sparse_softmax_cross_entropy_with_logits(
-        labels=labels_placeholder, logits=logits)
-    loss = math_ops.reduce_mean(cross_entropy)
-
-    # Taking ths second gradient should fail, since it is not
-    # yet supported.
-    with self.assertRaisesRegexp(LookupError,
-                                 "explicitly disabled"):
-      _ = gradients_impl.hessians(loss, [weights])
+    with self.session() as sess:
+      l = constant_op.constant([3, 0, 1], name="l")
+      f = constant_op.constant(
+          [0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
+          shape=[3, 4],
+          dtype=dtypes.float64,
+          name="f")
+      x = nn_ops.sparse_softmax_cross_entropy_with_logits(
+          labels=l, logits=f, name="xent")
+
+      gradients = gradients_impl.gradients(x, [f])[0]
+      err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
+                                                    [3, 4])
+
+      # Check that the second derivative is calculated: with the current
+      # implementation of the xentropy grad, its presence shows up as a
+      # `BatchMatMul` op in the graph.
+      op_names = [
+          op.op_def.name for op in sess.graph.get_operations() if op.op_def
+      ]
+      self.assertIn("BatchMatMulV2", op_names)
+
+    self.assertLess(err, 5e-8)
 
   def _testHighDim(self, features, labels):
     np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index c0faf902327b68..54e0aa21ff36de 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -241,6 +241,7 @@ def testGradient(self):
         op.op_def.name for op in sess.graph.get_operations() if op.op_def
     ]
     self.assertNotIn("BatchMatMul", op_names)
+    self.assertNotIn("BatchMatMulV2", op_names)
     print("cross entropy gradient err = ", err)
     self.assertLess(err, 5e-8)
 
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 1422dafa5618e6..a494260e96e170 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -42,6 +42,7 @@
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
 from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=wildcard-import
 
@@ -2657,7 +2658,20 @@ def _constant_if_small(value, shape, dtype, name):
   return None
 
 
+def _tag_zeros_tensor(fun):
+  """Tags the result of a function by setting the _is_zeros_tensor attribute.
+
+  This is useful to compute Hessians of fused ops such as cross_entropy.
+  """
+  def wrapped(*args, **kwargs):
+    tensor = fun(*args, **kwargs)
+    tensor._is_zeros_tensor = True
+    return tensor
+  return tf_decorator.make_decorator(fun, wrapped)
+
+
 @tf_export("zeros")
+@_tag_zeros_tensor
 def zeros(shape, dtype=dtypes.float32, name=None):
   """Creates a tensor with all elements set to zero.
 
@@ -2790,6 +2804,7 @@ def zeros_like_v2(
   return zeros_like_impl(input, dtype, name, optimize=True)
 
 
+@_tag_zeros_tensor
 def zeros_like_impl(tensor, dtype, name, optimize=True):
   """Internal implementation for the v1/v2 zeros_like API calls."""
   with ops.name_scope(name, "zeros_like", [tensor]) as name:
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 51eec89723d8c6..5d6e7801650b60 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -19,10 +19,8 @@
 from __future__ import print_function
 
 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
@@ -524,18 +522,9 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   softmax_grad = op.outputs[1]
   grad = _BroadcastMul(grad_loss, softmax_grad)
 
-  def IsZero(g):
-    # Some introspection to check if the gradient is feeding zeros
-    if context.executing_eagerly():
-      # TODO(apassos) add an efficient way to detect eager zeros here.
-      return False
-    if g.op.type in ("ZerosLike", "Zeros"):
-      return True
-    const_fill_value = tensor_util.constant_value(g)
-    return const_fill_value is not None and (const_fill_value == 0).all()
-
   logits = op.inputs[0]
-  if grad_grad is not None and not IsZero(grad_grad):
+  if (grad_grad is not None
+      and not getattr(grad_grad, "_is_zeros_tensor", False)):
     softmax = nn_ops.softmax(logits)
 
     grad += ((grad_grad - array_ops.squeeze(
@@ -548,22 +537,29 @@ def IsZero(g):
 
 
 @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
-def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
+def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
-  # grad_0 is the backprop for cost, and we multiply it with the gradients
+  # grad_loss is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
+  # grad_grad is the backprop for the softmax gradient.
   # There is no gradient for the labels
   #
-  # Currently there is no way to take the second derivative of this op
-  # due to the fused implementation's interaction with tf.gradients(),
-  # so we make sure we prevent silently incorrect results by raising
-  # an error if the second derivative is requested via prevent_gradient.
-  sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
-      op.outputs[1],
-      message="Currently there is no way to take the second "
-      "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
-      "implementation's interaction with tf.gradients()")
-  return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
+  # The second derivative is just the softmax derivative w.r.t. the logits.
+  softmax_grad = op.outputs[1]
+  grad = _BroadcastMul(grad_loss, softmax_grad)
+
+  logits = op.inputs[0]
+  if (grad_grad is not None
+      and not getattr(grad_grad, "_is_zeros_tensor", False)):
+    softmax = nn_ops.softmax(logits)
+
+    grad += ((grad_grad - array_ops.squeeze(
+        math_ops.matmul(
+            array_ops.expand_dims(grad_grad, 1),
+            array_ops.expand_dims(softmax, 2)),
+        axis=1)) * softmax)
+
+  return grad, None
 
 
 @ops.RegisterGradient("Conv2D")
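
Note (not part of the patch): with the prevent_gradient guard removed from _SparseSoftmaxCrossEntropyWithLogitsGrad, second-order gradients of sparse_softmax_cross_entropy_with_logits become available through the public API. The following is a minimal eager-mode sketch of how the new behaviour could be exercised with nested GradientTapes, assuming a TensorFlow build that includes this change; the labels and logits mirror testSecondGradient above.

import tensorflow as tf

labels = tf.constant([3, 0, 1])
logits = tf.constant([[0.3, 0.4, 0.1, 1.2],
                      [0.1, 1.9, 0.1, 0.7],
                      [0.8, 0.2, 1.3, 1.3]], dtype=tf.float64)

with tf.GradientTape() as outer_tape:
  outer_tape.watch(logits)
  with tf.GradientTape() as inner_tape:
    inner_tape.watch(logits)
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))
  # First derivative of the loss w.r.t. the logits, shape [3, 4].
  grad = inner_tape.gradient(loss, logits)

# Backprop through the first derivative, i.e. a second-order gradient.
# Before this change the fused op's registered gradient routed the backprop
# through prevent_gradient, so this step raised an error instead of
# returning a value.
grad_grad = outer_tape.gradient(grad, logits)
print(grad_grad.shape)  # (3, 4)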