Skip to content

Commit fa96054

Browse files
committed
Add tests for array data access /index methods
1 parent 0625f7b commit fa96054

File tree

2 files changed

+249
-58
lines changed

2 files changed

+249
-58
lines changed

tests/test_numpy_array.cpp

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,97 @@
88
*/
99

1010
#include "pybind11_tests.h"
11+
1112
#include <pybind11/numpy.h>
1213
#include <pybind11/stl.h>
1314

15+
#include <cstdint>
16+
#include <vector>
17+
18+
using arr = py::array;
19+
using arr_t = py::array_t<uint16_t, 0>;
20+
21+
template<typename... Ix> arr data(const arr& a, Ix&&... index) {
22+
return arr(a.nbytes() - a.index_at(index...), a.data(index...));
23+
}
24+
25+
template<typename... Ix> arr data_t(const arr_t& a, Ix&&... index) {
26+
return arr(a.size() - a.index_at(index...), a.data(index...));
27+
}
28+
29+
arr& mutate_data(arr& a) {
30+
auto ptr = a.mutable_data();
31+
for (size_t i = 0; i < a.nbytes(); i++)
32+
ptr[i] = (uint8_t) (ptr[i] * 2);
33+
return a;
34+
}
35+
36+
arr_t& mutate_data_t(arr_t& a) {
37+
auto ptr = a.mutable_data();
38+
for (size_t i = 0; i < a.size(); i++)
39+
ptr[i]++;
40+
return a;
41+
}
42+
43+
template<typename... Ix> arr& mutate_data(arr& a, Ix&&... index) {
44+
auto ptr = a.mutable_data(index...);
45+
for (size_t i = 0; i < a.nbytes() - a.index_at(index...); i++)
46+
ptr[i] = (uint8_t) (ptr[i] * 2);
47+
return a;
48+
}
49+
50+
template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix&&... index) {
51+
auto ptr = a.mutable_data(index...);
52+
for (size_t i = 0; i < a.size() - a.index_at(index...); i++)
53+
ptr[i]++;
54+
return a;
55+
}
56+
1457
test_initializer numpy_array([](py::module &m) {
15-
m.def("get_arr_ndim", [](const py::array& arr) {
16-
return arr.ndim();
17-
});
18-
m.def("get_arr_shape", [](const py::array& arr) {
19-
return std::vector<size_t>(arr.shape(), arr.shape() + arr.ndim());
20-
});
21-
m.def("get_arr_shape", [](const py::array& arr, size_t dim) {
22-
return arr.shape(dim);
23-
});
24-
m.def("get_arr_strides", [](const py::array& arr) {
25-
return std::vector<size_t>(arr.strides(), arr.strides() + arr.ndim());
26-
});
27-
m.def("get_arr_strides", [](const py::array& arr, size_t dim) {
28-
return arr.strides(dim);
29-
});
30-
m.def("get_arr_writeable", [](const py::array& arr) {
31-
return arr.writeable();
32-
});
33-
m.def("get_arr_size", [](const py::array& arr) {
34-
return arr.size();
35-
});
36-
m.def("get_arr_itemsize", [](const py::array& arr) {
37-
return arr.itemsize();
38-
});
39-
m.def("get_arr_nbytes", [](const py::array& arr) {
40-
return arr.nbytes();
41-
});
42-
m.def("get_arr_owndata", [](const py::array& arr) {
43-
return arr.owndata();
44-
});
58+
auto sm = m.def_submodule("array");
59+
60+
sm.def("ndim", [](const arr& a) { return a.ndim(); });
61+
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
62+
sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); });
63+
sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); });
64+
sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); });
65+
sm.def("writeable", [](const arr& a) { return a.writeable(); });
66+
sm.def("size", [](const arr& a) { return a.size(); });
67+
sm.def("itemsize", [](const arr& a) { return a.itemsize(); });
68+
sm.def("nbytes", [](const arr& a) { return a.nbytes(); });
69+
sm.def("owndata", [](const arr& a) { return a.owndata(); });
70+
sm.def("index_at", [](const arr& a) { return a.index_at(); });
71+
sm.def("index_at", [](const arr& a, int i) { return a.index_at(i); });
72+
sm.def("index_at", [](const arr& a, int i, int j) { return a.index_at(i, j); });
73+
sm.def("index_at", [](const arr& a, int i, int j, int k) { return a.index_at(i, j, k); });
74+
sm.def("data", [](const arr& a) { return data(a); });
75+
sm.def("data", [](const arr& a, int i) { return data(a, i); });
76+
sm.def("data", [](const arr& a, int i, int j) { return data(a, i, j); });
77+
sm.def("data", [](const arr& a, int i, int j, int k) { return data(a, i, j, k); });
78+
sm.def("mutate_data", [](arr& a) { return mutate_data(a); });
79+
sm.def("mutate_data", [](arr& a) { return mutate_data(a); });
80+
sm.def("mutate_data", [](arr& a, int i) { return mutate_data(a, i); });
81+
sm.def("mutate_data", [](arr& a, int i, int j) { return mutate_data(a, i, j); });
82+
sm.def("mutate_data", [](arr& a, int i, int j, int k) { return mutate_data(a, i, j, k); });
83+
sm.def("index_at_t", [](const arr_t& a) { return a.index_at(); });
84+
sm.def("index_at_t", [](const arr_t& a, int i) { return a.index_at(i); });
85+
sm.def("index_at_t", [](const arr_t& a, int i, int j) { return a.index_at(i, j); });
86+
sm.def("index_at_t", [](const arr_t& a, int i, int j, int k) { return a.index_at(i, j, k); });
87+
sm.def("data_t", [](const arr_t& a) { return data_t(a); });
88+
sm.def("data_t", [](const arr_t& a, int i) { return data_t(a, i); });
89+
sm.def("data_t", [](const arr_t& a, int i, int j) { return data_t(a, i, j); });
90+
sm.def("data_t", [](const arr_t& a, int i, int j, int k) { return data_t(a, i, j, k); });
91+
sm.def("mutate_data_t", [](arr_t& a) { return mutate_data_t(a); });
92+
sm.def("mutate_data_t", [](arr_t& a) { return mutate_data_t(a); });
93+
sm.def("mutate_data_t", [](arr_t& a, int i) { return mutate_data_t(a, i); });
94+
sm.def("mutate_data_t", [](arr_t& a, int i, int j) { return mutate_data_t(a, i, j); });
95+
sm.def("mutate_data_t", [](arr_t& a, int i, int j, int k) { return mutate_data_t(a, i, j, k); });
96+
sm.def("at_t", [](const arr_t& a) { return a.at(); });
97+
sm.def("at_t", [](const arr_t& a, int i) { return a.at(i); });
98+
sm.def("at_t", [](const arr_t& a, int i, int j) { return a.at(i, j); });
99+
sm.def("at_t", [](const arr_t& a, int i, int j, int k) { return a.at(i, j, k); });
100+
sm.def("mutate_at_t", [](arr_t& a) { a.mutable_at()++; return a; });
101+
sm.def("mutate_at_t", [](arr_t& a, int i) { a.mutable_at(i)++; return a; });
102+
sm.def("mutate_at_t", [](arr_t& a, int i, int j) { a.mutable_at(i, j)++; return a; });
103+
sm.def("mutate_at_t", [](arr_t& a, int i, int j, int k) { a.mutable_at(i, j, k)++; return a; });
45104
});

tests/test_numpy_array.py

Lines changed: 160 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,172 @@
44
import numpy as np
55

66

7+
@pytest.fixture(scope='function')
8+
def arr():
9+
return np.array([[1, 2, 3], [4, 5, 6]], '<u2')
10+
11+
712
@pytest.requires_numpy
813
def test_array_attributes():
9-
from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable,
10-
get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata)
14+
from pybind11_tests.array import (
15+
ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
16+
)
1117

1218
a = np.array(0, 'f8')
13-
assert get_arr_ndim(a) == 0
14-
assert get_arr_shape(a) == []
15-
assert get_arr_strides(a) == []
16-
with pytest.raises(RuntimeError):
17-
get_arr_shape(a, 1)
18-
with pytest.raises(RuntimeError):
19-
get_arr_strides(a, 0)
20-
assert get_arr_writeable(a)
21-
assert get_arr_size(a) == 1
22-
assert get_arr_itemsize(a) == 8
23-
assert get_arr_nbytes(a) == 8
24-
assert get_arr_owndata(a)
19+
assert ndim(a) == 0
20+
assert all(shape(a) == [])
21+
assert all(strides(a) == [])
22+
with pytest.raises(IndexError) as excinfo:
23+
shape(a, 0)
24+
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
25+
with pytest.raises(IndexError) as excinfo:
26+
strides(a, 0)
27+
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
28+
assert writeable(a)
29+
assert size(a) == 1
30+
assert itemsize(a) == 8
31+
assert nbytes(a) == 8
32+
assert owndata(a)
2533

2634
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
2735
a.flags.writeable = False
28-
assert get_arr_ndim(a) == 2
29-
assert get_arr_shape(a) == [2, 3]
30-
assert get_arr_shape(a, 0) == 2
31-
assert get_arr_shape(a, 1) == 3
32-
assert get_arr_strides(a) == [6, 2]
33-
assert get_arr_strides(a, 0) == 6
34-
assert get_arr_strides(a, 1) == 2
36+
assert ndim(a) == 2
37+
assert all(shape(a) == [2, 3])
38+
assert shape(a, 0) == 2
39+
assert shape(a, 1) == 3
40+
assert all(strides(a) == [6, 2])
41+
assert strides(a, 0) == 6
42+
assert strides(a, 1) == 2
43+
with pytest.raises(IndexError) as excinfo:
44+
shape(a, 2)
45+
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
46+
with pytest.raises(IndexError) as excinfo:
47+
strides(a, 2)
48+
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
49+
assert not writeable(a)
50+
assert size(a) == 6
51+
assert itemsize(a) == 2
52+
assert nbytes(a) == 12
53+
assert not owndata(a)
54+
55+
56+
@pytest.requires_numpy
57+
def test_index_at(arr):
58+
from pybind11_tests.array import index_at, index_at_t
59+
60+
assert index_at(arr) == 0
61+
assert index_at(arr, 0) == 0
62+
assert index_at(arr, 1) == 6
63+
assert index_at(arr, 0, 1) == 2
64+
assert index_at(arr, 1, 2) == 10
65+
with pytest.raises(IndexError) as excinfo:
66+
index_at(arr, 1, 2, 3)
67+
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
68+
69+
assert index_at_t(arr) == 0
70+
assert index_at_t(arr, 0) == 0
71+
assert index_at_t(arr, 1) == 3
72+
assert index_at_t(arr, 0, 1) == 1
73+
assert index_at_t(arr, 1, 2) == 5
74+
with pytest.raises(IndexError) as excinfo:
75+
index_at_t(arr, 1, 2, 3)
76+
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
77+
78+
79+
@pytest.requires_numpy
80+
def test_data(arr):
81+
from pybind11_tests.array import data, data_t
82+
83+
assert all(data(arr) == [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0])
84+
assert all(data(arr) == [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0])
85+
assert all(data(arr, 1) == [4, 0, 5, 0, 6, 0])
86+
assert all(data(arr, 0, 1) == [2, 0, 3, 0, 4, 0, 5, 0, 6, 0])
87+
assert all(data(arr, 1, 2) == [6, 0])
88+
with pytest.raises(IndexError):
89+
data(arr, 1, 2, 3)
90+
91+
assert all(data_t(arr) == [1, 2, 3, 4, 5, 6])
92+
assert all(data_t(arr) == [1, 2, 3, 4, 5, 6])
93+
assert all(data_t(arr, 1) == [4, 5, 6])
94+
assert all(data_t(arr, 0, 1) == [2, 3, 4, 5, 6])
95+
assert all(data_t(arr, 1, 2) == [6])
96+
with pytest.raises(IndexError):
97+
data_t(arr, 1, 2, 3)
98+
99+
100+
@pytest.requires_numpy
101+
def test_mutate_reaadonly(arr):
102+
from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
103+
arr.flags.writeable = False
104+
35105
with pytest.raises(RuntimeError):
36-
get_arr_shape(a, 2)
106+
mutate_data(arr)
37107
with pytest.raises(RuntimeError):
38-
get_arr_strides(a, 2)
39-
assert not get_arr_writeable(a)
40-
assert get_arr_size(a) == 6
41-
assert get_arr_itemsize(a) == 2
42-
assert get_arr_nbytes(a) == 12
43-
assert not get_arr_owndata(a)
108+
mutate_data_t(arr)
109+
with pytest.raises(RuntimeError):
110+
mutate_at_t(arr, 0, 0)
111+
112+
113+
@pytest.requires_numpy
114+
def test_at(arr):
115+
from pybind11_tests.array import at_t, mutate_at_t
116+
117+
with pytest.raises(IndexError) as excinfo:
118+
at_t(arr)
119+
assert str(excinfo.value) == 'invalid index shape: 0 (ndim = 2)'
120+
with pytest.raises(IndexError) as excinfo:
121+
at_t(arr, 1)
122+
assert str(excinfo.value) == 'invalid index shape: 1 (ndim = 2)'
123+
with pytest.raises(IndexError) as excinfo:
124+
at_t(arr, 1, 2, 3)
125+
assert str(excinfo.value) == 'invalid index shape: 3 (ndim = 2)'
126+
assert at_t(arr, 0, 2) == 3
127+
assert at_t(arr, 1, 0) == 4
128+
129+
with pytest.raises(IndexError) as excinfo:
130+
mutate_at_t(arr)
131+
assert str(excinfo.value) == 'invalid index shape: 0 (ndim = 2)'
132+
with pytest.raises(IndexError) as excinfo:
133+
mutate_at_t(arr, 1)
134+
assert str(excinfo.value) == 'invalid index shape: 1 (ndim = 2)'
135+
with pytest.raises(IndexError) as excinfo:
136+
mutate_at_t(arr, 1, 2, 3)
137+
assert str(excinfo.value) == 'invalid index shape: 3 (ndim = 2)'
138+
assert all(mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
139+
assert all(mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
140+
141+
142+
@pytest.requires_numpy
143+
def test_mutate_data(arr):
144+
from pybind11_tests.array import mutate_data, mutate_data_t
145+
146+
assert all(mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12])
147+
assert all(mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24])
148+
assert all(mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48])
149+
assert all(mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96])
150+
assert all(mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192])
151+
with pytest.raises(IndexError):
152+
mutate_data(arr, 1, 2, 3)
153+
154+
assert all(mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193])
155+
assert all(mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194])
156+
assert all(mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195])
157+
assert all(mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196])
158+
assert all(mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
159+
with pytest.raises(IndexError):
160+
mutate_data_t(arr, 1, 2, 3)
161+
162+
163+
@pytest.requires_numpy
164+
def test_bounds_check(arr):
165+
from pybind11_tests.array import (index_at, index_at_t, data, data_t,
166+
mutate_data, mutate_data_t, at_t, mutate_at_t)
167+
funcs = (index_at, index_at_t, data, data_t,
168+
mutate_data, mutate_data_t, at_t, mutate_at_t)
169+
for func in funcs:
170+
with pytest.raises(IndexError) as excinfo:
171+
index_at(arr, 2, 0)
172+
assert str(excinfo.value) == 'out of bounds access'
173+
with pytest.raises(IndexError) as excinfo:
174+
index_at(arr, 0, 3)
175+
assert str(excinfo.value) == 'out of bounds access'

0 commit comments

Comments
 (0)