Merge pull request #354 from zdgriffith/efs-fit-params
Adds fit_params support for ExhaustiveFeatureSelector
rasbt committed Mar 27, 2018
2 parents c019c87 + d1ef89f commit 9c8529a
Showing 6 changed files with 84 additions and 20 deletions.
docs/sources/CHANGELOG.md (5 changes: 4 additions, 1 deletion)
@@ -18,7 +18,10 @@ The CHANGELOG for the current development version is available at
##### New Features


- The fit method of the ExhaustiveFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith)
- The fit method of the SequentialFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith)
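Both entries describe the same keyword-forwarding pattern. A minimal sketch of the new call, assuming scikit-learn's iris data and an estimator whose fit accepts sample_weight; the weights and hyperparameters below are illustrative, not taken from this commit:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from mlxtend.feature_selection import SequentialFeatureSelector as SFS

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=10, random_state=0)
sfs = SFS(forest, k_features=2, scoring='accuracy', cv=4)

# Extra keyword arguments to fit() are forwarded to forest.fit
sfs = sfs.fit(X, y, sample_weight=np.ones(X.shape[0]))
print(sfs.k_feature_idx_)
```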


docs/sources/user_guide/feature_selection/ExhaustiveFeatureSelector.ipynb
@@ -1666,7 +1666,7 @@
"\n",
"<hr>\n",
"\n",
"*fit(X, y)*\n",
"*fit(X, y, **fit_params)*\n",
"\n",
"Perform feature selection and learn model from training data.\n",
"\n",
@@ -1681,14 +1681,18 @@
"\n",
" Target values.\n",
"\n",
"- `fit_params` : dict of string -> object, optional\n",
"\n",
" Parameters to pass to to the fit method of classifier.\n",
"\n",
"**Returns**\n",
"\n",
"- `self` : object\n",
"\n",
"\n",
"<hr>\n",
"\n",
"*fit_transform(X, y)*\n",
"*fit_transform(X, y, **fit_params)*\n",
"\n",
"Fit to training data and return the best selected features from X.\n",
"\n",
@@ -1699,6 +1703,14 @@
" Training vectors, where n_samples is the number of samples and\n",
" n_features is the number of features.\n",
"\n",
"- `y` : array-like, shape = [n_samples]\n",
"\n",
" Target values.\n",
"\n",
"- `fit_params` : dict of string -> object, optional\n",
"\n",
" Parameters to pass to to the fit method of classifier.\n",
"\n",
"**Returns**\n",
"\n",
"Feature subset of X, shape={n_samples, k_features}\n",
@@ -1815,7 +1827,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
"version": "3.6.3"
}
},
"nbformat": 4,
docs/sources/user_guide/feature_selection/SequentialFeatureSelector.ipynb
@@ -1592,6 +1592,10 @@
" Training vectors, where n_samples is the number of samples and\n",
" n_features is the number of features.\n",
"\n",
"- `y` : array-like, shape = [n_samples]\n",
"\n",
" Target values.\n",
"\n",
"- `fit_params` : dict of string -> object, optional\n",
"\n",
" Parameters to pass to to the fit method of classifier.\n",
mlxtend/feature_selection/exhaustive_feature_selector.py (40 changes: 24 additions, 16 deletions)
@@ -25,16 +25,17 @@
from sklearn.externals.joblib import Parallel, delayed


def _calc_score(selector, X, y, indices):
def _calc_score(selector, X, y, indices, **fit_params):
if selector.cv:
scores = cross_val_score(selector.est_,
X[:, indices], y,
cv=selector.cv,
scoring=selector.scorer,
n_jobs=1,
pre_dispatch=selector.pre_dispatch)
pre_dispatch=selector.pre_dispatch,
fit_params=fit_params)
else:
selector.est_.fit(X[:, indices], y)
selector.est_.fit(X[:, indices], y, **fit_params)
scores = np.array([selector.scorer(selector.est_, X[:, indices], y)])
return indices, scores
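Both branches deliver the extra keyword arguments to the underlying estimator's fit method: through cross_val_score's fit_params argument when selector.cv is set, and as direct keyword arguments otherwise. A standalone sketch of that plumbing, with illustrative data and weights that are not part of the diff:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=10, random_state=123)
weights = np.ones(X.shape[0])  # illustrative per-sample weights

# cv branch: the dict is handed to cross_val_score, which passes it on
# to clf.fit inside every fold
scores = cross_val_score(clf, X[:, (0, 2, 3)], y, cv=4, scoring='accuracy',
                         fit_params={'sample_weight': weights})

# no-cv branch: the same kwargs go straight to clf.fit
clf.fit(X[:, (0, 2, 3)], y, sample_weight=weights)
```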

@@ -127,7 +128,7 @@ def __init__(self, estimator, min_features=1, max_features=1,
self.est_ = self.estimator
self.fitted = False

def fit(self, X, y):
def fit(self, X, y, **fit_params):
"""Perform feature selection and learn model from training data.
Parameters
@@ -137,6 +138,8 @@ def fit(self, X, y):
n_features is the number of features.
y : array-like, shape = [n_samples]
Target values.
fit_params : dict of string -> object, optional
Parameters to pass to the fit method of the classifier.
Returns
-------
@@ -160,41 +163,42 @@ def fit(self, X, y):
raise AttributeError('min_features must be <= max_features')

candidates = chain(*((combinations(range(X.shape[1]), r=i))
for i in range(self.min_features,
self.max_features + 1)))
for i in range(self.min_features,
self.max_features + 1)))

self.subsets_ = {}

def ncr(n, r):
"""Return the number of combinations of length r from n items.
Parameters
----------
n : {integer}
Total number of items
r : {integer}
Number of items to select from n
Returns
-------
Number of combinations, integer
"""

r = min(r, n-r)
if r == 0:
return 1
numer = reduce(op.mul, range(n, n-r, -1))
denom = reduce(op.mul, range(1, r+1))
return numer//denom

all_comb = np.sum([ncr(n=X.shape[1], r=i)
for i in range(self.min_features,
self.max_features + 1)])
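# Illustrative note (not in the original source): with the four iris
# features and min_features=max_features=3, ncr(n=4, r=3) == 4, so
# all_comb == 4 and only four candidate subsets are scored, matching
# the four entries asserted in test_fit_params below.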

n_jobs = min(self.n_jobs, all_comb)
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=self.pre_dispatch)
work = enumerate(parallel(delayed(_calc_score)(self, X, y, c)
work = enumerate(parallel(delayed(_calc_score)
(self, X, y, c, **fit_params)
for c in candidates))

for iteration, (c, cv_scores) in work:
@@ -239,21 +243,25 @@ def transform(self, X):
self._check_fitted()
return X[:, self.best_idx_]

def fit_transform(self, X, y):
def fit_transform(self, X, y, **fit_params):
"""Fit to training data and return the best selected features from X.
Parameters
----------
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
Training vectors, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples]
Target values.
fit_params : dict of string -> object, optional
Parameters to pass to the fit method of the classifier.
Returns
-------
Feature subset of X, shape={n_samples, k_features}
"""
self.fit(X, y)
self.fit(X, y, **fit_params)
return self.transform(X)

def get_metric_dict(self, confidence_interval=0.95):
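Putting the pieces together, a hedged usage sketch of the public API after this change, mirroring the new test below (the forest size and unit weights follow the test's illustrative choices):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from mlxtend.feature_selection import ExhaustiveFeatureSelector as EFS

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=100, random_state=123)
efs = EFS(forest, min_features=3, max_features=3,
          scoring='accuracy', cv=4, print_progress=False)

# sample_weight travels through **fit_params to forest.fit in every fold
X_best = efs.fit_transform(X, y, sample_weight=np.ones(X.shape[0]))
print(efs.best_idx_, round(efs.best_score_, 4))
```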
mlxtend/feature_selection/sequential_feature_selector.py (2 changes: 2 additions, 0 deletions)
@@ -481,6 +481,8 @@ def fit_transform(self, X, y, **fit_params):
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
Training vectors, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples]
Target values.
fit_params : dict of string -> object, optional
Parameters to pass to the fit method of the classifier.
mlxtend/feature_selection/tests/test_exhaustive_feature_selector.py
@@ -8,6 +8,7 @@
import numpy as np
from numpy.testing import assert_almost_equal
from mlxtend.feature_selection import ExhaustiveFeatureSelector as EFS
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from mlxtend.classifier import SoftmaxRegression
from sklearn.datasets import load_iris
@@ -164,6 +165,40 @@ def test_knn_cv3():
assert round(efs1.best_score_, 4) == 0.9728


def test_fit_params():
iris = load_iris()
X = iris.data
y = iris.target
sample_weight = np.ones(X.shape[0])
forest = RandomForestClassifier(n_estimators=100, random_state=123)
efs1 = EFS(forest,
min_features=3,
max_features=3,
scoring='accuracy',
cv=4,
print_progress=False)
efs1 = efs1.fit(X, y, sample_weight=sample_weight)
expect = {0: {'feature_idx': (0, 1, 2),
'cv_scores': np.array([0.94871795, 0.92307692,
0.91666667, 0.97222222]),
'avg_score': 0.9401709401709402},
1: {'feature_idx': (0, 1, 3),
'cv_scores': np.array([0.92307692, 0.92307692,
0.88888889, 1.]),
'avg_score': 0.9337606837606838},
2: {'feature_idx': (0, 2, 3),
'cv_scores': np.array([0.97435897, 0.94871795,
0.94444444, 0.97222222]),
'avg_score': 0.9599358974358974},
3: {'feature_idx': (1, 2, 3),
'cv_scores': np.array([0.97435897, 0.94871795,
0.91666667, 1.]),
'avg_score': 0.9599358974358974}}
dict_compare_utility(d1=expect, d2=efs1.subsets_)
assert efs1.best_idx_ == (0, 2, 3)
assert round(efs1.best_score_, 4) == 0.9599


def test_regression():
boston = load_boston()
X, y = boston.data[:, [1, 2, 6, 8, 12]], boston.target
