Navigation Menu

Skip to content

Commit

Permalink
Convert CHECKs to OP_REQUIRES in quantile ops for boosted trees.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 414159811
Change-Id: I59ca6a106b8e4159a2098966001e707729c80bcf
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 4, 2021
1 parent c9a75c7 commit c99ecf8
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tensorflow/core/kernels/boosted_trees/quantile_ops.cc
Expand Up @@ -170,7 +170,8 @@ class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input(kExampleWeightsName, &example_weights_t));
DCHECK(float_features_list.size() > 0) << "Got empty feature list";
OP_REQUIRES(context, float_features_list.size() > 0,
errors::Internal("Got empty feature list"));
auto example_weights = example_weights_t->flat<float>();
const int64_t weight_size = example_weights.size();
const int64_t batch_size = float_features_list[0].flat<float>().size();
Expand Down Expand Up @@ -324,8 +325,11 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
OpInputList summaries_list;
OP_REQUIRES_OK(context,
context->input_list(kSummariesName, &summaries_list));
int32_t num_streams = stream_resource->num_streams();
CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
auto num_streams = stream_resource->num_streams();
OP_REQUIRES(
context, num_streams == summaries_list.size(),
errors::Internal("Expected num_streams == summaries_list.size(), got ",
num_streams, " vs ", summaries_list.size()));

auto do_quantile_add_summary = [&](const int64_t begin, const int64_t end) {
// Iterating all features.
Expand All @@ -340,7 +344,10 @@ class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
const auto summary_values = summaries.matrix<float>();
const auto& tensor_shape = summaries.shape();
const int64_t entries_size = tensor_shape.dim_size(0);
CHECK_EQ(tensor_shape.dim_size(1), 4);
OP_REQUIRES(
context, tensor_shape.dim_size(1) == 4,
errors::Internal("Expected tensor_shape.dim_size(1) == 4, got ",
tensor_shape.dim_size(1)));
std::vector<QuantileSummaryEntry> summary_entries;
summary_entries.reserve(entries_size);
for (int64_t i = 0; i < entries_size; i++) {
Expand Down Expand Up @@ -512,7 +519,9 @@ class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
mutex_lock l(*stream_resource->mutex());

const int64_t num_streams = stream_resource->num_streams();
CHECK_EQ(num_features_, num_streams);
OP_REQUIRES(context, num_streams == num_features_,
errors::Internal("Expected num_streams == num_features_, got ",
num_streams, " vs ", num_features_));
OpOutputList bucket_boundaries_list;
OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
&bucket_boundaries_list));
Expand Down

0 comments on commit c99ecf8

Please sign in to comment.