Skip to content

Commit

Permalink
Adds fit_params support for StackingClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Sep 22, 2017
1 parent 3424df6 commit 6fd998d
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions mlxtend/classifier/stacking_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __init__(self, classifiers, meta_classifier,
self.verbose = verbose
self.use_features_in_secondary = use_features_in_secondary

def fit(self, X, y):
""" Fit ensemble classifers and the meta-classifier.
def fit(self, X, y, **fit_params):
"""Fit ensemble classifers and the meta-classifier.
Parameters
----------
Expand All @@ -87,18 +87,26 @@ def fit(self, X, y):
n_features is the number of features.
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
Target values.
fit_params : dict, optional
Parameters to pass to the fit methods of the classifiers and
meta_classifier.
Returns
-------
self : object
"""
self.clfs_ = [clone(clf) for clf in self.classifiers]
self.named_clfs_ = {key: value for key, value in
_name_estimators(self.clfs_)}
self.meta_clf_ = clone(self.meta_classifier)
self.named_meta_clf_ = {'meta-%s' % key: value for key, value in
_name_estimators([self.meta_clf_])}

if self.verbose > 0:
print("Fitting %d classifiers..." % (len(self.classifiers)))

for clf in self.clfs_:
for name, clf in six.iteritems(self.named_clfs_):

if self.verbose > 0:
i = self.clfs_.index(clf) + 1
Expand All @@ -112,14 +120,27 @@ def fit(self, X, y):
if self.verbose > 1:
print(_name_estimators((clf,))[0][1])

clf.fit(X, y)
# Extract fit_params for clf
clf_fit_params = {}
for key, value in six.iteritems(fit_params):
if name in key and 'meta-' not in key:
clf_fit_params[key.replace(name+'__', '')] = value

clf.fit(X, y, **clf_fit_params)

meta_features = self._predict_meta_features(X)
# Extract fit_params for meta_clf_
meta_clf_fit_params = {}
meta_clf_name = list(self.named_meta_clf_.keys())[0]
for key, value in six.iteritems(fit_params):
if meta_clf_name in key and 'meta-' in meta_clf_name:
meta_clf_fit_params[key.replace(meta_clf_name+'__', '')] = value

if not self.use_features_in_secondary:
self.meta_clf_.fit(meta_features, y)
self.meta_clf_.fit(meta_features, y, **meta_clf_fit_params)
else:
self.meta_clf_.fit(np.hstack((X, meta_features)), y)
self.meta_clf_.fit(np.hstack((X, meta_features)), y,
**meta_clf_fit_params)

return self

Expand Down

0 comments on commit 6fd998d

Please sign in to comment.