Merge pull request #358 from rasbt/feature-importance
Add a new feature_importance_permutation function
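For quick orientation, a minimal usage sketch of the new function follows. The dataset, estimator, and split are illustrative placeholders (they mirror the unit tests further down) and are not part of this PR:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from mlxtend.evaluate import feature_importance_permutation

# Illustrative data and model; any estimator exposing a predict method should work.
X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_redundant=0, n_classes=2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y)

clf = SVC(C=1.0, kernel='rbf', random_state=0).fit(X_train, y_train)

# Permutation importance on the held-out set.
imp_vals, imp_all = feature_importance_permutation(
    predict_method=clf.predict,
    X=X_test,
    y=y_test,
    metric='accuracy',
    num_rounds=10,
    seed=1)

print(imp_vals)        # mean importance per feature, shape (n_features,)
print(imp_all.shape)   # per-round importances, shape (n_features, num_rounds)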
Showing 12 changed files with 894 additions and 11 deletions.
639 changes: 639 additions & 0 deletions
docs/sources/user_guide/evaluate/feature_importance_permutation.ipynb (large diff not rendered)
Binary file added BIN +5.9 KB ...te/feature_importance_permutation_files/feature_importance_permutation_17_0.png
Binary file added BIN +6.63 KB ...te/feature_importance_permutation_files/feature_importance_permutation_23_0.png
Binary file added BIN +8.04 KB ...te/feature_importance_permutation_files/feature_importance_permutation_27_0.png
Binary file added BIN +7.14 KB ...te/feature_importance_permutation_files/feature_importance_permutation_32_0.png
Binary file added BIN +3.26 KB ...te/feature_importance_permutation_files/feature_importance_permutation_35_0.png
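The images above belong to the new user-guide notebook and presumably show bar charts of the permutation importances. A sketch of producing such a plot, assuming matplotlib and using the two arrays returned by the new function (variable names illustrative):

import matplotlib.pyplot as plt
import numpy as np

# imp_vals, imp_all as returned by feature_importance_permutation.
indices = np.argsort(imp_vals)[::-1]   # sort features by mean importance
std = imp_all.std(axis=1)              # spread across permutation rounds

plt.bar(range(len(imp_vals)), imp_vals[indices], yerr=std[indices])
plt.xticks(range(len(imp_vals)), indices)
plt.xlabel('Feature index')
plt.ylabel('Permutation importance')
plt.tight_layout()
plt.show()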
@@ -0,0 +1,96 @@
# Sebastian Raschka 2014-2018
# mlxtend Machine Learning Library Extensions
#
# Feature Importance Estimation Through Permutation
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

import numpy as np


def feature_importance_permutation(X, y, predict_method,
                                   metric, num_rounds=1, seed=None):
    """Feature importance estimation via permutation importance

    Parameters
    ----------
    X : NumPy array, shape = [n_samples, n_features]
        Dataset, where n_samples is the number of samples and
        n_features is the number of features.
    y : NumPy array, shape = [n_samples]
        Target values.
    predict_method : prediction function
        A callable function that predicts the target values
        from X.
    metric : str, callable
        The metric for evaluating the feature importance through
        permutation. By default, the string 'accuracy' is
        recommended for classifiers and the string 'r2' is
        recommended for regressors. Optionally, a custom
        scoring function (e.g., `metric=scoring_func`) can be
        used that accepts two arguments, y_true and y_pred, which
        have a similar shape to the `y` array.
    num_rounds : int (default=1)
        Number of rounds the feature columns are permuted to
        compute the permutation importance.
    seed : int or None (default=None)
        Random seed for permuting the feature columns.

    Returns
    -------
    mean_importance_vals, all_importance_vals : NumPy arrays.
        The first array, mean_importance_vals, has shape [n_features, ] and
        contains the importance values for all features.
        The shape of the second array is [n_features, num_rounds] and contains
        the feature importance for each repetition. If num_rounds=1,
        it contains the same values as the first array, mean_importance_vals.

    """
    if not isinstance(num_rounds, int):
        raise ValueError('num_rounds must be an integer.')
    if num_rounds < 1:
        raise ValueError('num_rounds must be greater than or equal to 1.')

    if not (metric in ('r2', 'accuracy') or hasattr(metric, '__call__')):
        raise ValueError('metric must be either "r2", "accuracy", '
                         'or a function with signature func(y_true, y_pred).')

    if metric == 'r2':
        def score_func(y_true, y_pred):
            # coefficient of determination: 1 - SS_res / SS_tot
            ss_res = np.sum(np.square(y_true - y_pred))
            ss_tot = np.sum(np.square(y_true - y_true.mean()))
            r2_score = 1. - (ss_res / ss_tot)
            return r2_score

    elif metric == 'accuracy':
        def score_func(y_true, y_pred):
            return np.mean(y_true == y_pred)

    else:
        # user-provided callable with signature func(y_true, y_pred)
        score_func = metric

    rng = np.random.RandomState(seed)

    mean_importance_vals = np.zeros(X.shape[1])
    all_importance_vals = np.zeros((X.shape[1], num_rounds))

    baseline = score_func(y, predict_method(X))

    for round_idx in range(num_rounds):
        for col_idx in range(X.shape[1]):
            # permute one feature column, re-score, then restore the column
            save_col = X[:, col_idx].copy()
            rng.shuffle(X[:, col_idx])
            new_score = score_func(y, predict_method(X))
            X[:, col_idx] = save_col
            importance = baseline - new_score
            mean_importance_vals[col_idx] += importance
            all_importance_vals[col_idx, round_idx] = importance
    mean_importance_vals /= num_rounds

    return mean_importance_vals, all_importance_vals
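Since the docstring allows `metric` to be a callable with signature func(y_true, y_pred), here is a hedged sketch of passing a custom score. Because the importance is computed as baseline score minus permuted score, the custom metric should be "higher is better"; the fitted regressor `reg` and the arrays `X_test`, `y_test` are placeholders, not part of the PR:

import numpy as np
from mlxtend.evaluate import feature_importance_permutation

def neg_mean_abs_error(y_true, y_pred):
    # negative MAE so that larger values mean a better model
    return -np.mean(np.abs(y_true - y_pred))

imp_vals, imp_all = feature_importance_permutation(
    predict_method=reg.predict,   # reg: any fitted regressor (placeholder)
    X=X_test,
    y=y_test,
    metric=neg_mean_abs_error,
    num_rounds=5,
    seed=0)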
@@ -0,0 +1,144 @@
# Sebastian Raschka 2014-2018
# mlxtend Machine Learning Library Extensions
#
# Feature Importance Estimation Through Permutation
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

import numpy as np
from sklearn.datasets import make_classification
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.svm import SVR
from mlxtend.utils import assert_raises
from mlxtend.evaluate import feature_importance_permutation


def test_num_rounds_not_int():
    assert_raises(ValueError,
                  'num_rounds must be an integer.',
                  feature_importance_permutation,
                  lambda x, y: (x, y),
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  'accuracy',
                  1.23)


def test_num_rounds_negative_int():
    assert_raises(ValueError,
                  'num_rounds must be greater than or equal to 1.',
                  feature_importance_permutation,
                  lambda x, y: (x, y),
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  'accuracy',
                  -1)


def test_metric_wrong():
    assert_raises(ValueError,
                  ('metric must be either "r2", "accuracy", or a '
                   'function with signature '
                   'func(y_true, y_pred).'),
                  feature_importance_permutation,
                  lambda x, y: (x, y),
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  'some-metric')


def test_classification():

    X, y = make_classification(n_samples=1000,
                               n_features=6,
                               n_informative=3,
                               n_redundant=0,
                               n_repeated=0,
                               n_classes=2,
                               random_state=0,
                               shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0, stratify=y)

    svm = SVC(C=1.0, kernel='rbf', random_state=0)
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='accuracy',
        num_rounds=1,
        seed=1)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 1)
    assert imp_vals[0] > 0.2
    assert imp_vals[1] > 0.2
    assert imp_vals[2] > 0.2
    assert sum(imp_vals[3:]) <= 0.02


def test_regression():

    X, y = make_regression(n_samples=1000,
                           n_features=5,
                           n_informative=2,
                           n_targets=1,
                           random_state=123,
                           shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=123)

    svm = SVR(kernel='rbf')
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='r2',
        num_rounds=1,
        seed=123)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 1)
    assert imp_vals[0] > 0.2
    assert imp_vals[1] > 0.2
    assert sum(imp_vals[3:]) <= 0.01


def test_n_rounds():

    X, y = make_classification(n_samples=1000,
                               n_features=6,
                               n_informative=3,
                               n_redundant=0,
                               n_repeated=0,
                               n_classes=2,
                               random_state=0,
                               shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0, stratify=y)

    svm = SVC(C=1.0, kernel='rbf', random_state=0)
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='accuracy',
        num_rounds=100,
        seed=1)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 100)
    # average the per-round importances across the 100 rounds
    assert imp_all[0].mean() > 0.2
    assert imp_all[1].mean() > 0.2
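One further check that could be added here (variable names as in the tests above): the docstring states that the first return value is the per-feature mean of the per-round importances, so the two arrays should satisfy the relationship below.

# mean_importance_vals should equal the row-wise mean of all_importance_vals
assert np.allclose(imp_vals, imp_all.mean(axis=1))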