diff --git a/.all-contributorsrc b/.all-contributorsrc index a8bf95e11dc..b29626d8cfd 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -2144,7 +2144,8 @@ "profile": "https://github.com/yarnabrina/", "contributions": [ "bug", - "code" + "code", + "test" ] } ] diff --git a/sktime/forecasting/sarimax.py b/sktime/forecasting/sarimax.py index 9c99ca00256..20e2dfaa56c 100644 --- a/sktime/forecasting/sarimax.py +++ b/sktime/forecasting/sarimax.py @@ -4,7 +4,9 @@ """Implements SARIMAX.""" __all__ = ["SARIMAX"] -__author__ = ["TNTran92"] +__author__ = ["TNTran92", "yarnabrina"] + +import pandas as pd from sktime.forecasting.base.adapters import _StatsModelsAdapter @@ -116,6 +118,7 @@ class SARIMAX(_StatsModelsAdapter): _tags = { "ignores-exogeneous-X": False, + "capability:pred_int": True, } def __init__( @@ -193,3 +196,59 @@ def summary(self): https://www.statsmodels.org/dev/examples/notebooks/generated/statespace_structural_harvey_jaeger.html """ return self._fitted_forecaster.summary() + + def _predict_interval(self, fh, X=None, coverage=0.95): + """Compute/return prediction interval forecasts. + + private _predict_interval containing the core logic, + called from predict_interval and default _predict_quantiles + + Parameters + ---------- + fh : guaranteed to be ForecastingHorizon + The forecasting horizon with the steps ahead to to predict. + X : optional (default=None) + guaranteed to be of a type in self.get_tag("X_inner_mtype") + Exogeneous time series to predict from. + coverage : float or list of float, optional (default=0.95) + nominal coverage(s) of predictive interval(s) + + Returns + ------- + pred_int : pd.DataFrame + Column has multi-index: first level is variable name from y in fit, + second level coverage fractions for which intervals were computed. + in the same order as in input `coverage`. + Third level is string "lower" or "upper", for lower/upper interval end. + Row index is fh, with additional (upper) levels equal to instance levels, + from y seen in fit, if y_inner_mtype is Panel or Hierarchical. + Entries are forecasts of lower/upper interval end, + for var in col index, at nominal coverage in second col index, + lower/upper depending on third col index, for the row index. + Upper/lower interval end forecasts are equivalent to + quantile forecasts at alpha = 0.5 - c/2, 0.5 + c/2 for c in coverage. + """ + start, end = fh.to_absolute_int(self._y.index[0], self.cutoff)[[0, -1]] + valid_indices = fh.to_absolute(self.cutoff).to_pandas() + + prediction_results = self._fitted_forecaster.get_prediction( + start=start, end=end, exog=X + ) + + columns = pd.MultiIndex.from_product( + [["Coverage"], coverage, ["lower", "upper"]] + ) + pred_int = pd.DataFrame(index=valid_indices, columns=columns) + + for c in coverage: + pred_statsmodels = prediction_results.conf_int(alpha=(1 - c)) + pred_statsmodels.columns = ["lower", "upper"] + + pred_int[("Coverage", c, "lower")] = pred_statsmodels.loc[ + valid_indices, "lower" + ] + pred_int[("Coverage", c, "upper")] = pred_statsmodels.loc[ + valid_indices, "upper" + ] + + return pred_int diff --git a/sktime/forecasting/tests/test_sarimax.py b/sktime/forecasting/tests/test_sarimax.py index 17fdd9984eb..45c7e2fd772 100644 --- a/sktime/forecasting/tests/test_sarimax.py +++ b/sktime/forecasting/tests/test_sarimax.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- """Tests the SARIMAX model.""" -__author__ = ["TNTran92"] +__author__ = ["TNTran92", "yarnabrina"] import pytest from numpy.testing import assert_allclose +from pandas.testing import assert_frame_equal from sktime.forecasting.sarimax import SARIMAX from sktime.utils._testing.forecasting import make_forecasting_problem @@ -28,3 +29,62 @@ def test_SARIMAX_against_statsmodels(): stats_fit = stats.fit() stats_pred = stats_fit.predict(df.index[0]) assert_allclose(y_pred.tolist(), stats_pred.tolist()) + + +@pytest.mark.skipif( + not _check_soft_dependencies("statsmodels", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_SARIMAX_single_interval_against_statsmodels(): + """Compares Sktime's and Statsmodel's SARIMAX. + + Notes + ----- + * Predict confidence intervals using underlying estimator and the wrapper. + * Predicts for a single coverage. + * Uses a non-default value of 97.5% to test inputs are actually being respected. + """ + from statsmodels.tsa.api import SARIMAX as _SARIMAX + + sktime_model = SARIMAX(order=(1, 0, 0), trend="t", seasonal_order=(1, 0, 0, 6)) + sktime_model.fit(df) + sktime_pred_int = sktime_model.predict_interval(df.index, coverage=0.975) + sktime_pred_int = sktime_pred_int.xs(("Coverage", 0.975), axis="columns") + + stats = _SARIMAX(endog=df, order=(1, 0, 0), trend="t", seasonal_order=(1, 0, 0, 6)) + stats_fit = stats.fit() + stats_pred_int = stats_fit.get_prediction(df.index[0]).conf_int(alpha=0.025) + stats_pred_int.columns = ["lower", "upper"] + + assert_frame_equal(sktime_pred_int, stats_pred_int) + + +@pytest.mark.skipif( + not _check_soft_dependencies("statsmodels", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_SARIMAX_multiple_intervals_against_statsmodels(): + """Compares Sktime's and Statsmodel's SARIMAX. + + Notes + ----- + * Predict confidence intervals using underlying estimator and the wrapper. + * Predicts for multiple coverage values, viz. 70% and 80%. + """ + from statsmodels.tsa.api import SARIMAX as _SARIMAX + + sktime_model = SARIMAX(order=(1, 0, 0), trend="t", seasonal_order=(1, 0, 0, 6)) + sktime_model.fit(df) + sktime_pred_int = sktime_model.predict_interval(df.index, coverage=[0.70, 0.80]) + sktime_pred_int_70 = sktime_pred_int.xs(("Coverage", 0.70), axis="columns") + sktime_pred_int_80 = sktime_pred_int.xs(("Coverage", 0.80), axis="columns") + + stats = _SARIMAX(endog=df, order=(1, 0, 0), trend="t", seasonal_order=(1, 0, 0, 6)) + stats_fit = stats.fit() + stats_pred_int_70 = stats_fit.get_prediction(df.index[0]).conf_int(alpha=0.30) + stats_pred_int_70.columns = ["lower", "upper"] + stats_pred_int_80 = stats_fit.get_prediction(df.index[0]).conf_int(alpha=0.20) + stats_pred_int_80.columns = ["lower", "upper"] + + assert_frame_equal(sktime_pred_int_70, stats_pred_int_70) + assert_frame_equal(sktime_pred_int_80, stats_pred_int_80)