[BUG]: Number of folds not reported correctly when passing ExpandingWindowSplitter #3011
Hi @ngupta23, I was looking into this, and initially I thought the problem was in `_set_fold_generator`:

```python
def _set_fold_generator(self) -> "TSForecastingExperiment":
    """Sets up the cross-validation fold generator for the training dataset."""
    possible_time_series_fold_strategies = ["expanding", "sliding", "rolling"]
    if not (
        self.fold_strategy in possible_time_series_fold_strategies
        or is_sklearn_cv_generator(self.fold_strategy)
    ):
        raise TypeError(
            "fold_strategy parameter must be either a sktime compatible CV generator "
            f"object or one of '{', '.join(possible_time_series_fold_strategies)}'."
        )
    if self.fold_strategy in possible_time_series_fold_strategies:
        # Number of folds
        self.fold_param = self.fold
        self.fold_generator = self.get_fold_generator(fold=self.fold_param)
    else:
        self.fold_generator = self.fold_strategy
        self.fold_param = self.fold_generator.get_n_splits(y=self.y_train)
    return self
```

So, I modified the second `if` block to calculate the fold count for CV objects, like this:

```python
if self.fold_strategy in possible_time_series_fold_strategies:
    # Calculate the fold_param based on the number of splits
    self.fold_generator = self.get_fold_generator(fold=self.fold)
    self.fold_param = (
        self.fold_generator.get_n_splits(y=self.y_train)
        if hasattr(self.fold_generator, "get_n_splits")
        else None
    )
```

This fixed the output: running the same code as you, the fold number now shows up correctly. I was about to make a pull request with these changes, but at that point I started looking at the documentation of PyCaret's time series module, and I found out that the `fold` parameter expects an integer number of folds, while a custom CV object such as `ExpandingWindowSplitter` should be passed to `fold_strategy`. So, your code should be modified to this:

```python
# !pip install pycaret==3.0.0rc3
import numpy as np
from pycaret.time_series import TSForecastingExperiment
from pycaret.datasets import get_data
from sktime.forecasting.model_selection import ExpandingWindowSplitter

y = get_data(114, folder="time_series/seasonal", verbose=False)
cv = ExpandingWindowSplitter(fh=np.arange(1, 13), initial_window=24, step_length=4)

exp = TSForecastingExperiment()
exp.setup(y, fh=12, fold_strategy=cv)
```

This should return the output as expected!
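To sanity-check the "Fold Number" that `setup` should now report, the fold count of an expanding window can also be worked out by hand. Below is a minimal sketch; `expanding_window_n_splits` is a hypothetical helper, not a PyCaret or sktime API. It assumes that the first training window ends after `initial_window` points, that each later window grows by `step_length` points, and that a fold is only kept when all `fh_max` forecast steps after its cutoff fit inside the series. I believe this matches how sktime's `ExpandingWindowSplitter` places its cutoffs, but `cv.get_n_splits(y=...)` remains the authoritative count.

```python
def expanding_window_n_splits(
    n_timepoints: int, initial_window: int, step_length: int, fh_max: int
) -> int:
    """Count expanding-window CV folds (hypothetical helper, not a library API).

    Assumed convention: training windows end (exclusive) at initial_window,
    initial_window + step_length, ..., and a fold is kept only if the
    fh_max forecast steps after its cutoff lie inside the series.
    """
    if min(initial_window, step_length, fh_max) < 1:
        raise ValueError("initial_window, step_length and fh_max must be >= 1")
    # The latest admissible window end is n_timepoints - fh_max, so the
    # range stop is one past it.
    return len(range(initial_window, n_timepoints - fh_max + 1, step_length))


# Example: the settings from the snippet above (initial_window=24,
# step_length=4, fh up to 12) on a hypothetical 132-point training series.
print(expanding_window_n_splits(132, 24, 4, 12))  # 25
```

Because the count depends on the length of the training series, it is only known once `y` is available, which is why `_set_fold_generator` calls `get_n_splits(y=self.y_train)` instead of reading a fixed number from the CV object.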
However, since a CV object can still be passed to the `fold` parameter by mistake, I added a validation check:

```python
# Validate fold
if not isinstance(fold, int):
    raise TypeError(
        f"The 'fold' parameter must be an integer. You provided: {type(fold).__name__}. "
        "If you intended to use a custom cross-validation object such as SlidingWindowSplitter or "
        "ExpandingWindowSplitter, please pass it to the 'fold_strategy' parameter instead. "
        "The 'fold' parameter is ignored when 'fold_strategy' is a custom CV object."
    )
```

I am making a pull request. Let me know if you have any suggestions!
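The resolution rule discussed in this thread (an integer `fold` for the string strategies, `get_n_splits` for a custom CV object, and a clear error when a CV object lands in `fold`) can be sketched without PyCaret. `resolve_fold_param` and `StubSplitter` are hypothetical names used only for illustration; the stub merely mimics the `get_n_splits(y=...)` method that sktime splitters provide, and this is not PyCaret's actual implementation.

```python
from typing import Any, Optional

POSSIBLE_STRATEGIES = ("expanding", "sliding", "rolling")


def resolve_fold_param(fold: Any, fold_strategy: Any, y: Optional[Any] = None) -> int:
    """Return the number of CV folds (hypothetical helper, not PyCaret's API)."""
    # Mirror the proposed validation: `fold` must always be an integer.
    if not isinstance(fold, int):
        raise TypeError(
            f"The 'fold' parameter must be an integer. You provided: "
            f"{type(fold).__name__}. Pass CV objects to 'fold_strategy' instead."
        )
    if isinstance(fold_strategy, str):
        if fold_strategy not in POSSIBLE_STRATEGIES:
            raise ValueError(
                f"fold_strategy must be one of {', '.join(POSSIBLE_STRATEGIES)}"
            )
        # String strategies: the integer `fold` is the fold count.
        return fold
    # Custom CV object: ignore `fold` and ask the object for its split count.
    if not hasattr(fold_strategy, "get_n_splits"):
        raise TypeError("fold_strategy object must provide get_n_splits(y=...)")
    return fold_strategy.get_n_splits(y=y)


class StubSplitter:
    """Stand-in for an sktime splitter; always reports 5 splits."""

    def get_n_splits(self, y=None):
        return 5


print(resolve_fold_param(3, "expanding"))     # 3
print(resolve_fold_param(3, StubSplitter()))  # 5
```

Raising early seems preferable to silently accepting the object, since otherwise the object itself ends up in the "Fold Number" row, which is exactly the symptom reported in this issue.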
pycaret version checks
I have checked that this issue has not already been reported here.
I have confirmed this bug exists on the latest version of pycaret.
I have confirmed this bug exists on the master branch of pycaret (pip install -U git+https://github.com/pycaret/pycaret.git@master).
Issue Description
The number of folds is not shown correctly when passing an sktime `ExpandingWindowSplitter` to the `fold` parameter.
Reproducible Example
Expected Behavior
Expected: "Fold Number" should show the number of folds in the CV step, but it shows the `ExpandingWindowSplitter` object instead.
Actual results:

Actual Results
Installed Versions
PyCaret required dependencies:
pip: 21.1.3
setuptools: 57.4.0
pycaret: 3.0.0.rc3
IPython: 7.9.0
ipywidgets: 7.7.1
tqdm: 4.64.1
numpy: 1.21.6
pandas: 1.3.5
jinja2: 2.11.3
scipy: 1.7.3
joblib: 1.1.0
sklearn: 1.0.2
pyod: Installed but version unavailable
imblearn: 0.8.1
category_encoders: 2.5.0
lightgbm: 3.3.2
numba: 0.55.2
requests: 2.28.1
matplotlib: 3.5.3
scikitplot: 0.3.7
yellowbrick: 1.5
plotly: 5.5.0
kaleido: 0.2.1
statsmodels: 0.13.2
sktime: 0.11.4
tbats: Installed but version unavailable
pmdarima: 2.0.1
psutil: 5.9.2