Skip to content

Commit 0782254

Browse files
mlewis1729paulha
authored andcommitted
[MRG+1] fixed OOB_Score bug for bagging classifiers. (scikit-learn#8936)
* fixed OOB_Score bug for bagging slassifiers. See: scikit-learn#8933 * Added white space * more white space fixing * Adding test for oob_score validity * removing pandas, replacing with numpy matrices * fixing white space * more white space fixing * white space ... * fixed labels to allow for strings * white space * simplifying test * white space * reformatting test * white space * pressed enter at end of file * removing line at end of file
1 parent 5590312 commit 0782254

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sklearn/ensemble/bagging.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,7 @@ def _set_oob_score(self, X, y):
608608

609609
oob_decision_function = (predictions /
610610
predictions.sum(axis=1)[:, np.newaxis])
611-
oob_score = accuracy_score(y, classes_.take(np.argmax(predictions,
612-
axis=1)))
611+
oob_score = accuracy_score(y, np.argmax(predictions, axis=1))
613612

614613
self.oob_decision_function_ = oob_decision_function
615614
self.oob_score_ = oob_score

sklearn/ensemble/tests/test_bagging.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,20 @@ def test_max_samples_consistency():
723723
max_features=0.5, random_state=1)
724724
bagging.fit(X, y)
725725
assert_equal(bagging._max_samples, max_samples)
726+
727+
728+
def test_set_oob_score_label_encoding():
729+
# Make sure the oob_score doesn't change when the labels change
730+
# See: https://github.com/scikit-learn/scikit-learn/issues/8933
731+
randState = 5
732+
X = [[-1], [0], [1]] * 5
733+
Y1 = ['A', 'B', 'C'] * 5
734+
Y2 = [-1, 0, 1] * 5
735+
Y3 = [0, 1, 2] * 5
736+
x1 = BaggingClassifier(oob_score=True,
737+
random_state=randState).fit(X, Y1).oob_score_
738+
x2 = BaggingClassifier(oob_score=True,
739+
random_state=randState).fit(X, Y2).oob_score_
740+
x3 = BaggingClassifier(oob_score=True,
741+
random_state=randState).fit(X, Y3).oob_score_
742+
assert_equal([x1, x2], [x3, x3])

0 commit comments

Comments
 (0)