Skip to content

Commit

Permalink
Properly handle non-interned keyword argument names (#469)
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj committed Mar 12, 2024
1 parent df8996a commit fa873d1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ noteworthy:
if nanobind did not know about that base.
(PR `#471 <https://github.com/wjakob/nanobind/pull/471>`__).

* nanobind can now handle keyword arguments that are not interned, which
avoids spurious TypeErrors in constructs like ``fn(**pickle.loads(...))``.
The speed of normal function calls (which generally do have interned
keyword arguments) should be unaffected.
(PR `#469 <https://github.com/wjakob/nanobind/pull/469>`__).

* ABI version 14.

.. rubric:: Footnote
Expand Down
47 changes: 40 additions & 7 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,43 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
uint8_t *args_flags = (uint8_t *) alloca(max_nargs * sizeof(uint8_t));
bool *kwarg_used = (bool *) alloca(nkwargs_in * sizeof(bool));

// Ensure that keyword argument names are interned. That makes it faster
// to compare them against pre-interned argument names in the overload chain.
// Normal function calls will have their keyword arguments already interned,
// but we can't rely on that; it fails for things like fn(**json.loads(...)).
PyObject **kwnames = nullptr;

#if !defined(PYPY_VERSION) && !defined(Py_LIMITED_API)
bool kwnames_interned = true;
for (size_t i = 0; i < nkwargs_in; ++i) {
PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, i);
kwnames_interned &= ((PyASCIIObject *) key)->state.interned != 0;
}
if (NB_LIKELY(kwnames_interned)) {
kwnames = ((PyTupleObject *) kwargs_in)->ob_item;
goto traverse_overloads;
}
#endif

kwnames = (PyObject **) alloca(nkwargs_in * sizeof(PyObject *));
for (size_t i = 0; i < nkwargs_in; ++i) {
PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, i),
*key_interned = key;
Py_INCREF(key_interned);

PyUnicode_InternInPlace(&key_interned);

if (NB_LIKELY(key == key_interned)) // string was already interned
Py_DECREF(key_interned);
else
cleanup.append(key_interned);
kwnames[i] = key_interned;
}

#if !defined(PYPY_VERSION) && !defined(Py_LIMITED_API)
traverse_overloads:
#endif

/* The logic below tries to find a suitable overload using two passes
of the overload chain (or 1, if there are no overloads). The first pass
is strict and permits no implicit conversions, while the second pass
Expand Down Expand Up @@ -610,12 +647,8 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
if (kwargs_in && ad.name_py) {
PyObject *hit = nullptr;
for (size_t j = 0; j < nkwargs_in; ++j) {
PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, j);
#if defined(PYPY_VERSION)
bool match = PyUnicode_Compare(key, ad.name_py) == 0;
#else
bool match = (key == ad.name_py);
#endif
PyObject *key = kwnames[j];
bool match = (key == ad.name_py);
if (match) {
hit = args_in[nargs_in + j];
kwarg_used[j] = true;
Expand Down Expand Up @@ -668,7 +701,7 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
if (has_var_kwargs) {
PyObject *dict = PyDict_New();
for (size_t j = 0; j < nkwargs_in; ++j) {
PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, j);
PyObject *key = kwnames[j];
if (!kwarg_used[j])
PyDict_SetItem(dict, key, args_in[nargs_in + j]);
}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ NB_MODULE(test_functions_ext, m) {
m.def("test_01", []() { });

// Simple binary function (via function pointer)
auto test_02 = [](int j, int k) -> int { return j - k; };
m.def("test_02", (int (*)(int, int)) test_02, "j"_a = 8, "k"_a = 1);
auto test_02 = [](int up, int down) -> int { return up - down; };
m.def("test_02", (int (*)(int, int)) test_02, "up"_a = 8, "down"_a = 1);

// Simple binary function with capture object
int i = 42;
Expand Down
18 changes: 14 additions & 4 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@ def test02_default_args():
def test03_kwargs():
# Basic use of keyword arguments
assert t.test_02(3, 5) == -2
assert t.test_02(3, k=5) == -2
assert t.test_02(k=5, j=3) == -2
assert t.test_02(3, down=5) == -2
assert t.test_02(down=5, up=3) == -2

# Make sure non-interned keyword names work also
i_cant_believe_its_not_down = "".join("down")
assert i_cant_believe_its_not_down is not "down"
assert t.test_02(**{i_cant_believe_its_not_down: 5, "up": 3}) == -2
assert t.test_02(**{i_cant_believe_its_not_down: 5}) == 3
with pytest.raises(TypeError):
t.test_02(unexpected=27)
with pytest.raises(TypeError):
t.test_02(**{i_cant_believe_its_not_down: None})


def test04_overloads():
Expand All @@ -42,7 +52,7 @@ def test04_overloads():

def test05_signature():
assert t.test_01.__doc__ == "test_01() -> None"
assert t.test_02.__doc__ == "test_02(j: int = 8, k: int = 1) -> int"
assert t.test_02.__doc__ == "test_02(up: int = 8, down: int = 1) -> int"
assert t.test_05.__doc__ == (
"test_05(arg: int, /) -> int\n"
"test_05(arg: float, /) -> int\n"
Expand Down Expand Up @@ -420,7 +430,7 @@ def test39_del():
def test40_nb_signature():
assert t.test_01.__nb_signature__ == ((r"def test_01() -> None", None, None),)
assert t.test_02.__nb_signature__ == (
(r"def test_02(j: int = \0, k: int = \1) -> int", None, (8, 1)),
(r"def test_02(up: int = \0, down: int = \1) -> int", None, (8, 1)),
)
assert t.test_05.__nb_signature__ == (
(r"def test_05(arg: int, /) -> int", "doc_1", None),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class kw_only_methods:

def test_01() -> None: ...

def test_02(j: int = 8, k: int = 1) -> int: ...
def test_02(up: int = 8, down: int = 1) -> int: ...

def test_03(arg0: int, arg1: int, /) -> int: ...

Expand Down

0 comments on commit fa873d1

Please sign in to comment.