Skip to content

Commit

Permalink
some formatting changes, e.g., sort imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Ni committed Mar 12, 2020
1 parent 2e896dc commit 2aad039
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 28 deletions.
30 changes: 16 additions & 14 deletions mlxtend/classifier/ensemble_vote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
#
# License: BSD 3 clause

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import TransformerMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.base import clone
import numpy as np
from sklearn.base import (BaseEstimator, ClassifierMixin, TransformerMixin,
clone)
from sklearn.exceptions import NotFittedError
from ..externals.name_estimators import _name_estimators
from sklearn.preprocessing import LabelEncoder

from ..externals import six
import numpy as np
from ..externals.name_estimators import _name_estimators


class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
Expand Down Expand Up @@ -61,7 +60,7 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
fit_base_estimators : bool (default: True)
Refits classifiers in `clfs` if True; uses references to the `clfs`,
otherwise (assumes that the classifiers were already fit).
Note: fit_base_estimators=False will enforce use_clones to be False,
Note: fit_base_estimators=False will enforce use_clones to be False,
and is incompatible to most scikit-learn wrappers!
For instance, if any form of cross-validation is performed
this would require the re-fitting classifiers to training folds, which
Expand Down Expand Up @@ -109,8 +108,10 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
http://rasbt.github.io/mlxtend/user_guide/classifier/EnsembleVoteClassifier/
"""

def __init__(self, clfs, voting='hard', weights=None,
verbose=0, use_clones=True, fit_base_estimators=True):
def __init__(self, clfs, voting='hard',
weights=None, verbose=0,
use_clones=True,
fit_base_estimators=True):

self.clfs = clfs
self.named_clfs = {key: value for key, value in _name_estimators(clfs)}
Expand Down Expand Up @@ -161,6 +162,7 @@ def fit(self, X, y, sample_weight=None):
self.classes_ = self.le_.classes_

if not self.fit_base_estimators:
print('Warning: enforce use_clones to be False')
self.use_clones = False

if self.use_clones:
Expand Down Expand Up @@ -221,8 +223,8 @@ def predict(self, X):
predictions = self._predict(X)

maj = np.apply_along_axis(lambda x:
np.argmax(np.bincount(x,
weights=self.weights)),
np.argmax(np.bincount(
x, weights=self.weights)),
axis=1,
arr=predictions)

Expand Down Expand Up @@ -283,8 +285,8 @@ def get_params(self, deep=True):
for key, value in six.iteritems(step.get_params(deep=True)):
out['%s__%s' % (name, key)] = value

for key, value in six.iteritems(super(EnsembleVoteClassifier,
self).get_params(deep=False)):
for key, value in six.iteritems(
super(EnsembleVoteClassifier, self).get_params(deep=False)):
out['%s' % key] = value
return out

Expand Down
8 changes: 4 additions & 4 deletions mlxtend/classifier/stacking_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
#
# License: BSD 3 clause

import numpy as np
from scipy import sparse
from sklearn.base import TransformerMixin, clone

from ..externals.estimator_checks import check_is_fitted
from ..externals.name_estimators import _name_estimators
from ..utils.base_compostion import _BaseXComposition
from ._base_classification import _BaseStackingClassifier
from scipy import sparse
from sklearn.base import TransformerMixin
from sklearn.base import clone
import numpy as np


class StackingClassifier(_BaseXComposition, _BaseStackingClassifier,
Expand Down
13 changes: 7 additions & 6 deletions mlxtend/classifier/stacking_cv_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
#
# License: BSD 3 clause

from ..externals.name_estimators import _name_estimators
from ..externals.estimator_checks import check_is_fitted
from ..utils.base_compostion import _BaseXComposition
from ._base_classification import _BaseStackingClassifier
import numpy as np
from scipy import sparse
from sklearn.base import TransformerMixin
from sklearn.base import clone
from sklearn.base import TransformerMixin, clone
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection._split import check_cv

from ..externals.estimator_checks import check_is_fitted
from ..externals.name_estimators import _name_estimators
from ..utils.base_compostion import _BaseXComposition
from ._base_classification import _BaseStackingClassifier

# from sklearn.utils import check_X_y


Expand Down
1 change: 0 additions & 1 deletion mlxtend/classifier/tests/test_stacking_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def test_not_fitted():
clf2 = GaussianNB()
sclf = StackingClassifier(classifiers=[clf1, clf2],
use_probas=True,
use_clones=True,
meta_classifier=meta)

X, _ = iris_data()
Expand Down
6 changes: 3 additions & 3 deletions mlxtend/classifier/tests/test_stacking_cv_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ def test_meta_feat_reordering():
shuffle=True,
random_state=42,
store_train_meta_features=True)
X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast,
random_state=0,
test_size=0.3)
X_train, _, y_train, _ = train_test_split(X_breast, y_breast,
random_state=0,
test_size=0.3)
stclf.fit(X_train, y_train)

if Version(sklearn_version) < Version("0.21"):
Expand Down

0 comments on commit 2aad039

Please sign in to comment.