Skip to content

Commit

Permalink
Support hashing of nb::enum_ instances (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
mnijhuis-tos committed Dec 5, 2022
1 parent 633672c commit ee35767
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,33 @@ int nb_enum_traverse(PyObject *o, visitproc visit, void *arg) {
return 0;
}

Py_hash_t nb_enum_hash(PyObject *o) {
Py_hash_t value = 0;
type_data *t = nb_type_data(Py_TYPE(o));
if (t->flags & (uint32_t(type_flags::is_unsigned_enum) |
uint32_t(type_flags::is_signed_enum))) {
const void *p = inst_ptr((nb_inst *) o);
switch (t->size) {
case 1: value = *(const int8_t *) p; break;
case 2: value = *(const int16_t *) p; break;
case 4: value = *(const int32_t *) p; break;
case 8: value = *(const int64_t *) p; break;
default:
PyErr_SetString(PyExc_TypeError, "nb_enum: invalid type size!");
return -1;
}
} else {
PyErr_SetString(PyExc_TypeError, "nb_enum: input is not an enumeration!");
return -1;
}

// Hash functions should return -1 when an error occurred.
// Return -2 that case, since hash(-1) also yields -2.
if (value == -1) value = -2;

return value;
}

void nb_enum_prepare(PyType_Slot **s, bool is_arithmetic) {
PyType_Slot *t = *s;

Expand All @@ -214,6 +241,7 @@ void nb_enum_prepare(PyType_Slot **s, bool is_arithmetic) {
*t++ = { Py_tp_getset, (void *) nb_enum_getset };
*t++ = { Py_tp_traverse, (void *) nb_enum_traverse };
*t++ = { Py_tp_clear, (void *) nb_enum_clear };
*t++ = { Py_tp_hash, (void *) nb_enum_hash };

if (is_arithmetic) {
*t++ = { Py_nb_add, (void *) nb_enum_add };
Expand Down
2 changes: 1 addition & 1 deletion src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ PyObject *nb_type_new(const type_data *t) noexcept {
}
char *name_copy = NB_STRDUP(name.c_str());

constexpr size_t nb_enum_max_slots = 21,
constexpr size_t nb_enum_max_slots = 22,
nb_type_max_slots = 10,
nb_extra_slots = 80,
nb_total_slots = nb_enum_max_slots +
Expand Down
6 changes: 6 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def test01_unsigned_enum():
assert t.to_enum(0) == t.Enum.A
assert t.to_enum(1) == t.Enum.B
assert t.to_enum(0xffffffff) == t.Enum.C
assert hash(t.Enum.A) == 0
assert hash(t.Enum.B) == 1
assert hash(t.Enum.C) == -2 # -1 is an invalid hash value.

with pytest.raises(RuntimeError) as excinfo:
t.to_enum(5).__name__
Expand All @@ -51,6 +54,9 @@ def test02_signed_enum():
assert t.from_enum(t.SEnum.A) == 0
assert t.from_enum(t.SEnum.B) == 1
assert t.from_enum(t.SEnum.C) == -1
assert hash(t.SEnum.A) == 0
assert hash(t.SEnum.B) == 1
assert hash(t.SEnum.C) == -2 # -1 is an invalid hash value.


def test03_enum_arithmetic():
Expand Down

0 comments on commit ee35767

Please sign in to comment.