From e033c8fab4a14cbb9c5b0e08b1bdf49af2a9cb22 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Thu, 27 Apr 2023 12:34:54 +0200 Subject: [PATCH] Create ``nb_ndarray`` on demand In combination with the previous build system improvements, this enables more aggressive size reductions when extensions don't use n-d arrays. --- src/nb_internals.cpp | 28 +---------------------- src/nb_internals.h | 8 +++---- src/nb_ndarray.cpp | 53 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/src/nb_internals.cpp b/src/nb_internals.cpp index f76330d5..240ab30d 100644 --- a/src/nb_internals.cpp +++ b/src/nb_internals.cpp @@ -85,10 +85,6 @@ extern int nb_bound_method_clear(PyObject *); extern void nb_bound_method_dealloc(PyObject *); extern PyObject *nb_method_descr_get(PyObject *, PyObject *, PyObject *); extern int nb_type_setattro(PyObject*, PyObject*, PyObject*); -extern PyObject *nb_ndarray_get(PyObject *, PyObject *); -extern int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int); -extern void nb_ndarray_releasebuffer(PyObject *, Py_buffer *); -extern void nb_ndarray_dealloc(PyObject *self); static PyObject *nb_static_property_get(PyObject *, PyObject *, PyObject *); #if PY_VERSION_HEX >= 0x03090000 @@ -217,23 +213,6 @@ static PyType_Spec nb_static_property_spec = { /* .slots = */ nb_static_property_slots }; -static PyType_Slot nb_ndarray_slots[] = { - { Py_tp_dealloc, (void *) nb_ndarray_dealloc }, -#if PY_VERSION_HEX >= 0x03090000 - { Py_bf_getbuffer, (void *) nb_ndarray_getbuffer }, - { Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer }, -#endif - { 0, nullptr } -}; - -static PyType_Spec nb_ndarray_spec = { - /* .name = */ "nanobind.nb_ndarray", - /* .basicsize = */ (int) sizeof(nb_ndarray), - /* .itemsize = */ 0, - /* .flags = */ Py_TPFLAGS_DEFAULT, - /* .slots = */ nb_ndarray_slots -}; - /// `nb_static_property_property.__get__()`: Always pass the class instead of the instance. static PyObject *nb_static_property_get(PyObject *self, PyObject *, PyObject *cls) { if (internals_get().nb_static_property_enabled) { @@ -441,17 +420,12 @@ static NB_NOINLINE nb_internals *internals_make() { (PyTypeObject *) PyType_FromSpec(&nb_static_property_spec); p->nb_static_property_enabled = true; - // Tensor type - p->nb_ndarray = (PyTypeObject *) PyType_FromSpec(&nb_ndarray_spec); - if (!p->nb_func || !p->nb_method || !p->nb_bound_method || !p->nb_type || - !p->nb_static_property || !p->nb_ndarray) + !p->nb_static_property) fail("nanobind::detail::internals_make(): type initialization failed!"); #if PY_VERSION_HEX < 0x03090000 - p->nb_ndarray->tp_as_buffer->bf_getbuffer = nb_ndarray_getbuffer; - p->nb_ndarray->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer; p->nb_func->tp_flags |= NB_HAVE_VECTORCALL; p->nb_func->tp_vectorcall_offset = offsetof(nb_func, vectorcall); p->nb_method->tp_flags |= NB_HAVE_VECTORCALL; diff --git a/src/nb_internals.h b/src/nb_internals.h index 3ac6f774..bfdbff8e 100644 --- a/src/nb_internals.h +++ b/src/nb_internals.h @@ -186,13 +186,13 @@ struct nb_internals { PyTypeObject *nb_static_property; bool nb_static_property_enabled; - /// N-dimensional array wrapper - PyTypeObject *nb_ndarray; + /// N-dimensional array wrapper (constructed optionally) + PyTypeObject *nb_ndarray = nullptr; - /// Instance pointer -> Python object mapping + /// C++ -> Python instance map nb_inst_map inst_c2p; - /// C++ type -> Python type mapping + /// C++ -> Python type map nb_type_map type_c2p; /// Dictionary of sets storing keep_alive references diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index 4402e5ab..9873d3c5 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -22,7 +22,7 @@ struct ndarray_handle { bool call_deleter; }; -void nb_ndarray_dealloc(PyObject *self) { +static void nb_ndarray_dealloc(PyObject *self) { ndarray_dec_ref(((nb_ndarray *) self)->th); freefunc tp_free; @@ -35,7 +35,7 @@ void nb_ndarray_dealloc(PyObject *self) { tp_free(self); } -int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) { +static int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) { nb_ndarray *self = (nb_ndarray *) exporter; dlpack::dltensor &t = self->th->ndarray->dltensor; @@ -114,11 +114,47 @@ int nb_ndarray_getbuffer(PyObject *exporter, Py_buffer *view, int) { return 0; } -void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) { +static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) { PyMem_Free(view->shape); PyMem_Free(view->strides); } +static PyTypeObject *nb_ndarray_get() noexcept { + nb_internals &internals = internals_get(); + PyTypeObject *tp = internals.nb_ndarray; + + if (NB_UNLIKELY(!tp)) { + PyType_Slot slots[] = { + { Py_tp_dealloc, (void *) nb_ndarray_dealloc }, +#if PY_VERSION_HEX >= 0x03090000 + { Py_bf_getbuffer, (void *) nb_ndarray_getbuffer }, + { Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer }, +#endif + { 0, nullptr } + }; + + PyType_Spec spec = { + /* .name = */ "nanobind.nb_ndarray", + /* .basicsize = */ (int) sizeof(nb_ndarray), + /* .itemsize = */ 0, + /* .flags = */ Py_TPFLAGS_DEFAULT, + /* .slots = */ slots + }; + + tp = (PyTypeObject *) PyType_FromSpec(&spec); + if (!tp) + fail("nb_ndarray type creation failed!"); + +#if PY_VERSION_HEX < 0x03090000 + tp->tp_as_buffer->bf_getbuffer = nb_ndarray_getbuffer; + tp->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer; +#endif + internals.nb_ndarray = tp; + } + + return tp; +} + static PyObject *dlpack_from_buffer_protocol(PyObject *o) { scoped_pymalloc view; scoped_pymalloc mt; @@ -467,9 +503,9 @@ void ndarray_dec_ref(ndarray_handle *th) noexcept { } ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in, - PyObject *owner, const int64_t *strides_in, - dlpack::dtype *dtype, int32_t device_type, - int32_t device_id) { + PyObject *owner, const int64_t *strides_in, + dlpack::dtype *dtype, int32_t device_type, + int32_t device_id) { /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but PyTorch unfortunately ignores the 'byte_offset' value.. :-( */ #if 0 @@ -539,7 +575,8 @@ static void ndarray_capsule_destructor(PyObject *o) { PyErr_Clear(); } -PyObject *ndarray_wrap(ndarray_handle *th, int framework, rv_policy policy) noexcept { +PyObject *ndarray_wrap(ndarray_handle *th, int framework, + rv_policy policy) noexcept { if (!th) return none().release().ptr(); @@ -547,7 +584,7 @@ PyObject *ndarray_wrap(ndarray_handle *th, int framework, rv_policy policy) noex if ((ndarray_framework) framework == ndarray_framework::numpy) { try { - object o = steal(PyType_GenericAlloc(internals_get().nb_ndarray, 0)); + object o = steal(PyType_GenericAlloc(nb_ndarray_get(), 0)); if (!o.is_valid()) return nullptr; ((nb_ndarray *) o.ptr())->th = th;