/
slicetools.py
234 lines (189 loc) · 6.54 KB
/
slicetools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
Utility functions for working with Python slice objects
"""
#***************************************************************************************************
# Copyright 2015, 2019 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights
# in this software.
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************
import numpy as _np
def length(s):
    """
    Returns the length (the number of indices) contained in a slice.

    Non-slice arguments (lists, arrays, ...) are simply passed to `len`.
    A slice whose start or stop is unspecified (`None`) is treated as empty
    and has length 0.  A slice whose stop lies before its start (for a
    positive step) also has length 0, exactly like ``len(range(...))``.

    Parameters
    ----------
    s : slice
        The slice to operate upon.

    Returns
    -------
    int
    """
    if not isinstance(s, slice):
        return len(s)
    if s.start is None or s.stop is None:
        # Open-ended slices are treated as empty: no array length is
        # available here to resolve them.
        return 0
    step = 1 if s.step is None else s.step
    # len(range(...)) is O(1) and never negative, unlike a plain
    # `stop - start`, which went negative for "inverted" slices.
    return len(range(s.start, s.stop, step))
def shift(s, offset):
    """
    Returns a new slice whose start and stop points are shifted by `offset`.

    The step is carried over unchanged.  The "null slice" `slice(0, 0)` is
    special-cased and returned as-is, so shifting it is a no-op.

    Parameters
    ----------
    s : slice
        The slice to operate upon.

    offset : int
        The amount to shift the start and stop members of `s`.

    Returns
    -------
    slice
    """
    null_slice = slice(0, 0, None)
    if s != null_slice:
        return slice(s.start + offset, s.stop + offset, s.step)
    # shifted(null_slice) == null_slice
    return s
def intersect(s1, s2):
    """
    Returns the intersection of two slices (which must have the same step).

    The result starts at the larger of the two starts and stops at the
    smaller of the two stops; a `None` bound defers to the other slice's
    bound.  If the resulting stop would precede the start, it is raised to
    the start, which gives an empty slice.  Only the bounds are compared.
    For non-unit steps, the two slices' index lattices are *not* checked for
    alignment, and negative bounds or steps get no special treatment.

    Parameters
    ----------
    s1 : slice
        First slice.

    s2 : slice
        Second slice.

    Returns
    -------
    slice
    """
    assert s1.step == s2.step, "Only implemented for same-step slices"

    def _combine(a, b, pick):
        # Pick between two bounds, letting a missing (None) bound defer to the other.
        if a is None:
            return b
        if b is None:
            return a
        return pick(a, b)

    lo = _combine(s1.start, s2.start, max)
    hi = _combine(s1.stop, s2.stop, min)
    if lo is not None and hi is not None and hi < lo:
        hi = lo  # disjoint ranges -> empty slice anchored at `lo`
    return slice(lo, hi, s1.step)
def indices(s, n=None):
    """
    Returns a list of the indices specified by slice `s`.

    Negative start/stop values are interpreted relative to `n`, as in Python
    indexing.  Unspecified (`None`) bounds default according to the step
    direction:

    - positive step: start defaults to 0 and stop to `n`;
    - negative step: start defaults to `n - 1` and the indices run down
      through index 0.

    Unlike `slice.indices`, out-of-range values are *not* clamped to `[0, n]`.
    As a special case, a slice whose start and stop are *both* `None` yields
    an empty list, regardless of `n`.

    Parameters
    ----------
    s : slice
        The slice to operate upon.

    n : int, optional
        The number of elements in the array being indexed,
        used for computing *negative* start/stop points.

    Returns
    -------
    list of ints

    Raises
    ------
    AssertionError
        If `n` is needed to resolve a bound but was not supplied.
    ValueError
        If the slice's step is zero (raised by `range`).
    """
    if s.start is None and s.stop is None:
        return []

    step = 1 if s.step is None else s.step

    if s.start is None:
        if step < 0:
            assert(n is not None), \
                "Must supply `n` to obtain indices of a slice with unspecified start point and negative step!"
            start = n - 1  # a descending slice begins at the last element
        else:
            start = 0
    elif s.start < 0:
        assert(n is not None), "Must supply `n` to obtain indices of a slice with negative start point!"
        start = n + s.start
    else:
        start = s.start

    if s.stop is None:
        if step < 0:
            # Absolute range bound (not end-relative): include index 0.
            stop = -1
        else:
            assert(n is not None), "Must supply `n` to obtain indices of a slice with unspecified stop point!"
            stop = n
    elif s.stop < 0:
        assert(n is not None), "Must supply `n` to obtain indices of a slice with negative stop point!"
        stop = n + s.stop
    else:
        stop = s.stop

    return list(range(start, stop, step))
def list_to_slice(lst, array_ok=False, require_contiguous=True):
    """
    Returns a slice corresponding to a given list of (integer) indices, if this is possible.

    If not, `array_ok` determines the behavior.

    Parameters
    ----------
    lst : list or slice
        The list of integers to convert to a slice (must be contiguous
        if `require_contiguous == True`).  A slice argument is returned
        unchanged, subject to the same contiguity requirement.  `None` or an
        empty list yields the "null slice" `slice(0, 0)`.

    array_ok : bool, optional
        If True, an integer array (of type `numpy.ndarray`, dtype int64) is
        returned when `lst` does not correspond to a single (acceptable)
        slice. Otherwise, a ValueError is raised.

    require_contiguous : bool, optional
        If True, then lst will only be converted to a contiguous (step=1)
        slice, otherwise either a ValueError is raised (if `array_ok`
        is False) or an array is returned.

    Returns
    -------
    numpy.ndarray or slice

    Raises
    ------
    ValueError
        If `lst` cannot be represented as an acceptable slice and
        `array_ok` is False.
    """
    if isinstance(lst, slice):
        if require_contiguous and not (lst.step is None or lst.step == 1):
            if array_ok:
                # Expand the non-unit-step slice into an explicit index array
                # (int64, matching the other array-returning branches).
                return _np.array(range(lst.start, lst.stop, lst.step), _np.int64)
            raise ValueError("Slice must be contiguous!")
        return lst

    if lst is None or len(lst) == 0:
        return slice(0, 0)

    start = lst[0]
    if len(lst) == 1:
        return slice(start, start + 1)

    step = lst[1] - lst[0]
    stop = start + step * len(lst)
    # A zero step (repeated leading element) can never form a slice; testing it
    # first also keeps `range` from raising before `array_ok` is honored.
    if step != 0 and list(lst) == list(range(start, stop, step)):
        if require_contiguous and step != 1:
            if array_ok:
                return _np.array(lst, _np.int64)
            raise ValueError("Slice must be contiguous (or array_ok must be True)!")
        # NOTE(review): a descending list ending at index 0 (e.g. [2, 1, 0]) gives
        # stop == -1, which Python indexing reads relative to the end -- confirm
        # callers never rely on that case with require_contiguous=False.
        return slice(start, stop, None if step == 1 else step)
    if array_ok:
        return _np.array(lst, _np.int64)
    raise ValueError("List does not correspond to a slice!")
def to_array(slc_or_list_like):
    """
    Returns `slc_or_list_like` as an index array (an integer numpy.ndarray).

    A slice is first expanded into its explicit indices with `indices`, which
    is called without an array length.  Slices whose bounds depend on that
    length may therefore fail to expand.  Any other input is handed directly
    to `numpy.array`.

    Parameters
    ----------
    slc_or_list_like : slice or list
        A slice, list, or array.

    Returns
    -------
    numpy.ndarray
        An int64 array of indices.
    """
    if not isinstance(slc_or_list_like, slice):
        return _np.array(slc_or_list_like, _np.int64)
    index_list = indices(slc_or_list_like)
    return _np.array(index_list, _np.int64)
def divide(slc, max_len):
    """
    Divides a slice into sub-slices based on a maximum length (for each sub-slice).

    For example:
    `divide(slice(0,10,2), 2) == [slice(0,4,2), slice(4,8,2), slice(8,10,2)]`

    The slice must have explicit (non-`None`) start and stop values.
    Sub-slices are generated in increasing order from `slc.start`.  When
    `slc.start >= slc.stop`, an empty list is returned.  Otherwise
    `max_len * step` must be positive, where `step` defaults to 1.

    Parameters
    ----------
    slc : slice
        The slice to divide

    max_len : int
        The maximum length (i.e. number of indices) allowed in a sub-slice.

    Returns
    -------
    list of slices
        Consecutive sub-slices, each carrying `slc.step`, that together
        cover `slc`.

    Raises
    ------
    ValueError
        If `slc.start < slc.stop` but `max_len * step <= 0`.  Previously this
        case looped forever.
    """
    step = 1 if (slc.step is None) else slc.step
    # Distance between consecutive sub-slice starts.  A span of max_len*step
    # holds exactly max_len indices, since for a positive step
    # len(range(start, stop, step)) == (stop - start + step - 1) // step.
    chunk = max_len * step
    if slc.start < slc.stop and chunk <= 0:
        raise ValueError("Cannot divide a non-empty slice: `max_len` times the slice step must be positive!")

    sub_slices = []
    sub_start = slc.start
    while sub_start < slc.stop:
        sub_slices.append(slice(sub_start, min(sub_start + chunk, slc.stop), slc.step))
        sub_start += chunk
    return sub_slices