Skip to content

Commit

Permalink
Adds fit_param test for StackingClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Sep 22, 2017
1 parent 6fd998d commit 27206b3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mlxtend/classifier/stacking_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ def fit(self, X, y, **fit_params):

meta_features = self._predict_meta_features(X)
# Extract fit_params for meta_clf_
meta_clf_fit_params = {}
meta_fit_params = {}
meta_clf_name = list(self.named_meta_clf_.keys())[0]
for key, value in six.iteritems(fit_params):
if meta_clf_name in key and 'meta-' in meta_clf_name:
meta_clf_fit_params[key.replace(meta_clf_name+'__', '')] = value
meta_fit_params[key.replace(meta_clf_name+'__', '')] = value

if not self.use_features_in_secondary:
self.meta_clf_.fit(meta_features, y, **meta_clf_fit_params)
self.meta_clf_.fit(meta_features, y, **meta_fit_params)
else:
self.meta_clf_.fit(np.hstack((X, meta_features)), y,
**meta_clf_fit_params)
**meta_fit_params)

return self

Expand Down
23 changes: 23 additions & 0 deletions mlxtend/classifier/tests/test_stacking_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ def test_StackingClassifier_proba_concat_1():
assert scores_mean == 0.93, scores_mean


def test_StackingClassifier_fit_params():
np.random.seed(123)
meta = LogisticRegression()
clf1 = RandomForestClassifier()
clf2 = GaussianNB()
sclf = StackingClassifier(classifiers=[clf1, clf2],
meta_classifier=meta)
n_samples = X.shape[0]
fit_params = {
'randomforestclassifier__sample_weight': np.ones(n_samples),
'meta-logisticregression__sample_weight': np.arange(n_samples)
}

scores = cross_val_score(sclf,
X,
y,
cv=5,
scoring='accuracy',
fit_params=fit_params)
scores_mean = (round(scores.mean(), 2))
assert scores_mean == 0.95


def test_StackingClassifier_avg_vs_concat():
np.random.seed(123)
lr1 = LogisticRegression()
Expand Down

0 comments on commit 27206b3

Please sign in to comment.