Skip to content
Permalink
Browse files Browse the repository at this point in the history
Disallow empty node_id_range in tf.raw_ops.BoostedTreesCalculateBestF…
…eatureSplitV2 and tf.raw_ops.BoostedTreesCalculateBestGainsPerFeature

PiperOrigin-RevId: 387165936
Change-Id: I2f70341af96236b2776c2a592c917d549c1fc1e2
  • Loading branch information
pak-laura authored and tensorflower-gardener committed Jul 27, 2021
1 parent 9adfe49 commit 9c87c32
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tensorflow/core/kernels/boosted_trees/stats_ops.cc
Expand Up @@ -51,6 +51,16 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
// node_id_range
const Tensor* node_id_range_t;
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
OP_REQUIRES(
context, node_id_range_t->dims() == 1,
errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
"given node_id_range has dims of ",
node_id_range_t->dims()));
OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
errors::InvalidArgument(
"node_id_range must be a rank 1 tensor with shape=[2], but "
"given node_id_range has shape ",
node_id_range_t->dim_size(0), " on its first dim"));
const auto node_id_range = node_id_range_t->vec<int32>();
const int32_t node_id_first = node_id_range(0); // inclusive
const int32_t node_id_last = node_id_range(1); // exclusive
Expand Down Expand Up @@ -570,6 +580,16 @@ class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
const Tensor* node_id_range_t;
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
const auto node_id_range = node_id_range_t->vec<int32>();
OP_REQUIRES(
context, node_id_range_t->dims() == 1,
errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
"given node_id_range has dims of ",
node_id_range_t->dims()));
OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
errors::InvalidArgument(
"node_id_range must be a rank 1 tensor with shape=[2], but "
"given node_id_range has shape ",
node_id_range_t->dim_size(0), " on its first dim"));
const int32_t node_id_first = node_id_range(0); // Inclusive.
const int32_t node_id_last = node_id_range(1); // Exclusive.

Expand Down

0 comments on commit 9c87c32

Please sign in to comment.