In [7]:
# Pick up edits to imported modules (e.g. the pytil heap implementation)
# without restarting the kernel.
%load_ext autoreload
%autoreload 3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from heapq import *

import numpy as np
import numba as nb

from numba import njit
from numba.typed import List
from numba.types import Tuple

from pytil.data_structures.array_heap import get_array_heap_1d_items_jitclass

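### My ArrayHeap1 class (`get_array_heap_1d_items_jitclass`)

Interleave random pushes and pops on a fixed-capacity heap of 1-D `float64` items, then drain it and check that the items come out in ascending lexicographic order.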


In [20]:
@njit
def test_array_heap_interspersed(capacity, heap, size, iterations, item_size, seed=None):
    """
    Stress-test a heap object with interleaved random pushes and pops, then drain it.

    The function first pushes ``size`` random items. Each of the ``iterations``
    steps then pops in two cases: when the heap holds ``capacity`` items, or,
    with probability 1/2, when it holds more than ``capacity // 2``. Otherwise
    the step pushes a new random item. Finally every remaining item is popped
    and collected.

    Parameters
    ----------
    capacity : int
        Maximum number of items this driver lets the heap hold. Also the
        exclusive upper bound of the random integers that make up each item.
    heap : object with ``heappush(item)`` and ``heappop()`` methods
        Assumed to be empty on entry. ``size`` is the only record of how many
        items it contains.
    size : int
        Number of items pushed before the interleaved phase.
    iterations : int
        Number of interleaved push/pop steps.
    item_size : int
        Length of each pushed ``float64`` item array.
    seed : int or None, optional
        When given, seeds NumPy's random generator *inside* the jitted code so
        the run is reproducible. Numba keeps its own generator state, so seeding
        from plain Python has no effect here. ``None`` (the default) leaves the
        generator unseeded, as before.

    Returns
    -------
    list
        The drained items in pop order (ascending if the heap is correct).
    """
    if seed is not None:
        np.random.seed(seed)
    for _ in range(size):
        heap.heappush(np.random.randint(0, capacity, size=(item_size,)).astype(np.float64))
    for _ in range(iterations):
        # Draw the coin on every step, before looking at the fill level, so the
        # sequence of random draws does not depend on which branch is taken.
        coin_says_pop = np.random.randint(0, 2) == 0
        if (coin_says_pop and size > capacity // 2) or size == capacity:
            heap.heappop()
            size -= 1
        else:
            heap.heappush(np.random.randint(0, capacity, size=(item_size,)).astype(np.float64))
            size += 1
    output = []
    for _ in range(size):
        output.append(heap.heappop())
    return output

In [34]:
capacity = 2_000
item_size = 3
iterations = capacity * 2_000
heap = get_array_heap_1d_items_jitclass(nb.float64)(capacity, item_size)
size = capacity // 2
output = test_array_heap_interspersed(capacity, heap, size, iterations, item_size)

# Draining the heap must yield items in ascending lexicographic order.
output_tuples = [tuple(x) for x in output]
assert output_tuples == sorted(output_tuples), "heap pops are not in sorted order"
# The driver never lets the heap drop below capacity // 2 or grow past capacity.
assert capacity // 2 <= len(output) <= capacity, len(output)
print(len(output))

1518


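### Using a numba typed `List` of ndarrays

This is the same interleaved push/pop test. It uses `heapq`-style functions that operate on a `numba.typed.List` and compare items lexicographically with `is_less`.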


In [4]:
@njit(inline='always')
def is_less(a, b):
    """
    Lexicographic "less than" for two 1-D arrays.

    Elements are compared pairwise up to the length of the shorter array, and
    the first differing pair decides the result. If every compared pair is
    equal, the shorter array sorts first: an array is less than any longer
    array it is a prefix of. Equal arrays are not less than each other. For
    NaN-free inputs this is the same ordering as ``tuple(a) < tuple(b)``.

    Parameters
    ----------
    a, b : 1-D arrays (or other indexable sequences supporting ``len``)

    Returns
    -------
    bool
        True if ``a`` sorts strictly before ``b``; False otherwise.
    """
    # Only index up to the shorter length. Numba does not bounds-check by
    # default, so reading past the end of the shorter array would silently
    # read invalid memory.
    n = min(len(a), len(b))
    for i in range(n):
        if a[i] < b[i]:
            return True
        if a[i] > b[i]:
            return False
    # The common prefix is identical, so the shorter array sorts first.
    return len(a) < len(b)


@njit
def njit_np_heappush(heap, item):
    """Push ``item`` onto ``heap`` and keep the min-heap invariant (cf. ``heapq.heappush``)."""
    new_index = len(heap)
    heap.append(item)
    # Bubble the new leaf up towards the root at index 0.
    _siftdown_np(heap, 0, new_index)


@njit
def njit_np_heappop(heap):
    """
    Remove and return the smallest item of ``heap`` (cf. ``heapq.heappop``).

    Popping from an empty heap fails inside ``heap.pop()``.
    """
    last = heap.pop()
    if len(heap) == 0:
        # ``last`` was the only item, so it is also the minimum.
        return last
    smallest = heap[0]
    # Move the former last leaf to the root, then restore the heap invariant.
    heap[0] = last
    _siftup_np(heap, 0)
    return smallest


@njit
def njit_np_heapify(heap):
    """Rearrange ``heap`` in place into a min-heap (cf. ``heapq.heapify``)."""
    # Apply _siftup_np to every internal node, from the last parent back to the
    # root. The leaves are already trivial heaps.
    node = len(heap) // 2 - 1
    while node >= 0:
        _siftup_np(heap, node)
        node -= 1


@njit
def _siftdown_np(heap, startpos, pos):
    """
    Move ``heap[pos]`` up towards ``startpos`` until its parent is not larger.

    This mirrors ``heapq._siftdown``. Ancestors that compare greater than the
    item are shifted down one level, and the item is then written into the
    remaining hole.
    """
    item = heap[pos]
    hole = pos
    while hole > startpos:
        parent_idx = (hole - 1) >> 1
        parent_item = heap[parent_idx]
        if not is_less(item, parent_item):
            break
        heap[hole] = parent_item
        hole = parent_idx
    heap[hole] = item


@njit
def _siftup_np(heap, pos):
    """
    Restore the heap invariant for the subtree rooted at ``pos`` (mirrors ``heapq._siftup``).

    The item at ``pos`` is set aside, and the smaller child is repeatedly
    promoted until a leaf slot is reached. The item is then placed there and
    bubbled back up with ``_siftdown_np``, never rising above ``pos``.
    """
    size = len(heap)
    root = pos
    item = heap[pos]
    hole = pos
    while True:
        child = 2 * hole + 1
        if child >= size:
            break
        sibling = child + 1
        # Take the right child unless the left one is strictly smaller
        # (the same tie-breaking as heapq).
        if sibling < size and not is_less(heap[child], heap[sibling]):
            child = sibling
        heap[hole] = heap[child]
        hole = child
    heap[hole] = item
    _siftdown_np(heap, root, hole)

In [30]:
@njit
def test_njit_np_interspersed(capacity, heap, size, iterations, item_size, seed=None):
    """
    Stress-test the typed-List heap functions with interleaved random pushes and pops.

    The function first pushes ``size`` random items with ``njit_np_heappush``.
    Each of the ``iterations`` steps then pops in two cases: when the heap holds
    ``capacity`` items, or, with probability 1/2, when it holds more than
    ``capacity // 2``. Otherwise the step pushes a new random item. Finally
    every remaining item is popped and collected.

    Parameters
    ----------
    capacity : int
        Maximum number of items this driver lets the heap hold. Also the
        exclusive upper bound of the random integers that make up each item.
    heap : numba.typed.List of 1-D float64 arrays
        Assumed to be empty on entry. ``size`` is the only record of how many
        items it contains.
    size : int
        Number of items pushed before the interleaved phase.
    iterations : int
        Number of interleaved push/pop steps.
    item_size : int
        Length of each pushed ``float64`` item array.
    seed : int or None, optional
        When given, seeds NumPy's random generator *inside* the jitted code so
        the run is reproducible. Numba keeps its own generator state, so seeding
        from plain Python has no effect here. ``None`` (the default) leaves the
        generator unseeded, as before.

    Returns
    -------
    list
        The drained items in pop order (ascending if the heap is correct).
    """
    if seed is not None:
        np.random.seed(seed)
    for _ in range(size):
        njit_np_heappush(heap, np.random.randint(0, capacity, size=(item_size,)).astype(np.float64))
    for _ in range(iterations):
        # Draw the coin on every step, before looking at the fill level, so the
        # sequence of random draws does not depend on which branch is taken.
        coin_says_pop = np.random.randint(0, 2) == 0
        if (coin_says_pop and size > capacity // 2) or size == capacity:
            njit_np_heappop(heap)
            size -= 1
        else:
            njit_np_heappush(heap, np.random.randint(0, capacity, size=(item_size,)).astype(np.float64))
            size += 1
    output = []
    for _ in range(size):
        output.append(njit_np_heappop(heap))
    return output

In [35]:
capacity = 2_000
item_size = 3
iterations = capacity * 2_000
size = capacity // 2
heap = List.empty_list(nb.float64[:])
output = test_njit_np_interspersed(capacity, heap, size, iterations, item_size)

# Draining the heap must yield items in ascending lexicographic order.
output_tuples = [tuple(x) for x in output]
assert output_tuples == sorted(output_tuples), "heap pops are not in sorted order"
# The driver never lets the heap drop below capacity // 2 or grow past capacity.
assert capacity // 2 <= len(output) <= capacity, len(output)
# Every item must have been drained from the typed List.
assert len(heap) == 0, len(heap)
print(len(output))

1060
