diff --git a/sktime/datatypes/_utilities.py b/sktime/datatypes/_utilities.py index a834550011d..ddc77a02a12 100644 --- a/sktime/datatypes/_utilities.py +++ b/sktime/datatypes/_utilities.py @@ -501,7 +501,7 @@ def get_window(obj, window_length=None, lag=None): ) -def get_slice(obj, start=None, end=None): +def get_slice(obj, start=None, end=None, start_inclusive=True, end_inclusive=False): """Slice obj with start (inclusive) and end (exclusive) indices. Returns time series or time series panel with time indices @@ -522,9 +522,16 @@ def get_slice(obj, start=None, end=None): must be int if obj is int indexed, timestamp if datetime indexed Exclusive end of slice. Default = None If None, then no slice at the end + start_inclusive : bool, optional, default = True + whether start index is inclusive (True) or not (False) + end_inclusive : bool, optional, default = False + whether end index is inclusive (True) or not (False) + Returns ------- - obj sub-set sliced for `start` (inclusive) and `end` (exclusive) indices + obj sub-set sliced for `start` to `end`, default is start/inclusive, end/exclusive + contains all indices from `start` (in- or exclusive as per `start_inclusive`) + up until `end` (in- or exclusive as per `end_inclusive`) None if obj was None """ from sktime.datatypes import check_is_scitype, convert_to @@ -550,6 +557,16 @@ def get_slice(obj, start=None, end=None): # and always subset on first dimension if obj.ndim > 1: obj = obj.swapaxes(1, -1) + # deal with inclusive/exclusive + if not start_inclusive: + start = start + 1 + if end_inclusive: + end = end + 1 + # deal with out-of-index + if start < 0: + start = 0 + if start >= len(obj): + start = len(obj) - 1 # subsetting if start and end: obj_subset = obj[start:end] @@ -571,12 +588,24 @@ def get_slice(obj, start=None, end=None): else: time_indices = obj.index.get_level_values(-1) + def get_start_cond(): + if start_inclusive: + return time_indices >= start + else: + return time_indices > start + + def get_end_cond(): + if end_inclusive: + return time_indices <= end + else: + return time_indices < end + if start and end: - slice_select = (time_indices >= start) & (time_indices < end) + slice_select = get_start_cond() & get_end_cond() elif end: - slice_select = time_indices < end + slice_select = get_end_cond() elif start: - slice_select = time_indices >= start + slice_select = get_start_cond() obj_subset = obj.iloc[slice_select] return convert_to(obj_subset, obj_in_mtype)