Skip to content

Commit

Permalink
Convert CHECK to OP_REQUIRES in lambda of BoostedTreesPredictOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 414159842
Change-Id: Ib9cea6f63986ec952c437cb0f517743abe1d0e86
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 4, 2021
1 parent 929bddd commit e841358
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tensorflow/core/kernels/boosted_trees/prediction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ class BoostedTreesPredictOp : public OpKernel {
}

const int32_t last_tree = resource->num_trees() - 1;
auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
this](int64_t start, int64_t end) {
auto do_work = [&context, &resource, &bucketized_features, &output_logits,
last_tree, this](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
std::vector<float> tree_logits(logits_dimension_, 0.0);
int32_t tree_id = 0;
Expand All @@ -274,7 +274,11 @@ class BoostedTreesPredictOp : public OpKernel {
if (resource->is_leaf(tree_id, node_id)) {
const float tree_weight = resource->GetTreeWeight(tree_id);
const auto& leaf_logits = resource->node_value(tree_id, node_id);
DCHECK_EQ(leaf_logits.size(), logits_dimension_);
OP_REQUIRES(
context, leaf_logits.size() == logits_dimension_,
errors::Internal(
"Expected leaf_logits.size() == logits_dimension_, got ",
leaf_logits.size(), " vs ", logits_dimension_));
for (int32_t j = 0; j < logits_dimension_; ++j) {
tree_logits[j] += tree_weight * leaf_logits[j];
}
Expand Down

0 comments on commit e841358

Please sign in to comment.