Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix ARCH._check_predict_proba #5349 #5384

Merged
merged 5 commits into from Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 0 additions & 22 deletions sktime/forecasting/arch/_uarch.py
Expand Up @@ -464,28 +464,6 @@ def _predict_var(self, fh=None, X=None, cov=False):

return pred_var

def _predict_proba(self, fh, X, marginal=True):
"""Compute/return fully probabilistic forecasts.

Parameters
----------
fh : int, list, np.array or ForecastingHorizon (not optional)
The forecasting horizon encoding the time stamps to forecast at.
if has not been passed in fit, must be passed, not optional
X : sktime time series object, optional (default=None)
Exogeneous time series for the forecast
marginal : bool, optional (default=True)
whether returned distribution is marginal by time index

Returns
-------
pred_dist : sktime BaseDistribution
predictive distribution
"""
pred_dist = super()._predict_proba(fh=fh, X=X, marginal=marginal)
pred_dist.columns = pd.Index([0])
return pred_dist

def _get_fitted_params(self):
"""Get fitted parameters.

Expand Down
5 changes: 4 additions & 1 deletion sktime/forecasting/tests/test_all_forecasters.py
Expand Up @@ -528,7 +528,10 @@ def _check_predict_proba(self, pred_dist, y_train, fh_int):

# check columns
if isinstance(y_train, pd.Series):
assert (pred_cols == pd.Index([0])).all()
if y_train.name is not None:
assert (pred_cols == y_train.name).all()
else:
assert (pred_cols == pd.Index([0])).all()
else:
assert (pred_cols == y_train.columns).all()

Expand Down