From 41b7da33f1bc5c583bb98df66bdac2a058ec5c15 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Wed, 12 Oct 2022 20:54:58 +0200 Subject: [PATCH] custom exception translators --- include/nanobind/nb_error.h | 23 +++++++++++++++++++++++ include/nanobind/nb_lib.h | 10 ++++++++++ src/error.cpp | 35 +++++++++++++++++++++++++++++++++++ src/nb_func.cpp | 6 ++++-- src/nb_internals.cpp | 5 +++-- src/nb_internals.h | 2 +- tests/test_exception.cpp | 30 ++++++++++++++++++++++++++++++ tests/test_exception.py | 15 +++++++++++++++ 8 files changed, 121 insertions(+), 5 deletions(-) diff --git a/include/nanobind/nb_error.h b/include/nanobind/nb_error.h index 3b790f4c..54febc3d 100644 --- a/include/nanobind/nb_error.h +++ b/include/nanobind/nb_error.h @@ -71,4 +71,27 @@ NB_EXCEPTION(attribute_error) #undef NB_EXCEPTION +inline void register_exception_translator(detail::exception_translator t, + void *payload = nullptr) { + detail::register_exception_translator(t, payload); +} + +template +class exception : public object { + NB_OBJECT_DEFAULT(exception, object, "Exception", PyExceptionClass_Check) + + exception(handle mod, const char *name, handle base = PyExc_Exception) + : object(detail::exception_new(mod.ptr(), name, base.ptr()), + detail::steal_t()) { + detail::register_exception_translator( + [](const std::exception_ptr &p, void *payload) { + try { + std::rethrow_exception(p); + } catch (T &e) { + PyErr_SetString((PyObject *) payload, e.what()); + } + }, m_ptr); + } +}; + NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 2ad8bfe5..7fb2bffc 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -369,6 +369,16 @@ NB_CORE void print(PyObject *file, PyObject *str, PyObject *end); // ======================================================================== +typedef void (*exception_translator)(const std::exception_ptr &, void *); + +NB_CORE void register_exception_translator(exception_translator translator, + void *payload); + +NB_CORE PyObject *exception_new(PyObject *mod, const char *name, + PyObject *base); + +// ======================================================================== + NB_CORE std::pair load_i8 (PyObject *o, uint8_t flags) noexcept; NB_CORE std::pair load_u8 (PyObject *o, uint8_t flags) noexcept; NB_CORE std::pair load_i16(PyObject *o, uint8_t flags) noexcept; diff --git a/src/error.cpp b/src/error.cpp index 4a5ebd19..4dd05853 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -9,6 +9,7 @@ #include #include "buffer.h" +#include "nb_internals.h" NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) @@ -124,4 +125,38 @@ NB_EXCEPTION(attribute_error, PyExc_AttributeError) #undef NB_EXCEPTION +NAMESPACE_BEGIN(detail) + +void register_exception_translator(exception_translator t, void *payload) { + auto &et = internals_get().exception_translators; + et.insert(et.begin(), { t, payload }); +} + +NB_CORE PyObject *exception_new(PyObject *scope, const char *name, + PyObject *base) { + object modname; + if (PyModule_Check(scope)) + modname = getattr(scope, "__name__", handle()); + else + modname = getattr(scope, "__module__", handle()); + + if (!modname.is_valid()) + raise("nanobind::detail::exception_new(): could not determine module name!"); + + str combined = steal( + PyUnicode_FromFormat("%U.%s", modname.ptr(), name)); + + PyObject *result = PyErr_NewException(combined.c_str(), base, nullptr); + if (!result) + raise("nanobind::detail::exception_new(): creation failed!"); + + if (hasattr(scope, name)) + raise("nb::detail::exception_new(): an object of the same name already " + "exists!"); + + setattr(scope, name, result); + return result; +} + +NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/src/nb_func.cpp b/src/nb_func.cpp index e9d490c8..146d86ed 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -319,9 +319,11 @@ static NB_NOINLINE PyObject *nb_func_error_noconvert(PyObject *self, /// Used by nb_func_vectorcall: convert a C++ exception into a Python error static NB_NOINLINE void nb_func_convert_cpp_exception() noexcept { std::exception_ptr e = std::current_exception(); - for (auto const &et : internals_get().exception_translators) { + + for (auto pair : internals_get().exception_translators) { try { - et(e); + // Try exception translator & forward payload + pair.first(e, pair.second); return; } catch (...) { e = std::current_exception(); diff --git a/src/nb_internals.cpp b/src/nb_internals.cpp index d78aca8b..0e8023fb 100644 --- a/src/nb_internals.cpp +++ b/src/nb_internals.cpp @@ -254,7 +254,7 @@ NB_THREAD_LOCAL current_method current_method_data = static nb_internals *internals_p = nullptr; -void default_exception_translator(const std::exception_ptr &p) { +void default_exception_translator(const std::exception_ptr &p, void *) { try { std::rethrow_exception(p); } catch (python_error &e) { @@ -325,7 +325,6 @@ static void internals_make() { str nb_name("nanobind"); internals_p = new nb_internals(); - internals_p->exception_translators.push_back(default_exception_translator); PyObject *capsule = PyCapsule_New(internals_p, nullptr, nullptr); PyObject *nb_module = PyModule_NewObject(nb_name.ptr()); @@ -388,6 +387,8 @@ static void internals_make() { internals_p->nb_bound_method->tp_vectorcall_offset = offsetof(nb_bound_method, vectorcall); #endif + register_exception_translator(default_exception_translator, nullptr); + if (Py_AtExit(internals_cleanup)) fprintf(stderr, "Warning: could not install the nanobind cleanup handler! This " diff --git a/src/nb_internals.h b/src/nb_internals.h index 8ca07a31..1fd5f9e5 100644 --- a/src/nb_internals.h +++ b/src/nb_internals.h @@ -200,7 +200,7 @@ struct nb_internals { py_set funcs; /// Registered C++ -> Python exception translators - std::vector exception_translators; + std::vector> exception_translators; }; struct current_method { diff --git a/tests/test_exception.cpp b/tests/test_exception.cpp index a3cecdf2..2755ab66 100644 --- a/tests/test_exception.cpp +++ b/tests/test_exception.cpp @@ -2,6 +2,21 @@ namespace nb = nanobind; +class MyError1 : public std::exception { +public: + virtual const char *what() const noexcept { return "MyError1"; } +}; + +class MyError2 : public std::exception { +public: + virtual const char *what() const noexcept { return "MyError2"; } +}; + +class MyError3 : public std::exception { +public: + virtual const char *what() const noexcept { return "MyError3"; } +}; + NB_MODULE(test_exception_ext, m) { m.def("raise_generic", [] { throw std::exception(); }); m.def("raise_bad_alloc", [] { throw std::bad_alloc(); }); @@ -19,4 +34,19 @@ NB_MODULE(test_exception_ext, m) { m.def("raise_import_error", [] { throw nb::import_error("an import error"); }); m.def("raise_attribute_error", [] { throw nb::attribute_error("an attribute error"); }); m.def("raise_stop_iteration", [] { throw nb::stop_iteration("a stop iteration error"); }); + + m.def("raise_my_error_1", [] { throw MyError1(); }); + + nb::register_exception_translator( + [](const std::exception_ptr &p, void * /* unused */) { + try { + std::rethrow_exception(p); + } catch (const MyError2 &e) { + PyErr_SetString(PyExc_IndexError, e.what()); + } + }); + m.def("raise_my_error_2", [] { throw MyError2(); }); + + nb::exception(m, "MyError3"); + m.def("raise_my_error_3", [] { throw MyError3(); }); } diff --git a/tests/test_exception.py b/tests/test_exception.py index c40d2092..18a6aaf3 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -78,3 +78,18 @@ def test16_stop_iteration(): with pytest.raises(StopIteration) as excinfo: assert t.raise_stop_iteration() assert str(excinfo.value) == 'a stop iteration error' + +def test17_raise_my_error_1(): + with pytest.raises(RuntimeError) as excinfo: + assert t.raise_my_error_1() + assert str(excinfo.value) == 'MyError1' + +def test18_raise_my_error_2(): + with pytest.raises(IndexError) as excinfo: + assert t.raise_my_error_2() + assert str(excinfo.value) == 'MyError2' + +def test19_raise_my_error_3(): + with pytest.raises(t.MyError3) as excinfo: + assert t.raise_my_error_3() + assert str(excinfo.value) == 'MyError3'