Skip to content

Commit

Permalink
update efs
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 6, 2018
1 parent 3b9dfa9 commit 8787033
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 32 deletions.
6 changes: 4 additions & 2 deletions docs/sources/CHANGELOG.md
Expand Up @@ -20,12 +20,14 @@ The CHANGELOG for the current development version is available at
- The `SequentialFeatureSelector` now accepts custom feature names via the `fit` method for more interpretable feature subset reports. ([#379](https://github.com/rasbt/mlxtend/pull/379))
- The `SequentialFeatureSelector` is now also compatible with Pandas DataFrames and uses DataFrame column-names for more interpretable feature subset reports. ([#379](https://github.com/rasbt/mlxtend/pull/379))
- `ColumnSelector` now works with Pandas DataFrames columns. ([#378](https://github.com/rasbt/mlxtend/pull/378) by [Manuel Garrido](https://github.com/manugarri))
- The `ExhaustiveFeatureSelector` estimator in `mlxtend.feature_selection` now is safely stoppable mid-process by control+c. ([#380](https://github.com/rasbt/mlxtend/pull/380))


##### Changes

- For concistency, the `best_idx_` attribute of the `ExhaustiveFeatureSelector` was renamed to `k_feature_idx_`, which is used by the `SequentialFeatureSelector`. Likewise, `best_score_` was renamed to `k_score_`. ([#380](https://github.com/rasbt/mlxtend/pull/380))


- -

##### Bug Fixes

Expand Down Expand Up @@ -303,7 +305,7 @@ Note that this didn't cause any difference in performance on any of the test sce
- The `StackingClassifier` has a new parameter `average_probas` that is set to `True` by default to maintain the current behavior. A deprecation warning was added though, and it will default to `False` in future releases (0.6.0); `average_probas=False` will result in stacking of the level-1 predicted probabilities rather than averaging these.
- New `StackingCVClassifier` estimator in 'mlxtend.classifier' for implementing a stacking ensemble that uses cross-validation techniques for training the meta-estimator to avoid overfitting ([Reiichiro Nakano](https://github.com/reiinakano))
- New `OnehotTransactions` encoder class added to the `preprocessing` submodule for transforming transaction data into a one-hot encoded array
- The `SequentialFeatureSelector` estimator in `mlxtend.feature_selection` now is safely stoppable mid-process by control+c, and deprecated print_progress in favor of a more tunable verbose parameter ([Will McGinnis](https://github.com/wdm0006))
- The `SequentialFeatureSelector` estimator in `mlxtend.feature_selection` now is safely stoppable mid-process by control+c, and deprecated `print_progress` in favor of a more tunable `verbose` parameter ([Will McGinnis](https://github.com/wdm0006))
- New `apriori` function in `association` to extract frequent itemsets from transaction data for association rule mining
- New `checkerboard_plot` function in `plotting` to plot checkerboard tables / heat maps
- New `mcnemar_table` and `mcnemar` functions in `evaluate` to compute 2x2 contingency tables and McNemar's test
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
120 changes: 100 additions & 20 deletions mlxtend/feature_selection/exhaustive_feature_selector.py
Expand Up @@ -40,6 +40,32 @@ def _calc_score(selector, X, y, indices, **fit_params):
return indices, scores


def _get_featurenames(subsets_dict, feature_idx, custom_feature_names, X):
feature_names = None
if feature_idx is not None:
if custom_feature_names is not None:
feature_names = tuple((custom_feature_names[i]
for i in feature_idx))
elif hasattr(X, 'loc'):
feature_names = tuple((X.columns[i] for i in feature_idx))
else:
feature_names = tuple(str(i) for i in feature_idx)

subsets_dict_ = deepcopy(subsets_dict)
for key in subsets_dict_:
if custom_feature_names is not None:
new_tuple = tuple((custom_feature_names[i]
for i in subsets_dict[key]['feature_idx']))
elif hasattr(X, 'loc'):
new_tuple = tuple((X.columns[i]
for i in subsets_dict[key]['feature_idx']))
else:
new_tuple = tuple(str(i) for i in subsets_dict[key]['feature_idx'])
subsets_dict_[key]['feature_names'] = new_tuple

return subsets_dict_, feature_names


class ExhaustiveFeatureSelector(BaseEstimator, MetaEstimatorMixin):

"""Exhaustive Feature Selection for Classification and Regression.
Expand Down Expand Up @@ -91,18 +117,30 @@ class ExhaustiveFeatureSelector(BaseEstimator, MetaEstimatorMixin):
Attributes
----------
best_idx_ : array-like, shape = [n_predictions]
k_feature_idx_ : array-like, shape = [n_predictions]
Feature Indices of the selected feature subsets.
best_score_ : float
k_feature_names_ : array-like, shape = [n_predictions]
Feature names of the selected feature subsets. If pandas
DataFrames are used in the `fit` method, the feature
names correspond to the column names. Otherwise, the
feature names are string representation of the feature
array indices. New in v 0.13.0.
k_score_ : float
Cross validation average score of the selected subset.
subsets_ : dict
A dictionary of selected feature subsets during the
sequential selection, where the dictionary keys are
the lengths k of these feature subsets. The dictionary
values are dictionaries themselves with the following
keys: 'feature_idx' (tuple of indices of the feature subset)
'feature_names' (tuple of feature names of the feat. subset)
'cv_scores' (list individual cross-validation scores)
'avg_score' (average cross-validation score)
Note that if pandas
DataFrames are used in the `fit` method, the 'feature_names'
correspond to the column names. Otherwise, the
feature names are string representation of the feature
array indices. The 'feature_names' is new in v 0.13.0.
Examples
-----------
Expand Down Expand Up @@ -132,8 +170,12 @@ def __init__(self, estimator, min_features=1, max_features=1,
else:
self.est_ = self.estimator
self.fitted = False
self.interrupted_ = False

def fit(self, X, y, **fit_params):
# don't mess with this unless testing
self._TESTING_INTERRUPT_MODE = False

def fit(self, X, y, custom_feature_names=None, **fit_params):
"""Perform feature selection and learn model from training data.
Parameters
Expand All @@ -152,6 +194,25 @@ def fit(self, X, y, **fit_params):
"""

# reset from a potential previous fit run
self.subsets_ = {}
self.fitted = False
self.interrupted_ = False
self.k_feature_idx_ = None
self.k_feature_names_ = None
self.k_score_ = None

if hasattr(X, 'loc'):
X_ = X.values
else:
X_ = X

if (custom_feature_names is not None
and len(custom_feature_names) != X.shape[1]):
raise ValueError('If custom_feature_names is not None, '
'the number of elements in custom_feature_names '
'must equal the number of columns in X.')

if (not isinstance(self.max_features, int) or
(self.max_features > X.shape[1] or self.max_features < 1)):
raise AttributeError('max_features must be'
Expand All @@ -167,12 +228,10 @@ def fit(self, X, y, **fit_params):
if self.max_features < self.min_features:
raise AttributeError('min_features must be <= max_features')

candidates = chain(*((combinations(range(X.shape[1]), r=i))
candidates = chain(*((combinations(range(X_.shape[1]), r=i))
for i in range(self.min_features,
self.max_features + 1)))

self.subsets_ = {}

def ncr(n, r):
"""Return the number of combinations of length r from n items.
Expand All @@ -196,26 +255,39 @@ def ncr(n, r):
denom = reduce(op.mul, range(1, r+1))
return numer//denom

all_comb = np.sum([ncr(n=X.shape[1], r=i)
all_comb = np.sum([ncr(n=X_.shape[1], r=i)
for i in range(self.min_features,
self.max_features + 1)])

n_jobs = min(self.n_jobs, all_comb)
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=self.pre_dispatch)
work = enumerate(parallel(delayed(_calc_score)
(self, X, y, c, **fit_params)
(self, X_, y, c, **fit_params)
for c in candidates))

for iteration, (c, cv_scores) in work:
try:
for iteration, (c, cv_scores) in work:

self.subsets_[iteration] = {'feature_idx': c,
'cv_scores': cv_scores,
'avg_score': np.mean(cv_scores)}

self.subsets_[iteration] = {'feature_idx': c,
'cv_scores': cv_scores,
'avg_score': np.mean(cv_scores)}
if self.print_progress:
sys.stderr.write('\rFeatures: %d/%d' % (
iteration + 1, all_comb))
sys.stderr.flush()

if self.print_progress:
sys.stderr.write('\rFeatures: %d/%d' % (
iteration + 1, all_comb))
sys.stderr.flush()
if self._TESTING_INTERRUPT_MODE:
self.subsets_, self.k_feature_names_ = \
_get_featurenames(self.subsets_,
self.k_feature_idx_,
custom_feature_names,
X)
raise KeyboardInterrupt

except KeyboardInterrupt as e:
self.interrupted_ = True
sys.stderr.write('\nSTOPPING EARLY DUE TO KEYBOARD INTERRUPT...')

max_score = float('-inf')
for c in self.subsets_:
Expand All @@ -225,10 +297,14 @@ def ncr(n, r):
score = max_score
idx = self.subsets_[best_subset]['feature_idx']

self.best_idx_ = idx
self.best_score_ = score
self.subsets_plus_ = dict()
self.k_feature_idx_ = idx
self.k_score_ = score
self.fitted = True
self.subsets_, self.k_feature_names_ = \
_get_featurenames(self.subsets_,
self.k_feature_idx_,
custom_feature_names,
X)
return self

def transform(self, X):
Expand All @@ -246,7 +322,11 @@ def transform(self, X):
"""
self._check_fitted()
return X[:, self.best_idx_]
if hasattr(X, 'loc'):
X_ = X.values
else:
X_ = X
return X_[:, self.k_feature_idx_]

def fit_transform(self, X, y, **fit_params):
"""Fit to training data and return the best selected features from X.
Expand Down
41 changes: 31 additions & 10 deletions mlxtend/feature_selection/tests/test_exhaustive_feature_selector.py
Expand Up @@ -20,9 +20,12 @@
def dict_compare_utility(d1, d2):
assert d1.keys() == d2.keys(), "%s != %s" % (d1, d2)
for i in d1:
err_msg = ("d1[%s]['feature_idx']"
" != d2[%s]['feature_idx']" % (i, i))
assert d1[i]['feature_idx'] == d1[i]["feature_idx"], err_msg
err_msg1 = ("d1[%s]['feature_idx']"
" != d2[%s]['feature_idx']" % (i, i))
err_msg2 = ("d1[%s]['feature_names']"
" != d2[%s]['feature_names']" % (i, i))
assert d1[i]['feature_idx'] == d2[i]["feature_idx"], err_msg1
assert d1[i]['feature_names'] == d2[i]["feature_names"], err_msg2
assert_almost_equal(d1[i]['avg_score'],
d2[i]['avg_score'],
decimal=3,
Expand Down Expand Up @@ -99,33 +102,43 @@ def test_knn_wo_cv():
print_progress=False)
efs1 = efs1.fit(X, y)
expect = {0: {'feature_idx': (0, 1),
'feature_names': ('0', '1'),
'avg_score': 0.82666666666666666,
'cv_scores': np.array([0.82666667])},
1: {'feature_idx': (0, 2),
'feature_names': ('0', '2'),
'avg_score': 0.95999999999999996,
'cv_scores': np.array([0.96])},
2: {'feature_idx': (0, 3),
'feature_names': ('0', '3'),
'avg_score': 0.96666666666666667,
'cv_scores': np.array([0.96666667])},
3: {'feature_idx': (1, 2),
'feature_names': ('1', '2'),
'avg_score': 0.95999999999999996,
'cv_scores': np.array([0.96])},
4: {'feature_idx': (1, 3),
'feature_names': ('1', '3'),
'avg_score': 0.95999999999999996,
'cv_scores': np.array([0.96])},
5: {'feature_idx': (2, 3),
'feature_names': ('2', '3'),
'avg_score': 0.97333333333333338,
'cv_scores': np.array([0.97333333])},
6: {'feature_idx': (0, 1, 2),
'feature_names': ('0', '1', '2'),
'avg_score': 0.95999999999999996,
'cv_scores': np.array([0.96])},
7: {'feature_idx': (0, 1, 3),
'feature_names': ('0', '1', '3'),
'avg_score': 0.96666666666666667,
'cv_scores': np.array([0.96666667])},
8: {'feature_idx': (0, 2, 3),
'feature_names': ('0', '2', '3'),
'avg_score': 0.96666666666666667,
'cv_scores': np.array([0.96666667])},
9: {'feature_idx': (1, 2, 3),
'feature_names': ('1', '2', '3'),
'avg_score': 0.97333333333333338,
'cv_scores': np.array([0.97333333])}}
dict_compare_utility(d1=expect, d2=efs1.subsets_)
Expand All @@ -145,23 +158,27 @@ def test_knn_cv3():
efs1 = efs1.fit(X, y)
expect = {0: {'avg_score': 0.9391025641025641,
'feature_idx': (0, 1, 2),
'feature_names': ('0', '1', '2'),
'cv_scores': np.array([0.97435897, 0.94871795,
0.88888889, 0.94444444])},
1: {'avg_score': 0.94017094017094016,
'feature_idx': (0, 1, 3),
'feature_names': ('0', '1', '3'),
'cv_scores': np.array([0.92307692, 0.94871795,
0.91666667, 0.97222222])},
2: {'avg_score': 0.95299145299145294,
'feature_idx': (0, 2, 3),
'feature_names': ('0', '2', '3'),
'cv_scores': np.array([0.97435897, 0.94871795,
0.91666667, 0.97222222])},
3: {'avg_score': 0.97275641025641035,
'feature_idx': (1, 2, 3),
'feature_names': ('1', '2', '3'),
'cv_scores': np.array([0.97435897, 1.,
0.94444444, 0.97222222])}}
dict_compare_utility(d1=expect, d2=efs1.subsets_)
assert efs1.best_idx_ == (1, 2, 3)
assert round(efs1.best_score_, 4) == 0.9728
assert efs1.k_feature_idx_ == (1, 2, 3)
assert round(efs1.k_score_, 4) == 0.9728


def test_fit_params():
Expand All @@ -178,24 +195,28 @@ def test_fit_params():
print_progress=False)
efs1 = efs1.fit(X, y, sample_weight=sample_weight)
expect = {0: {'feature_idx': (0, 1, 2),
'feature_names': ('0', '1', '2'),
'cv_scores': np.array([0.94871795, 0.92307692,
0.91666667, 0.97222222]),
'avg_score': 0.9401709401709402},
1: {'feature_idx': (0, 1, 3),
'feature_names': ('0', '1', '3'),
'cv_scores': np.array([0.92307692, 0.92307692,
0.88888889, 1.]),
'avg_score': 0.9337606837606838},
2: {'feature_idx': (0, 2, 3),
'feature_names': ('0', '2', '3'),
'cv_scores': np.array([0.97435897, 0.94871795,
0.94444444, 0.97222222]),
'avg_score': 0.9599358974358974},
3: {'feature_idx': (1, 2, 3),
'feature_names': ('1', '2', '3'),
'cv_scores': np.array([0.97435897, 0.94871795,
0.91666667, 1.]),
'avg_score': 0.9599358974358974}}
dict_compare_utility(d1=expect, d2=efs1.subsets_)
assert efs1.best_idx_ == (0, 2, 3)
assert round(efs1.best_score_, 4) == 0.9599
assert efs1.k_feature_idx_ == (0, 2, 3)
assert round(efs1.k_score_, 4) == 0.9599


def test_regression():
Expand All @@ -209,8 +230,8 @@ def test_regression():
cv=10,
print_progress=False)
efs_r = efs_r.fit(X, y)
assert efs_r.best_idx_ == (0, 2, 4)
assert round(efs_r.best_score_, 4) == -40.8777
assert efs_r.k_feature_idx_ == (0, 2, 4)
assert round(efs_r.k_score_, 4) == -40.8777


def test_clone_params_fail():
Expand Down Expand Up @@ -305,7 +326,7 @@ def test_clone_params_pass():
print_progress=False,
n_jobs=1)
efs1 = efs1.fit(X, y)
assert(efs1.best_idx_ == (1, 3))
assert(efs1.k_feature_idx_ == (1, 3))


def test_transform_not_fitted():
Expand Down

0 comments on commit 8787033

Please sign in to comment.