Skip to content

Commit

Permalink
[MNT] speed up clustering dunder test (#4982)
Browse files Browse the repository at this point in the history
This PR speeds up a costly test for the clustering estimator dunders, by
using a much smaller dataset.
The test time was >1 min and should now be in the range of seconds at
most.
Related to: #2890

Also removes the dependency on `numba` by using a clusterer that depends
only on `sklearn`.

Test coverage does not change, as the requirements are only that the
estimator is a clusterer, and produces a non-trivial clustering (ensured
by the choice of parameters and random seeds).
  • Loading branch information
fkiraly committed Aug 3, 2023
1 parent 1d8d6de commit 3e3d98c
Showing 1 changed file with 9 additions and 23 deletions.
32 changes: 9 additions & 23 deletions sktime/clustering/compose/tests/test_pipeline.py
Expand Up @@ -4,34 +4,28 @@
__author__ = ["fkiraly"]
__all__ = []

import pytest
from sklearn.preprocessing import StandardScaler

from sktime.clustering.compose import ClustererPipeline
from sktime.clustering.k_means import TimeSeriesKMeans
from sktime.clustering.kernel_k_means import TimeSeriesKernelKMeans
from sktime.clustering.dbscan import TimeSeriesDBSCAN
from sktime.dists_kernels import FlatDist
from sktime.transformations.panel.padder import PaddingTransformer
from sktime.transformations.series.exponent import ExponentTransformer
from sktime.transformations.series.impute import Imputer
from sktime.utils._testing.estimator_checks import _assert_array_almost_equal
from sktime.utils._testing.panel import _make_panel_X
from sktime.utils.validation._dependencies import _check_estimator_deps


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_dunder_mul():
"""Test the mul dunder method."""
RAND_SEED = 42
X = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X_test = _make_panel_X(n_instances=5, n_timepoints=20, random_state=RAND_SEED)
X = _make_panel_X(n_instances=10, n_timepoints=12, random_state=RAND_SEED)
X_test = X

t1 = ExponentTransformer(power=4)
t2 = ExponentTransformer(power=0.25)

c = TimeSeriesKernelKMeans(random_state=RAND_SEED)
c = TimeSeriesDBSCAN(FlatDist.create_test_instance(), eps=4, min_samples=1)
t12c_1 = t1 * (t2 * c)
t12c_2 = (t1 * t2) * c
t12c_3 = t1 * t2 * c
Expand All @@ -47,19 +41,15 @@ def test_dunder_mul():
_assert_array_almost_equal(y_pred, t12c_3.fit(X).predict(X_test))


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_mul_sklearn_autoadapt():
"""Test auto-adapter for sklearn in mul."""
RAND_SEED = 42
X = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X_test = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X = _make_panel_X(n_instances=10, n_timepoints=12, random_state=RAND_SEED)
X_test = X

t1 = ExponentTransformer(power=2)
t2 = StandardScaler()
c = TimeSeriesKMeans(random_state=RAND_SEED)
c = TimeSeriesDBSCAN(FlatDist.create_test_instance(), eps=4, min_samples=1)

t12c_1 = t1 * (t2 * c)
t12c_2 = (t1 * t2) * c
Expand All @@ -75,13 +65,9 @@ def test_mul_sklearn_autoadapt():
_assert_array_almost_equal(y_pred, t12c_3.fit(X).predict(X_test))


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_missing_unequal_tag_inference():
"""Test that ClustererPipeline infers missing/unequal tags correctly."""
c = TimeSeriesKMeans()
c = TimeSeriesDBSCAN(FlatDist.create_test_instance())
c1 = ExponentTransformer() * PaddingTransformer() * ExponentTransformer() * c
c2 = ExponentTransformer() * ExponentTransformer() * c
c3 = Imputer() * ExponentTransformer() * c
Expand Down

0 comments on commit 3e3d98c

Please sign in to comment.