Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Sample weights for median_absolute_error #17225

Merged
merged 49 commits into from May 27, 2020
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
4c472fa
weighted percentile 2d
lucyleeow May 14, 2020
9837a2a
sample weights none
lucyleeow May 14, 2020
fd4bbe5
impl in median abs err
lucyleeow May 14, 2020
0f9ef5e
lint
lucyleeow May 14, 2020
309f55f
lint
lucyleeow May 14, 2020
89d96eb
fix shape
lucyleeow May 14, 2020
564acae
fix type
lucyleeow May 14, 2020
eacc4d9
lint
lucyleeow May 14, 2020
2baa8ab
squeeze
lucyleeow May 15, 2020
8f7c07f
ndim 0 case
lucyleeow May 15, 2020
2a04c7c
merge master
lucyleeow May 15, 2020
f71509d
fix text
lucyleeow May 15, 2020
965021b
take along axis version
lucyleeow May 15, 2020
8c6fe32
use new fun
lucyleeow May 15, 2020
64b7cdf
fix format
lucyleeow May 15, 2020
b7893d2
spelling
lucyleeow May 15, 2020
30032e9
lint
lucyleeow May 15, 2020
d33bbdf
lint
lucyleeow May 15, 2020
91c2a81
fix logic
lucyleeow May 15, 2020
02d7475
add comment
lucyleeow May 15, 2020
f8a63d9
add test for multioutput
lucyleeow May 15, 2020
93ecc29
suggestions
lucyleeow May 16, 2020
82ef7b6
dont store sorted array, suggestions
lucyleeow May 18, 2020
00953a6
remove print
lucyleeow May 18, 2020
fd26fd5
remove none sampleweight
lucyleeow May 18, 2020
5b93383
fix logic
lucyleeow May 19, 2020
7995cc3
lint
lucyleeow May 19, 2020
d6a85b9
lint
lucyleeow May 19, 2020
bd6344c
add tests, mv to utils
lucyleeow May 20, 2020
fad3f2b
reg target for test
lucyleeow May 20, 2020
3ee6c25
lint
lucyleeow May 20, 2020
45d89e1
odd size in test
lucyleeow May 20, 2020
c3aa8c0
wording
lucyleeow May 20, 2020
482c40f
lint
lucyleeow May 20, 2020
149705a
fix test
lucyleeow May 20, 2020
af5e07b
better 2d test
lucyleeow May 20, 2020
8de9397
Merge branch 'master' into median_abs_err
lucyleeow May 20, 2020
a198468
whats new
lucyleeow May 20, 2020
e528f55
lint
lucyleeow May 20, 2020
f16a7b6
merge master
lucyleeow May 21, 2020
664f13f
suggestions, better docstring
lucyleeow May 26, 2020
8c52d0b
weight 1d, amend test
lucyleeow May 26, 2020
954dc1e
split reg and clas in test sample weight
lucyleeow May 26, 2020
60f67cf
lint
lucyleeow May 26, 2020
b8c93d7
use require positive
lucyleeow May 26, 2020
62fd3ff
word
lucyleeow May 26, 2020
6920f6e
suggestions
lucyleeow May 27, 2020
ba0efa2
formatting
lucyleeow May 27, 2020
cedc674
fix typos
lucyleeow May 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion doc/whats_new/v0.24.rst
Expand Up @@ -76,7 +76,13 @@ Changelog
attribute name/path or a `callable` for extracting feature importance from
the estimator. :pr:`15361` by :user:`Venkatachalam N <venkyyuvy>`


:mod:`sklearn.metrics`
......................

- |Enhancement| Add `sample_weight` parameter to
:class:`metrics.median_absolute_error`.
:pr:`17225` by :user:`Lucy Liu <lucyleeow>`.

:mod:`sklearn.tree`
...................

Expand Down
31 changes: 0 additions & 31 deletions sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
Expand Up @@ -8,7 +8,6 @@
import pytest

from sklearn.utils import check_random_state
from sklearn.utils.stats import _weighted_percentile
from sklearn.ensemble._gb_losses import RegressionLossFunction
from sklearn.ensemble._gb_losses import LeastSquaresError
from sklearn.ensemble._gb_losses import LeastAbsoluteError
Expand Down Expand Up @@ -103,36 +102,6 @@ def test_sample_weight_init_estimators():
assert_allclose(out, sw_out, rtol=1e-2)


def test_weighted_percentile():
y = np.empty(102, dtype=np.float64)
y[:50] = 0
y[-51:] = 2
y[-1] = 100000
y[50] = 1
sw = np.ones(102, dtype=np.float64)
sw[-1] = 0.0
score = _weighted_percentile(y, sw, 50)
assert score == 1


def test_weighted_percentile_equal():
y = np.empty(102, dtype=np.float64)
y.fill(0.0)
sw = np.ones(102, dtype=np.float64)
sw[-1] = 0.0
score = _weighted_percentile(y, sw, 50)
assert score == 0


def test_weighted_percentile_zero_weight():
y = np.empty(102, dtype=np.float64)
y.fill(1.0)
sw = np.ones(102, dtype=np.float64)
sw.fill(0.0)
score = _weighted_percentile(y, sw, 50)
assert score == 1.0


def test_quantile_loss_function():
# Non regression test for the QuantileLossFunction object
# There was a sign problem when evaluating the function
Expand Down
17 changes: 15 additions & 2 deletions sklearn/metrics/_regression.py
Expand Up @@ -30,6 +30,8 @@
_num_samples)
from ..utils.validation import column_or_1d
from ..utils.validation import _deprecate_positional_args
from ..utils.validation import _check_sample_weight
from ..utils.stats import _weighted_percentile
from ..exceptions import UndefinedMetricWarning


Expand Down Expand Up @@ -335,7 +337,8 @@ def mean_squared_log_error(y_true, y_pred, *,


@_deprecate_positional_args
def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average'):
def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average',
sample_weight=None):
"""Median absolute error regression loss

Median absolute error output is non-negative floating point. The best value
Expand All @@ -360,6 +363,11 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average'):
'uniform_average' :
Errors of all outputs are averaged with uniform weight.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

.. versionadded:: 0.24

Returns
-------
loss : float or ndarray of floats
Expand Down Expand Up @@ -387,7 +395,12 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average'):
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
if sample_weight is None:
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
else:
sample_weight = _check_sample_weight(sample_weight, y_pred)
output_errors = _weighted_percentile(np.abs(y_pred - y_true),
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
lucyleeow marked this conversation as resolved.
Show resolved Hide resolved
sample_weight=sample_weight)
if isinstance(multioutput, str):
if multioutput == 'raw_values':
return output_errors
Expand Down
70 changes: 55 additions & 15 deletions sklearn/metrics/tests/test_score_objects.py
Expand Up @@ -33,7 +33,7 @@
from sklearn.linear_model import Ridge, LogisticRegression, Perceptron
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.datasets import make_blobs
from sklearn.datasets import make_classification
from sklearn.datasets import make_classification, make_regression
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, cross_val_score
Expand Down Expand Up @@ -89,7 +89,7 @@ def _make_estimators(X_train, y_train, y_ml_train):
# Make estimators that make sense to test various scoring methods
sensible_regr = DecisionTreeRegressor(random_state=0)
# some of the regressions scorers require strictly positive input.
sensible_regr.fit(X_train, y_train + 1)
sensible_regr.fit(X_train, _require_positive_y(y_train))
sensible_clf = DecisionTreeClassifier(random_state=0)
sensible_clf.fit(X_train, y_train)
sensible_ml_clf = DecisionTreeClassifier(random_state=0)
Expand Down Expand Up @@ -474,8 +474,9 @@ def test_raises_on_score_list():


@ignore_warnings
def test_scorer_sample_weight():
# Test that scorers support sample_weight or raise sensible errors
def test_classification_scorer_sample_weight():
# Test that classification scorers support sample_weight or raise sensible
# errors

# Unlike the metrics invariance test, in the scorer case it's harder
# to ensure that, on the classifier output, weighted and unweighted
Expand All @@ -493,31 +494,70 @@ def test_scorer_sample_weight():
estimator = _make_estimators(X_train, y_train, y_ml_train)

for name, scorer in SCORERS.items():
if name in REGRESSION_SCORERS:
# skip the regression scores
continue
if name in MULTILABEL_ONLY_SCORERS:
target = y_ml_test
else:
target = y_test
if name in REQUIRE_POSITIVE_Y_SCORERS:
target = _require_positive_y(target)
try:
weighted = scorer(estimator[name], X_test, target,
sample_weight=sample_weight)
ignored = scorer(estimator[name], X_test[10:], target[10:])
unweighted = scorer(estimator[name], X_test, target)
assert weighted != unweighted, (
"scorer {0} behaves identically when "
"called with sample weights: {1} vs "
"{2}".format(name, weighted, unweighted))
f"scorer {name} behaves identically when called with "
f"sample weights: {weighted} vs {unweighted}")
assert_almost_equal(weighted, ignored,
err_msg="scorer {0} behaves differently when "
"ignoring samples and setting sample_weight to"
" 0: {1} vs {2}".format(name, weighted,
ignored))
err_msg=f"scorer {name} behaves differently "
"when ignoring samples and setting "
"sample_weight to 0: {weighted} vs {ignored}")
glemaitre marked this conversation as resolved.
Show resolved Hide resolved

except TypeError as e:
assert "sample_weight" in str(e), (
"scorer {0} raises unhelpful exception when called "
"with sample weights: {1}".format(name, str(e)))
f"scorer {name} raises unhelpful exception when called "
f"with sample weights: {str(e)}")


@ignore_warnings
def test_regression_scorer_sample_weight():
# Test that regression scorers support sample_weight or raise sensible
# errors

# Odd number of test samples req for neg_median_absolute_error
X, y = make_regression(n_samples=101, n_features=20, random_state=0)
y = _require_positive_y(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

sample_weight = np.ones_like(y_test)
# Odd number req for neg_median_absolute_error
sample_weight[:11] = 0

reg = DecisionTreeRegressor(random_state=0)
reg.fit(X_train, y_train)

for name, scorer in SCORERS.items():
if name not in REGRESSION_SCORERS:
# skip classification scorers
continue
try:
weighted = scorer(reg, X_test, y_test,
sample_weight=sample_weight)
ignored = scorer(reg, X_test[11:], y_test[11:])
unweighted = scorer(reg, X_test, y_test)
assert weighted != unweighted, (
f"scorer {name} behaves identically when called with "
f"sample weights: {weighted} vs {unweighted}")
assert_almost_equal(weighted, ignored,
err_msg=f"scorer {name} behaves differently "
"when ignoring samples and setting "
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
"sample_weight to 0: {weighted} vs {ignored}")
glemaitre marked this conversation as resolved.
Show resolved Hide resolved

except TypeError as e:
assert "sample_weight" in str(e), (
f"scorer {name} raises unhelpful exception when called "
f"with sample weights: {str(e)}")


@pytest.mark.parametrize('name', SCORERS)
Expand Down
36 changes: 36 additions & 0 deletions sklearn/utils/fixes.py
Expand Up @@ -165,3 +165,39 @@ class loguniform(scipy.stats.reciprocal):
)
class MaskedArray(_MaskedArray):
pass # TODO: remove in 0.25


def _take_along_axis(arr, indices, axis):
"""Implements a simplified version of np.take_along_axis if numpy
version < 1.15"""
if np_version > (1, 14):
return np.take_along_axis(arr=arr, indices=indices, axis=axis)
else:
if axis is None:
arr = arr.flatten()

if not np.issubdtype(indices.dtype, np.intp):
raise IndexError('`indices` must be an integer array')
if arr.ndim != indices.ndim:
raise ValueError(
"`indices` and `arr` must have the same number of dimensions")

shape_ones = (1,) * indices.ndim
dest_dims = (
list(range(axis)) +
[None] +
list(range(axis+1, indices.ndim))
)

# build a fancy index, consisting of orthogonal aranges, with the
# requested index inserted at the right location
fancy_index = []
for dim, n in zip(dest_dims, arr.shape):
if dim is None:
fancy_index.append(indices)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
fancy_index.append(np.arange(n).reshape(ind_shape))

fancy_index = tuple(fancy_index)
return arr[fancy_index]
61 changes: 52 additions & 9 deletions sklearn/utils/stats.py
@@ -1,18 +1,61 @@
import numpy as np

from .extmath import stable_cumsum
from .fixes import _take_along_axis


def _weighted_percentile(array, sample_weight, percentile=50):
"""Compute weighted percentile

Computes lower weighted percentile. If `array` is a 2D array, the
`percentile` is computed along the axis 0.

.. versionchanged:: 0.24
Accepts 2D `array`.

Parameters
----------
array : 1D or 2D array
Values to take the weighted percentile of.

sample_weight: 1D or 2D array
Weights for each value in `array`. Must be same shape as `array` or
of shape `(array.shape[0],)`.

percentile: int, default=50
Percentile to compute. Must be value between 0 and 100.

Returns
-------
percentile : int if `array` 1D, ndarray if `array` 2D
Weighted percentile.
"""
Compute the weighted ``percentile`` of ``array`` with ``sample_weight``.
"""
sorted_idx = np.argsort(array)
n_dim = array.ndim
if n_dim == 0:
return array[()]
if array.ndim == 1:
array = array.reshape((-1, 1))
# When sample_weight 1D, repeat for each array.shape[1]
if (array.shape != sample_weight.shape and
array.shape[0] == sample_weight.shape[0]):
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T
sorted_idx = np.argsort(array, axis=0)
sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0)

# Find index of median prediction for each sample
weight_cdf = stable_cumsum(sample_weight[sorted_idx])
percentile_idx = np.searchsorted(
weight_cdf, (percentile / 100.) * weight_cdf[-1])
# in rare cases, percentile_idx equals to len(sorted_idx)
percentile_idx = np.clip(percentile_idx, 0, len(sorted_idx)-1)
return array[sorted_idx[percentile_idx]]
weight_cdf = stable_cumsum(sorted_weights, axis=0)
adjusted_percentile = percentile / 100 * weight_cdf[-1]
percentile_idx = np.array([
np.searchsorted(weight_cdf[:, i], adjusted_percentile[i])
for i in range(weight_cdf.shape[1])
])
percentile_idx = np.array(percentile_idx)
# In rare cases, percentile_idx equals to sorted_idx.shape[0]
max_idx = sorted_idx.shape[0] - 1
percentile_idx = np.apply_along_axis(lambda x: np.clip(x, 0, max_idx),
axis=0, arr=percentile_idx)

col_index = np.arange(array.shape[1])
percentile_in_sorted = sorted_idx[percentile_idx, col_index]
percentile = array[percentile_in_sorted, col_index]
return percentile[0] if n_dim == 1 else percentile