1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -994,6 +994,7 @@ details.
metrics.mean_poisson_deviance
metrics.mean_gamma_deviance
metrics.mean_tweedie_deviance
metrics.d2_tweedie_score
metrics.mean_pinball_loss

Multilabel ranking metrics
30 changes: 29 additions & 1 deletion doc/modules/model_evaluation.rst
@@ -2354,6 +2354,34 @@ the difference in errors decreases. Finally, by setting, ``power=2``::
we would get identical errors. The deviance when ``power=2`` is thus only
sensitive to relative errors.

.. _d2_tweedie_score:

D² score, the coefficient of determination
-------------------------------------------

The :func:`d2_tweedie_score` function computes the percentage of deviance
explained. It is a generalization of R², where the squared error is replaced by
the Tweedie deviance. D², also known as McFadden's likelihood ratio index, is
calculated as

.. math::

D^2(y, \hat{y}) = 1 - \frac{\text{D}(y, \hat{y})}{\text{D}(y, \bar{y})} \,.

The argument ``power`` defines the Tweedie power as for
:func:`mean_tweedie_deviance`. Note that for ``power=0``,
:func:`d2_tweedie_score` equals :func:`r2_score` (for single targets).

Like R², the best possible score is 1.0 and it can be negative (because the
model can be arbitrarily worse). A constant model that always predicts the
expected value of y, disregarding the input features, would get a D² score
of 0.0.
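
For instance, reusing the toy values from the :func:`d2_tweedie_score`
docstring (an illustrative check of the metric itself, not a new API)::

>>> from sklearn.metrics import d2_tweedie_score
>>> y_true = [0.5, 1, 2.5, 7]
>>> y_pred = [1, 1, 5, 3.5]
>>> d2_tweedie_score(y_true, y_pred)
0.285...
>>> d2_tweedie_score(y_true, y_pred, power=2)
0.630...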

A scorer object with a specific choice of ``power`` can be built by::

>>> from sklearn.metrics import d2_tweedie_score, make_scorer
>>> d2_tweedie_score_15 = make_scorer(d2_tweedie_score, power=1.5)
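
The resulting scorer can then be used like any other scorer object, for
instance with :func:`~sklearn.model_selection.cross_val_score`. The snippet
below is a minimal sketch assuming synthetic data and a
:class:`~sklearn.linear_model.TweedieRegressor`; it is not part of this
change::

>>> import numpy as np
>>> from sklearn.linear_model import TweedieRegressor
>>> from sklearn.model_selection import cross_val_score
>>> rng = np.random.RandomState(0)
>>> X = rng.uniform(size=(30, 2))
>>> y = rng.gamma(shape=2.0, size=30)
>>> scores = cross_val_score(
...     TweedieRegressor(power=1.5), X, y, cv=3, scoring=d2_tweedie_score_15
... )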

.. _pinball_loss:

Pinball loss
@@ -2386,7 +2414,7 @@ Here is a small example of usage of the :func:`mean_pinball_loss` function::
>>> mean_pinball_loss(y_true, y_true, alpha=0.9)
0.0

It is possible to build a scorer object with a specific choice of alpha::
It is possible to build a scorer object with a specific choice of ``alpha``::

>>> from sklearn.metrics import make_scorer
>>> mean_pinball_loss_95p = make_scorer(mean_pinball_loss, alpha=0.95)
8 changes: 7 additions & 1 deletion doc/whats_new/v1.0.rst
@@ -602,6 +602,12 @@ Changelog
quantile regression. :pr:`19415` by :user:`Xavier Dupré <sdpython>`
and :user:`Olivier Grisel <ogrisel>`.

- |Feature| :func:`metrics.d2_tweedie_score` calculates the D^2 regression
score for Tweedie deviances with power parameter ``power``. This is a
generalization of :func:`metrics.r2_score` and can be interpreted as the
percentage of Tweedie deviance explained.
:pr:`17036` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Feature| :func:`metrics.mean_squared_log_error` now supports
`squared=False`.
:pr:`20326` by :user:`Uttam kumar <helper-uttam>`.
@@ -683,7 +689,7 @@ Changelog
.............................

- |Fix| :class:`neural_network.MLPClassifier` and
:class:`neural_network.MLPRegressor` now correct supports continued training
:class:`neural_network.MLPRegressor` now correctly support continued training
when loading from a pickled file. :pr:`19631` by `Thomas Fan`_.

:mod:`sklearn.pipeline`
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -74,6 +74,7 @@
from ._regression import mean_tweedie_deviance
from ._regression import mean_poisson_deviance
from ._regression import mean_gamma_deviance
from ._regression import d2_tweedie_score


from ._scorer import check_scoring
@@ -109,6 +110,7 @@
"confusion_matrix",
"consensus_score",
"coverage_error",
"d2_tweedie_score",
"dcg_score",
"davies_bouldin_score",
"DetCurveDisplay",
109 changes: 107 additions & 2 deletions sklearn/metrics/_regression.py
@@ -24,15 +24,16 @@
# Uttam kumar <bajiraouttamsinha@gmail.com>
# License: BSD 3 clause

import numpy as np
import warnings

import numpy as np

from .._loss.glm_distribution import TweedieDistribution
from ..exceptions import UndefinedMetricWarning
from ..utils.validation import check_array, check_consistent_length, _num_samples
from ..utils.validation import column_or_1d
from ..utils.validation import _check_sample_weight
from ..utils.stats import _weighted_percentile
from ..exceptions import UndefinedMetricWarning


__ALL__ = [
@@ -986,3 +987,107 @@ def mean_gamma_deviance(y_true, y_pred, *, sample_weight=None):
1.0568...
"""
return mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2)


def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
"""D^2 regression score function, percentage of Tweedie deviance explained.

Best possible score is 1.0 and it can be negative (because the model can be
arbitrarily worse). A model that always uses the empirical mean of `y_true` as
constant prediction, disregarding the input features, gets a D^2 score of 0.0.

Read more in the :ref:`User Guide <d2_tweedie_score>`.

.. versionadded:: 1.0

Parameters
----------
y_true : array-like of shape (n_samples,)
Ground truth (correct) target values.

y_pred : array-like of shape (n_samples,)
Estimated target values.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

power : float, default=0
Tweedie power parameter. Either power <= 0 or power >= 1.

The higher ``power``, the less weight is given to extreme
deviations between true and predicted targets.

- power < 0: Extreme stable distribution. Requires: y_pred > 0.
- power = 0 : Normal distribution, output corresponds to r2_score.
y_true and y_pred can be any real numbers.
- power = 1 : Poisson distribution. Requires: y_true >= 0 and
y_pred > 0.
- 1 < power < 2 : Compound Poisson distribution. Requires: y_true >= 0
and y_pred > 0.
- power = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0.
- power = 3 : Inverse Gaussian distribution. Requires: y_true > 0
and y_pred > 0.
- otherwise : Positive stable distribution. Requires: y_true > 0
and y_pred > 0.

Returns
-------
z : float or ndarray of floats
The D^2 score.

Notes
-----
This is not a symmetric function.

Like R^2, D^2 score may be negative (it need not actually be the square of
a quantity D).

This metric is not well-defined for single samples and will return a NaN
value if n_samples is less than two.

References
----------
.. [1] Eq. (3.11) of Hastie, Trevor J., Robert Tibshirani and Martin J.
Wainwright. "Statistical Learning with Sparsity: The Lasso and
Generalizations." (2015). https://trevorhastie.github.io

Examples
--------
>>> from sklearn.metrics import d2_tweedie_score
>>> y_true = [0.5, 1, 2.5, 7]
>>> y_pred = [1, 1, 5, 3.5]
>>> d2_tweedie_score(y_true, y_pred)
0.285...
>>> d2_tweedie_score(y_true, y_pred, power=1)
0.487...
>>> d2_tweedie_score(y_true, y_pred, power=2)
0.630...
>>> d2_tweedie_score(y_true, y_true, power=2)
1.0
"""
y_type, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, None, dtype=[np.float64, np.float32]
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in d2_tweedie_score")
check_consistent_length(y_true, y_pred, sample_weight)

if _num_samples(y_pred) < 2:
msg = "D^2 score is not well-defined with less than two samples."
warnings.warn(msg, UndefinedMetricWarning)
return float("nan")

if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
sample_weight = sample_weight[:, np.newaxis]

dist = TweedieDistribution(power=power)

dev = dist.unit_deviance(y_true, y_pred, check_input=True)
numerator = np.average(dev, weights=sample_weight)

y_avg = np.average(y_true, weights=sample_weight)
dev = dist.unit_deviance(y_true, y_avg, check_input=True)
denominator = np.average(dev, weights=sample_weight)

return 1 - numerator / denominator
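
As an illustrative aside (not part of the diff), for uniform sample weights the
implementation above reduces to one minus the ratio of the model's mean Tweedie
deviance to the mean Tweedie deviance of a constant predictor equal to the mean
of ``y_true``, which can be checked with the public API:

# Sketch: verify d2_tweedie_score against mean_tweedie_deviance by hand.
import numpy as np
from sklearn.metrics import d2_tweedie_score, mean_tweedie_deviance

y_true = np.array([0.5, 1.0, 2.5, 7.0])
y_pred = np.array([1.0, 1.0, 5.0, 3.5])
power = 1.5

# Mean deviance of the model's predictions.
num = mean_tweedie_deviance(y_true, y_pred, power=power)
# Mean deviance of a constant prediction equal to the mean of y_true.
den = mean_tweedie_deviance(
    y_true, np.full_like(y_true, y_true.mean()), power=power
)

assert np.isclose(d2_tweedie_score(y_true, y_pred, power=power), 1 - num / den)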
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -29,6 +29,7 @@
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import coverage_error
from sklearn.metrics import d2_tweedie_score
from sklearn.metrics import det_curve
from sklearn.metrics import explained_variance_score
from sklearn.metrics import f1_score
@@ -110,6 +111,7 @@
"mean_poisson_deviance": mean_poisson_deviance,
"mean_gamma_deviance": mean_gamma_deviance,
"mean_compound_poisson_deviance": partial(mean_tweedie_deviance, power=1.4),
"d2_tweedie_score": partial(d2_tweedie_score, power=1.4),
}

CLASSIFICATION_METRICS = {
@@ -510,6 +512,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"mean_gamma_deviance",
"mean_poisson_deviance",
"mean_compound_poisson_deviance",
"d2_tweedie_score",
"mean_absolute_percentage_error",
}

@@ -526,6 +529,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"mean_poisson_deviance",
"mean_gamma_deviance",
"mean_compound_poisson_deviance",
"d2_tweedie_score",
}


59 changes: 45 additions & 14 deletions sklearn/metrics/tests/test_regression.py
@@ -1,6 +1,7 @@
import numpy as np
from scipy import optimize
from numpy.testing import assert_allclose
from scipy.special import factorial, xlogy
from itertools import product
import pytest

@@ -20,6 +21,7 @@
from sklearn.metrics import mean_pinball_loss
from sklearn.metrics import r2_score
from sklearn.metrics import mean_tweedie_deviance
from sklearn.metrics import d2_tweedie_score
from sklearn.metrics import make_scorer

from sklearn.metrics._regression import _check_reg_targets
@@ -53,6 +55,9 @@ def test_regression_metrics(n_samples=50):
mean_tweedie_deviance(y_true, y_pred, power=0),
mean_squared_error(y_true, y_pred),
)
assert_almost_equal(
d2_tweedie_score(y_true, y_pred, power=0), r2_score(y_true, y_pred)
)

# Tweedie deviance needs positive y_pred, except for p=0,
# p>=2 needs positive y_true
@@ -78,6 +83,17 @@ def test_regression_metrics(n_samples=50):
mean_tweedie_deviance(y_true, y_pred, power=3), np.sum(1 / y_true) / (4 * n)
)

dev_mean = 2 * np.mean(xlogy(y_true, 2 * y_true / (n + 1)))
assert_almost_equal(
d2_tweedie_score(y_true, y_pred, power=1),
1 - (n + 1) * (1 - np.log(2)) / dev_mean,
)

dev_mean = 2 * np.log((n + 1) / 2) - 2 / n * np.log(factorial(n))
assert_almost_equal(
d2_tweedie_score(y_true, y_pred, power=2), 1 - (2 * np.log(2) - 1) / dev_mean
)


def test_mean_squared_error_multioutput_raw_value_squared():
# non-regression test for
@@ -131,59 +147,74 @@ def test_regression_metrics_at_limits():
assert_almost_equal(max_error([0.0], [0.0]), 0.0)
assert_almost_equal(explained_variance_score([0.0], [0.0]), 1.0)
assert_almost_equal(r2_score([0.0, 1], [0.0, 1]), 1.0)
err_msg = (
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
)
with pytest.raises(ValueError, match=err_msg):
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([-1.0], [-1.0])
err_msg = (
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
)
with pytest.raises(ValueError, match=err_msg):
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([1.0, 2.0, 3.0], [1.0, -2.0, 3.0])
err_msg = (
msg = (
"Mean Squared Logarithmic Error cannot be used when targets "
"contain negative values."
)
with pytest.raises(ValueError, match=err_msg):
with pytest.raises(ValueError, match=msg):
mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])

# Tweedie deviance error
power = -1.2
assert_allclose(
mean_tweedie_deviance([0], [1.0], power=power), 2 / (2 - power), rtol=1e-3
)
with pytest.raises(
ValueError, match="can only be used on strictly positive y_pred."
):
msg = "can only be used on strictly positive y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.0], [0.0], power=power)
assert_almost_equal(mean_tweedie_deviance([0.0], [0.0], power=0), 0.00, 2)
with pytest.raises(ValueError, match=msg):
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

assert_almost_equal(mean_tweedie_deviance([0.0], [0.0], power=0), 0.0, 2)

power = 1.0
msg = "only be used on non-negative y and strictly positive y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.0], [0.0], power=1.0)
mean_tweedie_deviance([0.0], [0.0], power=power)
with pytest.raises(ValueError, match=msg):
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

power = 1.5
assert_allclose(mean_tweedie_deviance([0.0], [1.0], power=power), 2 / (2 - power))
msg = "only be used on non-negative y and strictly positive y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.0], [0.0], power=power)
with pytest.raises(ValueError, match=msg):
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

power = 2.0
assert_allclose(mean_tweedie_deviance([1.0], [1.0], power=power), 0.00, atol=1e-8)
msg = "can only be used on strictly positive y and y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.0], [0.0], power=power)
with pytest.raises(ValueError, match=msg):
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

power = 3.0
assert_allclose(mean_tweedie_deviance([1.0], [1.0], power=power), 0.00, atol=1e-8)

msg = "can only be used on strictly positive y and y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.0], [0.0], power=power)
with pytest.raises(ValueError, match=msg):
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

power = 0.5
with pytest.raises(ValueError, match="is only defined for power<=0 and power>=1"):
mean_tweedie_deviance([0.0], [0.0], power=power)
with pytest.raises(ValueError, match="is only defined for power<=0 and power>=1"):
mean_tweedie_deviance([0.0], [0.0], power=0.5)
d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)


def test__check_reg_targets():
Expand Down Expand Up @@ -319,7 +350,7 @@ def test_regression_custom_weights():
assert_almost_equal(msle, msle2, decimal=2)


@pytest.mark.parametrize("metric", [r2_score])
@pytest.mark.parametrize("metric", [r2_score, d2_tweedie_score])
def test_regression_single_sample(metric):
y_true = [0]
y_pred = [1]