-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] migrate tests of distribution prediction metrics to `skbase` class (#208)

This PR migrates tests of distribution prediction metrics to an `skbase` class. Also generalizes the test for censored ground truth.
- Loading branch information
Showing 3 changed files with 61 additions and 54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,64 @@ | ||
"""Tests for probabilistic metrics for distribution predictions.""" | ||
import warnings | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from skbase.testing import QuickTester | ||
|
||
from skpro.distributions import Normal | ||
from skpro.metrics import ( | ||
CRPS, | ||
SPLL, | ||
ConcordanceHarrell, | ||
LinearizedLogLoss, | ||
LogLoss, | ||
SquaredDistrLoss, | ||
) | ||
|
||
warnings.filterwarnings("ignore", category=FutureWarning) | ||
|
||
DISTR_METRICS = [ | ||
CRPS, | ||
SPLL, | ||
ConcordanceHarrell, | ||
LinearizedLogLoss, | ||
LogLoss, | ||
SquaredDistrLoss, | ||
] | ||
|
||
normal_dists = [Normal] | ||
|
||
|
||
@pytest.mark.parametrize("normal", normal_dists) | ||
@pytest.mark.parametrize("metric", DISTR_METRICS) | ||
@pytest.mark.parametrize("multivariate", [True, False]) | ||
@pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average"]) | ||
def test_distr_evaluate(normal, metric, multivariate, multioutput): | ||
"""Test expected output of evaluate functions.""" | ||
y_pred = normal.create_test_instance() | ||
y_true = y_pred.sample() | ||
|
||
m = metric(multivariate=multivariate, multioutput=multioutput) | ||
|
||
if not multivariate: | ||
expected_cols = y_true.columns | ||
else: | ||
expected_cols = ["score"] | ||
|
||
res = m.evaluate_by_index(y_true, y_pred) | ||
assert isinstance(res, pd.DataFrame) | ||
assert (res.columns == expected_cols).all() | ||
assert res.shape == (y_true.shape[0], len(expected_cols)) | ||
|
||
res = m.evaluate(y_true, y_pred) | ||
|
||
expect_df = not multivariate and multioutput == "raw_values" | ||
if expect_df: | ||
from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig | ||
|
||
TEST_DISTS = [Normal] | ||
|
||
|
||
class TestAllDistrMetrics(PackageConfig, BaseFixtureGenerator, QuickTester): | ||
    """Generic tests for all distribution prediction metrics in the package.""" | ||
|
||
# class variables which can be overridden by descendants | ||
# ------------------------------------------------------ | ||
|
||
# which object types are generated; None=all, or scitype string | ||
# passed to skpro.registry.all_objects as object_type | ||
object_type_filter = "metric_distr" | ||
|
||
@pytest.mark.parametrize("dist", TEST_DISTS) | ||
@pytest.mark.parametrize("pass_c", [True, False]) | ||
@pytest.mark.parametrize("multivariate", [True, False]) | ||
@pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average"]) | ||
def test_distr_evaluate( | ||
self, object_instance, dist, pass_c, multivariate, multioutput | ||
): | ||
"""Test expected output of evaluate functions.""" | ||
metric = object_instance | ||
|
||
y_pred = dist.create_test_instance() | ||
y_true = y_pred.sample() | ||
|
||
m = metric.set_params(multioutput=multioutput) | ||
if "multivariate" in metric.get_params(): | ||
m = m.set_params(multivariate=multivariate) | ||
|
||
if not multivariate: | ||
expected_cols = y_true.columns | ||
else: | ||
expected_cols = ["score"] | ||
|
||
metric_args = {"y_true": y_true, "y_pred": y_pred} | ||
if pass_c: | ||
c_true = np.random.randint(0, 2, size=y_true.shape) | ||
c_true = pd.DataFrame(c_true, columns=y_true.columns, index=y_true.index) | ||
metric_args["c_true"] = c_true | ||
|
||
res = m.evaluate_by_index(**metric_args) | ||
assert isinstance(res, pd.DataFrame) | ||
assert (res.columns == expected_cols).all() | ||
assert res.shape == (1, len(expected_cols)) | ||
else: | ||
assert isinstance(res, float) | ||
assert res.shape == (y_true.shape[0], len(expected_cols)) | ||
|
||
res = m.evaluate(**metric_args) | ||
|
||
expect_df = not multivariate and multioutput == "raw_values" | ||
if expect_df: | ||
assert isinstance(res, pd.DataFrame) | ||
assert (res.columns == expected_cols).all() | ||
assert res.shape == (1, len(expected_cols)) | ||
else: | ||
assert isinstance(res, float) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters