-
Notifications
You must be signed in to change notification settings - Fork 4
/
_array_split.py
60 lines (49 loc) · 1.88 KB
/
_array_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import dask.array
def _array_split(ary, indices_or_sections, axis=0):
"""
*** From NumPy, adapted for Dask Array's
Split an array into multiple sub-arrays.
Please refer to the ``split`` documentation. The only difference
between these functions is that ``array_split`` allows
`indices_or_sections` to be an integer that does *not* equally
divide the axis. For an array of length l that should be split
into n sections, it returns l % n sub-arrays of size l//n + 1
and the rest of size l//n.
See Also
--------
split : Split array into multiple sub-arrays of equal size.
Examples
--------
>>> x = np.arange(8.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]
>>> x = np.arange(9)
>>> np.array_split(x, 4)
[array([0, 1, 2]), array([3, 4]), array([5, 6]), array([7, 8])]
"""
try:
Ntotal = ary.shape[axis]
except AttributeError:
Ntotal = len(ary)
try:
# handle array case.
Nsections = len(indices_or_sections) + 1
div_points = [0] + list(indices_or_sections) + [Ntotal]
except TypeError:
# indices_or_sections is a scalar, not an array.
Nsections = int(indices_or_sections)
if Nsections <= 0:
raise ValueError('number sections must be larger than 0.') from None
Neach_section, extras = divmod(Ntotal, Nsections)
section_sizes = ([0] +
extras * [Neach_section+1] +
(Nsections-extras) * [Neach_section])
div_points = np.array(section_sizes, dtype=np.intp).cumsum()
sub_arys = []
sary = dask.array.swapaxes(ary, axis, 0)
for i in range(Nsections):
st = div_points[i]
end = div_points[i + 1]
sub_arys.append(dask.array.swapaxes(sary[st:end], axis, 0))
return sub_arys