Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] quantile method for distributions, default implementation of forecaster predict_quantiles if predict_proba is present #4513

Merged
merged 4 commits into from Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions sktime/forecasting/base/_base.py
Expand Up @@ -2103,6 +2103,12 @@ def _predict_quantiles(self, fh, X, alpha):

pred_int.columns = int_idx

elif implements_proba:

# 0.19.0 - one instance of legacy_interface to remove
pred_proba = self.predict_proba(fh=fh, X=X, legacy_interface=False)
pred_int = pred_proba.quantile(alpha=alpha)

return pred_int

def _predict_var(self, fh=None, X=None, cov=False):
Expand Down
50 changes: 50 additions & 0 deletions sktime/proba/base.py
Expand Up @@ -358,6 +358,56 @@ def pdfnorm(self, a=2):
spl = [self.pdf(self.sample()) ** (a - 1) for _ in range(self.APPROX_SPL)]
return pd.concat(spl, axis=0).groupby(level=0).mean()

def _coerce_to_self_index_df(self, x):
x = np.array(x)
x = x.reshape(1, -1)
df_shape = self.shape
x = np.broadcast_to(x, df_shape)
df = pd.DataFrame(x, index=self.index, columns=self.columns)
return df

def quantile(self, alpha):
"""Return entry-wise quantiles, in Proba/pred_quantiles mtype format.

This method broadcasts as follows:
for a scalar `alpha`, computes the `alpha`-quantile entry-wise,
and returns as a `pd.DataFrame` with same index, and columns as in return.
If `alpha` is iterable, multiple quantiles will be calculated,
and the result will be concatenated column-wise (axis=1).

The `ppf` method also computes quantiles, but broadcasts differently, in
`numpy` style closer to `tensorflow`.
In contrast, this `quantile` method broadcasts
as forecaster `predict_quantiles`, i.e., columns first.

Parameters
----------
alpha : float or list of float of unique values
A probability or list of, at which quantiles are computed.

Returns
-------
quantiles : pd.DataFrame
Column has multi-index: first level is variable name from `self.columns`,
second level being the values of `alpha` passed to the function.
Row index is `self.index`.
Entries in the i-th row, (j, p)-the column is
the p-th quantile of the marginal of `self` at index (i, j).
"""
if not isinstance(alpha, list):
alpha = [alpha]

qdfs = []
for p in alpha:
p = self._coerce_to_self_index_df(p)
qdf = self.ppf(p)
qdfs += [qdf]

qres = pd.concat(qdfs, axis=1, keys=alpha)
qres = qres.reorder_levels([1, 0], axis=1)
quantiles = qres.sort_index(axis=1)
return quantiles

def sample(self, n_samples=None):
"""Sample from the distribution.

Expand Down
21 changes: 21 additions & 0 deletions sktime/proba/tests/test_all_distrs.py
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import pytest

from sktime.datatypes import check_is_mtype
from sktime.tests.test_all_estimators import BaseFixtureGenerator, QuickTester


Expand Down Expand Up @@ -108,6 +109,26 @@ def test_methods_p(self, estimator_instance, method):

_check_output_format(res, d, method)

@pytest.mark.parametrize("q", [0.7, [0.1, 0.3, 0.9]])
def test_quantile(self, estimator_instance, q):
"""Test expected return of quantile method."""
if not _has_capability(estimator_instance, "ppf"):
return None

d = estimator_instance

def _check_quantile_output(obj, q):
assert check_is_mtype(obj, "pred_quantiles", "Proba")
assert (obj.index == d.index).all()

if not isinstance(q, list):
q = [q]
expected_columns = pd.MultiIndex.from_product([d.columns, q])
assert (obj.columns == expected_columns).all()

res = d.quantile(q)
_check_quantile_output(res, q)


def _check_output_format(res, dist, method):
"""Check output format expectations for BaseDistribution tests."""
Expand Down