Calculate feature_importances for BoostedTreesRegressor and BoostedTreesClassifier #21509
Merged
tensorflow-copybara
merged 24 commits into
tensorflow:master
from
facaiy:ENH/feature_importances_for_boosted_tree
Sep 25, 2018
b81f4bb
ENH: implement feature importances
facaiy 54fbe83
TST: add test case
facaiy 7ad6047
ENH: mapping idx to feature_name
facaiy 0845a01
CLN: revise code according to comments
facaiy 196f547
CLN: use CheckpointReader to load TreeEnsemble proto
facaiy 7ed0680
TST: revise test case
facaiy 52d637e
CLN: normalize is False by default
facaiy 5630efc
CLN: revise according to comments
facaiy e39bbe4
TST: add test case for negative feature importances
facaiy 88d722c
ENH: don't divide by the sum of tree weights
facaiy 73c8cbb
TST: add test case for full tree with leaves
facaiy 4979d73
CLN: revise codes
facaiy 407a64b
TST: revise test case and too long line
facaiy f8ee979
ENH: raise exception if unsupported features/columns is given
facaiy b3114e5
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy 8c51bbc
BLD: update golden file
facaiy fd41d2c
CLN: fix code style
facaiy 04ddc2d
Merge branch 'master' into ENH/feature_importances_for_boosted_tree
facaiy 30e176f
CLN: only assert gains >= 0 for normalization
facaiy 9fcf40a
CLN: remove unused import
facaiy cc3a7a8
CLN: minor changes
facaiy c7fcdf8
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy fb2918f
TST: introduce test case from upstream/master
facaiy 046c74c
Merge remote-tracking branch 'upstream/master' into ENH/feature_impor…
facaiy
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,9 @@ | |
import collections | ||
import functools | ||
|
||
import numpy as np | ||
|
||
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 | ||
from tensorflow.python.estimator import estimator | ||
from tensorflow.python.estimator import model_fn as model_fn_lib | ||
from tensorflow.python.estimator.canned import boosted_trees_utils | ||
|
@@ -40,6 +43,7 @@ | |
from tensorflow.python.ops.array_ops import identity as tf_identity | ||
from tensorflow.python.ops.losses import losses | ||
from tensorflow.python.summary import summary | ||
from tensorflow.python.training import checkpoint_utils | ||
from tensorflow.python.training import session_run_hook | ||
from tensorflow.python.training import training_util | ||
from tensorflow.python.util.tf_export import estimator_export | ||
|
@@ -193,6 +197,43 @@ def _calculate_num_features(sorted_feature_columns): | |
return num_features | ||
|
||
|
||
def _generate_feature_name_mapping(sorted_feature_columns): | ||
"""Returns a list of feature names indexed by feature id. | ||
|
||
Args: | ||
sorted_feature_columns: a list/set of tf.feature_column sorted by name. | ||
|
||
Returns: | ||
feature_name_mapping: a list of feature names indexed by the feature ids. | ||
|
||
Raises: | ||
ValueError: when unsupported features/columns are tried. | ||
""" | ||
names = [] | ||
for column in sorted_feature_columns: | ||
if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access | ||
categorical_column = column.categorical_column | ||
if isinstance(categorical_column, | ||
feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access | ||
for value in categorical_column.vocabulary_list: | ||
names.append('{}:{}'.format(column.name, value)) | ||
elif isinstance(categorical_column, | ||
feature_column_lib._BucketizedColumn): # pylint:disable=protected-access | ||
boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf] | ||
for pair in zip(boundaries[:-1], boundaries[1:]): | ||
names.append('{}:{}'.format(column.name, pair)) | ||
else: | ||
for num in range(categorical_column._num_buckets): # pylint:disable=protected-access | ||
names.append('{}:{}'.format(column.name, num)) | ||
elif isinstance(column, feature_column_lib._BucketizedColumn): | ||
names.append(column.name) | ||
else: | ||
raise ValueError( | ||
'For now, only bucketized_column and indicator_column are supported ' | ||
'but got: {}'.format(column)) | ||
return names | ||
|
||
|
||
def _cache_transformed_features(features, sorted_feature_columns, batch_size): | ||
"""Transform features and cache, then returns (cached_features, cache_op).""" | ||
num_features = _calculate_num_features(sorted_feature_columns) | ||
|
@@ -966,6 +1007,60 @@ def _create_regression_head(label_dimension, weight_column=None): | |
# pylint: enable=protected-access | ||
|
||
|
||
def _compute_feature_importances_per_tree(tree, num_features): | ||
"""Computes the importance of each feature in the tree.""" | ||
importances = np.zeros(num_features) | ||
|
||
for node in tree.nodes: | ||
node_type = node.WhichOneof('node') | ||
if node_type == 'bucketized_split': | ||
feature_id = node.bucketized_split.feature_id | ||
importances[feature_id] += node.metadata.gain | ||
elif node_type == 'leaf': | ||
assert node.metadata.gain == 0 | ||
else: | ||
raise ValueError('Unexpected split type %s' % node_type) | ||
|
||
return importances | ||
|
||
|
||
def _compute_feature_importances(tree_ensemble, num_features, normalize): | ||
"""Computes gain-based feature importances. | ||
|
||
The higher the value, the more important the feature. | ||
|
||
Args: | ||
tree_ensemble: a trained tree ensemble, instance of proto | ||
boosted_trees.TreeEnsemble. | ||
num_features: The total number of feature ids. | ||
normalize: If True, normalize the feature importances. | ||
|
||
Returns: | ||
sorted_feature_idx: A list of feature_id which is sorted | ||
by its feature importance. | ||
feature_importances: A list of corresponding feature importances. | ||
|
||
Raises: | ||
AssertionError: When normalize = True, if feature importances | ||
contain negative value, or if normalization is not possible | ||
(e.g. ensemble is empty or trees contain only a root node). | ||
""" | ||
tree_importances = [_compute_feature_importances_per_tree(tree, num_features) | ||
for tree in tree_ensemble.trees] | ||
tree_importances = np.array(tree_importances) | ||
tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1) | ||
feature_importances = np.sum(tree_importances * tree_weights, axis=0) | ||
if normalize: | ||
assert np.all(feature_importances >= 0), ('feature_importances ' | ||
'must be non-negative.') | ||
normalizer = np.sum(feature_importances) | ||
assert normalizer > 0, 'Trees are all empty or contain only a root node.' | ||
feature_importances /= normalizer | ||
|
||
sorted_feature_idx = np.argsort(feature_importances)[::-1] | ||
return sorted_feature_idx, feature_importances[sorted_feature_idx] | ||
|
||
|
||
def _bt_explanations_fn(features, | ||
head, | ||
sorted_feature_columns, | ||
|
@@ -1053,9 +1148,41 @@ class for more detail. | |
feature_columns, key=lambda tc: tc.name) | ||
self._head = head | ||
self._n_features = _calculate_num_features(self._sorted_feature_columns) | ||
self._names_for_feature_id = np.array( | ||
_generate_feature_name_mapping(self._sorted_feature_columns)) | ||
self._center_bias = center_bias | ||
self._is_classification = is_classification | ||
|
||
def experimental_feature_importances(self, normalize=False): | ||
"""Computes gain-based feature importances. | ||
|
||
The higher the value, the more important the corresponding feature. | ||
|
||
Args: | ||
normalize: If True, normalize the feature importances. | ||
|
||
Returns: | ||
sorted_feature_names: 1-D array of feature name which is sorted | ||
by its feature importance. | ||
feature_importances: 1-D array of the corresponding feature importance. | ||
|
||
Raises: | ||
ValueError: When attempting to normalize on an empty ensemble | ||
[Inline review comment: or when attempting to normalize and feature importances have negative values] |
||
or an ensemble of trees which have no splits. Or when attempting | ||
to normalize and feature importances have negative values. | ||
""" | ||
reader = checkpoint_utils.load_checkpoint(self._model_dir) | ||
serialized = reader.get_tensor('boosted_trees:0_serialized') | ||
if not serialized: | ||
raise ValueError('Found empty serialized string for TreeEnsemble. ' | ||
'You should only call this method after training.') | ||
ensemble_proto = boosted_trees_pb2.TreeEnsemble() | ||
ensemble_proto.ParseFromString(serialized) | ||
|
||
sorted_feature_id, importances = _compute_feature_importances( | ||
ensemble_proto, self._n_features, normalize) | ||
return self._names_for_feature_id[sorted_feature_id], importances | ||
|
||
def experimental_predict_with_explanations(self, | ||
input_fn, | ||
predict_keys=None, | ||
|
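The aggregation performed by `_compute_feature_importances` in the diff above can be illustrated with a small dependency-free sketch. It assumes the per-tree gains have already been summed per feature id. The function name `aggregate_importances` and its arguments are hypothetical, and unlike the PR code (which uses NumPy and `assert`), this sketch uses plain lists and raises `ValueError`, so tie-breaking order may also differ from `np.argsort`.

```python
def aggregate_importances(per_tree_gains, tree_weights, normalize=False):
    """Combines per-tree gain vectors into ensemble-level importances.

    Hypothetical helper, not part of TensorFlow.

    Args:
      per_tree_gains: list of lists; per_tree_gains[t][f] is the summed split
        gain of feature id f in tree t.
      tree_weights: one weight per tree.
      normalize: if True, scale the importances so they sum to 1.

    Returns:
      (sorted_feature_ids, sorted_importances), most important first.
    """
    if len(per_tree_gains) != len(tree_weights):
        raise ValueError('Need exactly one weight per tree.')
    num_features = len(per_tree_gains[0]) if per_tree_gains else 0

    # Weighted sum of gains over all trees, per feature id.
    totals = [0.0] * num_features
    for gains, weight in zip(per_tree_gains, tree_weights):
        for feature_id, gain in enumerate(gains):
            totals[feature_id] += weight * gain

    if normalize:
        # Mirrors the two checks in the diff: non-negative values, positive sum.
        if any(value < 0 for value in totals):
            raise ValueError('Importances must be non-negative to normalize.')
        normalizer = sum(totals)
        if normalizer <= 0:
            raise ValueError('No split gains to normalize.')
        totals = [value / normalizer for value in totals]

    # Sort feature ids by importance, descending.
    order = sorted(range(num_features), key=lambda i: totals[i], reverse=True)
    return order, [totals[i] for i in order]


# Two trees over three features; the second tree has weight 0.5.
ids, scores = aggregate_importances(
    per_tree_gains=[[3.0, 0.0, 1.0], [1.0, 2.0, 1.0]],
    tree_weights=[1.0, 0.5],
    normalize=True)
```

With these inputs the weighted totals are 3.5, 1.0 and 1.5 for feature ids 0, 1 and 2, so the returned order is feature 0, then 2, then 1.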
It seems this function needs unit testing.
We're planning to expand feature column support, and it will be quite easy to overlook this function. Either it should support all feature columns, or it should raise an error for unsupported ones.
I think testFeatureImportancesNamesForCategoricalColumn verifies the method in part. Do we need to add unit tests for this private method?
Let's say we add support for multi-dimensional numeric_column and forget to update this function. What will happen?
Yes, it's a problem if we forget to update this function.
I'm afraid that _calculate_num_features will have the same problem if we add a new subclass and forget to update it. Because we use external functions to calculate the number of features and the feature names, it's easy to forget them when adding a new subclass. Perhaps we should make them properties of the FeatureColumn base class and force every subclass to handle the details. I think that question is beyond the scope of this PR.
I think unit tests can only ensure that old behavior stays unchanged. Say we support the numeric_column class here and use tests to track its behavior. If we add a new class, a multi-dimensional numeric_column, but forget to add the corresponding tests, I'm afraid unit tests can do nothing for this case.
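The base-class idea raised in this comment could look roughly like the sketch below. None of these classes exist in TensorFlow: `SketchColumn`, `SketchBucketized`, `SketchIndicator` and `SketchForgotten` are hypothetical stand-ins, used only to show how an abstract property would force each new column type to declare its own feature names.

```python
import abc


class SketchColumn(abc.ABC):
    """Hypothetical stand-in for a feature column base class."""

    def __init__(self, name):
        self.name = name

    @property
    @abc.abstractmethod
    def feature_names(self):
        """One human-readable name per feature id this column produces."""

    @property
    def num_features(self):
        # Derived from feature_names, so the two can never disagree.
        return len(self.feature_names)


class SketchBucketized(SketchColumn):
    """A bucketized column contributes a single feature id."""

    @property
    def feature_names(self):
        return [self.name]


class SketchIndicator(SketchColumn):
    """An indicator column contributes one feature id per vocabulary entry."""

    def __init__(self, name, vocabulary):
        super().__init__(name)
        self.vocabulary = list(vocabulary)

    @property
    def feature_names(self):
        return ['{}:{}'.format(self.name, value) for value in self.vocabulary]


class SketchForgotten(SketchColumn):
    """A new subclass whose author forgot to implement feature_names."""


columns = [SketchBucketized('age'), SketchIndicator('color', ['red', 'blue'])]
all_names = [name for column in columns for name in column.feature_names]
```

Under this design, instantiating `SketchForgotten` raises `TypeError` immediately, because the abstract property is not overridden. The omission then surfaces when the column is constructed instead of when importances are requested.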
Hi Yan,
The proposal is to raise an error if an unexpected feature column is given, and to have unit tests that cover this. That way, whenever this method is not updated, we will get a clear error message.
Problems in the other functions you mentioned would cause learning to behave unexpectedly, so we are likely to catch them, since there will be unit tests for new columns. But feature importance is not part of the learning tests, which is why it's likely to be missed. If you assume that it may not be you, me, or Natalia who adds new features, these kinds of errors help new developers.
Please let me know whether this makes sense.
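A minimal sketch of this proposal, assuming stand-in column classes instead of the real `feature_column_lib` types: `FakeBucketizedColumn`, `FakeNumericColumn`, `feature_names_or_raise` and `raises_value_error` are all hypothetical names introduced for illustration. The name-mapping helper rejects anything it does not recognise, and a small check pins that behaviour down.

```python
class FakeBucketizedColumn:
    """Hypothetical stand-in for a supported column type."""

    def __init__(self, name):
        self.name = name


class FakeNumericColumn:
    """Hypothetical stand-in for a column type added later, not yet supported."""

    def __init__(self, name):
        self.name = name


def feature_names_or_raise(sorted_columns):
    """Maps columns to feature names, failing loudly on unknown column types."""
    names = []
    for column in sorted_columns:
        if isinstance(column, FakeBucketizedColumn):
            names.append(column.name)
        else:
            # A clear error beats silently returning a wrong or short mapping.
            raise ValueError(
                'Unsupported feature column: {}'.format(type(column).__name__))
    return names


def raises_value_error(columns):
    """Returns True if the mapping rejects the given columns."""
    try:
        feature_names_or_raise(columns)
    except ValueError:
        return True
    return False


supported = feature_names_or_raise([FakeBucketizedColumn('age')])
rejected = raises_value_error([FakeNumericColumn('height')])
```

A test built on `raises_value_error` only guards the failure path. Once a new column type is actually supported, it would need to be replaced by a test of the names that type produces.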
Oh, my mistake. You're right! The function doesn't support all feature columns, for example multi-dimensional numeric_column.
I'll add a check like the one in _get_transformed_features: we only support _BucketizedColumn and _IndicatorColumn so far.