Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add remaining missing validation to `BoostedTreesCalculateBestFeatureSplit`

PiperOrigin-RevId: 387423006
Change-Id: I8eaf30efb223011519e60707bfa751b275d3a443
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 28, 2021
1 parent 4f8db85 commit 429f009
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion tensorflow/core/kernels/boosted_trees/stats_ops.cc
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <limits>
#include <string>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
Expand All @@ -22,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
Expand Down Expand Up @@ -254,12 +256,18 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
// node_id_range
const Tensor* node_id_range_t;
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
OP_REQUIRES(
context, node_id_range_t->NumElements() == 2,
errors::InvalidArgument("node_id_range argument must have shape [2]"));
const auto node_id_range = node_id_range_t->vec<int32>();
const int32_t node_id_first = node_id_range(0); // inclusive
const int32_t node_id_last = node_id_range(1); // exclusive

const Tensor* stats_summary_t;
OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
OP_REQUIRES(
context, stats_summary_t->shape().dims() == 4,
errors::InvalidArgument("stats_summary argument must have rank 4"));
TTypes<float, 4>::ConstTensor stats_summary =
stats_summary_t->tensor<float, 4>();
const int32_t feature_dims = stats_summary_t->dim_size(1);
Expand All @@ -272,6 +280,8 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {

const Tensor* l1_t;
OP_REQUIRES_OK(context, context->input("l1", &l1_t));
OP_REQUIRES(context, l1_t->NumElements() == 1,
errors::InvalidArgument("l1 argument must be a scalar"));
const auto l1 = l1_t->scalar<float>()();
DCHECK_GE(l1, 0);
if (logits_dim_ > 1) {
Expand All @@ -281,17 +291,25 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {

const Tensor* l2_t;
OP_REQUIRES_OK(context, context->input("l2", &l2_t));
OP_REQUIRES(context, l2_t->NumElements() == 1,
errors::InvalidArgument("l2 argument must be a scalar"));
const auto l2 = l2_t->scalar<float>()();
DCHECK_GE(l2, 0);

const Tensor* tree_complexity_t;
OP_REQUIRES_OK(context,
context->input("tree_complexity", &tree_complexity_t));
OP_REQUIRES(
context, tree_complexity_t->NumElements() == 1,
errors::InvalidArgument("tree_complexity argument must be a scalar"));
const auto tree_complexity = tree_complexity_t->scalar<float>()();

const Tensor* min_node_weight_t;
OP_REQUIRES_OK(context,
context->input("min_node_weight", &min_node_weight_t));
OP_REQUIRES(
context, min_node_weight_t->NumElements() == 1,
errors::InvalidArgument("min_node_weight argument must be a scalar"));
const auto min_node_weight = min_node_weight_t->scalar<float>()();

std::vector<int32> output_node_ids;
Expand All @@ -300,7 +318,7 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
std::vector<int32> output_thresholds;
std::vector<Eigen::VectorXf> output_left_node_contribs;
std::vector<Eigen::VectorXf> output_right_node_contribs;
std::vector<string> output_split_types;
std::vector<std::string> output_split_types;

// TODO(tanzheny) parallelize the computation.
// Iterate each node and find the best gain per node.
Expand Down

0 comments on commit 429f009

Please sign in to comment.