Skip to content

Commit

Permalink
fix str label issue in ensemblevoteclassifier with refit=false
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jan 18, 2018
1 parent 2a5fd8d commit 65663b3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/sources/CHANGELOG.md
Expand Up @@ -26,7 +26,7 @@ The CHANGELOG for the current development version is available at

##### Bug Fixes

- -
- Fixed issue when class labels were provided to the `EnsembleVoteClassifier` when `refit` was set to `false`. ([#322](https://github.com/rasbt/mlxtend/issues/322))



Expand Down
7 changes: 6 additions & 1 deletion mlxtend/classifier/ensemble_vote.py
Expand Up @@ -260,7 +260,12 @@ def get_params(self, deep=True):

def _predict(self, X):
"""Collect results from clf.predict calls."""
return np.asarray([clf.predict(X) for clf in self.clfs_]).T

if self.refit:
return np.asarray([clf.predict(X) for clf in self.clfs_]).T
else:
return np.asarray([self.le_.transform(clf.predict(X))
for clf in self.clfs_]).T

def _predict_probas(self, X):
"""Collect results from clf.predict_proba calls."""
Expand Down
31 changes: 31 additions & 0 deletions mlxtend/classifier/tests/test_ensemble_vote_classifier.py
Expand Up @@ -164,3 +164,34 @@ def test_string_labels_python_list():
scoring='accuracy')
scores_mean = (round(scores.mean(), 2))
assert(scores_mean == 0.94)


def test_string_labels_refit_false():
np.random.seed(123)
clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
clf3 = GaussianNB()

y_str = y.copy()
y_str = y_str.astype(str)
y_str[:50] = 'a'
y_str[50:100] = 'b'
y_str[100:150] = 'c'

clf1.fit(X, y_str)
clf2.fit(X, y_str)
clf3.fit(X, y_str)

eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3],
voting='hard',
refit=False)

eclf.fit(X, y_str)
assert round(eclf.score(X, y_str), 2) == 0.97

eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3],
voting='soft',
refit=False)

eclf.fit(X, y_str)
assert round(eclf.score(X, y_str), 2) == 0.97

0 comments on commit 65663b3

Please sign in to comment.