Skip to content

Commit

Permalink
Merge pull request #3711 from pycaret/validate_ts_data_split_shuffle
Browse files Browse the repository at this point in the history
Validate `data_split_shuffle` for `timeseries` CV
  • Loading branch information
Yard1 committed Sep 9, 2023
2 parents d4b92e5 + d645f23 commit 9bfe756
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
15 changes: 8 additions & 7 deletions pycaret/classification/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,18 +747,19 @@ def setup(
self.data_split_stratify = data_split_stratify
self.data_split_shuffle = data_split_shuffle

self._prepare_train_test(
train_size=train_size,
test_data=test_data,
data_split_stratify=data_split_stratify,
data_split_shuffle=data_split_shuffle,
)

self._prepare_folds(
fold_strategy=fold_strategy,
fold=fold,
fold_shuffle=fold_shuffle,
fold_groups=fold_groups,
data_split_shuffle=data_split_shuffle,
)

self._prepare_train_test(
train_size=train_size,
test_data=test_data,
data_split_stratify=data_split_stratify,
data_split_shuffle=data_split_shuffle,
)

self._prepare_column_types(
Expand Down
10 changes: 6 additions & 4 deletions pycaret/internal/preprocess/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def _prepare_column_types(
# Features to keep during all preprocessing
self._fxs["Keep"] = keep_features or []

def _prepare_folds(self, fold_strategy, fold, fold_shuffle, fold_groups):
def _prepare_folds(
self, fold_strategy, fold, fold_shuffle, fold_groups, data_split_shuffle
):
"""Assign the fold strategy."""
self.logger.info("Set up folding strategy.")
allowed_fold_strategy = ["kfold", "stratifiedkfold", "groupkfold", "timeseries"]
Expand All @@ -344,11 +346,11 @@ def _prepare_folds(self, fold_strategy, fold, fold_shuffle, fold_groups):
)

if fold_strategy == "timeseries" or isinstance(fold_strategy, TimeSeriesSplit):
if fold_shuffle:
if fold_shuffle or data_split_shuffle:
raise ValueError(
"Invalid value for the fold_strategy parameter. 'timeseries' "
"requires 'data_split_shuffle' to be False as it can lead to "
"unexpected data split."
"requires 'data_split_shuffle' and 'fold_shuffle' to be False "
"as it can lead to unexpected data split."
)

if isinstance(fold_groups, str):
Expand Down
15 changes: 8 additions & 7 deletions pycaret/regression/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,18 +714,19 @@ def setup(
self.data_split_stratify = data_split_stratify
self.data_split_shuffle = data_split_shuffle

self._prepare_train_test(
train_size=train_size,
test_data=test_data,
data_split_stratify=data_split_stratify,
data_split_shuffle=data_split_shuffle,
)

self._prepare_folds(
fold_strategy=fold_strategy,
fold=fold,
fold_shuffle=fold_shuffle,
fold_groups=fold_groups,
data_split_shuffle=data_split_shuffle,
)

self._prepare_train_test(
train_size=train_size,
test_data=test_data,
data_split_stratify=data_split_stratify,
data_split_shuffle=data_split_shuffle,
)

self._prepare_column_types(
Expand Down

0 comments on commit 9bfe756

Please sign in to comment.