Skip to content

Commit

Permalink
Change nb::keep_alive in the presence of implicit conversions
Browse files Browse the repository at this point in the history
This commit changes the behavior of the ``nb::keep_alive<Nurse,
Patient>`` function binding annotation as follows: when function call
requires the implicit conversion of an argument, the lifetime constraint
now applies to the newly produced argument instead of the original
object. The former behavior seems highly undesirable and bug-prone.

Capture of such implicitly converted arguments is supported for the base
type caster (anything bound via ``nb::class_<>``) and ``nb::ndarray``.
  • Loading branch information
wjakob committed Oct 13, 2023
1 parent 0b7f3b1 commit 9d4b2e3
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 56 deletions.
50 changes: 36 additions & 14 deletions docs/api_core.rst
Expand Up @@ -1533,20 +1533,42 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,

.. cpp:struct:: template <size_t Nurse, size_t Patient> keep_alive

Following the function evaluation, keep object ``Patient`` alive while the
object ``Nurse`` exists. The function uses the following indexing
convention:

- Index ``0`` refers to the return value of methods. (It should not be used in
constructors)

- Index ``1`` refers to the first argument. In methods and constructors, index ``1``
refers to the implicit ``this`` pointer, while regular arguments begin at index ``2``.

When the nurse or patient equal ``None``, the annotation does nothing.

nanobind will raise an exception when the nurse object is neither a
nanobind-registered type nor weak-referenceable.
Following evaluation of the bound function, keep the object referenced by
index ``Patient`` alive *as long as* the object with index ``Nurse`` exists.
This uses the following indexing convention:

- Index ``0`` refers to the return value of methods. It should not be used
in constructors or functions that do not return a result.

- Index ``1`` refers to the first argument. In methods and constructors,
index ``1`` refers to the implicit ``this`` pointer, while regular
arguments begin at index ``2``.

The annotation has the following runtime characteristics:

- It does nothing when the nurse or patient object are ``None``.

- It raises an exception when the nurse object is neither
weak-referenceable nor an instance of a binding created via
:cpp:class:`nb::class_\<..\> <class_>`.

Two additional caveats regarding :cpp:class:`keep_alive <keep_alive>` are
noteworthy:

- It *usually* doesn't make sense to specify a ``Nurse`` or ``Patient`` for an
argument or return value handled by a :ref:`type caster <type_caster>` (e.g.,
a STL vector handled via the include directive ``#include
<nanobind/stl/vector.h>``). That's because type casters copy-convert the
Python object into an equivalent C++ object, whose lifetime is decoupled
from the original Python object. However, the :cpp:class:`keep_alive
<keep_alive>` annotation *only* affects the lifetime of Python objects
*and not their C++ copy*.

- Dispatching a Python → C++ function call may require the :ref:`implicit
conversion <noconvert>` of function arguments. In this case, the objects
passed to the C++ function differ from the originally specified arguments.
The ``Nurse`` and ``Patient`` annotation always refer to the *final* object
following implicit conversion.

.. cpp:struct:: raw_doc

Expand Down
19 changes: 13 additions & 6 deletions docs/functions.rst
Expand Up @@ -319,17 +319,24 @@ appends an entry to a log data structure.
nb::class_<Log>(m, "Log")
.def("append",
[](Log &log, Entry *entry) { ... },
[](Log &log, Entry *entry) -> void { ... },
nb::keep_alive<1, 2>());
Here, ``Nurse = 1`` refers to the ``log`` argument, while ``Patient = 2``
refers to ``entry``. See the definition of :cpp:class:`nb::keep_alive
<keep_alive>` for details on the numbering convention.
refers to ``entry``. Setting ``Nurse/Patient = 0`` would select the function
return value (here, the function doesn't return anything, so ``0`` is not a
valid choice).

The example uses the annotation to tie the lifetime of the ``entry`` to that of
the ``log``. Without it, Python may delete the ``Entry`` instance at a later
point, which would be problematic if ``Log`` did not make a copy but references
the instance through its pointer address.
``log``. Without it, Python could potentially delete ``entry`` *before*
``log``, which would be problematic if the ``log.append()`` operation causes
``log`` to reference ``entry`` through a pointer address instead of making a
copy. Whether or not this is a good design is another question (for example,
shared ownership via ``std::shared_ptr<T>`` or intrusive reference counting
would avoid the problem altogether).

See the definition of :cpp:class:`nb::keep_alive <keep_alive>` for further
discussion and limitations of this method.

.. _call_guards:

Expand Down
38 changes: 24 additions & 14 deletions include/nanobind/nb_attr.h
Expand Up @@ -90,7 +90,7 @@ enum class func_flags : uint32_t {

/// Did the user specify a name for this function, or is it anonymous?
has_name = (1 << 4),
/// Did the user specify a scope where this function should be installed?
/// Did the user specify a scope in which this function should be installed?
has_scope = (1 << 5),
/// Did the user specify a docstring?
has_doc = (1 << 6),
Expand All @@ -100,7 +100,7 @@ enum class func_flags : uint32_t {
has_var_args = (1 << 8),
/// Does the function signature contain an *kwargs-style argument?
has_var_kwargs = (1 << 9),
/// Is this function a class method?
/// Is this function a method of a class?
is_method = (1 << 10),
/// Is this function a method called __init__? (automatically generated)
is_constructor = (1 << 11),
Expand All @@ -113,7 +113,9 @@ enum class func_flags : uint32_t {
/// Should the func_new() call return a new reference?
return_ref = (1 << 15),
/// Does this overload specify a raw docstring that should take precedence?
raw_doc = (1 << 16)
raw_doc = (1 << 16),
/// Does this function have one or more nb::keep_alive() annotations?
has_keep_alive = (1 << 17)
};

struct arg_data {
Expand Down Expand Up @@ -226,30 +228,38 @@ template <typename F, typename... Ts>
NB_INLINE void func_extra_apply(F &, call_guard<Ts...>, size_t &) {}

template <typename F, size_t Nurse, size_t Patient>
NB_INLINE void func_extra_apply(F &, nanobind::keep_alive<Nurse, Patient>,
size_t &) {}

template <typename... Ts> struct extract_guard { using type = void; };
NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive<Nurse, Patient>, size_t &) {
f.flags |= (uint32_t) func_flags::has_keep_alive;
}

template <typename T, typename... Ts> struct extract_guard<T, Ts...> {
using type = typename extract_guard<Ts...>::type;
template <typename... Ts> struct func_extra_info {
using call_guard = void;
static constexpr bool keep_alive = false;
};

template <typename T, typename... Ts> struct func_extra_info<T, Ts...>
: func_extra_info<Ts...> { };

template <typename... Cs, typename... Ts>
struct extract_guard<call_guard<Cs...>, Ts...> {
static_assert(std::is_same_v<typename extract_guard<Ts...>::type, void>,
struct func_extra_info<nanobind::call_guard<Cs...>, Ts...> : func_extra_info<Ts...> {
static_assert(std::is_same_v<typename func_extra_info<Ts...>::call_guard, void>,
"call_guard<> can only be specified once!");
using type = call_guard<Cs...>;
using call_guard = nanobind::call_guard<Cs...>;
};

template <size_t Nurse, size_t Patient, typename... Ts>
struct func_extra_info<nanobind::keep_alive<Nurse, Patient>, Ts...> : func_extra_info<Ts...> {
static constexpr bool keep_alive = true;
};

template <typename T>
NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) {}
NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { }

template <size_t Nurse, size_t Patient>
NB_INLINE void
process_keep_alive(PyObject **args, PyObject *result,
nanobind::keep_alive<Nurse, Patient> *) {
keep_alive(Nurse == 0 ? result : args[Nurse - 1],
keep_alive(Nurse == 0 ? result : args[Nurse - 1],
Patient == 0 ? result : args[Patient - 1]);
}

Expand Down
41 changes: 32 additions & 9 deletions include/nanobind/nb_func.h
Expand Up @@ -10,17 +10,33 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Caster>
bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags,
cleanup_list *cleanup, size_t index) {
size_t size_before = cleanup->size();
if (!c.from_python(args[index], args_flags[index], cleanup))
return false;

// If an implicit conversion took place, update the 'args' array so that
// the keep_alive annotation can later process this change
size_t size_after = cleanup->size();
if (size_after != size_before)
args[index] = (*cleanup)[size_after - 1];

return true;
}

template <bool ReturnRef, bool CheckGuard, typename Func, typename Return,
typename... Args, size_t... Is, typename... Extra>
NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
std::index_sequence<Is...> is,
const Extra &...extra) {
using Guard = typename extract_guard<Extra...>::type;
using Info = func_extra_info<Extra...>;

if constexpr (CheckGuard && !std::is_same_v<Guard, void>) {
if constexpr (CheckGuard && !std::is_same_v<typename Info::call_guard, void>) {
return func_create<ReturnRef, false>(
[func = (forward_t<Func>) func](Args... args) NB_INLINE_LAMBDA {
typename Guard::type g;
typename Info::call_guard::type g;
(void) g;
return func((forward_t<Args>) args...);
},
Expand All @@ -43,7 +59,7 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
constexpr bool is_method_det =
(std::is_same_v<is_method, Extra> + ... + 0) != 0;

/// A few compile-time consistency checks
// A few compile-time consistency checks
static_assert(args_pos_1 == args_pos_n && kwargs_pos_1 == kwargs_pos_n,
"Repeated use of nb::kwargs or nb::args in the function signature!");
static_assert(nargs_provided == 0 || nargs_provided + is_method_det == nargs,
Expand Down Expand Up @@ -104,7 +120,7 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),

f.impl = [](void *p, PyObject **args, uint8_t *args_flags, rv_policy policy,
cleanup_list *cleanup) NB_INLINE_LAMBDA -> PyObject * {
(void)p; (void)args; (void)args_flags; (void)policy; (void)cleanup;
(void) p; (void) args; (void) args_flags; (void) policy; (void) cleanup;

const capture *cap;
if constexpr (sizeof(capture) <= sizeof(f.capture))
Expand All @@ -115,9 +131,15 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
tuple<make_caster<Args>...> in;
(void) in;

if ((!in.template get<Is>().from_python(args[Is], args_flags[Is],
cleanup) || ...))
return NB_NEXT_OVERLOAD;
if constexpr (Info::keep_alive) {
if ((!from_python_keep_alive(in.template get<Is>(), args,
args_flags, cleanup, Is) || ...))
return NB_NEXT_OVERLOAD;
} else {
if ((!in.template get<Is>().from_python(args[Is], args_flags[Is],
cleanup) || ...))
return NB_NEXT_OVERLOAD;
}

PyObject *result;
if constexpr (std::is_void_v<Return>) {
Expand All @@ -131,7 +153,8 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
policy, cleanup).ptr();
}

(process_keep_alive(args, result, (Extra *) nullptr), ...);
if constexpr (Info::keep_alive)
(process_keep_alive(args, result, (Extra *) nullptr), ...);

return result;
};
Expand Down
14 changes: 11 additions & 3 deletions include/nanobind/nb_lib.h
Expand Up @@ -50,8 +50,14 @@ struct NB_CORE cleanup_list {
/// Decrease the reference count of all appended objects
void release() noexcept;

/// Does the list contain any entries?
inline bool used() { return m_size != 1; }
/// Does the list contain any entries? (besides the 'self' argument)
bool used() { return m_size != 1; }

/// Return the size of the cleanup stack
size_t size() const { return m_size; }

/// Subscript operator
PyObject *operator[](size_t index) const { return m_data[index]; }

protected:
/// Out of memory, expand..
Expand Down Expand Up @@ -382,6 +388,7 @@ NB_CORE void keep_alive(PyObject *nurse, PyObject *patient);
NB_CORE void keep_alive(PyObject *nurse, void *payload,
void (*deleter)(void *) noexcept) noexcept;


// ========================================================================

/// Indicate to nanobind that an implicit constructor can convert 'src' -> 'dst'
Expand Down Expand Up @@ -424,7 +431,8 @@ NB_CORE PyObject *module_new_submodule(PyObject *base, const char *name,

// Try to import a reference-counted ndarray object via DLPack
NB_CORE ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
bool convert) noexcept;
bool convert,
cleanup_list *cleanup) noexcept;

// Describe a local ndarray object using a DLPack capsule
NB_CORE ndarray_handle *ndarray_create(void *value, size_t ndim,
Expand Down
4 changes: 2 additions & 2 deletions include/nanobind/nb_types.h
Expand Up @@ -14,9 +14,9 @@ NAMESPACE_BEGIN(NB_NAMESPACE)
public: \
static constexpr auto Name = ::nanobind::detail::const_name(Str); \
NB_INLINE Type(handle h, ::nanobind::detail::borrow_t) \
: Parent(h, ::nanobind::detail::borrow_t{}) {} \
: Parent(h, ::nanobind::detail::borrow_t{}) { } \
NB_INLINE Type(handle h, ::nanobind::detail::steal_t) \
: Parent(h, ::nanobind::detail::steal_t{}) {} \
: Parent(h, ::nanobind::detail::steal_t{}) { } \
NB_INLINE static bool check_(handle h) { \
return Check(h.ptr()); \
}
Expand Down
4 changes: 2 additions & 2 deletions include/nanobind/ndarray.h
Expand Up @@ -540,14 +540,14 @@ template <typename... Args> struct type_caster<ndarray<Args...>> {
concat_maybe(detail::ndarray_arg<Args>::name...) +
const_name("]"));

bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
constexpr size_t size = (0 + ... + detail::ndarray_arg<Args>::size);
size_t shape[size + 1];
detail::ndarray_req req;
req.shape = shape;
(detail::ndarray_arg<Args>::apply(req), ...);
value = ndarray<Args...>(ndarray_import(
src.ptr(), &req, flags & (uint8_t) cast_flags::convert));
src.ptr(), &req, flags & (uint8_t) cast_flags::convert, cleanup));
return value.is_valid();
}

Expand Down
3 changes: 2 additions & 1 deletion src/nb_func.cpp
Expand Up @@ -181,6 +181,7 @@ PyObject *nb_func_new(const void *in_) noexcept {
has_args = f->flags & (uint32_t) func_flags::has_args,
has_var_args = f->flags & (uint32_t) func_flags::has_var_args,
has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_kwargs,
has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive,
has_doc = f->flags & (uint32_t) func_flags::has_doc,
is_implicit = f->flags & (uint32_t) func_flags::is_implicit,
is_method = f->flags & (uint32_t) func_flags::is_method,
Expand Down Expand Up @@ -248,7 +249,7 @@ PyObject *nb_func_new(const void *in_) noexcept {
has_name ? f->name : "<anonymous>");

func->max_nargs_pos = f->nargs;
func->complex_call = has_args || has_var_args || has_var_kwargs;
func->complex_call = has_args || has_var_args || has_var_kwargs || has_keep_alive;

if (func_prev) {
func->complex_call |= ((nb_func *) func_prev)->complex_call;
Expand Down
13 changes: 9 additions & 4 deletions src/nb_ndarray.cpp
Expand Up @@ -293,7 +293,7 @@ bool ndarray_check(PyObject *o) noexcept {


ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
bool convert) noexcept {
bool convert, cleanup_list *cleanup) noexcept {
object capsule;
bool is_pycapsule = PyCapsule_CheckExact(o);

Expand Down Expand Up @@ -456,10 +456,15 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
} catch (...) { converted.reset(); }

// Potentially try again recursively
if (!converted.is_valid())
if (!converted.is_valid()) {
return nullptr;
else
return ndarray_import(converted.ptr(), req, false);
} else {
ndarray_handle *h =
ndarray_import(converted.ptr(), req, false, nullptr);
if (h && cleanup)
cleanup->append(converted.release().ptr());
return h;
}
}

if (!pass_dtype || !pass_device || !pass_shape || !pass_order)
Expand Down

0 comments on commit 9d4b2e3

Please sign in to comment.