diff --git a/sktime/annotation/clasp.py b/sktime/annotation/clasp.py index 573159f2897..3bb6d0c2d93 100644 --- a/sktime/annotation/clasp.py +++ b/sktime/annotation/clasp.py @@ -226,10 +226,11 @@ class ClaSPSegmentation(BaseSeriesAnnotator): "python_dependencies": "numba", } # for unit test cases - def __init__(self, period_length=10, n_cps=1, exclusion_radius=0.05): + def __init__(self, period_length=10, n_cps=1, fmt="sparse", exclusion_radius=0.05): self.period_length = int(period_length) self.n_cps = n_cps self.exclusion_radius = exclusion_radius + self.fmt = fmt super().__init__() def _fit(self, X, Y=None): @@ -261,7 +262,12 @@ def _predict(self, X): Y : pd.Series or an IntervalSeries Change points in sequence X. """ - return self._predict_points(X) + change_points = self._predict_points(X) + if self.fmt == "dense": + return self.change_points_to_segments( + change_points, X.index.min(), X.index.max() + ) + return change_points def _predict_points(self, X): """Predict changepoints on test/deployment data. diff --git a/sktime/annotation/tests/test_clasp.py b/sktime/annotation/tests/test_clasp.py index 6e027d4c0c2..3f6d17ccb38 100644 --- a/sktime/annotation/tests/test_clasp.py +++ b/sktime/annotation/tests/test_clasp.py @@ -46,13 +46,11 @@ def test_clasp_dense(): ts, period_size, cps = load_gun_point_segmentation() # compute a ClaSP segmentation - clasp = ClaSPSegmentation(period_size, n_cps=1) + clasp = ClaSPSegmentation(period_size, n_cps=1, fmt="dense") clasp.fit(ts) - segmentation = clasp.transform(ts) + segmentation = clasp.predict(ts) - # Find the index of the first 1 - cp_index = segmentation.index[segmentation.values == 1][0] _, profile = clasp.predict_scores(ts) - assert len(segmentation) == len(ts) and cp_index == 893 + assert len(segmentation) == 2 and segmentation.index[0].right == 893 assert np.argmax(profile) == 893