Fix crash in softmax-xent when some input dimensions are 1.
Before, tf.nn.softmax_cross_entropy_with_logits would fail a CHECK if one input tensor had shape (1, 1) and the other did not.

In particular, the call to ToIndexArray<2> at https://github.com/tensorflow/tensorflow/blob/1f3da84a89702d3b4f234ee83762d738caffe098/tensorflow/core/kernels/xent_op.cc#L99 would fail, since it assumed the array had two dimensions. If both dimensions were 1, BCast would merge them into a single dimension. Passing fewer_dims_optimization=false disables this optimization.

PiperOrigin-RevId: 384844496
Change-Id: Ifb02dc74964132c3ed3f3bc98b0858dbe4e258b7
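
For reference, a minimal Python reproduction of the previously-crashing case (a sketch mirroring the shapes used in the test added by this commit; it goes through tf.raw_ops rather than the generated gen_nn_ops wrapper used in the test, an assumed equivalent entry point to the same kernel):

import numpy as np
import tensorflow as tf

# Logits of shape (1, 1) broadcast against labels of shape (2, 1); before this
# fix these inputs failed a CHECK in the kernel.  With a single class the
# softmax is always 1, so the expected loss and backprop are exactly zero.
logits = tf.constant(np.array([[1.]], dtype=np.float32))        # shape (1, 1)
labels = tf.constant(np.array([[1.], [1.]], dtype=np.float32))  # shape (2, 1)

loss, backprop = tf.raw_ops.SoftmaxCrossEntropyWithLogits(
    features=logits, labels=labels)
print(loss.numpy())      # expected: [0. 0.]
print(backprop.numpy())  # expected: [[0.] [0.]]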
reedwm authored and tensorflower-gardener committed Jul 15, 2021
1 parent ffae69d commit 4d74d8a
Showing 3 changed files with 18 additions and 15 deletions.
23 changes: 8 additions & 15 deletions tensorflow/core/kernels/xent_op.cc
@@ -46,7 +46,8 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
     TensorShape shape_in = logits_in.shape();
 
     BCast bcast(BCast::FromShape(logits_in.shape()),
-                BCast::FromShape(labels_in.shape()));
+                BCast::FromShape(labels_in.shape()),
+                /*fewer_dims_optimization=*/false);
     if (!logits_in.IsSameSize(labels_in)) {
       OP_REQUIRES(context, bcast.IsValid(),
                   errors::InvalidArgument(
@@ -88,20 +89,12 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
                                 {0}, 1, shape_in, &back_out));
     if (shape_in.dim_size(0) > 0) {
       functor::XentFunctor<Device, T> functor;
-      if (logits_in.IsSameSize(labels_in)) {
-        functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
-                Eigen::array<Eigen::DenseIndex, 2>{1, 1},
-                Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(),
-                labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
-                back_out->matrix<T>());
-      } else {
-        functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
-                BCast::ToIndexArray<2>(bcast.x_bcast()),
-                BCast::ToIndexArray<2>(bcast.y_bcast()),
-                logits_in.template shaped<T, 2>(bcast.x_reshape()),
-                labels_in.template shaped<T, 2>(bcast.y_reshape()),
-                scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
-      }
+      functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
+              BCast::ToIndexArray<2>(bcast.x_bcast()),
+              BCast::ToIndexArray<2>(bcast.y_bcast()),
+              logits_in.template shaped<T, 2>(bcast.x_reshape()),
+              labels_in.template shaped<T, 2>(bcast.y_reshape()),
+              scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
     }
   }
 };
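
A loose numpy analogy of why disabling fewer_dims_optimization matters (illustrative only: this is not the BCast implementation, and the factor computation below is a hand-rolled stand-in):

import numpy as np

# With the optimization disabled, both operands keep their two dimensions, so
# there is always one broadcast factor per dimension -- two entries, which is
# exactly what ToIndexArray<2> reads.
logits_shape, labels_shape = (1, 1), (2, 1)
out_shape = np.broadcast_shapes(logits_shape, labels_shape)       # (2, 1)

x_bcast = tuple(o // s for o, s in zip(out_shape, logits_shape))  # (2, 1)
y_bcast = tuple(o // s for o, s in zip(out_shape, labels_shape))  # (1, 1)

# If a (1, 1) operand were first collapsed to a single dimension, its factor
# list would hold only one entry, and reading two entries would run past it,
# which is the CHECK failure this commit fixes.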
7 changes: 7 additions & 0 deletions tensorflow/python/kernel_tests/xent_op_test.py
@@ -63,6 +63,13 @@ def testFeaturesBroadcast(self):
     self.assertAllCloseAccordingToType(np_loss, tf_loss)
     self.assertAllCloseAccordingToType(np_gradient, tf_gradient)
 
+    tf_f = constant_op.constant(np.array([[1.]]).astype(np.float32))
+    tf_l = constant_op.constant(np.array([[1.], [1.]]).astype(np.float32))
+    tf_loss, tf_gradient = gen_nn_ops.softmax_cross_entropy_with_logits(
+        tf_f, tf_l)
+    self.assertAllClose([0, 0], tf_loss)
+    self.assertAllCloseAccordingToType([[0], [0]], tf_gradient)
+
   @test_util.run_deprecated_v1
   def testNotMatrix(self):
     with self.cached_session():
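The expected values in the test above follow from the usual softmax cross-entropy formulas; a short numpy check (a sketch, assuming per-row loss = -sum(labels * log(softmax(logits))) and backprop = softmax(logits) - labels for the op's two outputs):

import numpy as np

# Broadcast the (1, 1) logits against the (2, 1) labels, then apply the
# per-row formula.  With only one class, softmax(logits) is 1 everywhere,
# so both the loss and the backprop are exactly zero.
logits = np.broadcast_to(np.array([[1.]], dtype=np.float32), (2, 1))
labels = np.array([[1.], [1.]], dtype=np.float32)

softmax = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
loss = -(labels * np.log(softmax)).sum(axis=1)   # -> [0., 0.]
backprop = softmax - labels                      # -> [[0.], [0.]]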
3 changes: 3 additions & 0 deletions tensorflow/python/kernel_tests/xent_op_test_base.py
@@ -151,6 +151,9 @@ def _testLabelsBroadcast(self, uniform_labels_gradient):
     labels = np.array([[0., 0., 0., 1.]]).astype(np.float16)
     logits = np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16)
     self._testXent2D(labels, logits, with_placeholders=True)
+    labels = np.array([[1.]]).astype(np.float16)
+    logits = np.array([[1.], [2.]]).astype(np.float16)
+    self._testXent2D(labels, logits, with_placeholders=True)
     labels = np.array([[0.], [2.], [0.25]]).astype(np.float16)
     logits = np.array([[1., 1., 1., 1.], [1., 2., 3., 4.],
                        [1., 2., 3., 4.]]).astype(np.float16)
