
Commit

Merge branch 'master' into randomholdoutsplit
rasbt committed Sep 24, 2018
2 parents b0804b0 + c55d849 commit 858b3a4
Showing 11 changed files with 469 additions and 31 deletions.
1 change: 1 addition & 0 deletions docs/sources/CHANGELOG.md
@@ -17,6 +17,7 @@ The CHANGELOG for the current development version is available at
##### New Features

- Added a `scatterplotmatrix` function to the `plotting` module. ([#437](https://github.com/rasbt/mlxtend/pull/437))
- Added `sample_weight` option to `StackingRegressor`, `StackingClassifier`, `StackingCVRegressor`, `StackingCVClassifier`, `EnsembleVoteClassifier`. ([#438](https://github.com/rasbt/mlxtend/issues/438))
- Added a `RandomHoldoutSplit` class to perform a random train/valid split without fold rotation, for use in `SequentialFeatureSelector`, scikit-learn's `GridSearchCV`, etc. ([#442](https://github.com/rasbt/mlxtend/pull/442))

##### Changes
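The `RandomHoldoutSplit` entry above is easiest to see in code. A minimal sketch of the intended use with scikit-learn's `GridSearchCV` follows; the import path and the `valid_size`/`random_seed` arguments are assumptions about the new class, not taken from this diff.

# Sketch: RandomHoldoutSplit as a one-shot train/valid splitter for
# GridSearchCV. Import path and constructor arguments are assumed.
from mlxtend.evaluate import RandomHoldoutSplit
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

holdout = RandomHoldoutSplit(valid_size=0.3, random_seed=123)
grid = GridSearchCV(KNeighborsClassifier(),
                    param_grid={'n_neighbors': [1, 3, 5]},
                    cv=holdout)  # one fixed split, no fold rotation
grid.fit(X, y)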
14 changes: 12 additions & 2 deletions mlxtend/classifier/ensemble_vote.py
@@ -106,7 +106,7 @@ def __init__(self, clfs, voting='hard',
self.verbose = verbose
self.refit = refit

-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
"""Learn weight coefficients from training data for each classifier.
Parameters
@@ -118,6 +118,12 @@ def fit(self, X, y):
y : array-like, shape = [n_samples]
Target values.
sample_weight : array-like, shape = [n_samples], optional
    Sample weights passed as sample_weight to each classifier
    in the clfs list.
    Raises an error if any classifier does not support
    sample_weight in its fit() method.
Returns
-------
self : object
@@ -164,7 +170,11 @@ def fit(self, X, y):
if self.verbose > 1:
print(_name_estimators((clf,))[0][1])

-            clf.fit(X, self.le_.transform(y))
+            if sample_weight is None:
+                clf.fit(X, self.le_.transform(y))
+            else:
+                clf.fit(X, self.le_.transform(y),
+                        sample_weight=sample_weight)
return self

def predict(self, X):
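With the change above, `fit()` simply forwards the weights to every member classifier when they are given. A minimal usage sketch; the estimator choices are illustrative, not from this commit.

import numpy as np
from mlxtend.classifier import EnsembleVoteClassifier
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

X, y = load_iris(return_X_y=True)
w = np.random.RandomState(0).uniform(size=len(y))  # per-sample weights

eclf = EnsembleVoteClassifier(clfs=[LogisticRegression(), GaussianNB()])
eclf.fit(X, y, sample_weight=w)  # forwarded to each clf.fit()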
24 changes: 18 additions & 6 deletions mlxtend/classifier/stacking_classification.py
@@ -106,7 +106,7 @@ def __init__(self, classifiers, meta_classifier,
self.store_train_meta_features = store_train_meta_features
self.use_clones = use_clones

-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
""" Fit ensemble classifers and the meta-classifier.
Parameters
@@ -116,6 +116,11 @@ def fit(self, X, y):
n_features is the number of features.
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
Target values.
sample_weight : array-like, shape = [n_samples], optional
    Sample weights passed as sample_weight to each classifier
    in the classifiers list as well as the meta_classifier.
    Raises an error if any classifier does not support
    sample_weight in its fit() method.
Returns
-------
@@ -145,20 +150,27 @@ def fit(self, X, y):

if self.verbose > 1:
print(_name_estimators((clf,))[0][1])

-            clf.fit(X, y)
+            if sample_weight is None:
+                clf.fit(X, y)
+            else:
+                clf.fit(X, y, sample_weight=sample_weight)

meta_features = self.predict_meta_features(X)

if self.store_train_meta_features:
self.train_meta_features_ = meta_features

        if not self.use_features_in_secondary:
-            self.meta_clf_.fit(meta_features, y)
+            pass
        elif sparse.issparse(X):
-            self.meta_clf_.fit(sparse.hstack((X, meta_features)), y)
+            meta_features = sparse.hstack((X, meta_features))
        else:
-            self.meta_clf_.fit(np.hstack((X, meta_features)), y)
+            meta_features = np.hstack((X, meta_features))
+
+        if sample_weight is None:
+            self.meta_clf_.fit(meta_features, y)
+        else:
+            self.meta_clf_.fit(meta_features, y,
+                               sample_weight=sample_weight)

return self

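Note the small refactor above: `meta_features` is now assembled once (optionally stacked with `X`) and the meta-classifier is fitted in a single place, so the `sample_weight` branch is not repeated per case. A minimal usage sketch with illustrative estimators:

import numpy as np
from mlxtend.classifier import StackingClassifier
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
w = np.random.RandomState(1).uniform(size=len(y))

sclf = StackingClassifier(classifiers=[RandomForestClassifier()],
                          meta_classifier=LogisticRegression())
# Weights reach both the base classifiers and the meta-classifier;
# a TypeError propagates if any fit() lacks sample_weight support.
sclf.fit(X, y, sample_weight=w)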
36 changes: 26 additions & 10 deletions mlxtend/classifier/stacking_cv_classification.py
@@ -141,7 +141,7 @@ def __init__(self, classifiers, meta_classifier,
self.store_train_meta_features = store_train_meta_features
self.use_clones = use_clones

-    def fit(self, X, y, groups=None):
+    def fit(self, X, y, groups=None, sample_weight=None):
""" Fit ensemble classifers and the meta-classifier.
Parameters
@@ -157,6 +157,12 @@ def fit(self, X, y, groups=None):
The group that each sample belongs to. This is used by specific
folding strategies such as GroupKFold()
sample_weight : array-like, shape = [n_samples], optional
    Sample weights passed as sample_weight to each classifier
    in the classifiers list as well as the meta_classifier.
    Raises an error if any classifier does not support
    sample_weight in its fit() method.
Returns
-------
self : object
@@ -206,7 +212,11 @@
((num + 1), final_cv.get_n_splits()))

try:
-                    model.fit(X[train_index], y[train_index])
+                    if sample_weight is None:
+                        model.fit(X[train_index], y[train_index])
+                    else:
+                        model.fit(X[train_index], y[train_index],
+                                  sample_weight=sample_weight[train_index])
except TypeError as e:

if str(e).startswith('A sparse matrix was passed,'
@@ -279,19 +289,25 @@ def fit(self, X, y, groups=None):

# Fit the base models correctly this time using ALL the training set
for model in self.clfs_:
-            model.fit(X, y)
+            if sample_weight is None:
+                model.fit(X, y)
+            else:
+                model.fit(X, y, sample_weight=sample_weight)

# Fit the secondary model
        if not self.use_features_in_secondary:
-            self.meta_clf_.fit(all_model_predictions, reordered_labels)
+            meta_features = all_model_predictions
        elif sparse.issparse(X):
-            self.meta_clf_.fit(sparse.hstack((reordered_features,
-                                              all_model_predictions)),
-                               reordered_labels)
+            meta_features = sparse.hstack((reordered_features,
+                                           all_model_predictions))
        else:
-            self.meta_clf_.fit(np.hstack((reordered_features,
-                                          all_model_predictions)),
-                               reordered_labels)
+            meta_features = np.hstack((reordered_features,
+                                       all_model_predictions))
+
+        if sample_weight is None:
+            self.meta_clf_.fit(meta_features, reordered_labels)
+        else:
+            self.meta_clf_.fit(meta_features, reordered_labels,
+                               sample_weight=sample_weight)

return self

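The subtle point in the CV variant is the fold handling: the weights must be sliced with the same `train_index` as `X` and `y`, so each retained row keeps its own weight. A standalone sketch of that pattern in plain scikit-learn (not library code):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

X, y = load_iris(return_X_y=True)
sample_weight = np.random.RandomState(2).uniform(size=len(y))

for train_index, _ in KFold(n_splits=3, shuffle=True,
                            random_state=0).split(X, y):
    model = LogisticRegression()
    # Slice the weights with the same indices as the training rows
    model.fit(X[train_index], y[train_index],
              sample_weight=sample_weight[train_index])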
57 changes: 57 additions & 0 deletions mlxtend/classifier/tests/test_ensemble_vote_classifier.py
@@ -4,6 +4,7 @@
#
# License: BSD 3 clause

import random
from mlxtend.classifier import EnsembleVoteClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
@@ -14,6 +15,7 @@
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.base import clone
from nose.tools import raises


iris = datasets.load_iris()
@@ -37,6 +39,61 @@ def test_EnsembleVoteClassifier():
assert(scores_mean == 0.94)


def test_sample_weight():
# with no weight
np.random.seed(123)
clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
clf3 = GaussianNB()
eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3], voting='hard')
prob1 = eclf.fit(X, y).predict_proba(X)

# with weight = 1
w = np.ones(len(y))
np.random.seed(123)
clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
clf3 = GaussianNB()
eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3], voting='hard')
prob2 = eclf.fit(X, y, sample_weight=w).predict_proba(X)

# with random weight
random.seed(87)
w = np.array([random.random() for _ in range(len(y))])
np.random.seed(123)
clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
clf3 = GaussianNB()
eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3], voting='hard')
prob3 = eclf.fit(X, y, sample_weight=w).predict_proba(X)

diff12 = np.max(np.abs(prob1 - prob2))
diff23 = np.max(np.abs(prob2 - prob3))
assert diff12 < 1e-3, "max diff is %.4f" % diff12
assert diff23 > 1e-3, "max diff is %.4f" % diff23


@raises(TypeError)
def test_no_weight_support():
random.seed(87)
w = np.array([random.random() for _ in range(len(y))])
logi = LogisticRegression()
rf = RandomForestClassifier()
gnb = GaussianNB()
knn = KNeighborsClassifier()
eclf = EnsembleVoteClassifier(clfs=[logi, rf, gnb, knn], voting='hard')
eclf.fit(X, y, sample_weight=w)


def test_no_weight_support_with_no_weight():
logi = LogisticRegression()
rf = RandomForestClassifier()
gnb = GaussianNB()
knn = KNeighborsClassifier()
eclf = EnsembleVoteClassifier(clfs=[logi, rf, gnb, knn], voting='hard')
eclf.fit(X, y)


def test_1model_labels():
clf = LogisticRegression(multi_class='multinomial',
solver='newton-cg', random_state=123)
66 changes: 66 additions & 0 deletions mlxtend/classifier/tests/test_stacking_classifier.py
@@ -4,6 +4,7 @@
#
# License: BSD 3 clause

import random
from mlxtend.classifier import StackingClassifier
from mlxtend.externals.estimator_checks import NotFittedError
from scipy import sparse
@@ -19,6 +20,7 @@
from nose.tools import assert_almost_equal
from sklearn.model_selection import train_test_split
from sklearn.base import clone
from nose.tools import raises


iris = datasets.load_iris()
@@ -43,6 +45,70 @@ def test_StackingClassifier():
assert scores_mean == 0.95


def test_sample_weight():
    # Make sure that:
    #   prediction with random weights != prediction with no weights,
    #   and prediction with no weights == prediction with all-ones weights
random.seed(87)
w = np.array([random.random() for _ in range(len(y))])

np.random.seed(123)
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
sclf = StackingClassifier(classifiers=[clf1, clf2],
meta_classifier=meta)
prob1 = sclf.fit(X, y, sample_weight=w).predict_proba(X)

np.random.seed(123)
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
sclf = StackingClassifier(classifiers=[clf1, clf2],
meta_classifier=meta)
prob2 = sclf.fit(X, y, sample_weight=None).predict_proba(X)

maxdiff = np.max(np.abs(prob1 - prob2))
assert maxdiff > 1e-3, "max diff is %.4f" % maxdiff

np.random.seed(123)
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
sclf = StackingClassifier(classifiers=[clf1, clf2],
meta_classifier=meta)
prob3 = sclf.fit(X, y, sample_weight=np.ones(len(y))).predict_proba(X)

maxdiff = np.max(np.abs(prob2 - prob3))
assert maxdiff < 1e-3, "max diff is %.4f" % maxdiff


@raises(TypeError)
def test_weight_unsupported():
# Error since KNN does not support sample_weight
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
clf3 = KNeighborsClassifier()
sclf = StackingClassifier(classifiers=[clf1, clf2, clf3],
meta_classifier=meta)
random.seed(87)
w = np.array([random.random() for _ in range(len(y))])
    sclf.fit(X, y, sample_weight=w)


def test_weight_unsupported_no_weight():
# This is okay since we do not pass sample weight
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
clf3 = KNeighborsClassifier()
sclf = StackingClassifier(classifiers=[clf1, clf2, clf3],
meta_classifier=meta)
sclf.fit(X, y)


def test_StackingClassifier_proba_avg_1():

np.random.seed(123)
