Skip to content

Commit

Permalink
Create nb_ndarray on demand
Browse files Browse the repository at this point in the history
In combination with the previous build system improvements, this enables
more aggressive size reductions when extensions don't use n-d arrays.
  • Loading branch information
wjakob committed Apr 30, 2023
1 parent 5ead9ff commit e033c8f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 39 deletions.
28 changes: 1 addition & 27 deletions src/nb_internals.cpp
Expand Up @@ -85,10 +85,6 @@ extern int nb_bound_method_clear(PyObject *);
extern void nb_bound_method_dealloc(PyObject *);
extern PyObject *nb_method_descr_get(PyObject *, PyObject *, PyObject *);
extern int nb_type_setattro(PyObject*, PyObject*, PyObject*);
extern PyObject *nb_ndarray_get(PyObject *, PyObject *);
extern int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int);
extern void nb_ndarray_releasebuffer(PyObject *, Py_buffer *);
extern void nb_ndarray_dealloc(PyObject *self);
static PyObject *nb_static_property_get(PyObject *, PyObject *, PyObject *);

#if PY_VERSION_HEX >= 0x03090000
Expand Down Expand Up @@ -217,23 +213,6 @@ static PyType_Spec nb_static_property_spec = {
/* .slots = */ nb_static_property_slots
};

static PyType_Slot nb_ndarray_slots[] = {
{ Py_tp_dealloc, (void *) nb_ndarray_dealloc },
#if PY_VERSION_HEX >= 0x03090000
{ Py_bf_getbuffer, (void *) nb_ndarray_getbuffer },
{ Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
#endif
{ 0, nullptr }
};

static PyType_Spec nb_ndarray_spec = {
/* .name = */ "nanobind.nb_ndarray",
/* .basicsize = */ (int) sizeof(nb_ndarray),
/* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT,
/* .slots = */ nb_ndarray_slots
};

/// `nb_static_property_property.__get__()`: Always pass the class instead of the instance.
static PyObject *nb_static_property_get(PyObject *self, PyObject *, PyObject *cls) {
if (internals_get().nb_static_property_enabled) {
Expand Down Expand Up @@ -441,17 +420,12 @@ static NB_NOINLINE nb_internals *internals_make() {
(PyTypeObject *) PyType_FromSpec(&nb_static_property_spec);
p->nb_static_property_enabled = true;

// Tensor type
p->nb_ndarray = (PyTypeObject *) PyType_FromSpec(&nb_ndarray_spec);

if (!p->nb_func || !p->nb_method ||
!p->nb_bound_method || !p->nb_type ||
!p->nb_static_property || !p->nb_ndarray)
!p->nb_static_property)
fail("nanobind::detail::internals_make(): type initialization failed!");

#if PY_VERSION_HEX < 0x03090000
p->nb_ndarray->tp_as_buffer->bf_getbuffer = nb_ndarray_getbuffer;
p->nb_ndarray->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer;
p->nb_func->tp_flags |= NB_HAVE_VECTORCALL;
p->nb_func->tp_vectorcall_offset = offsetof(nb_func, vectorcall);
p->nb_method->tp_flags |= NB_HAVE_VECTORCALL;
Expand Down
8 changes: 4 additions & 4 deletions src/nb_internals.h
Expand Up @@ -186,13 +186,13 @@ struct nb_internals {
PyTypeObject *nb_static_property;
bool nb_static_property_enabled;

/// N-dimensional array wrapper
PyTypeObject *nb_ndarray;
/// N-dimensional array wrapper (constructed optionally)
PyTypeObject *nb_ndarray = nullptr;

/// Instance pointer -> Python object mapping
/// C++ -> Python instance map
nb_inst_map inst_c2p;

/// C++ type -> Python type mapping
/// C++ -> Python type map
nb_type_map type_c2p;

/// Dictionary of sets storing keep_alive references
Expand Down
53 changes: 45 additions & 8 deletions src/nb_ndarray.cpp
Expand Up @@ -22,7 +22,7 @@ struct ndarray_handle {
bool call_deleter;
};

void nb_ndarray_dealloc(PyObject *self) {
static void nb_ndarray_dealloc(PyObject *self) {
ndarray_dec_ref(((nb_ndarray *) self)->th);

freefunc tp_free;
Expand All @@ -35,7 +35,7 @@ void nb_ndarray_dealloc(PyObject *self) {
tp_free(self);
}

int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) {
static int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) {
nb_ndarray *self = (nb_ndarray *) exporter;

dlpack::dltensor &t = self->th->ndarray->dltensor;
Expand Down Expand Up @@ -114,11 +114,47 @@ int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) {
return 0;
}

void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
PyMem_Free(view->shape);
PyMem_Free(view->strides);
}

static PyTypeObject *nb_ndarray_get() noexcept {
nb_internals &internals = internals_get();
PyTypeObject *tp = internals.nb_ndarray;

if (NB_UNLIKELY(!tp)) {
PyType_Slot slots[] = {
{ Py_tp_dealloc, (void *) nb_ndarray_dealloc },
#if PY_VERSION_HEX >= 0x03090000
{ Py_bf_getbuffer, (void *) nb_ndarray_getbuffer },
{ Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
#endif
{ 0, nullptr }
};

PyType_Spec spec = {
/* .name = */ "nanobind.nb_ndarray",
/* .basicsize = */ (int) sizeof(nb_ndarray),
/* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT,
/* .slots = */ slots
};

tp = (PyTypeObject *) PyType_FromSpec(&spec);
if (!tp)
fail("nb_ndarray type creation failed!");

#if PY_VERSION_HEX < 0x03090000
tp->tp_as_buffer->bf_getbuffer = nb_ndarray_getbuffer;
tp->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer;
#endif
internals.nb_ndarray = tp;
}

return tp;
}

static PyObject *dlpack_from_buffer_protocol(PyObject *o) {
scoped_pymalloc<Py_buffer> view;
scoped_pymalloc<managed_dltensor> mt;
Expand Down Expand Up @@ -467,9 +503,9 @@ void ndarray_dec_ref(ndarray_handle *th) noexcept {
}

ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
PyObject *owner, const int64_t *strides_in,
dlpack::dtype *dtype, int32_t device_type,
int32_t device_id) {
PyObject *owner, const int64_t *strides_in,
dlpack::dtype *dtype, int32_t device_type,
int32_t device_id) {
/* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but
PyTorch unfortunately ignores the 'byte_offset' value.. :-( */
#if 0
Expand Down Expand Up @@ -539,15 +575,16 @@ static void ndarray_capsule_destructor(PyObject *o) {
PyErr_Clear();
}

PyObject *ndarray_wrap(ndarray_handle *th, int framework, rv_policy policy) noexcept {
PyObject *ndarray_wrap(ndarray_handle *th, int framework,
rv_policy policy) noexcept {
if (!th)
return none().release().ptr();

bool copy = policy == rv_policy::copy || policy == rv_policy::move;

if ((ndarray_framework) framework == ndarray_framework::numpy) {
try {
object o = steal(PyType_GenericAlloc(internals_get().nb_ndarray, 0));
object o = steal(PyType_GenericAlloc(nb_ndarray_get(), 0));
if (!o.is_valid())
return nullptr;
((nb_ndarray *) o.ptr())->th = th;
Expand Down

0 comments on commit e033c8f

Please sign in to comment.