Skip to content

Commit

Permalink
Merge pull request #322 from rasbt/support-string-labels
Browse files Browse the repository at this point in the history
Add unit tests for the EnsembleVoteClassifier with class labels in string format
  • Loading branch information
rasbt committed Jan 18, 2018
2 parents 827ba35 + 65663b3 commit cbb19df
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The CHANGELOG for the current development version is available at

##### Bug Fixes

- -
- Fixed an issue that occurred when string class labels were provided to the `EnsembleVoteClassifier` and `refit` was set to `False`. ([#322](https://github.com/rasbt/mlxtend/issues/322))



Expand Down
7 changes: 6 additions & 1 deletion mlxtend/classifier/ensemble_vote.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,12 @@ def get_params(self, deep=True):

def _predict(self, X):
    """Collect results from clf.predict calls.

    The predictions of every classifier in ``self.clfs_`` are stacked and
    transposed, so each column of the result holds one classifier's output.
    """
    if self.refit:
        # Predictions are stacked exactly as the classifiers return them.
        return np.asarray([clf.predict(X) for clf in self.clfs_]).T

    # refit=False: each classifier's predictions are mapped through
    # self.le_.transform before stacking (presumably because the
    # classifiers were fitted on the raw labels -- confirm against fit()).
    return np.asarray([self.le_.transform(clf.predict(X))
                       for clf in self.clfs_]).T

def _predict_probas(self, X):
"""Collect results from clf.predict_proba calls."""
Expand Down
73 changes: 73 additions & 0 deletions mlxtend/classifier/tests/test_ensemble_vote_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,76 @@ def test_classifier_gridsearch():
grid.fit(X, y)

assert len(grid.best_params_['clfs']) == 2


def test_string_labels_numpy_array():
    """Check hard voting with string labels stored in a NumPy array."""
    np.random.seed(123)
    lr = LogisticRegression()
    forest = RandomForestClassifier()
    nb = GaussianNB()
    eclf = EnsembleVoteClassifier(clfs=[lr, forest, nb], voting='hard')

    # Overwrite the labels in blocks of 50 with the letters 'a', 'b', 'c'.
    y_str = y.copy().astype(str)
    for start, letter in ((0, 'a'), (50, 'b'), (100, 'c')):
        y_str[start:start + 50] = letter

    scores = cross_val_score(eclf, X, y_str, cv=5, scoring='accuracy')
    assert round(scores.mean(), 2) == 0.94


def test_string_labels_python_list():
    """Check hard voting with string labels stored in a plain Python list."""
    np.random.seed(123)
    lr = LogisticRegression()
    forest = RandomForestClassifier()
    nb = GaussianNB()
    eclf = EnsembleVoteClassifier(clfs=[lr, forest, nb], voting='hard')

    # 50 labels each of 'a', 'b' and 'c', in that order.
    y_str = ['a'] * 50 + ['b'] * 50 + ['c'] * 50

    scores = cross_val_score(eclf, X, y_str, cv=5, scoring='accuracy')
    assert round(scores.mean(), 2) == 0.94


def test_string_labels_refit_false():
    """Check hard and soft voting on string labels with refit=False."""
    np.random.seed(123)
    base_clfs = [LogisticRegression(), RandomForestClassifier(), GaussianNB()]

    # Overwrite the labels in blocks of 50 with the letters 'a', 'b', 'c'.
    y_str = y.copy().astype(str)
    for start, letter in ((0, 'a'), (50, 'b'), (100, 'c')):
        y_str[start:start + 50] = letter

    # The classifiers are fitted up front; the ensemble is built with
    # refit=False below.
    for clf in base_clfs:
        clf.fit(X, y_str)

    for voting in ('hard', 'soft'):
        eclf = EnsembleVoteClassifier(clfs=list(base_clfs),
                                      voting=voting,
                                      refit=False)
        eclf.fit(X, y_str)
        assert round(eclf.score(X, y_str), 2) == 0.97

0 comments on commit cbb19df

Please sign in to comment.