Skip to content

Commit

Permalink
Add fmt argument back in for backward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-JG3 committed May 20, 2024
1 parent b8f8ee6 commit d8c865d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
10 changes: 8 additions & 2 deletions sktime/annotation/clasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ class ClaSPSegmentation(BaseSeriesAnnotator):
"python_dependencies": "numba",
} # for unit test cases

def __init__(self, period_length=10, n_cps=1, exclusion_radius=0.05):
def __init__(self, period_length=10, n_cps=1, fmt="sparse", exclusion_radius=0.05):
self.period_length = int(period_length)
self.n_cps = n_cps
self.exclusion_radius = exclusion_radius
self.fmt = fmt
super().__init__()

def _fit(self, X, Y=None):
Expand Down Expand Up @@ -261,7 +262,12 @@ def _predict(self, X):
Y : pd.Series or an IntervalSeries
Change points in sequence X.
"""
return self._predict_points(X)
change_points = self._predict_points(X)
if self.fmt == "dense":
return self.change_points_to_segments(
change_points, X.index.min(), X.index.max()
)
return change_points

def _predict_points(self, X):
"""Predict changepoints on test/deployment data.
Expand Down
8 changes: 3 additions & 5 deletions sktime/annotation/tests/test_clasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,11 @@ def test_clasp_dense():
ts, period_size, cps = load_gun_point_segmentation()

# compute a ClaSP segmentation
clasp = ClaSPSegmentation(period_size, n_cps=1)
clasp = ClaSPSegmentation(period_size, n_cps=1, fmt="dense")
clasp.fit(ts)
segmentation = clasp.transform(ts)
segmentation = clasp.predict(ts)

# Find the index of the first 1
cp_index = segmentation.index[segmentation.values == 1][0]
_, profile = clasp.predict_scores(ts)

assert len(segmentation) == len(ts) and cp_index == 893
assert len(segmentation) == 2 and segmentation.index[0].right == 893
assert np.argmax(profile) == 893

0 comments on commit d8c865d

Please sign in to comment.