Skip to content

Commit

Permalink
pickle support
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Jun 6, 2023
1 parent 2bf6d74 commit 59843e0
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 3 deletions.
34 changes: 34 additions & 0 deletions docs/classes.rst
Expand Up @@ -903,3 +903,37 @@ expected:
>>> my_ext.make_pet(my_ext.PetKind.Dog)
<my_ext.Dog object at 0x104da6ef0>
Pickling
--------

To pickle and unpickle objects bound using nanobind, expose the
``__getstate__`` and ``__setstate__`` methods. They should return and retrieve
the internal instance state using representations that themselves support
pickling. The example below, e.g., does this using a tuple.

The ``__setstate__`` method should construct the object in-place analogous to
custom ``__init__``-style constructors.

.. code-block:: cpp
#include <nanobind/stl/tuple.h>
struct Pet {
std::string name;
int age;
Pet(const std::string &name, int age) : name(name), age(age) { }
};
NB_MODULE(my_ext, m) {
nb::class_<Pet>(m, "Pet")
// ...
.def("__getstate__", [](const Pet &pet) { return std::make_tuple(pet.name, pet.age); })
.def("__setstate__", [](Pet &pet, const std::tuple<std::string, int> &state) {
new (&pet) Pet(
std::get<0>(state),
std::get<1>(state)
);
});
}
4 changes: 3 additions & 1 deletion src/nb_func.cpp
Expand Up @@ -222,7 +222,9 @@ PyObject *nb_func_new(const void *in_) noexcept {
}

// Is this method a constructor that takes a class binding as first parameter?
is_constructor = is_method && strcmp(f->name, "__init__") == 0 &&
is_constructor = is_method &&
(strcmp(f->name, "__init__") == 0 ||
strcmp(f->name, "__setstate__") == 0) &&
strncmp(f->descr, "({%}", 4) == 0;

// Don't use implicit conversions in copy constructors (causes infinite recursion)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_classes.cpp
Expand Up @@ -4,6 +4,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/tuple.h>
#include <memory>
#include <cstring>
#include <vector>
Expand All @@ -15,7 +16,7 @@ using namespace nb::literals;

static int default_constructed = 0, value_constructed = 0, copy_constructed = 0,
move_constructed = 0, copy_assigned = 0, move_assigned = 0,
destructed = 0;
destructed = 0, pickled = 0, unpickled = 0;

struct Struct;
std::unique_ptr<Struct> struct_tmp;
Expand All @@ -32,7 +33,9 @@ struct Struct {
~Struct() { destructed++; }

int value() const { return i; }
int getstate() const { ++pickled; return i; }
void set_value(int value) { i = value; }
void setstate(int value) { unpickled++; i = value; }

static int static_test(int) { return 1; }
static int static_test(float) { return 2; }
Expand Down Expand Up @@ -120,6 +123,8 @@ NB_MODULE(test_classes_ext, m) {
.def("set_value", &Struct::set_value, "value"_a)
.def("self", &Struct::self, nb::rv_policy::none)
.def("none", [](Struct &) -> const Struct * { return nullptr; })
.def("__getstate__", &Struct::getstate)
.def("__setstate__", &Struct::setstate)
.def_static("static_test", nb::overload_cast<int>(&Struct::static_test))
.def_static("static_test", nb::overload_cast<float>(&Struct::static_test))
.def_static("create_move", &Struct::create_move)
Expand All @@ -129,7 +134,6 @@ NB_MODULE(test_classes_ext, m) {
nb::rv_policy::copy)
.def_static("create_take", &Struct::create_take);


if (!nb::type<Struct>().is(cls))
nb::detail::raise("type lookup failed!");

Expand All @@ -147,6 +151,8 @@ NB_MODULE(test_classes_ext, m) {
d["copy_assigned"] = copy_assigned;
d["move_assigned"] = move_assigned;
d["destructed"] = destructed;
d["pickled"] = pickled;
d["unpickled"] = unpickled;
return d;
});

Expand All @@ -158,6 +164,8 @@ NB_MODULE(test_classes_ext, m) {
copy_assigned = 0;
move_assigned = 0;
destructed = 0;
pickled = 0;
unpickled = 0;
});

// test06_big
Expand Down
17 changes: 17 additions & 0 deletions tests/test_classes.py
Expand Up @@ -649,3 +649,20 @@ def test35_method_introspection():
assert m.__qualname__ == "Struct.value"
assert m.__module__ == t.__name__
assert m.__doc__ == t.Struct.value.__doc__ == "value(self) -> int"


def test38_pickle(clean):
import pickle

s = t.Struct(123)
s2 = pickle.dumps(s, protocol=pickle.HIGHEST_PROTOCOL)
s3 = pickle.loads(s2)
assert s.value() == s3.value()
del s, s3

assert_stats(
value_constructed=1,
pickled=1,
unpickled=1,
destructed=2
)

0 comments on commit 59843e0

Please sign in to comment.