Skip to content

Commit

Permalink
[ENH] clusterer test scenario with unequal length time series; fix cl…
Browse files Browse the repository at this point in the history
…usterer tags (#6277)

Fixes #6276:

* corrects tag `capability:unequal_length` for `TimeSeriesKMeansTslearn`
(to `False`)
* adds a clusterer test scenario with unequal length time series panels
  • Loading branch information
fkiraly committed Apr 14, 2024
1 parent 0185d5c commit 66176bf
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sktime/clustering/k_means/_k_means_tslearn.py
Expand Up @@ -107,7 +107,7 @@ class TimeSeriesKMeansTslearn(_TslearnAdapter, BaseClusterer):
# estimator type
# --------------
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:unequal_length": False,
}

# defines the name of the attribute containing the tslearn estimator
Expand Down
2 changes: 1 addition & 1 deletion sktime/clustering/k_shapes.py
Expand Up @@ -57,7 +57,7 @@ class TimeSeriesKShapes(_TslearnAdapter, BaseClusterer):
# estimator type
# --------------
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:unequal_length": False,
}

# defines the name of the attribute containing the tslearn estimator
Expand Down
80 changes: 79 additions & 1 deletion sktime/utils/_testing/scenarios_clustering.py
Expand Up @@ -7,7 +7,11 @@

__all__ = ["scenarios_clustering"]

from inspect import isclass

from sktime.base import BaseObject
from sktime.registry import scitype
from sktime.utils._testing.hierarchical import _make_hierarchical
from sktime.utils._testing.panel import _make_panel_X, make_clustering_problem
from sktime.utils._testing.scenarios import TestScenario

Expand Down Expand Up @@ -43,11 +47,52 @@ def get_args(self, key, obj=None, deepcopy_args=True):

return super().get_args(key=key, obj=obj, deepcopy_args=deepcopy_args)

def is_applicable(self, obj):
"""Check whether scenario is applicable to obj.
Parameters
----------
obj : class or object to check against scenario
Returns
-------
applicable: bool
True if self is applicable to obj, False if not
"""

def get_tag(obj, tag_name):
if isclass(obj):
return obj.get_class_tag(tag_name)
else:
return obj.get_tag(tag_name)

# applicable only if obj inherits from BaseClassifier, BaseEarlyClassifier or
# BaseRegressor. currently we test both classifiers and regressors using these
# scenarios
if scitype(obj) != "clusterer":
return False

# if X is multivariate, applicable only if can handle multivariate
is_multivariate = not self.get_tag("X_univariate")
if is_multivariate and not get_tag(obj, "capability:multivariate"):
return False

# if X is unequal length, applicable only if can handle unequal length
is_unequal_length = self.get_tag("X_unequal_length")
if is_unequal_length and not get_tag(obj, "capability:unequal_length"):
return False

return True


class ClustererFitPredict(ClustererTestScenario):
"""Fit/predict with panel Xmake_clustering_problem."""

_tags = {"X_univariate": True, "is_enabled": True}
_tags = {
"X_univariate": True,
"X_unequal_length": False,
"is_enabled": True,
}

@property
def args(self):
Expand All @@ -59,6 +104,39 @@ def args(self):
default_method_sequence = ["fit", "predict"]


class ClustererFitPredictUnequalLength(ClustererTestScenario):
"""Fit/predict with univariate panel X, unequal length series."""

_tags = {
"X_univariate": True,
"X_unequal_length": True,
"is_enabled": True,
}

@property
def args(self):
X_unequal_length = _make_hierarchical(
hierarchy_levels=(10,),
min_timepoints=10,
max_timepoints=15,
random_state=RAND_SEED,
)
X_unequal_length_test = _make_hierarchical(
hierarchy_levels=(5,),
min_timepoints=10,
max_timepoints=15,
random_state=RAND_SEED,
)
return {
"fit": {"X": X_unequal_length},
"predict": {"X": X_unequal_length_test},
}

default_method_sequence = ["fit", "predict", "predict_proba", "decision_function"]
default_arg_sequence = ["fit", "predict", "predict", "predict"]


scenarios_clustering = [
ClustererFitPredict,
ClustererFitPredictUnequalLength,
]

0 comments on commit 66176bf

Please sign in to comment.