Calculate feature_importances for BoostedTreesRegressor and BoostedTreesClassifier #21509

Merged
24 commits
b81f4bb
ENH: implement feature importances
facaiy Aug 9, 2018
54fbe83
TST: add test case
facaiy Aug 9, 2018
7ad6047
ENH: mapping idx to feature_name
facaiy Aug 13, 2018
0845a01
CLN: revise code according to comments
facaiy Aug 17, 2018
196f547
CLN: use CheckpointReader to load TreeEnsemble proto
facaiy Aug 17, 2018
7ed0680
TST: revise test case
facaiy Aug 17, 2018
52d637e
CLN: normalize is False by default
facaiy Aug 19, 2018
5630efc
CLN: revise according to comments
facaiy Aug 21, 2018
e39bbe4
TST: add test case for negative feature importances
facaiy Aug 21, 2018
88d722c
ENH: don't divide by the sum of tree weights
facaiy Aug 21, 2018
73c8cbb
TST: add test case for full tree with leaves
facaiy Aug 21, 2018
4979d73
CLN: revise codes
facaiy Aug 22, 2018
407a64b
TST: revise test case and too long line
facaiy Aug 23, 2018
f8ee979
ENH: raise exception if unsupported features/columns is given
facaiy Aug 24, 2018
b3114e5
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy Aug 31, 2018
8c51bbc
BLD: update golden file
facaiy Sep 12, 2018
fd41d2c
CLN: fix code style
facaiy Sep 12, 2018
04ddc2d
Merge branch 'master' into ENH/feature_importances_for_boosted_tree
facaiy Sep 13, 2018
30e176f
CLN: only assert gains >= 0 for normalization
facaiy Sep 14, 2018
9fcf40a
CLN: remove unused import
facaiy Sep 14, 2018
cc3a7a8
CLN: minor changes
facaiy Sep 18, 2018
c7fcdf8
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy Sep 19, 2018
fb2918f
TST: introduce test case from upstream/master
facaiy Sep 19, 2018
046c74c
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy Sep 19, 2018
@@ -34,18 +34,19 @@ def _input_fn():
return _input_fn


# pylint: disable=protected-access
def _is_classification_head(head):
"""Infers if the head is a classification head."""
# Check using all classification heads defined in canned/head.py. However, it
# is not a complete list - it does not check for other classification heads
# not defined in the head library.
# pylint: disable=protected-access
return isinstance(head,
(head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
# pylint: enable=protected-access


class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access
"""An Estimator for Tensorflow Boosted Trees models."""

def __init__(self,
@@ -113,6 +114,7 @@ def __init__(self,
are requested.
"""
# HParams for the model.
# pylint: disable=protected-access
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
127 changes: 127 additions & 0 deletions tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,6 +21,9 @@
import collections
import functools

import numpy as np

from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import boosted_trees_utils
@@ -40,6 +43,7 @@
from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -193,6 +197,43 @@ def _calculate_num_features(sorted_feature_columns):
return num_features


def _generate_feature_name_mapping(sorted_feature_columns):
Contributor:
It seems this function needs unit testing. We're planning to expand support of feature columns, and it will be quite easy to miss this function. Either we should support all feature columns in this function, or error for unsupported cases.

Member Author (facaiy):
I think testFeatureImportancesNamesForCategoricalColumn verifies the method in part. Do we need to add unit tests for this private method?

Contributor:
Let's say we added support for a multi-dimensional numeric_column and forgot to update this function. What would happen?

Member Author (facaiy):
Yes, it's a problem if we forget to update this function.

  1. I'm afraid _calculate_num_features will meet the same problem: if we add a new subclass, we may forget to update it. Because we use external functions to calculate num_features and feature_name, it's easy to forget them when adding a new subclass. Perhaps we'd better make them properties of the FeatureColumn base class and force every subclass to handle the details. I think that question is beyond the scope of this PR.

  2. I think unit tests can only assure that old behaviors stay unchanged. Say we have supported the numeric_column class here and we use tests to track its behavior. If we add a new class, a multi-dimensional numeric_column, but forget to add the corresponding tests, I'm afraid the unit tests can do nothing for that case.

Contributor:
Hi Yan,
The proposal is to error out if an unexpected feature column is given, and to have unit tests for that. Then whenever this method is not updated, we will get a clear error message.

The other functions you mentioned would cause learning to behave unexpectedly, so it's likely we'll catch errors in them, since there will be unit tests for new columns. But feature importance is not part of the learning tests; that's why it's likely to be missed. If you assume it may not be you, me, or Natalia who adds new features, these kinds of errors help new developers.

Please let me know whether this makes sense.

Member Author (@facaiy, Aug 24, 2018):
Oh, my mistake. You're right! The function doesn't support all feature columns, e.g. a multi-dimensional numeric_column.

> Either we should support all feature columns in this function, or error for unsupported cases.

I'll add a check like _get_transformed_features: we only support _BucketizedColumn and _IndicatorColumn so far.
"""Return a list of feature name for feature ids.

Args:
sorted_feature_columns: a list/set of tf.feature_column sorted by name.

Returns:
feature_name_mapping: a list of feature names indexed by the feature ids.

Raises:
ValueError: when unsupported features/columns are tried.
"""
names = []
for column in sorted_feature_columns:
if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
categorical_column = column.categorical_column
if isinstance(categorical_column,
feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access
for value in categorical_column.vocabulary_list:
names.append('{}:{}'.format(column.name, value))
elif isinstance(categorical_column,
feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
for pair in zip(boundaries[:-1], boundaries[1:]):
names.append('{}:{}'.format(column.name, pair))
else:
for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
names.append('{}:{}'.format(column.name, num))
elif isinstance(column, feature_column_lib._BucketizedColumn):
names.append(column.name)
else:
raise ValueError(
'For now, only bucketized_column and indicator_column are supported '
'but got: {}'.format(column))
return names
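As a rough stand-alone illustration of the naming scheme above (a sketch in plain Python and NumPy, with made-up column names 'country' and 'age' instead of the real feature-column classes):

```python
import numpy as np

def demo_feature_names(vocab, boundaries):
    """Mimic _generate_feature_name_mapping's naming scheme: one
    'column:value' name per vocabulary entry for an indicator column,
    and one 'column:(low, high)' name per bucket for a bucketized
    categorical column."""
    names = []
    # Indicator over a vocabulary list: one slot per vocabulary value.
    for value in vocab:
        names.append('country:{}'.format(value))
    # Bucketized column: one slot per bucket, with -inf/+inf padding
    # the outermost boundaries.
    edges = [-np.inf] + list(boundaries) + [np.inf]
    for pair in zip(edges[:-1], edges[1:]):
        names.append('age:{}'.format(pair))
    return names

print(demo_feature_names(['US', 'CA'], [18, 65]))
# ['country:US', 'country:CA', 'age:(-inf, 18)', 'age:(18, 65)', 'age:(65, inf)']
```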


def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
num_features = _calculate_num_features(sorted_feature_columns)
@@ -966,6 +1007,60 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access


def _compute_feature_importances_per_tree(tree, num_features):
"""Computes the importance of each feature in the tree."""
importances = np.zeros(num_features)

for node in tree.nodes:
node_type = node.WhichOneof('node')
if node_type == 'bucketized_split':
feature_id = node.bucketized_split.feature_id
importances[feature_id] += node.metadata.gain
elif node_type == 'leaf':
assert node.metadata.gain == 0
else:
raise ValueError('Unexpected split type %s' % node_type)

return importances
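A minimal stand-alone sketch of this per-tree accumulation, with a toy tree represented as (feature_id, gain) pairs instead of the proto's bucketized_split nodes:

```python
import numpy as np

def per_tree_importances(splits, num_features):
    """Sum split gains per feature id for a single tree; leaves
    contribute nothing, matching the function above."""
    importances = np.zeros(num_features)
    for feature_id, gain in splits:
        importances[feature_id] += gain
    return importances

# A tree that split twice on feature 0 and once on feature 2.
imps = per_tree_importances([(0, 1.5), (2, 0.5), (0, 1.0)], num_features=4)
print(imps)  # feature 0 accumulates 2.5, feature 2 gets 0.5
```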


def _compute_feature_importances(tree_ensemble, num_features, normalize):
"""Computes gain-based feature importances.

The higher the value, the more important the feature.

Args:
tree_ensemble: a trained tree ensemble, instance of proto
boosted_trees.TreeEnsemble.
num_features: The total number of feature ids.
normalize: If True, normalize the feature importances.

Returns:
sorted_feature_idx: A list of feature_id which is sorted
by its feature importance.
feature_importances: A list of corresponding feature importances.

Raises:
AssertionError: When normalize = True, if feature importances
contain negative value, or if normalization is not possible
(e.g. ensemble is empty or trees contain only a root node).
"""
tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
for tree in tree_ensemble.trees]
tree_importances = np.array(tree_importances)
tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
feature_importances = np.sum(tree_importances * tree_weights, axis=0)
if normalize:
assert np.all(feature_importances >= 0), ('feature_importances '
'must be non-negative.')
normalizer = np.sum(feature_importances)
assert normalizer > 0, 'Trees are all empty or contain only a root node.'
feature_importances /= normalizer

sorted_feature_idx = np.argsort(feature_importances)[::-1]
return sorted_feature_idx, feature_importances[sorted_feature_idx]
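The weighted aggregation and optional normalization can be sketched in isolation (pure NumPy, toy numbers):

```python
import numpy as np

def ensemble_importances(tree_importances, tree_weights, normalize=False):
    """Weight each tree's per-feature gains, sum across trees, and
    optionally normalize to a distribution, mirroring
    _compute_feature_importances above."""
    imp = np.array(tree_importances, dtype=np.float64)
    weights = np.array(tree_weights, dtype=np.float64).reshape(-1, 1)
    feature_importances = np.sum(imp * weights, axis=0)
    if normalize:
        assert np.all(feature_importances >= 0), 'must be non-negative'
        normalizer = np.sum(feature_importances)
        assert normalizer > 0, 'all trees empty or root-only'
        feature_importances /= normalizer
    order = np.argsort(feature_importances)[::-1]
    return order, feature_importances[order]

# Two trees over two features; the second tree has half weight:
# weighted sum is [3.5, 1.5], normalized to [0.7, 0.3].
order, vals = ensemble_importances([[3.0, 1.0], [1.0, 1.0]],
                                   [1.0, 0.5], normalize=True)
```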


def _bt_explanations_fn(features,
head,
sorted_feature_columns,
@@ -1053,9 +1148,41 @@ class for more detail.
feature_columns, key=lambda tc: tc.name)
self._head = head
self._n_features = _calculate_num_features(self._sorted_feature_columns)
self._names_for_feature_id = np.array(
_generate_feature_name_mapping(self._sorted_feature_columns))
self._center_bias = center_bias
self._is_classification = is_classification

def experimental_feature_importances(self, normalize=False):
"""Computes gain-based feature importances.

The higher the value, the more important the corresponding feature.

Args:
normalize: If True, normalize the feature importances.

Returns:
sorted_feature_names: 1-D array of feature name which is sorted
by its feature importance.
feature_importances: 1-D array of the corresponding feature importance.

Raises:
ValueError: When attempting to normalize on an empty ensemble
or an ensemble of trees which have no splits. Or when attempting
to normalize and feature importances have negative values.
"""
reader = checkpoint_utils.load_checkpoint(self._model_dir)
serialized = reader.get_tensor('boosted_trees:0_serialized')
if not serialized:
raise ValueError('Found empty serialized string for TreeEnsemble. '
'You should only call this method after training.')
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
ensemble_proto.ParseFromString(serialized)

sorted_feature_id, importances = _compute_feature_importances(
ensemble_proto, self._n_features, normalize)
return self._names_for_feature_id[sorted_feature_id], importances
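The return statement's final step is NumPy fancy indexing: the sorted feature ids pick names out of the name array built in __init__. A tiny sketch with made-up names and importances:

```python
import numpy as np

# Stand-in for self._names_for_feature_id: one entry per feature id.
names_for_feature_id = np.array(
    ['f0_bucketized', 'f1_indicator', 'f2_bucketized'])

# Suppose _compute_feature_importances ranked feature 2 first.
sorted_feature_id = np.array([2, 0, 1])
importances = np.array([0.6, 0.3, 0.1])

# Fancy indexing reorders the names to match the importance ranking.
sorted_names = names_for_feature_id[sorted_feature_id]
print([str(n) for n in sorted_names])
# ['f2_bucketized', 'f0_bucketized', 'f1_indicator']
```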

def experimental_predict_with_explanations(self,
input_fn,
predict_keys=None,