Skip to content

Commit

Permalink
bpo-39421: Fix posible crash in heapq with custom comparison operators (
Browse files Browse the repository at this point in the history
GH-18118) (GH-18146)

(cherry picked from commit 79f89e6)

Co-authored-by: Pablo Galindo <Pablogsal@gmail.com>
  • Loading branch information
2 people authored and ned-deily committed Jan 23, 2020
1 parent fe24458 commit c563f40
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 9 deletions.
31 changes: 31 additions & 0 deletions Lib/test/test_heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,37 @@ def test_heappop_mutating_heap(self):
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)

def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented

heap = []
self.module.heappush(heap, EvilClass(0))
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)

def test_comparison_operator_modifiying_heap_two_heaps(self):

class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented

class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented

list1, list2 = [], []

self.module.heappush(list1, h(0))
self.module.heappush(list2, g(0))

self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))

class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix possible crashes when operating with the functions in the :mod:`heapq`
module and custom comparison operators.
35 changes: 26 additions & 9 deletions Modules/_heapqmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos) {
parentpos = (pos - 1) >> 1;
parent = arr[parentpos];
Py_INCREF(newitem);
Py_INCREF(parent);
cmp = PyObject_RichCompareBool(newitem, parent, Py_LT);
Py_DECREF(parent);
Py_DECREF(newitem);
if (cmp < 0)
return -1;
if (size != PyList_GET_SIZE(heap)) {
Expand Down Expand Up @@ -71,10 +75,13 @@ siftup(PyListObject *heap, Py_ssize_t pos)
/* Set childpos to index of smaller child. */
childpos = 2*pos + 1; /* leftmost child position */
if (childpos + 1 < endpos) {
cmp = PyObject_RichCompareBool(
arr[childpos],
arr[childpos + 1],
Py_LT);
PyObject* a = arr[childpos];
PyObject* b = arr[childpos + 1];
Py_INCREF(a);
Py_INCREF(b);
cmp = PyObject_RichCompareBool(a, b, Py_LT);
Py_DECREF(a);
Py_DECREF(b);
if (cmp < 0)
return -1;
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
Expand Down Expand Up @@ -229,7 +236,10 @@ heappushpop(PyObject *self, PyObject *args)
return item;
}

cmp = PyObject_RichCompareBool(PyList_GET_ITEM(heap, 0), item, Py_LT);
PyObject* top = PyList_GET_ITEM(heap, 0);
Py_INCREF(top);
cmp = PyObject_RichCompareBool(top, item, Py_LT);
Py_DECREF(top);
if (cmp < 0)
return NULL;
if (cmp == 0) {
Expand Down Expand Up @@ -383,7 +393,11 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos) {
parentpos = (pos - 1) >> 1;
parent = arr[parentpos];
Py_INCREF(parent);
Py_INCREF(newitem);
cmp = PyObject_RichCompareBool(parent, newitem, Py_LT);
Py_DECREF(parent);
Py_DECREF(newitem);
if (cmp < 0)
return -1;
if (size != PyList_GET_SIZE(heap)) {
Expand Down Expand Up @@ -425,10 +439,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
/* Set childpos to index of smaller child. */
childpos = 2*pos + 1; /* leftmost child position */
if (childpos + 1 < endpos) {
cmp = PyObject_RichCompareBool(
arr[childpos + 1],
arr[childpos],
Py_LT);
PyObject* a = arr[childpos + 1];
PyObject* b = arr[childpos];
Py_INCREF(a);
Py_INCREF(b);
cmp = PyObject_RichCompareBool(a, b, Py_LT);
Py_DECREF(a);
Py_DECREF(b);
if (cmp < 0)
return -1;
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
Expand Down

0 comments on commit c563f40

Please sign in to comment.