Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] speed up clustering dunder test #4982

Merged
merged 2 commits on Aug 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 9 additions & 23 deletions sktime/clustering/compose/tests/test_pipeline.py
Expand Up @@ -4,34 +4,28 @@
__author__ = ["fkiraly"]
__all__ = []

import pytest
from sklearn.preprocessing import StandardScaler

from sktime.clustering.compose import ClustererPipeline
from sktime.clustering.k_means import TimeSeriesKMeans
from sktime.clustering.kernel_k_means import TimeSeriesKernelKMeans
from sktime.clustering.dbscan import TimeSeriesDBSCAN
from sktime.dists_kernels import FlatDist
from sktime.transformations.panel.padder import PaddingTransformer
from sktime.transformations.series.exponent import ExponentTransformer
from sktime.transformations.series.impute import Imputer
from sktime.utils._testing.estimator_checks import _assert_array_almost_equal
from sktime.utils._testing.panel import _make_panel_X
from sktime.utils.validation._dependencies import _check_estimator_deps


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_dunder_mul():
"""Test the mul dunder method."""
RAND_SEED = 42
X = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X_test = _make_panel_X(n_instances=5, n_timepoints=20, random_state=RAND_SEED)
X = _make_panel_X(n_instances=10, n_timepoints=12, random_state=RAND_SEED)
X_test = X

t1 = ExponentTransformer(power=4)
t2 = ExponentTransformer(power=0.25)

c = TimeSeriesKernelKMeans(random_state=RAND_SEED)
c = TimeSeriesDBSCAN(FlatDist.create_test_instance(), eps=4, min_samples=1)
t12c_1 = t1 * (t2 * c)
t12c_2 = (t1 * t2) * c
t12c_3 = t1 * t2 * c
Expand All @@ -47,19 +41,15 @@ def test_dunder_mul():
_assert_array_almost_equal(y_pred, t12c_3.fit(X).predict(X_test))


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_mul_sklearn_autoadapt():
"""Test auto-adapter for sklearn in mul."""
RAND_SEED = 42
X = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X_test = _make_panel_X(n_instances=10, n_timepoints=20, random_state=RAND_SEED)
X = _make_panel_X(n_instances=10, n_timepoints=12, random_state=RAND_SEED)
X_test = X

t1 = ExponentTransformer(power=2)
t2 = StandardScaler()
c = TimeSeriesKMeans(random_state=RAND_SEED)
c = TimeSeriesDBSCAN(FlatDist.create_test_instance(), eps=4, min_samples=1)

t12c_1 = t1 * (t2 * c)
t12c_2 = (t1 * t2) * c
Expand All @@ -75,13 +65,9 @@ def test_mul_sklearn_autoadapt():
_assert_array_almost_equal(y_pred, t12c_3.fit(X).predict(X_test))


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKMeans, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_missing_unequal_tag_inference():
"""Test that ClustererPipeline infers missing/unequal tags correctly."""
c = TimeSeriesKMeans()
c = TimeSeriesDBSCAN(FlatDist.create_test_instance())
c1 = ExponentTransformer() * PaddingTransformer() * ExponentTransformer() * c
c2 = ExponentTransformer() * ExponentTransformer() * c
c3 = Imputer() * ExponentTransformer() * c
Expand Down