Skip to content

Commit

Permalink
Fix zero-refcount bug for callbacks with a StatusOr<PyObject *> ret…
Browse files Browse the repository at this point in the history
…urn.

Note:

This change is the result of exploring and discarding multiple approaches to fixing the zero-refcount bug in a more general way. The only approach that worked out is this local fix in the callback code. Fundamentally, code involving `StatusOr<PyObject *>` objects is inherently unsafe and bug prone, because ownership of the Python reference is not managed automatically. Ideally use of `StatusOr<PyObject *>` would generate compilation errors (e.g. via `static_assert`), which would be easy to achieve just in the pybind11_abseil repo, but would require significant sprawling changes around the Google codebase. Unfortunately, currently this is infeasible.

The command used for manual leak checking (see `# Manual verification` comments in status_testing_no_cpp_eh_test_lib.py) was:

```
blaze run //third_party/pybind11_abseil/tests:status_testing_no_cpp_eh_test
```

The `top` command was used to visually monitor `RES` for about 10 seconds, for each test case.

PiperOrigin-RevId: 597061300
  • Loading branch information
rwgk authored and Copybara-Service committed Jan 9, 2024
1 parent 552e9e7 commit ecbbf71
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 5 deletions.
9 changes: 8 additions & 1 deletion pybind11_abseil/statusor_caster.h
Expand Up @@ -146,7 +146,14 @@ struct func_wrapper<absl::StatusOr<PayloadType>, Args...> : func_wrapper_base {
object py_result =
hfunc.f.call_with_policies(rvpp, std::forward<Args>(args)...);
try {
return py_result.template cast<absl::StatusOr<PayloadType>>();
auto cpp_result =
py_result.template cast<absl::StatusOr<PayloadType>>();
// Intentionally not `if constexpr`: runtime overhead is insignificant.
if (is_same_ignoring_cvref<PayloadType, PyObject*>::value) {
// Ownership of the Python reference was transferred to cpp_result.
py_result.release();
}
return cpp_result;
} catch (cast_error& e) {
return absl::Status(absl::StatusCode::kInvalidArgument, e.what());
}
Expand Down
2 changes: 2 additions & 0 deletions pybind11_abseil/tests/BUILD
Expand Up @@ -27,6 +27,7 @@ cc_library(
name = "status_testing_no_cpp_eh_lib",
hdrs = ["status_testing_no_cpp_eh_lib.h"],
deps = [
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_python//:python_headers", # buildcleaner: keep
Expand All @@ -47,6 +48,7 @@ pybind_extension(
py_library(
name = "status_testing_no_cpp_eh_test_lib",
srcs = ["status_testing_no_cpp_eh_test_lib.py"],
data = ["//pybind11_abseil:status.so"],
)

pybind_extension(
Expand Down
32 changes: 32 additions & 0 deletions pybind11_abseil/tests/status_testing_no_cpp_eh_lib.h
Expand Up @@ -10,6 +10,7 @@
#include <functional>
#include <string>

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"

Expand Down Expand Up @@ -44,6 +45,37 @@ inline absl::Status GenerateErrorStatusNotOk() {
return absl::AlreadyExistsError("Something went wrong, again.");
}

inline absl::StatusOr<PyObject *> ReturnStatusOrPyObjectPtr(bool is_ok) {
if (is_ok) {
return PyLong_FromLongLong(2314L);
}
return absl::InvalidArgumentError("!is_ok");
}

inline std::string PassStatusOrPyObjectPtr(
const absl::StatusOr<PyObject *> &obj) {
if (!obj.ok()) {
return "!obj.ok()@" + std::string(obj.status().message());
}
if (PyTuple_CheckExact(obj.value())) {
return "is_tuple";
}
return "!is_tuple";
}

inline std::string CallCallbackWithStatusOrPyObjectPtrReturn(
const std::function<absl::StatusOr<PyObject *>(std::string)> &cb,
std::string cb_arg) {
// Implicitly take ownership of Python reference:
absl::StatusOr<PyObject *> cb_result = cb(cb_arg);
std::string result = PassStatusOrPyObjectPtr(cb_result);
if (cb_result.ok()) {
ABSL_CHECK_NE(Py_REFCNT(cb_result.value()), 0);
Py_DECREF(cb_result.value()); // Release owned reference.
}
return result;
}

} // namespace status_testing_no_cpp_eh
} // namespace pybind11_abseil_tests

Expand Down
13 changes: 13 additions & 0 deletions pybind11_abseil/tests/status_testing_no_cpp_eh_pybind.cc
Expand Up @@ -23,6 +23,19 @@ PYBIND11_MODULE(status_testing_no_cpp_eh_pybind, m) {
&CallCallbackWithStatusOrObjectReturn,
pybind11::return_value_policy::take_ownership);
m.def("GenerateErrorStatusNotOk", &GenerateErrorStatusNotOk);

m.attr("PYBIND11_HAS_RETURN_VALUE_POLICY_PACK") =
#if defined(PYBIND11_HAS_RETURN_VALUE_POLICY_PACK)
true;
#else
false;
#endif

m.def("ReturnStatusOrPyObjectPtr", &ReturnStatusOrPyObjectPtr,
pybind11::return_value_policy::take_ownership);
m.def("PassStatusOrPyObjectPtr", &PassStatusOrPyObjectPtr);
m.def("CallCallbackWithStatusOrPyObjectPtrReturn",
&CallCallbackWithStatusOrPyObjectPtrReturn);
}

} // namespace status_testing_no_cpp_eh
Expand Down
62 changes: 58 additions & 4 deletions pybind11_abseil/tests/status_testing_no_cpp_eh_test_lib.py
Expand Up @@ -7,6 +7,8 @@
from absl.testing import absltest
from absl.testing import parameterized

from pybind11_abseil import status

# Exercises status_from_core_py_exc.cc:StatusFromFetchedExc()
TAB_StatusFromFetchedExc = (
(MemoryError, 'RESOURCE_EXHAUSTED: MemoryError'),
Expand Down Expand Up @@ -128,9 +130,6 @@ def setUp(self):
self.tm = self.getTestModule() # pytype: disable=attribute-error

def testStatusOrObject(self): # pylint: disable=invalid-name
if getattr(self.tm, '__pyclif_codegen_mode__', None) != 'c_api':
self.skipTest('TODO(cl/578064081)')
# No leak (manually verified under cl/485274434).
while True:
lst = [1, 2, 3, 4]

Expand All @@ -142,4 +141,59 @@ def cb():
res = self.tm.CallCallbackWithStatusOrObjectReturn(cb)
self.assertListEqual(res, lst)
self.assertIs(res, lst)
return # Comment out for manual leak checking (use `top` command).
break # Comment out for manual leak checking (use `top` command).
# Manual verification: cl/485274434, cl/578064081

def testReturnStatusOrPyObjectPtr(self): # pylint: disable=invalid-name
obj = self.tm.ReturnStatusOrPyObjectPtr(True)
self.assertEqual(obj, 2314)
with self.assertRaises(status.StatusNotOk) as ctx:
self.tm.ReturnStatusOrPyObjectPtr(False)
self.assertEqual(str(ctx.exception), '!is_ok [INVALID_ARGUMENT]')
while True:
self.tm.ReturnStatusOrPyObjectPtr(True)
break # Comment out for manual leak checking (use `top` command).
# Manual verification: cl/578064081
while True:
with self.assertRaises(status.StatusNotOk) as ctx:
self.tm.ReturnStatusOrPyObjectPtr(False)
break # Comment out for manual leak checking (use `top` command).
# Manual verification: cl/578064081

def testPassStatusOrPyObjectPtr(self): # pylint: disable=invalid-name
pass_fn = self.tm.PassStatusOrPyObjectPtr
self.assertEqual(pass_fn(()), 'is_tuple')
self.assertEqual(pass_fn([]), '!is_tuple')
while True:
pass_fn([])
break # Comment out for manual leak checking (use `top` command).
# Manual verification: cl/578064081

def testCallCallbackWithStatusOrPyObjectPtrReturn(self): # pylint: disable=invalid-name
def cb(arg):
if arg == 'tup':
return ()
if arg == 'lst':
return []
raise ValueError(f'Unknown arg: {repr(arg)}')

cc_fn = self.tm.CallCallbackWithStatusOrPyObjectPtrReturn
res = cc_fn(cb, 'tup')
self.assertEqual(res, 'is_tuple')
res = cc_fn(cb, 'lst')
self.assertEqual(res, '!is_tuple')

if (
hasattr(self.tm, '__pyclif_codegen_mode__')
or self.tm.PYBIND11_HAS_RETURN_VALUE_POLICY_PACK
):
res = cc_fn(cb, 'exc')
self.assertEqual(res, "!obj.ok()@ValueError: Unknown arg: 'exc'")
while True:
cc_fn(cb, 'exc')
break # Comment out for manual leak checking (use `top` command).
# Manual verification: cl/578064081
else:
with self.assertRaises(ValueError) as ctx:
cc_fn(cb, 'exc')
self.assertEqual(str(ctx.exception), "Unknown arg: 'exc'")

0 comments on commit ecbbf71

Please sign in to comment.