From 27206b381d3b50bc5182f17503ad4bff71a696b5 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 21 Sep 2017 23:53:05 -0500 Subject: [PATCH] Adds fit_param test for StackingClassifier --- mlxtend/classifier/stacking_classification.py | 8 +++---- .../tests/test_stacking_classifier.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/mlxtend/classifier/stacking_classification.py b/mlxtend/classifier/stacking_classification.py index 35dfec80e..5057b4c08 100644 --- a/mlxtend/classifier/stacking_classification.py +++ b/mlxtend/classifier/stacking_classification.py @@ -130,17 +130,17 @@ def fit(self, X, y, **fit_params): meta_features = self._predict_meta_features(X) # Extract fit_params for meta_clf_ - meta_clf_fit_params = {} + meta_fit_params = {} meta_clf_name = list(self.named_meta_clf_.keys())[0] for key, value in six.iteritems(fit_params): if meta_clf_name in key and 'meta-' in meta_clf_name: - meta_clf_fit_params[key.replace(meta_clf_name+'__', '')] = value + meta_fit_params[key.replace(meta_clf_name+'__', '')] = value if not self.use_features_in_secondary: - self.meta_clf_.fit(meta_features, y, **meta_clf_fit_params) + self.meta_clf_.fit(meta_features, y, **meta_fit_params) else: self.meta_clf_.fit(np.hstack((X, meta_features)), y, - **meta_clf_fit_params) + **meta_fit_params) return self diff --git a/mlxtend/classifier/tests/test_stacking_classifier.py b/mlxtend/classifier/tests/test_stacking_classifier.py index ca25d05b2..92cfa87d6 100644 --- a/mlxtend/classifier/tests/test_stacking_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_classifier.py @@ -79,6 +79,29 @@ def test_StackingClassifier_proba_concat_1(): assert scores_mean == 0.93, scores_mean +def test_StackingClassifier_fit_params(): + np.random.seed(123) + meta = LogisticRegression() + clf1 = RandomForestClassifier() + clf2 = GaussianNB() + sclf = StackingClassifier(classifiers=[clf1, clf2], + meta_classifier=meta) + n_samples = X.shape[0] + fit_params = { + 'randomforestclassifier__sample_weight': np.ones(n_samples), + 'meta-logisticregression__sample_weight': np.arange(n_samples) + } + + scores = cross_val_score(sclf, + X, + y, + cv=5, + scoring='accuracy', + fit_params=fit_params) + scores_mean = (round(scores.mean(), 2)) + assert scores_mean == 0.95 + + def test_StackingClassifier_avg_vs_concat(): np.random.seed(123) lr1 = LogisticRegression()