Skip to content

Commit 6539d87

Browse files
committed
pycontainer inherits from pybind11 object
1 parent 6c5f366 commit 6539d87

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

benchmark/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/numpy.h"
3+
//#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
34
#include "numpy/arrayobject.h"
45
#include "xtensor/xtensor.hpp"
56
#include "xtensor/xarray.hpp"

include/xtensor-python/pycontainer.hpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace xt
2121
{
2222

2323
template <class D>
24-
class pycontainer
24+
class pycontainer : public pybind11::object
2525
{
2626

2727
public:
@@ -48,8 +48,8 @@ namespace xt
4848
using stepper = xstepper<D>;
4949
using const_stepper = xstepper<const D>;
5050

51-
using broadcast_iterator = xiterator<stepper, shape_type>;
52-
using const_broadcast_iterator = xiterator<const_stepper, shape_type>;
51+
using broadcast_iterator = xiterator<stepper, shape_type*>;
52+
using const_broadcast_iterator = xiterator<const_stepper, shape_type*>;
5353

5454
size_type size() const;
5555
size_type dimension() const;
@@ -127,6 +127,9 @@ namespace xt
127127
pycontainer() = default;
128128
~pycontainer() = default;
129129

130+
pycontainer(pybind11::handle h, borrowed_t);
131+
pycontainer(pybind11::handle h, stolen_t);
132+
130133
pycontainer(const pycontainer&) = default;
131134
pycontainer& operator=(const pycontainer&) = default;
132135

@@ -159,12 +162,17 @@ namespace xt
159162
static constexpr bool value = false;
160163
};
161164

165+
constexpr int log2(size_t n, int k = 0)
166+
{
167+
return (n <= 1) ? k : log2(n >> 1, k + 1);
168+
}
169+
162170
template <typename T>
163171
struct is_fmt_numeric<T, std::enable_if_t<std::is_arithmetic<T>::value>>
164172
{
165173
static constexpr bool value = true;
166174
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
167-
std::is_integral<T>::value ? std::log2(sizeof(T)) * 2 + std::is_unsigned<T>::value : 8 + (
175+
std::is_integral<T>::value ? log2(sizeof(T)) * 2 + std::is_unsigned<T>::value : 8 + (
168176
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
169177
};
170178

@@ -193,6 +201,18 @@ namespace xt
193201
* pycontainer implementation *
194202
******************************/
195203

204+
template <class D>
205+
inline pycontainer<D>::pycontainer(pybind11::handle h, borrowed_t)
206+
: pybind11::object(h, borrowed)
207+
{
208+
}
209+
210+
template <class D>
211+
inline pycontainer<D>::pycontainer(pybind11::handle h, stolen_t)
212+
: pybind11::object(h, stolen)
213+
{
214+
}
215+
196216
template <class D>
197217
inline void pycontainer<D>::fill_default_strides(const shape_type& shape, strides_type& strides)
198218
{

include/xtensor-python/pytensor.hpp

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ namespace xt
7171
};
7272

7373
template <class T, std::size_t N>
74-
class pytensor : public pybind11::object,
75-
public pycontainer<pytensor<T, N>>,
74+
class pytensor : public pycontainer<pytensor<T, N>>,
7675
public xcontainer_semantic<pytensor<T, N>>
7776
{
7877

@@ -88,7 +87,7 @@ namespace xt
8887
using strides_type = typename base_type::strides_type;
8988
using backstrides_type = typename base_type::backstrides_type;
9089

91-
pytensor();
90+
pytensor() = default;
9291

9392
pytensor(pybind11::handle h, borrowed_t);
9493
pytensor(pybind11::handle h, stolen_t);
@@ -137,30 +136,24 @@ namespace xt
137136
* pytensor implementation *
138137
***************************/
139138

140-
template <class T, std::size_t N>
141-
inline pytensor<T, N>::pytensor()
142-
{
143-
}
144-
145139
template <class T, std::size_t N>
146140
inline pytensor<T, N>::pytensor(pybind11::handle h, borrowed_t)
147-
: pybind11::object(h, borrowed)
141+
: base_type(h, borrowed)
148142
{
149143
init_from_python();
150144
}
151145

152146
template <class T, std::size_t N>
153147
inline pytensor<T, N>::pytensor(pybind11::handle h, stolen_t)
154-
: pybind11::object(h, stolen)
148+
: base_type(h, stolen)
155149
{
156150
init_from_python();
157151
}
158152

159153
template <class T, std::size_t N>
160154
inline pytensor<T, N>::pytensor(const pybind11::object& o)
161-
: pybind11::object(base_type::raw_array_t(o.ptr()), stolen)
155+
: base_type(base_type::raw_array_t(o.ptr()), stolen)
162156
{
163-
//std::cout << "Object constructor" << std::endl;
164157
if(!this->m_ptr)
165158
throw pybind11::error_already_set();
166159
init_from_python();
@@ -170,14 +163,12 @@ namespace xt
170163
inline pytensor<T, N>::pytensor(const shape_type& shape,
171164
const strides_type& strides)
172165
{
173-
//std::cout << "Shape + strides constructor" << std::endl;
174166
init_tensor(shape, strides);
175167
}
176168

177169
template <class T, std::size_t N>
178170
inline pytensor<T, N>::pytensor(const shape_type& shape)
179171
{
180-
//std::cout << "Shape constructor" << std::endl;
181172
base_type::fill_default_strides(shape, m_strides);
182173
init_tensor(shape, m_strides);
183174
}
@@ -186,7 +177,6 @@ namespace xt
186177
template <class E>
187178
inline pytensor<T, N>::pytensor(const xexpression<E>& e)
188179
{
189-
//std::cout << "Extended constructor" << std::endl;
190180
semantic_base::assign(e);
191181
}
192182

@@ -200,7 +190,6 @@ namespace xt
200190
template <class T, std::size_t N>
201191
inline void pytensor<T, N>::reshape(const shape_type& shape)
202192
{
203-
//std::cout << "Reshape(shape)" << std::endl;
204193
if(shape != m_shape)
205194
{
206195
strides_type strides;
@@ -212,17 +201,14 @@ namespace xt
212201
template <class T, std::size_t N>
213202
inline void pytensor<T, N>::reshape(const shape_type& shape, const strides_type& strides)
214203
{
215-
//std::cout << "Reshape(shape, strides)" << std::endl;
216204
self_type tmp(shape, strides);
217205
*this = std::move(tmp);
218206
}
219207

220208
template <class T, std::size_t N>
221209
inline auto pytensor<T, N>::ensure(pybind11::handle h) -> self_type
222210
{
223-
//std::cout << "Ensure" << std::endl;
224211
auto result = pybind11::reinterpret_steal<self_type>(base_type::raw_array_t(h.ptr()));
225-
//auto result = pybind11::reinterpret_steal<self_type>(h.ptr());
226212
if(result.ptr() == nullptr)
227213
PyErr_Clear();
228214
return result;
@@ -238,7 +224,6 @@ namespace xt
238224
template <class T, std::size_t N>
239225
inline void pytensor<T, N>::init_tensor(const shape_type& shape, const strides_type& strides)
240226
{
241-
//std::cout << "init tensor" << std::endl;
242227
npy_intp python_strides[N];
243228
std::transform(strides.beign(), strides.end(), python_strides,
244229
[](auto v) { return sizeof(T) * v; });
@@ -266,7 +251,6 @@ namespace xt
266251
template <class T, std::size_t N>
267252
inline void pytensor<T, N>::init_from_python()
268253
{
269-
//std::cout << "init from python" << std::endl;
270254
if(PyArray_NDIM(this->m_ptr) != N)
271255
throw std::runtime_error("NumPy: ndarray has incorrect number of dimensions");
272256

0 commit comments

Comments
 (0)