Skip to content

Commit

Permalink
[ENH] allow inclusive/exclusive bounds in get_slice (#4483)
Browse files Browse the repository at this point in the history
This PR adds two arguments to `get_slice` to allow exclusive/inclusive
bounding at either end of the slice.

The defaults of the two new arguments reproduce the current behaviour, so no
deprecation is necessary.

This will be useful as a utility, and also for fixing
#3667
  • Loading branch information
fkiraly committed Apr 21, 2023
1 parent 750f02b commit 9bae78d
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions sktime/datatypes/_utilities.py
Expand Up @@ -501,7 +501,7 @@ def get_window(obj, window_length=None, lag=None):
)


def get_slice(obj, start=None, end=None):
def get_slice(obj, start=None, end=None, start_inclusive=True, end_inclusive=False):
"""Slice obj with start (inclusive) and end (exclusive) indices.
Returns time series or time series panel with time indices
Expand All @@ -522,9 +522,16 @@ def get_slice(obj, start=None, end=None):
must be int if obj is int indexed, timestamp if datetime indexed
Exclusive end of slice. Default = None
If None, then no slice at the end
start_inclusive : bool, optional, default = True
whether start index is inclusive (True) or not (False)
end_inclusive : bool, optional, default = False
whether end index is inclusive (True) or not (False)
Returns
-------
obj sub-set sliced for `start` (inclusive) and `end` (exclusive) indices
obj sub-set sliced for `start` to `end`, default is start/inclusive, end/exclusive
contains all indices from `start` (in- or exclusive as per `start_inclusive`)
up until `end` (in- or exclusive as per `end_inclusive`)
None if obj was None
"""
from sktime.datatypes import check_is_scitype, convert_to
Expand All @@ -550,6 +557,16 @@ def get_slice(obj, start=None, end=None):
# and always subset on first dimension
if obj.ndim > 1:
obj = obj.swapaxes(1, -1)
# deal with inclusive/exclusive
if not start_inclusive:
start = start + 1
if end_inclusive:
end = end + 1
# deal with out-of-index
if start < 0:
start = 0
if start >= len(obj):
start = len(obj) - 1
# subsetting
if start and end:
obj_subset = obj[start:end]
Expand All @@ -571,12 +588,24 @@ def get_slice(obj, start=None, end=None):
else:
time_indices = obj.index.get_level_values(-1)

def get_start_cond():
if start_inclusive:
return time_indices >= start
else:
return time_indices > start

def get_end_cond():
if end_inclusive:
return time_indices <= end
else:
return time_indices < end

if start and end:
slice_select = (time_indices >= start) & (time_indices < end)
slice_select = get_start_cond() & get_end_cond()
elif end:
slice_select = time_indices < end
slice_select = get_end_cond()
elif start:
slice_select = time_indices >= start
slice_select = get_start_cond()

obj_subset = obj.iloc[slice_select]
return convert_to(obj_subset, obj_in_mtype)
Expand Down

0 comments on commit 9bae78d

Please sign in to comment.