Skip to content

Commit

Permalink
Create a class to use for input specification in TensorForest, and su…
Browse files Browse the repository at this point in the history
…pport a mix of dense and sparse features. Move feature processing to tensor_forest to have the most sane interface (just dict of tensors or single tensor).

Had to move data_ops because of python3 import weirdness.
Change value == bias to go right instead of left because this inherently handles sparse one-hot categorical data (if a node's bias is set at 1 (the only value it could pick), both values of 1 and 0 are not strictly > than 1, so that node would always go left).
Change: 143970989
  • Loading branch information
tensorflower-gardener committed Jan 9, 2017
1 parent 0da4df0 commit 70b6a5d
Show file tree
Hide file tree
Showing 22 changed files with 1,042 additions and 591 deletions.
2 changes: 1 addition & 1 deletion tensorflow/contrib/cmake/tf_core_kernels.cmake
Expand Up @@ -46,13 +46,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/reinterpret_string_to_float_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/topn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/data/reinterpret_string_to_float_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc"
Expand Down
21 changes: 7 additions & 14 deletions tensorflow/contrib/learn/python/learn/estimators/random_forest.py
Expand Up @@ -27,7 +27,6 @@
from tensorflow.contrib.learn.python.learn.utils import export

from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.data import data_ops
from tensorflow.contrib.tensor_forest.python import tensor_forest

from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -110,16 +109,12 @@ def _model_fn(features, labels):
weights = features.pop(weights_name)
if keys_name and keys_name in features:
keys = features.pop(keys_name)
processed_features, spec = data_ops.ParseDataTensorOrDict(features)
_assert_float32(processed_features)
if labels is not None:
labels = data_ops.ParseLabelTensorOrDict(labels)
_assert_float32(labels)

graph_builder = graph_builder_class(params, device_assigner=device_assigner)
inference = {eval_metrics.INFERENCE_PROB_NAME:
graph_builder.inference_graph(processed_features,
data_spec=spec)}
inference = {
eval_metrics.INFERENCE_PROB_NAME:
graph_builder.inference_graph(features)
}
if not params.regression:
inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
inference[eval_metrics.INFERENCE_PROB_NAME], 1)
Expand All @@ -131,13 +126,11 @@ def _model_fn(features, labels):
training_loss = None
training_graph = None
if labels is not None:
training_loss = graph_builder.training_loss(processed_features, labels,
data_spec=spec,
name=LOSS_NAME)
training_loss = graph_builder.training_loss(
features, labels, name=LOSS_NAME)
training_graph = control_flow_ops.group(
graph_builder.training_graph(
processed_features, labels, data_spec=spec,
input_weights=weights),
features, labels, input_weights=weights),
state_ops.assign_add(contrib_framework.get_global_step(), 1))
# Put weights back in
if weights is not None:
Expand Down
12 changes: 5 additions & 7 deletions tensorflow/contrib/tensor_forest/BUILD
Expand Up @@ -28,12 +28,10 @@ filegroup(
srcs = glob(
[
"core/ops/*.cc",
"data/*.cc",
"ops/*.cc",
],
exclude = [
"core/ops/*_test.cc",
"data/*_test.cc",
],
),
)
Expand Down Expand Up @@ -70,11 +68,10 @@ py_library(
py_library(
name = "data_ops_py",
srcs = [
"data/data_ops.py",
"python/ops/data_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":constants",
":tensor_forest_ops_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
Expand All @@ -101,12 +98,12 @@ tf_custom_op_library(
"core/ops/count_extremely_random_stats_op.cc",
"core/ops/finished_nodes_op.cc",
"core/ops/grow_tree_op.cc",
"core/ops/reinterpret_string_to_float_op.cc",
"core/ops/sample_inputs_op.cc",
"core/ops/scatter_add_ndim_op.cc",
"core/ops/topn_ops.cc",
"core/ops/tree_predictions_op.cc",
"core/ops/update_fertile_slots_op.cc",
"data/reinterpret_string_to_float_op.cc",
"ops/tensor_forest_ops.cc",
],
deps = [":tree_utils"],
Expand All @@ -117,7 +114,6 @@ py_library(
srcs = [
"__init__.py",
"client/__init__.py",
"data/__init__.py",
"python/__init__.py",
],
srcs_version = "PY2AND3",
Expand Down Expand Up @@ -190,6 +186,7 @@ cc_library(
name = "tree_utils",
srcs = ["core/ops/tree_utils.cc"],
hdrs = [
"core/ops/data_spec.h",
"core/ops/tree_utils.h",
],
deps = [
Expand Down Expand Up @@ -218,7 +215,7 @@ py_test(
srcs = ["python/kernel_tests/count_extremely_random_stats_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":constants",
":data_ops_py",
":tensor_forest_ops_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_test_lib",
Expand Down Expand Up @@ -259,6 +256,7 @@ py_test(
srcs = ["python/kernel_tests/sample_inputs_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":data_ops_py",
":tensor_forest_ops_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
Expand Down
1 change: 0 additions & 1 deletion tensorflow/contrib/tensor_forest/__init__.py
Expand Up @@ -19,6 +19,5 @@

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.tensor_forest.client import *
from tensorflow.contrib.tensor_forest.data import *
from tensorflow.contrib.tensor_forest.python import *
# pylint: enable=unused-import,wildcard-import
Expand Up @@ -21,8 +21,8 @@
#include <unordered_set>
#include <vector>

#include "tensorflow/contrib/tensor_forest/core/ops/data_spec.h"
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
Expand All @@ -44,10 +44,10 @@ using tensorforest::LEAF_NODE;
using tensorforest::FREE_NODE;

using tensorforest::CheckTensorBounds;
using tensorforest::DataColumnTypes;
using tensorforest::DecideNode;
using tensorforest::TensorForestDataSpec;
using tensorforest::Initialize;
using tensorforest::IsAllInitialized;
using tensorforest::FeatureSpec;

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
Expand All @@ -69,9 +69,10 @@ struct InputDataResult {


struct EvaluateParams {
std::function<bool(int, int, float,
tensorforest::DataColumnTypes)> decide_function;
Tensor input_spec;
TensorForestDataSpec input_spec;
Tensor dense_input;
Tensor sparse_indices;
Tensor sparse_values;
Tensor input_labels;
Tensor tree_tensor;
Tensor tree_thresholds;
Expand Down Expand Up @@ -103,11 +104,11 @@ void Evaluate(const EvaluateParams& params, int32 start, int32 end) {
while (true) {
params.results[i].node_indices.push_back(node_index);
CHECK_LT(node_index, num_nodes);
int32 left_child = internal::SubtleMustCopy(
tree(node_index, CHILDREN_INDEX));
int32 left_child =
internal::SubtleMustCopy(tree(node_index, CHILDREN_INDEX));
if (left_child == LEAF_NODE) {
const int32 accumulator = internal::SubtleMustCopy(
node_map(node_index));
const int32 accumulator =
internal::SubtleMustCopy(node_map(node_index));
params.results[i].leaf_accumulator = accumulator;
// If the leaf is not fertile or is not yet initialized, we don't
// count it in the candidate/total split per-class-weights because
Expand All @@ -119,10 +120,10 @@ void Evaluate(const EvaluateParams& params, int32 start, int32 end) {
params.results[i].splits_initialized = true;
for (int split = 0; split < num_splits; split++) {
const int32 feature = split_features(accumulator, split);

if (!params.decide_function(
i, feature, split_thresholds(accumulator, split),
FeatureSpec(feature, params.input_spec))) {
if (!DecideNode(params.dense_input, params.sparse_indices,
params.sparse_values, i, feature,
split_thresholds(accumulator, split),
params.input_spec)) {
params.results[i].split_adds.push_back(split);
}
}
Expand All @@ -135,14 +136,13 @@ void Evaluate(const EvaluateParams& params, int32 start, int32 end) {
}
const int32 feature = tree(node_index, FEATURE_INDEX);
node_index =
left_child +
params.decide_function(i, feature, thresholds(node_index),
FeatureSpec(feature, params.input_spec));
left_child + DecideNode(params.dense_input, params.sparse_indices,
params.sparse_values, i, feature,
thresholds(node_index), params.input_spec);
}
}
}


class CountExtremelyRandomStats : public OpKernel {
public:
explicit CountExtremelyRandomStats(OpKernelConstruction* context)
Expand All @@ -151,29 +151,39 @@ class CountExtremelyRandomStats : public OpKernel {
"num_classes", &num_classes_));
OP_REQUIRES_OK(context, context->GetAttr(
"regression", &regression_));
string serialized_proto;
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
input_spec_.ParseFromString(serialized_proto);
}

void Compute(OpKernelContext* context) override {
const Tensor& input_data = context->input(0);
const Tensor& sparse_input_indices = context->input(1);
const Tensor& sparse_input_values = context->input(2);
const Tensor& sparse_input_shape = context->input(3);
const Tensor& input_spec = context->input(4);
const Tensor& input_labels = context->input(5);
const Tensor& input_weights = context->input(6);
const Tensor& tree_tensor = context->input(7);
const Tensor& tree_thresholds = context->input(8);
const Tensor& node_to_accumulator = context->input(9);
const Tensor& candidate_split_features = context->input(10);
const Tensor& candidate_split_thresholds = context->input(11);
const Tensor& birth_epochs = context->input(12);
const Tensor& current_epoch = context->input(13);
const Tensor& input_labels = context->input(4);
const Tensor& input_weights = context->input(5);
const Tensor& tree_tensor = context->input(6);
const Tensor& tree_thresholds = context->input(7);
const Tensor& node_to_accumulator = context->input(8);
const Tensor& candidate_split_features = context->input(9);
const Tensor& candidate_split_thresholds = context->input(10);
const Tensor& birth_epochs = context->input(11);
const Tensor& current_epoch = context->input(12);

bool sparse_input = (sparse_input_indices.shape().dims() == 2);
bool have_weights = (input_weights.shape().dim_size(0) > 0);
int32 num_data = -1;

// Check inputs.
if (sparse_input) {
const auto sparse_shape = sparse_input_shape.unaligned_flat<int64>();
// TODO(gilberth): This is because we can't figure out the shape
// of a sparse tensor at graph-build time, even if the dimension is
// actually known.
input_spec_.mutable_sparse(0)->set_size(sparse_shape(1));
num_data = sparse_shape(0);

OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1,
errors::InvalidArgument(
"sparse_input_shape should be one-dimensional"));
Expand All @@ -193,7 +203,17 @@ class CountExtremelyRandomStats : public OpKernel {
errors::InvalidArgument(
"sparse_input_indices and sparse_input_values should "
"agree on the number of non-zero values"));
} else {
}

if (input_data.shape().dim_size(0) > 0) {
const int32 dense_num_data =
static_cast<int32>(input_data.shape().dim_size(0));
if (num_data > 0) {
CHECK_EQ(num_data, dense_num_data)
<< "number of examples must match for sparse + dense input.";
}
num_data = dense_num_data;

OP_REQUIRES(context, input_data.shape().dims() == 2,
errors::InvalidArgument(
"input_data should be two-dimensional"));
Expand Down Expand Up @@ -279,33 +299,15 @@ class CountExtremelyRandomStats : public OpKernel {

// Evaluate input data in parallel.
const int32 epoch = current_epoch.unaligned_flat<int32>()(0);
int32 num_data;
std::function<bool(int, int, float,
tensorforest::DataColumnTypes)> decide_function;
if (sparse_input) {
num_data = sparse_input_shape.unaligned_flat<int64>()(0);
decide_function = [&sparse_input_indices, &sparse_input_values](
int32 i, int32 feature, float bias, DataColumnTypes type) {
const auto sparse_indices = sparse_input_indices.matrix<int64>();
const auto sparse_values = sparse_input_values.vec<float>();
return tensorforest::DecideSparseNode(
sparse_indices, sparse_values, i, feature, bias, type);
};
} else {
num_data = static_cast<int32>(input_data.shape().dim_size(0));
decide_function = [&input_data](
int32 i, int32 feature, float bias, DataColumnTypes type) {
const auto input_matrix = input_data.matrix<float>();
return tensorforest::DecideDenseNode(
input_matrix, i, feature, bias, type);
};
}

std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
EvaluateParams params;
params.decide_function = decide_function;
params.input_spec = input_spec;
params.dense_input = input_data;
params.sparse_indices = sparse_input_indices;
params.sparse_values = sparse_input_values;
params.input_spec = input_spec_;
params.input_labels = input_labels;
params.tree_tensor = tree_tensor;
params.tree_thresholds = tree_thresholds;
Expand Down Expand Up @@ -689,6 +691,7 @@ class CountExtremelyRandomStats : public OpKernel {

int32 num_classes_;
bool regression_;
tensorforest::TensorForestDataSpec input_spec_;
};


Expand Down

0 comments on commit 70b6a5d

Please sign in to comment.