Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
small refactor + test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
a-n-ermakov committed Mar 15, 2019
1 parent bfa331f commit be7decf
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
16 changes: 8 additions & 8 deletions stability_selection/stability_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,16 @@ def get_support(self, indices=False, threshold=None, max_features=None):
'got %s' % (n_features, max_features))

threshold_cutoff = self.threshold if threshold is None else threshold
max_features_cutoff = self.max_features if max_features is None else max_features
mask = (self.stability_scores_.max(axis=1) > threshold_cutoff)

if max_features_cutoff is None:
mask = (self.stability_scores_.max(axis=1) > threshold_cutoff)
else:
max_features_cutoff = self.max_features if max_features is None else max_features
if max_features_cutoff is not None:
exceed_counts = (self.stability_scores_ > threshold_cutoff).sum(axis=1)
max_features_cutoff = min(max_features_cutoff, (exceed_counts > 0).sum())
feature_indices = (-exceed_counts).argsort()[:max_features_cutoff]
mask = np.zeros(n_features, dtype=np.bool)
mask[feature_indices] = True
if max_features_cutoff < (exceed_counts > 0).sum():
feature_indices = (-exceed_counts).argsort()[:max_features_cutoff]
mask = np.zeros(n_features, dtype=np.bool)
mask[feature_indices] = True

return mask if not indices else np.where(mask)[0]

def transform(self, X, threshold=None):
Expand Down
50 changes: 31 additions & 19 deletions stability_selection/tests/test_stability_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _generate_dummy_classification_data(p=1000, n=1000, k=5,
def test_stability_selection_classification():
n, p, k = 1000, 1000, 5

X, y, important_betas = _generate_dummy_classification_data(n=n, k=k)
X, y, important_betas = _generate_dummy_classification_data(n=n, p=p, k=k)
selector = StabilitySelection(lambda_grid=np.logspace(-5, -1, 25), verbose=1)
selector.fit(X, y)

Expand All @@ -59,7 +59,7 @@ def test_stability_selection_classification():
def test_stability_selection_regression():
n, p, k = 500, 1000, 5

X, y, important_betas = _generate_dummy_regression_data(n=n, k=k)
X, y, important_betas = _generate_dummy_regression_data(n=n, p=p, k=k)

base_estimator = Pipeline([
('scaler', StandardScaler()),
Expand All @@ -81,7 +81,7 @@ def test_stability_selection_regression():
def test_with_complementary_pairs_bootstrap():
n, p, k = 500, 1000, 5

X, y, important_betas = _generate_dummy_regression_data(n=n, k=k)
X, y, important_betas = _generate_dummy_regression_data(n=n, p=p, k=k)

base_estimator = Pipeline([
('scaler', StandardScaler()),
Expand All @@ -104,7 +104,7 @@ def test_with_complementary_pairs_bootstrap():
def test_with_stratified_bootstrap():
n, p, k = 1000, 1000, 5

X, y, important_betas = _generate_dummy_classification_data(n=n, k=k)
X, y, important_betas = _generate_dummy_classification_data(n=n, p=p, k=k)
selector = StabilitySelection(lambda_grid=np.logspace(-5, -1, 25), verbose=1,
bootstrap_func='stratified')
selector.fit(X, y)
Expand All @@ -117,7 +117,7 @@ def test_with_stratified_bootstrap():
def test_different_shape():
n, p, k = 100, 200, 5

X, y, important_betas = _generate_dummy_regression_data(n=n, k=k)
X, y, important_betas = _generate_dummy_regression_data(n=n, p=p, k=k)

base_estimator = Pipeline([
('scaler', StandardScaler()),
Expand All @@ -136,7 +136,7 @@ def test_different_shape():
def test_no_features():
n, p, k = 100, 200, 0

X, y, important_betas = _generate_dummy_regression_data(n=n, k=k)
X, y, important_betas = _generate_dummy_regression_data(n=n, p=p, k=k)

base_estimator = Pipeline([
('scaler', StandardScaler()),
Expand All @@ -155,38 +155,50 @@ def test_no_features():


def test_stability_selection_max_features():
n, p, k = 1000, 1000, 5
n, p, k = 2000, 100, 5
lambda_grid=np.logspace(-5, -1, 25)

X, y, important_betas = _generate_dummy_classification_data(n=n, k=k)
selector = StabilitySelection(lambda_grid=lambda_grid,
max_features=1,
verbose=0)
X, y, important_betas = _generate_dummy_classification_data(n=n, p=p, k=k)
selector = StabilitySelection(lambda_grid=lambda_grid, max_features=1)
selector.fit(X, y)
X_r = selector.transform(X)
assert(X_r.shape == (n, 1))

selector = StabilitySelection(lambda_grid=lambda_grid,
max_features=k,
verbose=0)
selector = StabilitySelection(lambda_grid=lambda_grid, max_features=k)
selector.fit(X, y)
X_r = selector.transform(X)
assert(X_r.shape == (n, k))

selector = StabilitySelection(lambda_grid=lambda_grid,
max_features=k+1,
verbose=0)
selector = StabilitySelection(lambda_grid=lambda_grid, max_features=k+1)
selector.fit(X, y)
X_r = selector.transform(X)
assert(X_r.shape == (n, k))

print('ok')

@raises(ValueError)
def test_get_support_max_features_low():
    """get_support must reject max_features=0 with a ValueError."""
    n_samples, n_features, n_informative = 500, 200, 5

    X, y, _ = _generate_dummy_classification_data(n=n_samples,
                                                  p=n_features,
                                                  k=n_informative)
    selector = StabilitySelection(lambda_grid=np.logspace(-5, -1, 25))
    selector.fit(X, y)
    # Zero features is below the valid range [1, n_features] -> ValueError.
    selector.get_support(max_features=0)


@raises(ValueError)
def test_get_support_max_features_high():
    """get_support must reject max_features > n_features with a ValueError."""
    n_samples, n_features, n_informative = 500, 200, 5

    X, y, _ = _generate_dummy_classification_data(n=n_samples,
                                                  p=n_features,
                                                  k=n_informative)
    selector = StabilitySelection(lambda_grid=np.logspace(-5, -1, 25))
    selector.fit(X, y)
    # One more than the number of columns exceeds the valid range -> ValueError.
    selector.get_support(max_features=n_features + 1)


def test_stability_plot():
n, p, k = 500, 200, 5

X, y, important_betas = _generate_dummy_regression_data(n=n, k=k)
X, y, important_betas = _generate_dummy_regression_data(n=n, p=p, k=k)

base_estimator = Pipeline([
('scaler', StandardScaler()),
Expand Down

0 comments on commit be7decf

Please sign in to comment.