Skip to content

Commit

Permalink
[MRG+2] Norm inconsistency between RFE and SelectFromModel (was _Lear…
Browse files Browse the repository at this point in the history
…ntSelectorMixin) #2121 (#6181)

* Norm inconsistency between RFE and SelectFromModel (was _LearntSelectorMixin) #2121

* safe_pwr utility

* Norm fix

* Removed safe_pwr

* 1D arrays support for norm fix

* Test case for 2d coef in SelectFromModel

* Fix numpy version requirement for norm fix

* Implement fixes suggested by @jnothman

* Add numpy version requiring the fix.
  • Loading branch information
antoinewdg authored and amueller committed Oct 24, 2016
1 parent 177ac84 commit 74a9756
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -52,6 +52,9 @@ Enhancements
(`#7506 <https://github.com/scikit-learn/scikit-learn/pull/7506>`_) by
`Narine Kokhlikyan`_.

- Added ``norm_order`` parameter to :class:`feature_selection.SelectFromModel`
to enable selection of the norm order when ``coef_`` has more than one dimension.

Bug fixes
.........

Expand Down
16 changes: 12 additions & 4 deletions sklearn/feature_selection/from_model.py
Expand Up @@ -10,9 +10,10 @@
from ..utils import safe_mask, check_array, deprecated
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from ..utils.fixes import norm


def _get_feature_importances(estimator):
def _get_feature_importances(estimator, norm_order=1):
"""Retrieve or aggregate feature importances from estimator"""
importances = getattr(estimator, "feature_importances_", None)

Expand All @@ -21,7 +22,7 @@ def _get_feature_importances(estimator):
importances = np.abs(estimator.coef_)

else:
importances = np.sum(np.abs(estimator.coef_), axis=0)
importances = norm(estimator.coef_, axis=0, ord=norm_order)

elif importances is None:
raise ValueError(
Expand Down Expand Up @@ -172,6 +173,11 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
Otherwise train the model using ``fit`` and then ``transform`` to do
feature selection.
norm_order : non-zero int, inf, -inf, default 1
Order of the norm used to filter the vectors of coefficients below
``threshold`` in the case where the ``coef_`` attribute of the
estimator is of dimension 2.
Attributes
----------
`estimator_`: an estimator
Expand All @@ -182,10 +188,12 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
`threshold_`: float
The threshold value used for feature selection.
"""
def __init__(self, estimator, threshold=None, prefit=False):

def __init__(self, estimator, threshold=None, prefit=False, norm_order=1):
self.estimator = estimator
self.threshold = threshold
self.prefit = prefit
self.norm_order = norm_order

def _get_support_mask(self):
# SelectFromModel can directly call on transform.
Expand All @@ -197,7 +205,7 @@ def _get_support_mask(self):
raise ValueError(
'Either fit the model before transform or set "prefit=True"'
' while passing the fitted estimator to the constructor.')
scores = _get_feature_importances(estimator)
scores = _get_feature_importances(estimator, self.norm_order)
self.threshold_ = _calculate_threshold(estimator, scores,
self.threshold)
return scores >= self.threshold_
Expand Down
26 changes: 26 additions & 0 deletions sklearn/feature_selection/tests/test_from_model.py
Expand Up @@ -17,6 +17,7 @@
from sklearn.feature_selection import SelectFromModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.utils.fixes import norm

iris = datasets.load_iris()
data, y = iris.data, iris.target
Expand Down Expand Up @@ -102,6 +103,31 @@ def test_feature_importances():
assert_array_equal(X_new, X[:, mask])


@skip_if_32bit
def test_feature_importances_2d_coef():
    """Check ``SelectFromModel`` feature selection for each ``norm_order``.

    For every combination of threshold ("mean", "median") and norm order
    (1, 2, inf), the features kept by ``SelectFromModel`` are compared with
    a mask computed by hand from the norm of a reference estimator's
    ``coef_`` taken along axis 0.
    """
    X, y = datasets.make_classification(
        n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
        n_repeated=0, shuffle=False, random_state=0, n_classes=4)

    # The reference fit does not depend on the threshold or the norm order,
    # so it is done once instead of once per loop iteration.
    est = LogisticRegression()
    est.fit(X, y)

    for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
        for order in [1, 2, np.inf]:
            # Fit SelectFromModel on a multi-class problem
            transformer = SelectFromModel(estimator=LogisticRegression(),
                                          threshold=threshold,
                                          norm_order=order)
            transformer.fit(X, y)
            assert_true(hasattr(transformer.estimator_, 'coef_'))
            X_new = transformer.transform(X)
            assert_less(X_new.shape[1], X.shape[1])

            # Manually check that the norm is correctly performed.
            # ``>=`` mirrors the comparison SelectFromModel applies when
            # building its support mask (``scores >= threshold_``).
            importances = norm(est.coef_, axis=0, ord=order)
            feature_mask = importances >= func(importances)
            assert_array_equal(X_new, X[:, feature_mask])


def test_partial_fit():
est = PassiveAggressiveClassifier(random_state=0, shuffle=False)
transformer = SelectFromModel(estimator=est)
Expand Down
30 changes: 30 additions & 0 deletions sklearn/utils/fixes.py
Expand Up @@ -419,3 +419,33 @@ def __getstate__(self):
self._fill_value)
else:
from numpy.ma import MaskedArray # noqa

if 'axis' not in signature(np.linalg.norm).parameters:

def norm(X, ord=None, axis=None):
    """Compute a norm, optionally along one axis of a 2D array.

    Handles the axis parameter for the norm function in old versions of
    numpy (useless for numpy >= 1.8).

    Parameters
    ----------
    X : array-like
        Input data. It is converted with ``np.asarray``.

    ord : optional
        Order of the norm, forwarded unchanged to ``np.linalg.norm``.

    axis : {0, 1, -1, -2} or None, optional
        Axis along which the vector norms are computed. With ``axis=0``
        (or ``-2``) one norm per column is returned, with ``axis=1``
        (or ``-1``) one norm per row. If None, or if ``X`` is 1D, the
        result of ``np.linalg.norm(X, ord=ord)`` is returned directly.

    Returns
    -------
    result : float or ndarray of shape (n_vectors,)
        The norm of ``X`` or the float array of per-vector norms.

    Raises
    ------
    NotImplementedError
        If ``axis`` is given and ``X`` is neither 1D nor 2D, or if ``axis``
        is not a valid axis of a 2D array.
    """
    X = np.asarray(X)

    if axis is None or X.ndim == 1:
        return np.linalg.norm(X, ord=ord)

    # Slicing an array with more than 2 dimensions would hand 2D slices to
    # np.linalg.norm, which silently computes matrix norms instead of
    # vector norms, so reject it explicitly.
    if X.ndim != 2 or axis not in (0, 1, -1, -2):
        raise NotImplementedError(
            "The fix that adds axis parameter to the old numpy "
            "norm only works for 1D or 2D arrays.")

    # Iterate over the vectors whose norm is wanted: columns for axis 0
    # (hence the transpose), rows for axis 1.
    vectors = X.T if axis in (0, -2) else X
    return np.array([np.linalg.norm(vector, ord=ord) for vector in vectors],
                    dtype=float)

else:
norm = np.linalg.norm
25 changes: 25 additions & 0 deletions sklearn/utils/tests/test_fixes.py
Expand Up @@ -5,6 +5,7 @@

import pickle
import numpy as np
import math

from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_false
Expand All @@ -16,6 +17,7 @@
from sklearn.utils.fixes import divide, expit
from sklearn.utils.fixes import astype
from sklearn.utils.fixes import MaskedArray
from sklearn.utils.fixes import norm


def test_expit():
Expand Down Expand Up @@ -66,3 +68,26 @@ def test_masked_array_obj_dtype_pickleable():
marr_pickled = pickle.loads(pickle.dumps(marr))
assert_array_equal(marr.data, marr_pickled.data)
assert_array_equal(marr.mask, marr_pickled.mask)


def test_norm():
    """Check ``norm`` against hand-computed values and output shapes."""
    X = np.array([[-2, 4, 5],
                  [1, 3, -4],
                  [0, 0, 8],
                  [0, 0, 0]]).astype(float)

    # Test various axis and order.
    # Expected values involving square roots are compared with a tolerance
    # rather than with exact floating point equality.
    np.testing.assert_allclose(norm(X), math.sqrt(135))
    np.testing.assert_allclose(
        norm(X, axis=0),
        np.array([math.sqrt(5), math.sqrt(25), math.sqrt(105)])
    )
    assert_array_equal(np.array([3, 7, 17]), norm(X, axis=0, ord=1))
    assert_array_equal(np.array([2, 4, 8]), norm(X, axis=0, ord=np.inf))
    assert_array_equal(np.array([0, 0, 0]), norm(X, axis=0, ord=-np.inf))
    assert_array_equal(np.array([11, 8, 8, 0]), norm(X, axis=1, ord=1))

    # A 1D input with an explicit axis reduces to a single value
    np.testing.assert_allclose(norm(X[0], axis=0), math.sqrt(45))

    # Test shapes
    assert_equal((), norm(X).shape)
    assert_equal((), np.shape(norm(X[0], axis=0)))
    assert_equal((3,), norm(X, axis=0).shape)
    assert_equal((4,), norm(X, axis=1).shape)

0 comments on commit 74a9756

Please sign in to comment.