Skip to content

Commit

Permalink
[MNT] speed up various non-suite tests, part 2 (#5071)
Browse files Browse the repository at this point in the history
This PR speeds up various non-suite tests:

* `test_gscv_hierarchical`, by shortening the hierarchical data used in
the test
* tests for classifiers in `test_sklearn_compatability`, by using a
smaller data set
* `test_docs_tsfresh_extractor`, by using a smaller data set
  • Loading branch information
fkiraly committed Aug 13, 2023
1 parent 08e2d66 commit 1d299b4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
20 changes: 11 additions & 9 deletions sktime/classification/tests/test_sklearn_compatability.py
Expand Up @@ -17,8 +17,6 @@
GridSearchCV,
GroupKFold,
GroupShuffleSplit,
HalvingGridSearchCV,
HalvingRandomSearchCV,
KFold,
LeaveOneOut,
LeavePGroupsOut,
Expand All @@ -31,6 +29,10 @@
TimeSeriesSplit,
cross_val_score,
)

# removed because the test data is now too small for the halving searches:
# HalvingGridSearchCV,
# HalvingRandomSearchCV,
from sklearn.pipeline import Pipeline

from sktime.classification.interval_based import CanonicalIntervalForest
Expand All @@ -39,29 +41,29 @@
from sktime.utils.validation._dependencies import _check_soft_dependencies

DATA_ARGS = [
{"return_numpy": True, "n_columns": 2},
{"return_numpy": False, "n_columns": 2},
{"return_numpy": True, "n_columns": 2, "n_instances": 7, "n_timepoints": 12},
{"return_numpy": False, "n_columns": 2, "n_instances": 7, "n_timepoints": 12},
]

# StratifiedGroupKFold(n_splits=2), removed because it is not available in sklearn 0.24
CROSS_VALIDATION_METHODS = [
KFold(n_splits=2),
RepeatedKFold(n_splits=2, n_repeats=2),
LeaveOneOut(),
LeavePOut(p=5),
LeavePOut(p=2),
ShuffleSplit(n_splits=2, test_size=0.25),
StratifiedKFold(n_splits=2),
StratifiedShuffleSplit(n_splits=2, test_size=0.25),
GroupKFold(n_splits=2),
LeavePGroupsOut(n_groups=5),
LeavePGroupsOut(n_groups=2),
GroupShuffleSplit(n_splits=2, test_size=0.25),
TimeSeriesSplit(n_splits=2),
]
PARAMETER_TUNING_METHODS = [
GridSearchCV,
RandomizedSearchCV,
HalvingGridSearchCV,
HalvingRandomSearchCV,
# HalvingGridSearchCV,
# HalvingRandomSearchCV,
]

if _check_soft_dependencies("numba", severity="none"):
Expand Down Expand Up @@ -107,7 +109,7 @@ def test_sklearn_cross_validation(data_args):
def test_sklearn_cross_validation_iterators(data_args, cross_validation_method):
"""Test if sklearn cross-validation iterators can handle sktime panel data."""
fit_args = make_classification_problem(**data_args)
groups = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10]
groups = [1, 1, 2, 2, 3, 3, 4]

for train, test in cross_validation_method.split(*fit_args, groups=groups):
assert isinstance(train, np.ndarray) and isinstance(test, np.ndarray)
Expand Down
10 changes: 5 additions & 5 deletions sktime/forecasting/model_selection/tests/test_tune.py
Expand Up @@ -77,14 +77,14 @@ def _create_hierarchical_data():
y = _make_hierarchical(
random_state=TEST_RANDOM_SEEDS[0],
hierarchy_levels=(2, 2),
min_timepoints=20,
max_timepoints=20,
min_timepoints=15,
max_timepoints=15,
)
X = _make_hierarchical(
random_state=TEST_RANDOM_SEEDS[1],
hierarchy_levels=(2, 2),
min_timepoints=20,
max_timepoints=20,
min_timepoints=15,
max_timepoints=15,
)
return y, X

Expand All @@ -103,7 +103,7 @@ def _create_hierarchical_data():
}
CVs = [
*[SingleWindowSplitter(fh=fh) for fh in TEST_OOS_FHS],
SlidingWindowSplitter(fh=1, initial_window=15),
SlidingWindowSplitter(fh=1, initial_window=12, step_length=3),
]
ERROR_SCORES = [np.nan, "raise", 1000]

Expand Down
7 changes: 3 additions & 4 deletions sktime/transformations/panel/tests/test_tsfresh.py
Expand Up @@ -39,12 +39,11 @@ def test_tsfresh_extractor(default_fc_parameters):
)
def test_docs_tsfresh_extractor():
"""Test whether doc example runs through."""
X, y = load_arrow_head(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
X, _ = load_arrow_head(return_X_y=True)[:3]
ts_eff = TSFreshFeatureExtractor(
default_fc_parameters="efficient", disable_progressbar=True
)
ts_eff.fit_transform(X_train)
ts_eff.fit_transform(X)
features_to_calc = [
"dim_0__quantile__q_0.6",
"dim_0__longest_strike_above_mean",
Expand All @@ -53,7 +52,7 @@ def test_docs_tsfresh_extractor():
ts_custom = TSFreshFeatureExtractor(
kind_to_fc_parameters=features_to_calc, disable_progressbar=True
)
ts_custom.fit_transform(X_train)
ts_custom.fit_transform(X)


@pytest.mark.skipif(
Expand Down

0 comments on commit 1d299b4

Please sign in to comment.