SparseXentOp now returns NaNs for loss & grad rows where the label value is OOB.

Change: 128485714
ebrevdo authored and tensorflower-gardener committed Jul 26, 2016
1 parent 6d4ead2 commit dcb9053
Showing 3 changed files with 57 additions and 13 deletions.
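
Concretely: before this change an out-of-bounds label simply matched no class, so the row's loss summed to zero and its gradient was plain softmax; now the whole row is poisoned with NaN so the bad input is visible. A minimal sketch of the new behavior, assuming a TensorFlow build at this revision (1.x-style Session API, mirroring the test added below):

import tensorflow as tf

logits = tf.constant([[1., 2., 3., 4.],
                      [1., 2., 3., 4.]])
labels = tf.constant([0, 9])  # 9 is out of bounds for 4 classes

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=labels)
with tf.Session() as sess:
    # The valid row gets a finite loss; the OOB row is NaN rather than 0.
    print(sess.run(loss))  # ~> [3.4402, nan]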
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/bounds_check.h
@@ -42,7 +42,7 @@ namespace internal {
 // This function may only be used on primitive integral types (int32, int64,
 // etc). It does not guarantee any atomicity or barriers.
 template <typename T>
-const T SubtleMustCopy(const T &x) {
+EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC const T SubtleMustCopy(const T &x) {
   static_assert(std::is_integral<T>::value,
                 "SubtleMustCopy can only be used on integer types.");
   auto *to_x = reinterpret_cast<const volatile T *>(&x);
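
The added EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC qualifiers let SubtleMustCopy be called from the Eigen generators below, which can be compiled for and run on the GPU. Its contract is unchanged: read the label exactly once through a volatile pointer, so the bounds check and the later use cannot observe two different values if another thread mutates the buffer. A rough Python model of that check-then-use contract (names are illustrative, not TF API):

def checked_label(labels, batch, max_depth):
    # SubtleMustCopy: copy the value once, then use only the copy.
    label = labels[batch]
    # FastBoundsCheck: a single test for 0 <= label < max_depth.
    if not 0 <= label < max_depth:
        return None  # the generators turn this case into NaN
    return label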
44 changes: 32 additions & 12 deletions tensorflow/core/kernels/sparse_xent_op.h
@@ -19,6 +19,8 @@ limitations under the License.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -56,14 +58,22 @@ class SparseXentLossGenerator {
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator(
       typename TTypes<const T, 2>::Tensor32Bit logits,
       typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
-      typename TTypes<const Index, 1>::Tensor32Bit labels)
-      : logits_(logits), sum_exp_logits_(sum_exp_logits), labels_(labels) {}
+      typename TTypes<const Index, 1>::Tensor32Bit labels,
+      const Index max_depth)
+      : logits_(logits),
+        sum_exp_logits_(sum_exp_logits),
+        labels_(labels),
+        max_depth_(max_depth) {}
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
   operator()(const Eigen::array<int, 2>& coords) const {
-    int batch = coords[0];
-    int depth = coords[1];
-    return (labels_(batch) == depth)
+    const int batch = coords[0];
+    const int depth = coords[1];
+    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
+    if (!FastBoundsCheck(label, max_depth_)) {
+      return Eigen::NumTraits<T>::quiet_NaN();
+    }
+    return TF_PREDICT_FALSE(label == depth)
                ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords))
                : T(0.0);
   };
@@ -72,6 +82,7 @@ class SparseXentLossGenerator {
   typename TTypes<const T, 2>::Tensor32Bit logits_;
   typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
   typename TTypes<const Index, 1>::Tensor32Bit labels_;
+  const Index max_depth_;
 };
 
 // Generator for calculation of the sparse Xent gradient.
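
Summed along the class dimension, the loss generator above yields the usual sparse softmax cross-entropy per row, log(sum_j exp(x_j)) - x_label, except that an out-of-range label now makes every coefficient (and hence the sum) NaN. A numpy model of the per-row result (my reconstruction, not code from this change):

import numpy as np

def xent_loss_row(logits_row, label):
    num_classes = logits_row.shape[0]  # max_depth in the kernel
    if not 0 <= label < num_classes:   # FastBoundsCheck failed
        return np.nan
    # log(sum_exp_logits) - logit at the label position
    return np.log(np.exp(logits_row).sum()) - logits_row[label]

print(xent_loss_row(np.array([1., 1., 1., 1.]), 3))  # log(4) ~= 1.3863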
@@ -87,23 +98,30 @@ class SparseXentGradGenerator {
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator(
       typename TTypes<const T, 2>::Tensor32Bit exp_logits,
       typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
-      typename TTypes<const Index, 1>::Tensor32Bit labels)
+      typename TTypes<const Index, 1>::Tensor32Bit labels,
+      const Index max_depth)
       : exp_logits_(exp_logits),
         sum_exp_logits_(sum_exp_logits),
-        labels_(labels) {}
+        labels_(labels),
+        max_depth_(max_depth) {}
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
   operator()(const Eigen::array<int, 2>& coords) const {
-    int batch = coords[0];
-    int depth = coords[1];
-    T subtract = (depth == labels_(batch)) ? T(1.0) : T(0.0);
+    const int batch = coords[0];
+    const int depth = coords[1];
+    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
+    if (!FastBoundsCheck(label, max_depth_)) {
+      return Eigen::NumTraits<T>::quiet_NaN();
+    }
+    T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0);
     return exp_logits_(coords) / sum_exp_logits_(batch) - subtract;
   };
 
  private:
   typename TTypes<const T, 2>::Tensor32Bit exp_logits_;
   typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
   typename TTypes<const Index, 1>::Tensor32Bit labels_;
+  const Index max_depth_;
 };
 
 }  // namespace generator
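
The gradient generator is softmax minus a one-hot vector at the label, with the same NaN poisoning for an invalid label. A numpy model of one gradient row (same reconstruction caveat as above):

import numpy as np

def xent_grad_row(logits_row, label):
    num_classes = logits_row.shape[0]
    if not 0 <= label < num_classes:
        return np.full(num_classes, np.nan)  # whole row becomes NaN
    exp_logits = np.exp(logits_row)
    grad = exp_logits / exp_logits.sum()  # exp_logits_ / sum_exp_logits_
    grad[label] -= 1.0                    # the `subtract` term
    return grad

print(xent_grad_row(np.array([1., 1., 1., 1.]), 3))
# ~> [0.25, 0.25, 0.25, -0.75], matching the test expectations below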
@@ -185,7 +203,8 @@ struct SparseXentEigenImpl {
     //    along classes
     generator::SparseXentLossGenerator<T, Index> sparse_xent_loss_gen(
         sparse_xent_helpers::To32BitConst<T>(backprop),
-        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
+        backprop.dimension(1) /* max_depth */);
     To32Bit(loss).device(d) =
         To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class);
 
@@ -194,7 +213,8 @@ struct SparseXentEigenImpl {
     To32Bit(backprop).device(d) = To32Bit(backprop).exp();
     generator::SparseXentGradGenerator<T, Index> sparse_xent_grad_gen(
         sparse_xent_helpers::To32BitConst<T>(backprop),
-        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
+        backprop.dimension(1) /* max_depth */);
     To32Bit(backprop).device(d) =
         To32Bit(backprop).generate(sparse_xent_grad_gen);
   }
24 changes: 24 additions & 0 deletions tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -73,6 +73,30 @@ def testSingleClass(self):
     self._testSingleClass(use_gpu=True)
     self._testSingleClass(use_gpu=False)
 
+  def _testInvalidLabel(self, use_gpu):
+    features = [
+        [1., 1., 1., 1.],
+        [1., 1., 1., 1.],
+        [1., 2., 3., 4.],
+        [1., 2., 3., 4.]]
+    labels = [4, 3, 0, -1]
+    with self.test_session(use_gpu=use_gpu) as sess:
+      loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+          features, labels)
+      tf_loss, tf_backprop = sess.run([loss, backprop])
+      self.assertAllClose(
+          [[np.nan] * 4,
+           [0.25, 0.25, 0.25, -0.75],
+           [-0.968, 0.087, 0.237, 0.6439],
+           [np.nan] * 4],
+          tf_backprop, rtol=1e-3, atol=1e-3)
+      self.assertAllClose(
+          [np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3)
+
+  def testInvalidLabel(self):
+    self._testInvalidLabel(use_gpu=True)
+    self._testInvalidLabel(use_gpu=False)
+
   def testNpXent(self):
     # We create 2 batches of logits for testing.
     # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
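
The finite rows of those expected values can be reproduced by hand; a quick numpy check (mine, not part of the change):

import numpy as np

features = np.array([[1., 1., 1., 1.],
                     [1., 2., 3., 4.]])
labels = np.array([3, 0])  # the two in-bounds rows of the test
lse = np.log(np.exp(features).sum(axis=1))
print(lse - features[np.arange(2), labels])
# ~> [1.3863, 3.4402], within the test's 1e-3 tolerances of [1.3862, 3.4420]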
