Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] allow inclusive/exclusive bounds in get_slice #4483

Merged
merged 2 commits into from Apr 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 34 additions & 5 deletions sktime/datatypes/_utilities.py
Expand Up @@ -501,7 +501,7 @@ def get_window(obj, window_length=None, lag=None):
)


def get_slice(obj, start=None, end=None):
def get_slice(obj, start=None, end=None, start_inclusive=True, end_inclusive=False):
"""Slice obj with start (inclusive) and end (exclusive) indices.

Returns time series or time series panel with time indices
Expand All @@ -522,9 +522,16 @@ def get_slice(obj, start=None, end=None):
must be int if obj is int indexed, timestamp if datetime indexed
End of slice; exclusive by default, inclusive if `end_inclusive` is True. Default = None
If None, then no slice at the end
start_inclusive : bool, optional, default = True
whether start index is inclusive (True) or not (False)
end_inclusive : bool, optional, default = False
whether end index is inclusive (True) or not (False)

Returns
-------
obj sub-set sliced for `start` (inclusive) and `end` (exclusive) indices
obj sub-set sliced for `start` to `end`, default is start/inclusive, end/exclusive
contains all indices from `start` (in- or exclusive as per `start_inclusive`)
up until `end` (in- or exclusive as per `end_inclusive`)
None if obj was None
"""
from sktime.datatypes import check_is_scitype, convert_to
Expand All @@ -550,6 +557,16 @@ def get_slice(obj, start=None, end=None):
# and always subset on first dimension
if obj.ndim > 1:
obj = obj.swapaxes(1, -1)
# deal with inclusive/exclusive
if not start_inclusive:
start = start + 1
if end_inclusive:
end = end + 1
# deal with out-of-index
if start < 0:
start = 0
if start >= len(obj):
start = len(obj) - 1
# subsetting
if start and end:
obj_subset = obj[start:end]
Expand All @@ -571,12 +588,24 @@ def get_slice(obj, start=None, end=None):
else:
time_indices = obj.index.get_level_values(-1)

def get_start_cond():
    """Return the lower-bound comparison of `time_indices` against `start`.

    Uses `>=` when `start_inclusive` is True and `>` otherwise; the names
    `time_indices`, `start` and `start_inclusive` come from the enclosing scope.
    """
    return time_indices >= start if start_inclusive else time_indices > start

def get_end_cond():
    """Return the upper-bound comparison of `time_indices` against `end`.

    Uses `<=` when `end_inclusive` is True and `<` otherwise; the names
    `time_indices`, `end` and `end_inclusive` come from the enclosing scope.
    """
    return time_indices <= end if end_inclusive else time_indices < end

if start and end:
slice_select = (time_indices >= start) & (time_indices < end)
slice_select = get_start_cond() & get_end_cond()
elif end:
slice_select = time_indices < end
slice_select = get_end_cond()
elif start:
slice_select = time_indices >= start
slice_select = get_start_cond()

obj_subset = obj.iloc[slice_select]
return convert_to(obj_subset, obj_in_mtype)
Expand Down