Skip to content

Commit

Permalink
[ENH] Adds _predict_interval to SARIMAX to support `predict_inter…
Browse files Browse the repository at this point in the history
…val` and `predict_quantiles` (#4439)

Fixes #4301.

Before this PR, `SARIMAX` did not support predicting confidence intervals or quantiles. This PR adds that functionality, along with tests that check the results match those of the underlying `statsmodels` estimator.
  • Loading branch information
yarnabrina committed Apr 13, 2023
1 parent a095631 commit f37b7ea
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .all-contributorsrc
Expand Up @@ -2144,7 +2144,8 @@
"profile": "https://github.com/yarnabrina/",
"contributions": [
"bug",
"code"
"code",
"test"
]
}
]
Expand Down
61 changes: 60 additions & 1 deletion sktime/forecasting/sarimax.py
Expand Up @@ -4,7 +4,9 @@
"""Implements SARIMAX."""

__all__ = ["SARIMAX"]
__author__ = ["TNTran92"]
__author__ = ["TNTran92", "yarnabrina"]

import pandas as pd

from sktime.forecasting.base.adapters import _StatsModelsAdapter

Expand Down Expand Up @@ -116,6 +118,7 @@ class SARIMAX(_StatsModelsAdapter):

# Estimator tags declared for this forecaster class.
_tags = {
# NOTE(review): False presumably means exogenous X is used rather than
# ignored -- confirm against the tag documentation.
"ignores-exogeneous-X": False,
# Interval prediction is provided by ``_predict_interval`` in this class.
"capability:pred_int": True,
}

def __init__(
Expand Down Expand Up @@ -193,3 +196,59 @@ def summary(self):
https://www.statsmodels.org/dev/examples/notebooks/generated/statespace_structural_harvey_jaeger.html
"""
return self._fitted_forecaster.summary()

def _predict_interval(self, fh, X=None, coverage=0.95):
    """Compute/return prediction interval forecasts.

    Private method containing the core logic, called from
    ``predict_interval`` and the default ``_predict_quantiles``.

    Parameters
    ----------
    fh : guaranteed to be ForecastingHorizon
        The forecasting horizon with the steps ahead to predict.
    X : optional (default=None)
        guaranteed to be of a type in self.get_tag("X_inner_mtype")
        Exogeneous time series to predict from.
    coverage : float or list of float, optional (default=0.95)
        Nominal coverage(s) of the predictive interval(s).
        A single number is treated as a one-element list.

    Returns
    -------
    pred_int : pd.DataFrame
        Columns are a three-level multi-index:
        first level is the constant string "Coverage",
        second level holds the coverage fractions, in the same order as in
        the input `coverage`,
        third level is the string "lower" or "upper", for the lower/upper
        interval end.
        Row index is `fh` converted to absolute values relative to
        ``self.cutoff``.
        Entries are forecasts of the lower/upper interval end at the nominal
        coverage in the second column level, for the row index.
        Upper/lower interval end forecasts are equivalent to
        quantile forecasts at alpha = 0.5 - c/2, 0.5 + c/2 for c in coverage.
    """
    # The default (and any bare number) is a scalar, which can neither be
    # iterated nor passed to ``MultiIndex.from_product``; wrap it in a list.
    if isinstance(coverage, (int, float)):
        coverage = [coverage]

    # Integer positions of the first and last requested step, counted from
    # the first index of the training series.
    start, end = fh.to_absolute_int(self._y.index[0], self.cutoff)[[0, -1]]
    valid_indices = fh.to_absolute(self.cutoff).to_pandas()

    prediction_results = self._fitted_forecaster.get_prediction(
        start=start, end=end, exog=X
    )

    columns = pd.MultiIndex.from_product(
        [["Coverage"], coverage, ["lower", "upper"]]
    )
    pred_int = pd.DataFrame(index=valid_indices, columns=columns)

    for c in coverage:
        # statsmodels is parameterised by the significance level, not coverage
        pred_statsmodels = prediction_results.conf_int(alpha=(1 - c))
        pred_statsmodels.columns = ["lower", "upper"]

        # keep only the rows that belong to the requested horizon
        for bound in ("lower", "upper"):
            pred_int[("Coverage", c, bound)] = pred_statsmodels.loc[
                valid_indices, bound
            ]

    return pred_int
62 changes: 61 additions & 1 deletion sktime/forecasting/tests/test_sarimax.py
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
"""Tests the SARIMAX model."""
__author__ = ["TNTran92"]
__author__ = ["TNTran92", "yarnabrina"]

import pytest
from numpy.testing import assert_allclose
from pandas.testing import assert_frame_equal

from sktime.forecasting.sarimax import SARIMAX
from sktime.utils._testing.forecasting import make_forecasting_problem
Expand All @@ -28,3 +29,62 @@ def test_SARIMAX_against_statsmodels():
stats_fit = stats.fit()
stats_pred = stats_fit.predict(df.index[0])
assert_allclose(y_pred.tolist(), stats_pred.tolist())


@pytest.mark.skipif(
    not _check_soft_dependencies("statsmodels", severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_SARIMAX_single_interval_against_statsmodels():
    """Check one prediction interval from the wrapper against statsmodels.

    Notes
    -----
    * Fits the sktime wrapper and the statsmodels estimator on the same data
      with the same model specification.
    * Requests a single coverage level.
    * Uses the non-default coverage 0.975 so that the argument is shown to be
      respected.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    model_kwargs = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }

    # interval produced through the sktime interface
    wrapper = SARIMAX(**model_kwargs)
    wrapper.fit(df)
    actual = wrapper.predict_interval(df.index, coverage=0.975).xs(
        ("Coverage", 0.975), axis="columns"
    )

    # interval produced by calling statsmodels directly
    reference_fit = _SARIMAX(endog=df, **model_kwargs).fit()
    expected = reference_fit.get_prediction(df.index[0]).conf_int(alpha=0.025)
    expected.columns = ["lower", "upper"]

    assert_frame_equal(actual, expected)


@pytest.mark.skipif(
    not _check_soft_dependencies("statsmodels", severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_SARIMAX_multiple_intervals_against_statsmodels():
    """Check several prediction intervals from the wrapper against statsmodels.

    Notes
    -----
    * Fits the sktime wrapper and the statsmodels estimator on the same data
      with the same model specification.
    * Requests two coverage levels in one call, 70% and 80%, and compares
      each with the corresponding statsmodels interval.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    model_kwargs = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }

    wrapper = SARIMAX(**model_kwargs)
    wrapper.fit(df)
    intervals = wrapper.predict_interval(df.index, coverage=[0.70, 0.80])

    reference_fit = _SARIMAX(endog=df, **model_kwargs).fit()

    # (coverage, statsmodels alpha) pairs, spelled out as literals
    for level, alpha in ((0.70, 0.30), (0.80, 0.20)):
        actual = intervals.xs(("Coverage", level), axis="columns")

        expected = reference_fit.get_prediction(df.index[0]).conf_int(alpha=alpha)
        expected.columns = ["lower", "upper"]

        assert_frame_equal(actual, expected)

0 comments on commit f37b7ea

Please sign in to comment.