Skip to content

Commit

Permalink
FIX Use take instead of choose in compute_sample_weight (#12165)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnothman committed Oct 15, 2018
1 parent cec0fba commit e2a7b31
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.20.rst
Expand Up @@ -17,6 +17,10 @@ enhancements to features released in 0.20.0.
``n_jobs > 1``.
:issue:`12159` by :user:`Olivier Grisel <ogrisel>`.

- |Fix| Fixed a bug mostly affecting :class:`ensemble.RandomForestClassifier`
where ``class_weight='balanced_subsample'`` failed with more than 32 classes.
:issue:`12165` by `Joel Nothman`_.

- |Fix| :func:`linear_model.SGDClassifier` and variants
with ``early_stopping=True`` would not use a consistent validation
split in the multiclass case and this would cause a crash when using
Expand Down
12 changes: 6 additions & 6 deletions sklearn/utils/class_weight.py
Expand Up @@ -150,12 +150,12 @@ def compute_sample_weight(class_weight, y, indices=None):
y_subsample = y[indices, k]
classes_subsample = np.unique(y_subsample)

weight_k = np.choose(np.searchsorted(classes_subsample,
classes_full),
compute_class_weight(class_weight_k,
classes_subsample,
y_subsample),
mode='clip')
weight_k = np.take(compute_class_weight(class_weight_k,
classes_subsample,
y_subsample),
np.searchsorted(classes_subsample,
classes_full),
mode='clip')

classes_missing = set(classes_full) - set(classes_subsample)
else:
Expand Down
8 changes: 8 additions & 0 deletions sklearn/utils/tests/test_class_weight.py
Expand Up @@ -251,3 +251,11 @@ def test_compute_sample_weight_errors():

# Incorrect length list for multi-output
assert_raises(ValueError, compute_sample_weight, [{1: 2, 2: 1}], y_)


def test_compute_sample_weight_more_than_32():
# Non-regression smoke test for #12146
y = np.arange(50) # more than 32 distinct classes
indices = np.arange(50) # use subsampling
weight = compute_sample_weight('balanced', y, indices=indices)
assert_array_almost_equal(weight, np.ones(y.shape[0]))

0 comments on commit e2a7b31

Please sign in to comment.