From c5f425d4dabe7814476cd83802dc1aa0d9622a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 23 Apr 2024 22:16:26 +0100 Subject: [PATCH] [BUG] fix dependent tags of `TimeSeriesDBSCAN` (#6322) `TimeSeriesDBSCAN` does not set its tags correctly if one of the `numba`-based distances in `sktime` is used. In this case, it cannot support unequal length time series, and the corresponding tag needs to be set. Similarly, the tags for some of the `numba` distances were set incorrectly, which is now corrected as well. --- sktime/clustering/dbscan.py | 9 +++++++++ sktime/dists_kernels/dtw/_dtw_sktime.py | 1 + sktime/dists_kernels/edit_dist.py | 1 + 3 files changed, 11 insertions(+) diff --git a/sktime/clustering/dbscan.py b/sktime/clustering/dbscan.py index bd531cc34c1..2b431cb8424 100644 --- a/sktime/clustering/dbscan.py +++ b/sktime/clustering/dbscan.py @@ -104,6 +104,15 @@ def __init__( ] self.clone_tags(distance, tags_to_clone) + # numba distance in sktime (indexed by string) + # cannot support unequal length data, and require numpy3D input + if isinstance(distance, str): + tags_to_set = { + "X_inner_mtype": "numpy3D", + "capability:unequal_length": False, + } + self.set_tags(**tags_to_set) + self.dbscan_ = None def _fit(self, X, y=None): diff --git a/sktime/dists_kernels/dtw/_dtw_sktime.py b/sktime/dists_kernels/dtw/_dtw_sktime.py index caa030f46c3..dab9d983d48 100644 --- a/sktime/dists_kernels/dtw/_dtw_sktime.py +++ b/sktime/dists_kernels/dtw/_dtw_sktime.py @@ -131,6 +131,7 @@ class DtwDist(BasePairwiseTransformerPanel): # -------------- "symmetric": True, # all the distances are symmetric "X_inner_mtype": "numpy3D", + "capability:unequal_length": False, # can dist handle unequal length panels? 
} def __init__( diff --git a/sktime/dists_kernels/edit_dist.py b/sktime/dists_kernels/edit_dist.py index 06638c193e1..66c79e4e2c5 100644 --- a/sktime/dists_kernels/edit_dist.py +++ b/sktime/dists_kernels/edit_dist.py @@ -123,6 +123,7 @@ class EditDist(BasePairwiseTransformerPanel): # -------------- "symmetric": True, # all the distances are symmetric "X_inner_mtype": "numpy3D", + "capability:unequal_length": False, # can dist handle unequal length panels? } ALLOWED_DISTANCE_STR = ["lcss", "edr", "erp", "twe"]