Skip to content

Commit

Permalink
custom exception translators
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Oct 12, 2022
1 parent 8ca2796 commit 41b7da3
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 5 deletions.
23 changes: 23 additions & 0 deletions include/nanobind/nb_error.h
Expand Up @@ -71,4 +71,27 @@ NB_EXCEPTION(attribute_error)

#undef NB_EXCEPTION

inline void register_exception_translator(detail::exception_translator t,
void *payload = nullptr) {
detail::register_exception_translator(t, payload);
}

template <typename T>
class exception : public object {
NB_OBJECT_DEFAULT(exception, object, "Exception", PyExceptionClass_Check)

exception(handle mod, const char *name, handle base = PyExc_Exception)
: object(detail::exception_new(mod.ptr(), name, base.ptr()),
detail::steal_t()) {
detail::register_exception_translator(
[](const std::exception_ptr &p, void *payload) {
try {
std::rethrow_exception(p);
} catch (T &e) {
PyErr_SetString((PyObject *) payload, e.what());
}
}, m_ptr);
}
};

NAMESPACE_END(NB_NAMESPACE)
10 changes: 10 additions & 0 deletions include/nanobind/nb_lib.h
Expand Up @@ -369,6 +369,16 @@ NB_CORE void print(PyObject *file, PyObject *str, PyObject *end);

// ========================================================================

typedef void (*exception_translator)(const std::exception_ptr &, void *);

NB_CORE void register_exception_translator(exception_translator translator,
void *payload);

NB_CORE PyObject *exception_new(PyObject *mod, const char *name,
PyObject *base);

// ========================================================================

NB_CORE std::pair<int8_t, bool> load_i8 (PyObject *o, uint8_t flags) noexcept;
NB_CORE std::pair<uint8_t, bool> load_u8 (PyObject *o, uint8_t flags) noexcept;
NB_CORE std::pair<int16_t, bool> load_i16(PyObject *o, uint8_t flags) noexcept;
Expand Down
35 changes: 35 additions & 0 deletions src/error.cpp
Expand Up @@ -9,6 +9,7 @@

#include <nanobind/nanobind.h>
#include "buffer.h"
#include "nb_internals.h"

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)
Expand Down Expand Up @@ -124,4 +125,38 @@ NB_EXCEPTION(attribute_error, PyExc_AttributeError)

#undef NB_EXCEPTION

NAMESPACE_BEGIN(detail)

void register_exception_translator(exception_translator t, void *payload) {
auto &et = internals_get().exception_translators;
et.insert(et.begin(), { t, payload });
}

NB_CORE PyObject *exception_new(PyObject *scope, const char *name,
PyObject *base) {
object modname;
if (PyModule_Check(scope))
modname = getattr(scope, "__name__", handle());
else
modname = getattr(scope, "__module__", handle());

if (!modname.is_valid())
raise("nanobind::detail::exception_new(): could not determine module name!");

str combined = steal<str>(
PyUnicode_FromFormat("%U.%s", modname.ptr(), name));

PyObject *result = PyErr_NewException(combined.c_str(), base, nullptr);
if (!result)
raise("nanobind::detail::exception_new(): creation failed!");

if (hasattr(scope, name))
raise("nb::detail::exception_new(): an object of the same name already "
"exists!");

setattr(scope, name, result);
return result;
}

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
6 changes: 4 additions & 2 deletions src/nb_func.cpp
Expand Up @@ -319,9 +319,11 @@ static NB_NOINLINE PyObject *nb_func_error_noconvert(PyObject *self,
/// Used by nb_func_vectorcall: convert a C++ exception into a Python error
static NB_NOINLINE void nb_func_convert_cpp_exception() noexcept {
std::exception_ptr e = std::current_exception();
for (auto const &et : internals_get().exception_translators) {

for (auto pair : internals_get().exception_translators) {
try {
et(e);
// Try exception translator & forward payload
pair.first(e, pair.second);
return;
} catch (...) {
e = std::current_exception();
Expand Down
5 changes: 3 additions & 2 deletions src/nb_internals.cpp
Expand Up @@ -254,7 +254,7 @@ NB_THREAD_LOCAL current_method current_method_data =

static nb_internals *internals_p = nullptr;

void default_exception_translator(const std::exception_ptr &p) {
void default_exception_translator(const std::exception_ptr &p, void *) {
try {
std::rethrow_exception(p);
} catch (python_error &e) {
Expand Down Expand Up @@ -325,7 +325,6 @@ static void internals_make() {
str nb_name("nanobind");

internals_p = new nb_internals();
internals_p->exception_translators.push_back(default_exception_translator);

PyObject *capsule = PyCapsule_New(internals_p, nullptr, nullptr);
PyObject *nb_module = PyModule_NewObject(nb_name.ptr());
Expand Down Expand Up @@ -388,6 +387,8 @@ static void internals_make() {
internals_p->nb_bound_method->tp_vectorcall_offset = offsetof(nb_bound_method, vectorcall);
#endif

register_exception_translator(default_exception_translator, nullptr);

if (Py_AtExit(internals_cleanup))
fprintf(stderr,
"Warning: could not install the nanobind cleanup handler! This "
Expand Down
2 changes: 1 addition & 1 deletion src/nb_internals.h
Expand Up @@ -200,7 +200,7 @@ struct nb_internals {
py_set<void *, ptr_hash> funcs;

/// Registered C++ -> Python exception translators
std::vector<void (*)(const std::exception_ptr &)> exception_translators;
std::vector<std::pair<exception_translator, void *>> exception_translators;
};

struct current_method {
Expand Down
30 changes: 30 additions & 0 deletions tests/test_exception.cpp
Expand Up @@ -2,6 +2,21 @@

namespace nb = nanobind;

class MyError1 : public std::exception {
public:
virtual const char *what() const noexcept { return "MyError1"; }
};

class MyError2 : public std::exception {
public:
virtual const char *what() const noexcept { return "MyError2"; }
};

class MyError3 : public std::exception {
public:
virtual const char *what() const noexcept { return "MyError3"; }
};

NB_MODULE(test_exception_ext, m) {
m.def("raise_generic", [] { throw std::exception(); });
m.def("raise_bad_alloc", [] { throw std::bad_alloc(); });
Expand All @@ -19,4 +34,19 @@ NB_MODULE(test_exception_ext, m) {
m.def("raise_import_error", [] { throw nb::import_error("an import error"); });
m.def("raise_attribute_error", [] { throw nb::attribute_error("an attribute error"); });
m.def("raise_stop_iteration", [] { throw nb::stop_iteration("a stop iteration error"); });

m.def("raise_my_error_1", [] { throw MyError1(); });

nb::register_exception_translator(
[](const std::exception_ptr &p, void * /* unused */) {
try {
std::rethrow_exception(p);
} catch (const MyError2 &e) {
PyErr_SetString(PyExc_IndexError, e.what());
}
});
m.def("raise_my_error_2", [] { throw MyError2(); });

nb::exception<MyError3>(m, "MyError3");
m.def("raise_my_error_3", [] { throw MyError3(); });
}
15 changes: 15 additions & 0 deletions tests/test_exception.py
Expand Up @@ -78,3 +78,18 @@ def test16_stop_iteration():
with pytest.raises(StopIteration) as excinfo:
assert t.raise_stop_iteration()
assert str(excinfo.value) == 'a stop iteration error'

def test17_raise_my_error_1():
with pytest.raises(RuntimeError) as excinfo:
assert t.raise_my_error_1()
assert str(excinfo.value) == 'MyError1'

def test18_raise_my_error_2():
with pytest.raises(IndexError) as excinfo:
assert t.raise_my_error_2()
assert str(excinfo.value) == 'MyError2'

def test19_raise_my_error_3():
with pytest.raises(t.MyError3) as excinfo:
assert t.raise_my_error_3()
assert str(excinfo.value) == 'MyError3'

0 comments on commit 41b7da3

Please sign in to comment.