Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent crash when histogram is called with NaN values.
Fixes #45770

PiperOrigin-RevId: 443149951
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 20, 2022
1 parent 484b5e8 commit e57fd69
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tensorflow/core/kernels/histogram_op.cc
Expand Up @@ -50,6 +50,15 @@ struct HistogramFixedWidthFunctor<CPUDevice, T, Tout> {
static_cast<double>(nbins);
const double nbins_minus_1 = static_cast<double>(nbins - 1);

// We cannot handle NANs in the algorithm below (due to the case to int32)
const Eigen::Tensor<int32, 1, 1> nans_tensor =
values.isnan().template cast<int32>();
const Eigen::Tensor<int32, 0, 1> reduced_tensor = nans_tensor.sum();
const int num_nans = reduced_tensor(0);
if (num_nans > 0) {
return errors::InvalidArgument("Histogram values must not contain NaN");
}

// The calculation is done by finding the slot of each value in `values`.
// With [a, b]:
// step = (b - a) / nbins
Expand Down Expand Up @@ -98,12 +107,12 @@ class HistogramFixedWidthOp : public OpKernel {
const auto nbins = nbins_tensor.scalar<int32>()();

OP_REQUIRES(
ctx, (value_range(0) < value_range(1)),
ctx, value_range(0) < value_range(1),
errors::InvalidArgument("value_range should satisfy value_range[0] < "
"value_range[1], but got '[",
value_range(0), ", ", value_range(1), "]'"));
OP_REQUIRES(
ctx, (nbins > 0),
ctx, nbins > 0,
errors::InvalidArgument("nbins should be a positive number, but got '",
nbins, "'"));

Expand Down

0 comments on commit e57fd69

Please sign in to comment.