[ENH] migrate tests of distribution prediction metrics to skbase class (#208)

This PR migrates tests of distribution prediction metrics to an `skbase`
class.

Also generalizes the test for censored ground truth.
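
For illustration, a minimal sketch of scoring censored ground truth under the generalized interface, assuming that metrics such as `CRPS` accept the censoring indicator `c_true` as a keyword argument (this is what the migrated tests below exercise):

import numpy as np
import pandas as pd

from skpro.distributions import Normal
from skpro.metrics import CRPS

# predictive distribution and a sampled "ground truth" to score against
y_pred = Normal.create_test_instance()
y_true = y_pred.sample()

# binary censoring indicator with the same index and columns as y_true
c_true = pd.DataFrame(
    np.random.randint(0, 2, size=y_true.shape),
    index=y_true.index,
    columns=y_true.columns,
)

# extra kwargs such as c_true are forwarded through evaluate, see base.py below
crps = CRPS()
print(crps.evaluate(y_true, y_pred, c_true=c_true))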
fkiraly committed Feb 9, 2024
1 parent 6a23216 commit bf86474
Showing 3 changed files with 61 additions and 54 deletions.
3 changes: 2 additions & 1 deletion skpro/metrics/base.py
@@ -376,6 +376,7 @@ class BaseDistrMetric(BaseProbaMetric):
     """
 
     _tags = {
+        "object_type": ["metric", "metric_distr"],  # type of object
         "scitype:y_pred": "pred_proba",
         "lower_is_better": True,
     }
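
The added `object_type` tag is what generic tooling filters on. A hedged sketch of reading it back via skbase's class-tag interface:

from skpro.metrics import CRPS

# tags merge down the class hierarchy, so concrete distribution
# metrics inherit the new "metric_distr" marker from BaseDistrMetric
assert "metric_distr" in CRPS.get_class_tag("object_type")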
@@ -402,7 +403,7 @@ def evaluate(self, y_true, y_pred, **kwargs):
         multioutput = self.multioutput
         multivariate = self.multivariate
 
-        index_df = self.evaluate_by_index(y_true, y_pred)
+        index_df = self.evaluate_by_index(y_true, y_pred, **kwargs)
         out_df = pd.DataFrame(index_df.mean(axis=0)).T
         out_df.columns = index_df.columns
110 changes: 57 additions & 53 deletions skpro/metrics/tests/test_distr_metrics.py
@@ -1,60 +1,64 @@
 """Tests for probabilistic metrics for distribution predictions."""
-import warnings
-
+import numpy as np
 import pandas as pd
 import pytest
+from skbase.testing import QuickTester
 
 from skpro.distributions import Normal
-from skpro.metrics import (
-    CRPS,
-    SPLL,
-    ConcordanceHarrell,
-    LinearizedLogLoss,
-    LogLoss,
-    SquaredDistrLoss,
-)
-
-warnings.filterwarnings("ignore", category=FutureWarning)
-
-DISTR_METRICS = [
-    CRPS,
-    SPLL,
-    ConcordanceHarrell,
-    LinearizedLogLoss,
-    LogLoss,
-    SquaredDistrLoss,
-]
-
-normal_dists = [Normal]
-
-
-@pytest.mark.parametrize("normal", normal_dists)
-@pytest.mark.parametrize("metric", DISTR_METRICS)
-@pytest.mark.parametrize("multivariate", [True, False])
-@pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average"])
-def test_distr_evaluate(normal, metric, multivariate, multioutput):
-    """Test expected output of evaluate functions."""
-    y_pred = normal.create_test_instance()
-    y_true = y_pred.sample()
-
-    m = metric(multivariate=multivariate, multioutput=multioutput)
-
-    if not multivariate:
-        expected_cols = y_true.columns
-    else:
-        expected_cols = ["score"]
-
-    res = m.evaluate_by_index(y_true, y_pred)
-    assert isinstance(res, pd.DataFrame)
-    assert (res.columns == expected_cols).all()
-    assert res.shape == (y_true.shape[0], len(expected_cols))
-
-    res = m.evaluate(y_true, y_pred)
-
-    expect_df = not multivariate and multioutput == "raw_values"
-    if expect_df:
-        assert isinstance(res, pd.DataFrame)
-        assert (res.columns == expected_cols).all()
-        assert res.shape == (1, len(expected_cols))
-    else:
-        assert isinstance(res, float)
+from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig
+
+TEST_DISTS = [Normal]
+
+
+class TestAllDistrMetrics(PackageConfig, BaseFixtureGenerator, QuickTester):
+    """Generic tests for all distribution prediction metrics in the package."""
+
+    # class variables which can be overridden by descendants
+    # ------------------------------------------------------
+
+    # which object types are generated; None=all, or scitype string
+    # passed to skpro.registry.all_objects as object_type
+    object_type_filter = "metric_distr"
+
+    @pytest.mark.parametrize("dist", TEST_DISTS)
+    @pytest.mark.parametrize("pass_c", [True, False])
+    @pytest.mark.parametrize("multivariate", [True, False])
+    @pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average"])
+    def test_distr_evaluate(
+        self, object_instance, dist, pass_c, multivariate, multioutput
+    ):
+        """Test expected output of evaluate functions."""
+        metric = object_instance
+
+        y_pred = dist.create_test_instance()
+        y_true = y_pred.sample()
+
+        m = metric.set_params(multioutput=multioutput)
+        if "multivariate" in metric.get_params():
+            m = m.set_params(multivariate=multivariate)
+
+        if not multivariate:
+            expected_cols = y_true.columns
+        else:
+            expected_cols = ["score"]
+
+        metric_args = {"y_true": y_true, "y_pred": y_pred}
+        if pass_c:
+            c_true = np.random.randint(0, 2, size=y_true.shape)
+            c_true = pd.DataFrame(c_true, columns=y_true.columns, index=y_true.index)
+            metric_args["c_true"] = c_true
+
+        res = m.evaluate_by_index(**metric_args)
+        assert isinstance(res, pd.DataFrame)
+        assert (res.columns == expected_cols).all()
+        assert res.shape == (y_true.shape[0], len(expected_cols))
+
+        res = m.evaluate(**metric_args)
+
+        expect_df = not multivariate and multioutput == "raw_values"
+        if expect_df:
+            assert isinstance(res, pd.DataFrame)
+            assert (res.columns == expected_cols).all()
+            assert res.shape == (1, len(expected_cols))
+        else:
+            assert isinstance(res, float)
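
The class-based suite discovers its test fixtures through `object_type_filter`. A rough hand-rolled equivalent, assuming the `all_objects` query interface referenced in the class comments above (argument names may differ across skpro versions):

from skpro.registry import all_objects

# collect everything tagged "metric_distr", as the fixture generator does
metric_classes = all_objects("metric_distr", return_names=False)

for cls in metric_classes:
    metric = cls.create_test_instance()
    print(cls.__name__, "lower_is_better:", metric.get_tag("lower_is_better"))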
2 changes: 2 additions & 0 deletions skpro/tests/test_class_register.py
@@ -21,6 +21,7 @@ def get_test_class_registry():
     keys are scitypes, values are test classes TestAll[Scitype]
     """
     from skpro.distributions.tests.test_all_distrs import TestAllDistributions
+    from skpro.metrics.tests.test_distr_metrics import TestAllDistrMetrics
     from skpro.regression.tests.test_all_regressors import TestAllRegressors
     from skpro.tests.test_all_estimators import TestAllEstimators, TestAllObjects

@@ -37,6 +38,7 @@
     # so also imply estimator and object tests, or only object tests
     testclass_dict["distribution"] = TestAllDistributions
     testclass_dict["regressor_proba"] = TestAllRegressors
+    testclass_dict["metric_proba"] = TestAllDistrMetrics
 
     return testclass_dict
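
With the new entry in place, scitype-keyed tooling can route to the migrated test class; a hedged sketch:

from skpro.tests.test_class_register import get_test_class_registry

registry = get_test_class_registry()
# the new "metric_proba" entry resolves to the migrated skbase test class
assert registry["metric_proba"].__name__ == "TestAllDistrMetrics"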
