Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Adds _predict_interval to SARIMAX to support predict_interval and predict_quantiles #4439

Merged
merged 8 commits into from Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion .all-contributorsrc
Expand Up @@ -2144,7 +2144,8 @@
"profile": "https://github.com/yarnabrina/",
"contributions": [
"bug",
"code"
"code",
"test"
]
}
]
Expand Down
61 changes: 60 additions & 1 deletion sktime/forecasting/sarimax.py
Expand Up @@ -4,7 +4,9 @@
"""Implements SARIMAX."""

__all__ = ["SARIMAX"]
__author__ = ["TNTran92"]
__author__ = ["TNTran92", "yarnabrina"]

import pandas as pd

from sktime.forecasting.base.adapters import _StatsModelsAdapter

Expand Down Expand Up @@ -116,6 +118,7 @@ class SARIMAX(_StatsModelsAdapter):

_tags = {
"ignores-exogeneous-X": False,
"capability:pred_int": True,
}

def __init__(
Expand Down Expand Up @@ -193,3 +196,59 @@ def summary(self):
https://www.statsmodels.org/dev/examples/notebooks/generated/statespace_structural_harvey_jaeger.html
"""
return self._fitted_forecaster.summary()

def _predict_interval(self, fh, X=None, coverage=0.95):
    """Compute/return prediction interval forecasts.

    private _predict_interval containing the core logic,
    called from predict_interval and default _predict_quantiles

    Parameters
    ----------
    fh : guaranteed to be ForecastingHorizon
        The forecasting horizon with the steps ahead to predict.
    X : optional (default=None)
        guaranteed to be of a type in self.get_tag("X_inner_mtype")
        Exogeneous time series to predict from.
    coverage : float or list of float, optional (default=0.95)
        nominal coverage(s) of predictive interval(s).
        A scalar is treated as a one-element list.

    Returns
    -------
    pred_int : pd.DataFrame
        Column has multi-index: first level is the fixed label "Coverage",
        second level coverage fractions for which intervals were computed,
        in the same order as in input `coverage`.
        Third level is string "lower" or "upper", for lower/upper interval end.
        Row index is the absolute forecasting horizon derived from `fh`
        and the cutoff.
        Entries are forecasts of lower/upper interval end,
        at nominal coverage in second col index,
        lower/upper depending on third col index, for the row index.
        Interval ends for coverage c are obtained from the fitted
        statsmodels results via ``conf_int(alpha=1 - c)``.
    """
    # The signature default is a bare float, but everything below iterates
    # over ``coverage``; wrap scalars so a direct call with the default works.
    if pd.api.types.is_scalar(coverage):
        coverage = [coverage]

    # First and last integer positions of the horizon, relative to the start
    # of the training series, delimit the range requested from statsmodels.
    start, end = fh.to_absolute_int(self._y.index[0], self.cutoff)[[0, -1]]
    valid_indices = fh.to_absolute(self.cutoff).to_pandas()

    prediction_results = self._fitted_forecaster.get_prediction(
        start=start, end=end, exog=X
    )

    columns = pd.MultiIndex.from_product(
        [["Coverage"], coverage, ["lower", "upper"]]
    )
    pred_int = pd.DataFrame(index=valid_indices, columns=columns)

    for c in coverage:
        # statsmodels is parameterised by significance level, not coverage.
        pred_statsmodels = prediction_results.conf_int(alpha=(1 - c))
        pred_statsmodels.columns = ["lower", "upper"]

        # Keep only the rows that belong to the requested horizon.
        pred_int[("Coverage", c, "lower")] = pred_statsmodels.loc[
            valid_indices, "lower"
        ]
        pred_int[("Coverage", c, "upper")] = pred_statsmodels.loc[
            valid_indices, "upper"
        ]

    return pred_int
62 changes: 61 additions & 1 deletion sktime/forecasting/tests/test_sarimax.py
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
"""Tests the SARIMAX model."""
__author__ = ["TNTran92"]
__author__ = ["TNTran92", "yarnabrina"]

import pytest
from numpy.testing import assert_allclose
from pandas.testing import assert_frame_equal

from sktime.forecasting.sarimax import SARIMAX
from sktime.utils._testing.forecasting import make_forecasting_problem
Expand All @@ -28,3 +29,62 @@ def test_SARIMAX_against_statsmodels():
stats_fit = stats.fit()
stats_pred = stats_fit.predict(df.index[0])
assert_allclose(y_pred.tolist(), stats_pred.tolist())


@pytest.mark.skipif(
    not _check_soft_dependencies("statsmodels", severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_SARIMAX_single_interval_against_statsmodels():
    """Check one prediction interval from the wrapper against statsmodels.

    Notes
    -----
    * The same model specification is fitted through the sktime wrapper
      and through statsmodels directly.
    * A single coverage level is requested.
    * The level is 97.5%, i.e. ``alpha=0.025`` on the statsmodels side.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    # Shared specification so both fits are guaranteed to use the same model.
    model_spec = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }

    wrapper = SARIMAX(**model_spec)
    wrapper.fit(df)
    wrapped_intervals = wrapper.predict_interval(df.index, coverage=0.975)
    actual = wrapped_intervals.xs(("Coverage", 0.975), axis="columns")

    reference_fit = _SARIMAX(endog=df, **model_spec).fit()
    expected = reference_fit.get_prediction(df.index[0]).conf_int(alpha=0.025)
    expected.columns = ["lower", "upper"]

    assert_frame_equal(actual, expected)


@pytest.mark.skipif(
    not _check_soft_dependencies("statsmodels", severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_SARIMAX_multiple_intervals_against_statsmodels():
    """Check several prediction intervals from the wrapper against statsmodels.

    Notes
    -----
    * The same model specification is fitted through the sktime wrapper
      and through statsmodels directly.
    * Two coverage levels, 70% and 80%, are requested in one call.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    # Shared specification so both fits are guaranteed to use the same model.
    model_spec = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }
    # (coverage looked up in the sktime output, alpha passed to statsmodels).
    # The alphas are spelled out rather than computed as 1 - coverage so the
    # exact float values handed to statsmodels stay as written.
    levels = [(0.70, 0.30), (0.80, 0.20)]

    wrapper = SARIMAX(**model_spec)
    wrapper.fit(df)
    wrapped_intervals = wrapper.predict_interval(df.index, coverage=[0.70, 0.80])

    reference_fit = _SARIMAX(endog=df, **model_spec).fit()

    for level, alpha in levels:
        actual = wrapped_intervals.xs(("Coverage", level), axis="columns")
        expected = reference_fit.get_prediction(df.index[0]).conf_int(alpha=alpha)
        expected.columns = ["lower", "upper"]
        assert_frame_equal(actual, expected)