Skip to content

Commit

Permalink
[XLA:Python] Fix a memory corruption bug in the tp_name attribute of …
Browse files Browse the repository at this point in the history
…ArrayImpl and PjitFunction for Python 3.10 or earlier.

This works around python/cpython#89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed May 9, 2024
1 parent 4f68e6e commit 71de102
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
8 changes: 8 additions & 0 deletions third_party/xla/xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -1061,7 +1062,14 @@ void BuildPjitSubmodule(nb::module_& m) {
std::string name =
absl::StrCat(nb::cast<std::string>(m.attr("__name__")), ".PjitFunction");
PyType_Spec PjitFunction_spec = {
#if PY_VERSION_HEX < 0x030B0000
// Work around for https://github.com/python/cpython/issues/89478
// CPython 3.10 and earlier assume that the .name value remains alive
// forever.
/*.name=*/strdup(name.c_str()),
#else
/*.name=*/name.c_str(),
#endif // PY_VERSION_HEX < 0x030B0000
/*.basicsize=*/static_cast<int>(sizeof(PjitFunctionObject)),
/*.itemsize=*/0,
#if PY_VERSION_HEX < 0x030C0000
Expand Down
7 changes: 7 additions & 0 deletions third_party/xla/xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,14 @@ Status PyArray::RegisterTypes(nb::module_& m) {
absl::StrCat(nb::cast<std::string>(m.attr("__name__")), ".ArrayImpl");

PyType_Spec PyArray_spec = {
#if PY_VERSION_HEX < 0x030B0000
// Work around for https://github.com/python/cpython/issues/89478
// CPython 3.10 and earlier assume that the .name value remains alive
// forever.
/*.name=*/strdup(name.c_str()),
#else
/*.name=*/name.c_str(),
#endif // PY_VERSION_HEX < 0x030B0000
/*.basicsize=*/static_cast<int>(sizeof(PyArrayObject)),
/*.itemsize=*/0,
#if PY_VERSION_HEX < 0x030C0000
Expand Down

0 comments on commit 71de102

Please sign in to comment.