In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
from pytil.utility import closure

### Data structures definitions


I am looking to create Numba jitclasses for some data structures. I want to use numpy arrays to contain the bulk of the data in the data structures. The numpy arrays inside these data structures should all be preallocated, so all the **init** methods should have a capacity parameter as the first parameter.


In [None]:
from functools import cache
import numpy as np
import numba as nb
from numba import njit
from numba.experimental import jitclass


@njit(inline='always')
def stringify_1d_array(array):
    return '[' + ', '.join([str(entry) for entry in array]) + ']'


@cache
def get_array_stack_1d_items_jitclass(data_type):
    spec = (
        ('capacity', nb.int64),  # Maximum capacity of the stack
        ('item_size', nb.int64),  # Size of the each item
        ('size', nb.int64),  # Current size of the stack
        ('array', data_type[:, :]),  # Underlying NumPy array to store elements
    )

    @jitclass(spec)
    class ArrayStack1:
        def __init__(self, capacity, item_size):
            self.capacity = capacity
            self.item_size = item_size
            self.size = 0
            self.array = np.empty((capacity, item_size), dtype=data_type)

        def append(self, item):
            """
            Push an item onto the top of the stack.
            """
            if self.size >= self.capacity:
                raise IndexError("Stack is full")
            self.array[self.size] = item
            self.size += 1

        def pop(self):
            """
            Pop and return a copy of the top item from the stack.
            """
            if self.size == 0:
                raise IndexError("Pop from empty stack")
            value = self.array[self.size - 1].copy()
            self.size -= 1
            return value

        def pop_no_copy(self):
            """
            Pop and return a reference to the top item without copying.
            """
            if self.size == 0:
                raise IndexError("Pop from empty stack")
            value_ref = self.array[self.size - 1]
            self.size -= 1
            return value_ref

        def peek(self):
            """
            Return a reference to the top item without removing it.
            """
            if self.size == 0:
                raise IndexError("Peek from empty stack")
            return self.array[self.size - 1]

        def clear(self):
            """
            Remove all items from the stack.
            """
            self.size = 0

        def __getitem__(self, index):
            if index < -self.size or index >= self.size:
                raise IndexError("Index out of bounds")
            if index < 0:
                index += self.size
            return self.array[index]

        def __setitem__(self, index, value):
            if index < -self.size or index >= self.size:
                raise IndexError("Index out of bounds")
            if index < 0:
                index += self.size
            self.array[index] = value

        def __len__(self):
            return self.size

        def __str__(self):
            return '[' + ', '.join([stringify_1d_array(item) for item in self.array[: self.size]]) + ']'

    return ArrayStack1

In [None]:
ArrayStack1 = get_array_stack_1d_items_jitclass(nb.int64)


@closure
@njit
def hi():
    # Example usage
    a = ArrayStack1(5, 3)
    a.append((1, 3, 2))
    a.append((8, 3, 1))
    print(a[1])  # Output: 2.2
    print(len(a))  # Output: 2
    print(a)

[8 3 1]
2
[[1, 3, 2], [8, 3, 1]]


In [None]:
from functools import cache
import numpy as np
import numba as nb
from numba import njit
from numba.experimental import jitclass


@njit(inline='always')
def stringify_1d_array(array):
    return '[' + ', '.join([str(entry) for entry in array]) + ']'


@cache
def get_array_deque_1d_items_jitclass(data_type):
    spec = (
        ('capacity', nb.int64),  # total capacity
        ('item_size', nb.int64),  # size of each item
        ('size', nb.int64),  # current number of elements
        ('start', nb.int64),  # index of the left (front) element
        ('array', data_type[:, :]),
    )

    @jitclass(spec)
    class ArrayDeque1:
        def __init__(self, capacity, item_size):
            self.capacity = capacity
            self.item_size = item_size
            self.size = 0
            self.start = 0
            self.array = np.empty((capacity, item_size), dtype=data_type)

        def append(self, item):
            if self.size >= self.capacity:
                raise IndexError("Deque is full")
            end_index = (self.start + self.size) % self.capacity
            self.array[end_index] = item
            self.size += 1

        def pop(self):
            if self.size == 0:
                raise IndexError("Pop from empty deque")
            end_index = (self.start + self.size - 1) % self.capacity
            value = self.array[end_index].copy()
            self.size -= 1
            return value

        def pop_no_copy(self):
            if self.size == 0:
                raise IndexError("Pop from empty deque")
            end_index = (self.start + self.size - 1) % self.capacity
            value_ref = self.array[end_index]
            self.size -= 1
            return value_ref

        def appendleft(self, item):
            if self.size >= self.capacity:
                raise IndexError("Deque is full")
            left_index = (self.start - 1) % self.capacity
            self.array[left_index] = item
            self.start = left_index
            self.size += 1

        def popleft(self):
            if self.size == 0:
                raise IndexError("Popleft from empty deque")
            value = self.array[self.start].copy()
            self.start = (self.start + 1) % self.capacity
            self.size -= 1
            return value

        def popleft_no_copy(self):
            if self.size == 0:
                raise IndexError("Popleft from empty deque")
            value_ref = self.array[self.start]
            self.start = (self.start + 1) % self.capacity
            self.size -= 1
            return value_ref

        def peek(self):
            if self.size == 0:
                raise IndexError("Peek from empty deque")
            return self.array[(self.start + self.size - 1) % self.capacity]

        def peekleft(self):
            if self.size == 0:
                raise IndexError("Peek from empty deque")
            return self.array[self.start]

        def __getitem__(self, index):
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")
            actual_index = (self.start + index) % self.capacity
            return self.array[actual_index]

        def __setitem__(self, index, value):
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")
            actual_index = (self.start + index) % self.capacity
            self.array[actual_index] = value

        def clear(self):
            self.size = 0
            self.start = 0

        def __len__(self):
            return self.size

        def __str__(self):
            if self.size == 0:
                return "[]"
            elems = []
            for i in range(self.size):
                idx = (self.start + i) % self.capacity
                elems.append(stringify_1d_array(self.array[idx]))
            return '[' + ', '.join(elems) + ']'

    return ArrayDeque1

In [None]:
ArrayDeque1 = get_array_deque_1d_items_jitclass(nb.int64)


@closure
@njit
def hi():
    # Example usage
    a = ArrayDeque1(5, 3)
    a.append((1, 3, 2))
    a.append((8, 3, 1))
    print(a[1])  # Output: 2.2
    print(len(a))  # Output: 2
    print(a)

[8 3 1]
2
[[1, 3, 2], [8, 3, 1]]


In [None]:
from functools import cache
import numpy as np
import numba as nb
from numba import njit
from numba.experimental import jitclass


@njit(inline='always')
def array_is_less(a, b):
    """
    Compare two arrays element-by-element.
    Return True if 'a' is lexicographically less than 'b';
    False otherwise.
    """
    m = len(a)
    for i in range(m):
        if a[i] < b[i]:
            return True
        elif a[i] > b[i]:
            return False
    # They are equal
    return False


@njit(inline='always')
def stringify_1d_array(array):
    return '[' + ', '.join([str(entry) for entry in array]) + ']'


@cache
def get_array_heap_1d_items_jitclass(data_type):
    spec = (
        ('capacity', nb.int64),
        ('item_size', nb.int64),
        ('size', nb.int64),
        ('heap', data_type[:, :]),
    )

    @jitclass(spec)
    class ArrayHeap1:
        def __init__(self, capacity, item_size):
            self.capacity = capacity
            self.item_size = item_size
            self.size = 0
            self.heap = np.empty((capacity, item_size), dtype=data_type)

        def heappush(self, item):
            if self.size >= self.capacity:
                raise IndexError("Heap is full")
            self.heap[self.size] = item
            self.size += 1
            self._siftdown(0, self.size - 1)

        def heappop(self):
            if self.size == 0:
                raise IndexError("Pop from empty heap")
            lastelt = self.heap[self.size - 1].copy()
            self.size -= 1
            if self.size > 0:
                returnitem = self.heap[0].copy()
                self.heap[0] = lastelt
                self._siftup(0, self.size)
                return returnitem
            return lastelt

        def heappeek(self):
            if self.size == 0:
                raise IndexError("Peek from empty heap")
            return self.heap[0]  # Return reference

        def heapreplace(self, item):
            if self.size == 0:
                raise IndexError("Replace from empty heap")
            returnitem = self.heap[0].copy()
            self.heap[0] = item
            self._siftup(0, self.size)
            return returnitem

        def heappushpop(self, item):
            if self.size > 0 and array_is_less(self.heap[0], item):
                item, self.heap[0] = self.heap[0], item
                self._siftup(0, self.size)
            return item

        def heapify(self):
            for i in range(self.size // 2 - 1, -1, -1):
                self._siftup(i, self.size)

        def _siftdown(self, startpos, pos):
            newitem = self.heap[pos].copy()
            while pos > startpos:
                parentpos = (pos - 1) >> 1
                parent = self.heap[parentpos]
                if array_is_less(newitem, parent):
                    self.heap[pos] = parent
                    pos = parentpos
                    continue
                break
            self.heap[pos] = newitem

        def _siftup(self, pos, endpos):
            startpos = pos
            newitem = self.heap[pos].copy()
            childpos = 2 * pos + 1
            while childpos < endpos:
                rightpos = childpos + 1
                if rightpos < endpos and not array_is_less(self.heap[childpos], self.heap[rightpos]):
                    childpos = rightpos
                self.heap[pos] = self.heap[childpos]
                pos = childpos
                childpos = 2 * pos + 1
            self.heap[pos] = newitem
            self._siftdown(startpos, pos)

        def __getitem__(self, index):
            if index == 0:
                if self.size == 0:
                    raise IndexError("Index out of range for empty heap")
                # Return reference, not a copy
                return self.heap[0]
            raise IndexError("Only index 0 is supported")

        def __len__(self):
            return self.size

        def __str__(self):
            return '[' + ', '.join([stringify_1d_array(item) for item in self.heap[: self.size]]) + ']'

    return ArrayHeap1

In [None]:
ArrayHeap1 = get_array_heap_1d_items_jitclass(nb.int64)


@closure
@njit
def hi():
    # Example usage
    a = ArrayHeap1(5, 1)
    a.heappush([1])
    a.heappush([3])
    print(a.heappeek())  # Output: 2.2
    print(len(a))  # Output: 2
    print(a)

[1]
2
[[1], [3]]


In [None]:
from functools import cache

import numba as nb
import numpy as np
from numba import njit
from numba.experimental import jitclass


@njit(inline='always')
def array_is_less(a, b):
    """
    Compare two arrays element-by-element for a strict lexicographical order.
    Return True if 'a' is lexicographically less than 'b', False otherwise.
    """
    for i in range(len(a)):
        if a[i] < b[i]:
            return True
        elif a[i] > b[i]:
            return False
    return False  # They are equal in all positions.


@njit(inline='always')
def stringify_1d_array(array):
    return '[' + ', '.join([str(entry) for entry in array]) + ']'


@cache
def get_array_treap_1d_items_jitclass(data_type):
    """
    Factory function to create a treap class specialized for 1D items
    using only 'array_is_less' to define BST ordering (and equality detection).
    """
    assert (
        data_type is nb.int64 or data_type is nb.float64
    ), 'Only 64-bit data types supported as the random key generation is defined up to 10**9'

    spec = [
        ('capacity', nb.int64),
        ('item_size', nb.int64),
        ('size', nb.int64),
        ('root', nb.int64),
        ('data', data_type[:, :]),
        ('priorities', nb.int64[:]),
        ('left', nb.int64[:]),
        ('right', nb.int64[:]),
        ('parent', nb.int64[:]),
        ('free_list', nb.int64[:]),
        ('free_list_top', nb.int64),
        ('subtree_size', nb.int64[:]),
    ]

    @jitclass(spec)
    class ArrayTreap1:
        """
        A treap (tree-based heap) implemented with implicit indexing,
        using 'array_is_less' to compare multi-dimensional items.
        """

        def __init__(self, capacity, item_size):
            """
            Initialize the treap with a maximum capacity and item size.
            Each row of 'data' is considered one item.
            """
            self.capacity = capacity
            self.item_size = item_size
            self.size = 0
            self.root = -1

            # Preallocate arrays
            self.data = np.empty((capacity, item_size), dtype=data_type)
            self.priorities = np.empty(capacity, dtype=np.int64)
            self.left = -np.ones(capacity, dtype=np.int64)
            self.right = -np.ones(capacity, dtype=np.int64)
            self.parent = -np.ones(capacity, dtype=np.int64)

            # Free list: stack of available node indices
            self.free_list = np.empty(capacity, dtype=np.int64)
            for i in range(capacity):
                self.free_list[i] = capacity - 1 - i
            self.free_list_top = capacity

            # For each node, store subtree size
            self.subtree_size = np.zeros(capacity, dtype=np.int64)

        def _recalc(self, x):
            """
            Recalculate the subtree size for node x.
            """
            s = 1
            if self.left[x] != -1:
                s += self.subtree_size[self.left[x]]
            if self.right[x] != -1:
                s += self.subtree_size[self.right[x]]
            self.subtree_size[x] = s

        def _new_node(self, item):
            """
            Allocate a new node from the free list and initialize it.
            """
            if self.free_list_top == 0:
                raise Exception("Treap capacity exceeded")

            self.free_list_top -= 1
            idx = self.free_list[self.free_list_top]
            self.data[idx] = item
            self.priorities[idx] = np.random.randint(1, 10**9)  # random priority
            self.left[idx] = -1
            self.right[idx] = -1
            self.parent[idx] = -1
            self.subtree_size[idx] = 1
            self.size += 1
            return idx

        def rotate_right(self, x):
            y = self.left[x]
            self.left[x] = self.right[y]
            if self.right[y] != -1:
                self.parent[self.right[y]] = x
            self.right[y] = x
            parent_x = self.parent[x]
            self.parent[y] = parent_x
            self.parent[x] = y

            if parent_x == -1:
                self.root = y
            else:
                if self.left[parent_x] == x:
                    self.left[parent_x] = y
                else:
                    self.right[parent_x] = y

            # Recalculate subtree sizes for x and y first.
            self._recalc(x)
            self._recalc(y)

            # Then update the rest of the chain.
            temp = self.parent[y]
            while temp != -1:
                self._recalc(temp)
                temp = self.parent[temp]

        def rotate_left(self, x):
            y = self.right[x]
            self.right[x] = self.left[y]
            if self.left[y] != -1:
                self.parent[self.left[y]] = x
            self.left[y] = x
            parent_x = self.parent[x]
            self.parent[y] = parent_x
            self.parent[x] = y

            if parent_x == -1:
                self.root = y
            else:
                if self.left[parent_x] == x:
                    self.left[parent_x] = y
                else:
                    self.right[parent_x] = y

            # Recalculate subtree sizes for x and y first.
            self._recalc(x)
            self._recalc(y)

            # Then update the rest of the chain.
            temp = self.parent[y]
            while temp != -1:
                self._recalc(temp)
                temp = self.parent[temp]

        def insert(self, item):
            """
            Insert a new item into the treap, maintaining BST order via array_is_less.
            """
            new_idx = self._new_node(item)

            # If tree is empty, new node is the root
            if self.root == -1:
                self.root = new_idx
                return

            # Search for insertion point
            cur = self.root
            parent = -1
            path = []
            while cur != -1:
                path.append(cur)
                parent = cur
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                else:
                    # If not less, go right
                    cur = self.right[cur]

            # Attach new_idx to parent
            self.parent[new_idx] = parent
            if array_is_less(item, self.data[parent]):
                self.left[parent] = new_idx
            else:
                self.right[parent] = new_idx

            # Update subtree sizes along the path
            for node in path:
                self.subtree_size[node] += 1

            # Bubble up based on priority
            cur = new_idx
            while self.parent[cur] != -1 and self.priorities[self.parent[cur]] < self.priorities[cur]:
                p = self.parent[cur]
                if self.left[p] == cur:
                    self.rotate_right(p)
                else:
                    self.rotate_left(p)

        def __getitem__(self, index):
            """
            Retrieve the in-order item at position 'index', integer-only.
            Negative indices behave similarly to Python's sequence rules.
            """
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")
            # We do a standard subtree-based 'select'
            node_idx = self.select(index)
            return self.data[node_idx]

        # # Old slice method that wastefully traverses all nodes to the left of the range
        # def slice(self, start, stop):
        #     """
        #     Return a slice of the treap's in-order traversal from 'start' (inclusive)
        #     to 'stop' (exclusive), as a 2D NumPy array (a copy).
        #     Negative indices behave like Python slices. Indices beyond the
        #     valid range are clamped. If stop < start after adjustments, the slice is empty.
        #     """
        #     # Handle negative indices for start
        #     if start < 0:
        #         start += self.size
        #     if start < 0:
        #         start = 0
        #     if start > self.size:
        #         start = self.size

        #     # Handle negative indices for stop
        #     if stop < 0:
        #         stop += self.size
        #     if stop < 0:
        #         stop = 0
        #     if stop > self.size:
        #         stop = self.size

        #     if stop < start:
        #         # Empty slice
        #         return np.empty((0, self.item_size), dtype=self.data.dtype)

        #     length = stop - start
        #     out = np.empty((length, self.item_size), dtype=self.data.dtype)

        #     # Partial in-order traversal
        #     stack = np.empty(self.capacity, dtype=np.int64)
        #     stack_top = 0
        #     cur = self.root
        #     count = 0  # Current in-order index
        #     out_pos = 0  # Position in 'out'

        #     while stack_top > 0 or cur != -1:
        #         if cur != -1:
        #             stack[stack_top] = cur
        #             stack_top += 1
        #             cur = self.left[cur]
        #         else:
        #             stack_top -= 1
        #             cur = stack[stack_top]
        #             # We are visiting 'cur' in in-order
        #             if count >= start and count < stop:
        #                 out[out_pos] = self.data[cur]
        #                 out_pos += 1
        #                 if count + 1 == stop:
        #                     # Done collecting
        #                     break
        #             count += 1
        #             cur = self.right[cur]

        #     return out

        def slice(self, start, stop):
            """
            Return a slice of the treap's in-order traversal from 'start' (inclusive)
            to 'stop' (exclusive), as a 2D NumPy array (a copy). Negative indices behave
            like Python slices and are clamped to the valid range. This version uses the
            subtree_size field to jump directly to the start-th element and then uses a stack
            for in-order traversal over only the nodes in the slice.
            """
            # Normalize indices
            if start < 0:
                start += self.size
            if start < 0:
                start = 0
            if start > self.size:
                start = self.size

            if stop < 0:
                stop += self.size
            if stop < 0:
                stop = 0
            if stop > self.size:
                stop = self.size

            if stop < start:
                return np.empty((0, self.item_size), dtype=self.data.dtype)

            length = stop - start
            out = np.empty((length, self.item_size), dtype=self.data.dtype)

            # Build the initial stack representing the path to the start-th node.
            # 'order' tracks the in-order index relative to the subtree counts.
            stack = np.empty(self.capacity, dtype=np.int64)
            stack_top = 0
            order = start
            cur = self.root
            while cur != -1:
                # Compute the size of the left subtree, if any.
                left = self.left[cur]
                left_count = self.subtree_size[left] if left != -1 else 0
                if order < left_count:
                    # The desired node is in the left subtree.
                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                elif order == left_count:
                    # The current node is exactly the start-th element.
                    stack[stack_top] = cur
                    stack_top += 1
                    break
                else:
                    # Skip the entire left subtree and current node.
                    order -= left_count + 1
                    cur = self.right[cur]

            # Now perform an in-order traversal using the stack,
            # collecting exactly 'length' nodes.
            out_pos = 0
            while stack_top > 0 and out_pos < length:
                # Pop the top of the stack; this is the next in-order node.
                stack_top -= 1
                cur = stack[stack_top]
                out[out_pos] = self.data[cur]
                out_pos += 1

                # Process the right subtree of the current node:
                cur = self.right[cur]
                while cur != -1:
                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                # Continue until we've collected all nodes in [start, stop).
            return out

        def search(self, item):
            """
            Search for a node whose stored item is 'item', returning the node index if found, else -1.
            We assume 'item' <-> 'self.data[cur]' if array_is_less says so,
            and equality is the case when neither is less than the other.
            """
            cur = self.root
            while cur != -1:
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                elif array_is_less(self.data[cur], item):
                    cur = self.right[cur]
                else:
                    # Not less in either direction => they are equal
                    return cur
            return -1

        def remove(self, item):
            """
            Remove the node containing 'item' (found via BST search).
            Equivalent to the old 'delete(item)'.
            If 'item' is not found, no change is made.
            """
            # Find node
            cur = self.root
            while cur != -1:
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                elif array_is_less(self.data[cur], item):
                    cur = self.right[cur]
                else:
                    # Found
                    break

            if cur == -1:
                return  # Not found

            self.delete(cur)

        def delete(self, node_idx):
            """
            Delete the node at index 'node_idx' directly, without searching by item.
            Bubbles the node down until it becomes a leaf, then removes it.
            """
            # Bubble down until it becomes a leaf.
            while self.left[node_idx] != -1 or self.right[node_idx] != -1:
                if self.right[node_idx] == -1 or (
                    self.left[node_idx] != -1
                    and self.priorities[self.left[node_idx]] > self.priorities[self.right[node_idx]]
                ):
                    self.rotate_right(node_idx)
                else:
                    self.rotate_left(node_idx)

                # If node_idx's parent is -1, then node_idx is now the root.
                # (This branch is for when rotations change the root pointer.)
                if self.parent[node_idx] == -1:
                    self.root = node_idx

            # At this point, node_idx should be a leaf.
            # Remove the leaf node from its parent.
            parent = self.parent[node_idx]
            if parent != -1:
                if self.left[parent] == node_idx:
                    self.left[parent] = -1
                else:
                    self.right[parent] = -1

                # Update subtree sizes for all ancestors.
                while parent != -1:
                    self._recalc(parent)
                    parent = self.parent[parent]
            else:
                # If the node is the root, then after deletion the tree is empty.
                self.root = -1

            # Clean up this node's pointers so it doesn't create a cycle if re-used.
            self.left[node_idx] = -1
            self.right[node_idx] = -1
            self.parent[node_idx] = -1

            # Return node_idx to the free list.
            self.free_list[self.free_list_top] = node_idx
            self.free_list_top += 1
            self.size -= 1

        def delitem(self, index):
            """
            Remove the item at in-order position 'index'. Negative indexing is supported.
            Uses select(index) to locate the node, then calls delete(node_idx).
            """
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")

            node_idx = self.select(index)
            if node_idx == -1:
                # Shouldn't happen if index is valid, but just in case
                raise IndexError("Index out of range")

            self.delete(node_idx)

        def select(self, order):
            """
            Return the index of the node that is the `order`-th element
            in the in-order (sorted) traversal, 0-indexed.
            """
            if order < 0 or order >= self.size:
                return -1
            cur = self.root
            while cur != -1:
                left_count = self.subtree_size[self.left[cur]] if self.left[cur] != -1 else 0
                if order < left_count:
                    cur = self.left[cur]
                elif order == left_count:
                    return cur
                else:
                    order -= left_count + 1
                    cur = self.right[cur]
            return -1

        def inorder(self):
            """
            Perform an in-order traversal and return a 1D NumPy array
            of node indices in 'sorted' order (as defined by array_is_less).
            """
            result = np.empty(self.size, dtype=np.int64)
            stack = np.empty(self.capacity, dtype=np.int64)
            stack_top = 0
            cur = self.root
            count = 0

            while stack_top > 0 or cur != -1:
                if cur != -1:
                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                else:
                    stack_top -= 1
                    cur = stack[stack_top]
                    result[count] = cur
                    count += 1
                    cur = self.right[cur]
            return result

        def inorder_stack(self):
            """
            Perform an in-order traversal and return a 1D NumPy array
            of node indices in 'sorted' order (as defined by array_is_less).
            """
            result = np.empty(self.size, dtype=np.int64)
            stack = np.empty(self.capacity, dtype=np.int64)
            stack_top = 0
            cur = self.root
            count = 0

            while stack_top > 0 or cur != -1:
                if cur != -1:
                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                else:
                    stack_top -= 1
                    cur = stack[stack_top]
                    result[count] = cur
                    count += 1
                    cur = self.right[cur]
            return result

        def inorder_parent(self):
            """
            Perform an in-order traversal using parent pointers and return a 1D NumPy array
            of node indices in sorted order.
            """
            result = np.empty(self.size, dtype=np.int64)
            count = 0

            # Start at the leftmost node
            cur = self.root
            if cur == -1:
                return result
            while self.left[cur] != -1:
                cur = self.left[cur]

            # Traverse the tree using parent pointers
            while cur != -1:
                result[count] = cur
                count += 1

                # If there is a right subtree, go to its leftmost node.
                if self.right[cur] != -1:
                    cur = self.right[cur]
                    while self.left[cur] != -1:
                        cur = self.left[cur]
                else:
                    # Climb up the tree until we find a node that is a left child of its parent.
                    while self.parent[cur] != -1 and self.right[self.parent[cur]] == cur:
                        cur = self.parent[cur]
                    cur = self.parent[cur]
            return result

        def get(self, idx):
            """
            Retrieve a direct reference to the item stored at node index 'idx'.
            """
            return self.data[idx]

        def __str__(self):
            """
            Return a string representation of the treap's items in in-order order.
            Uses `stringify_1d_array` for formatting.
            """
            if self.size == 0:
                return "[]"
            inord = self.inorder()
            elems = [stringify_1d_array(self.data[node]) for node in inord]
            return "[" + ", ".join(elems) + "]"

    return ArrayTreap1

In [None]:
ArrayTreap1 = get_array_treap_1d_items_jitclass(nb.int64)

capacity = 100  # Maximum number of nodes in the treap
item_size = 3  # Each item is a 3-dimensional vector
treap = ArrayTreap1(capacity, item_size)

# Create several example items (each is a NumPy array).
# The first element is used as the key for BST ordering.
items = [
    np.array([5.0, 0.1, 0.2]),
    np.array([2.0, 1.1, 1.2]),
    np.array([8.0, 2.1, 2.2]),
    np.array([1.0, 3.1, 3.2]),
    np.array([3.0, 4.1, 4.2]),
    np.array([7.0, 5.1, 5.2]),
    np.array([9.0, 6.1, 6.2]),
]

# Insert items into the treap.
for item in items:
    treap.insert(item)

# Print in-order traversal (node indices) and then each item (sorted by key).
inorder_indices = treap.inorder()
print("In-order traversal (node indices):", inorder_indices)
for idx in inorder_indices:
    print("Item at node", idx, ":", treap.get(idx))

# Demonstrate the select() function:
# Get the 0th, 3rd, and last elements in sorted order.
idx0 = treap.select(0)
idx3 = treap.select(3)
idx_last = treap.select(treap.size - 1)
print("Select order 0:", treap.get(idx0))
print("Select order 3:", treap.get(idx3))
print("Select last order:", treap.get(idx_last))

print(treap[0])

In-order traversal (node indices): [3 1 4 0 5 2 6]
Item at node 3 : [1 3 3]
Item at node 1 : [2 1 1]
Item at node 4 : [3 4 4]
Item at node 0 : [5 0 0]
Item at node 5 : [7 5 5]
Item at node 2 : [8 2 2]
Item at node 6 : [9 6 6]
Select order 0: [1 3 3]
Select order 3: [5 0 0]
Select last order: [9 6 6]
[1 3 3]


In [112]:
capacity = 1_000_000
treap = ArrayTreap1(capacity, 3)
for i in range(capacity):
    treap.insert(np.random.randint(1_000_000_000, size=3))

In [123]:
for i in range(100):
    treap.inorder_stack()

In [152]:
capacity = 500
treap = ArrayTreap1(capacity, 1)
for x in np.random.permutation(capacity):
    treap.insert(np.asarray((x,)))
for i in range(capacity):
    for j in range(i, capacity + 1):
        assert (treap.slice(i, j).reshape(j - i) == np.arange(i, j)).all()

### Benchmarking


In [None]:
import numpy as np
import numba as nb
from functools import cache
from numba.typed import List

# Suppose your jitclass factory is in array_heap.py
from data_structures.array_heap import get_array_heap_1d_items_jitclass


@nb.njit
def make_item_int64():
    """Return a 1D np.array[int64] with a random int in [-1000,1000)."""
    return np.array([np.random.randint(-1000, 1000)], dtype=np.int64)


@nb.njit
def make_item_float64():
    """Return a 1D np.array[float64] with a random float in [-1000,1000)."""
    return np.array([2000.0 * np.random.random() - 1000.0], dtype=np.float64)


@cache
def create_heap_tests(data_type):
    """
    A factory function (decorated with @cache) that:
      1) Creates a specialized heap jitclass based on `data_type`.
      2) Chooses the right random item generator and typed-list type.
      3) Defines two nopython-mode functions:
         - correctness_test(n_ops: int) -> bool
         - benchmark_test(n_ops: int) -> (int or float)
      4) Returns them as a tuple (correctness_test, benchmark_test).

    Usage:
        correctness_fn, benchmark_fn = create_heap_tests(nb.int64)
        ok = correctness_fn(2000)
        total = benchmark_fn(100_000)
        ...
    """

    # 1) Create the specialized jitclass
    HeapClass = get_array_heap_1d_items_jitclass(data_type)

    # 2) Decide which item-maker to use and what typed-list element type to use for the oracle
    if data_type == nb.int64:
        make_item = make_item_int64
        list_dtype = nb.int64
    elif data_type == nb.float64:
        make_item = make_item_float64
        list_dtype = nb.float64
    else:
        raise TypeError("Unsupported data_type; use nb.int64 or nb.float64")

    @nb.njit
    def correctness_test(n_ops: int) -> bool:
        """
        Performs random correctness testing by comparing each heap operation
        to a naive 'oracle' that stores items in ascending order (typed list).
        Returns True if all checks pass, otherwise False.
        """
        ds = HeapClass(n_ops + 10, 1)  # positional args
        oracle = List.empty_list(list_dtype)

        for _ in range(n_ops):
            op = np.random.randint(0, 3)
            if op == 0:
                # heappush
                if ds.size < ds.capacity:
                    val = make_item()
                    ds.heappush(val)
                    # Insert val[0] in ascending order in oracle
                    inserted = False
                    for i in range(len(oracle)):
                        if oracle[i] > val[0]:
                            oracle.insert(i, val[0])
                            inserted = True
                            break
                    if not inserted:
                        oracle.append(val[0])

            elif op == 1:
                # heappop
                if ds.size > 0:
                    got = ds.heappop()
                    if len(oracle) == 0:
                        return False
                    ref = oracle[0]
                    del oracle[0]
                    if got[0] != ref:
                        return False

            else:
                # heappeek
                if ds.size > 0:
                    got = ds.heappeek()
                    if len(oracle) == 0:
                        return False
                    if got[0] != oracle[0]:
                        return False

        # Final length check
        if ds.size != len(oracle):
            return False
        return True

    @nb.njit
    def benchmark_test(n_ops: int):
        """
        Performs random heappush, heappop, heappeek operations in nopython mode,
        measuring performance. Returns a sum (int or float) so the compiler
        cannot optimize the loop away.
        """
        ds = HeapClass(n_ops + 5, 1)
        total = 0

        for _ in range(n_ops):
            op = np.random.randint(0, 3)
            if op == 0:
                # heappush
                if ds.size < ds.capacity:
                    val = make_item()
                    ds.heappush(val)
            elif op == 1:
                # heappop
                if ds.size > 0:
                    got = ds.heappop()
                    total += got[0]
            else:
                # heappeek
                if ds.size > 0:
                    got = ds.heappeek()
                    total += got[0]

        return total

    # Return both jitted functions from the factory
    return correctness_test, benchmark_test

In [27]:
# ---------------------------------------------------------------------
# USAGE EXAMPLE
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import time

    correctness_n_ops = 1000
    benchmark_n_ops = 1_000_000

    # Set seed for reproducibility
    np.random.seed(42)

    # 1) Get int64 test functions
    correctness_int64, benchmark_int64 = create_heap_tests(nb.int64)

    # 2) Run correctness for int64
    ok_int = correctness_int64(correctness_n_ops)
    print("Heap correctness int64:", ok_int)

    # 3) Run benchmark for int64
    t0 = time.time()
    sum_int = benchmark_int64(benchmark_n_ops)
    t1 = time.time()
    print(f"Heap benchmark int64 sum={sum_int}, time={t1 - t0:.4f}s")

    # 4) Similarly for float64
    correctness_float64, benchmark_float64 = create_heap_tests(nb.float64)

    ok_float = correctness_float64(correctness_n_ops)
    print("Heap correctness float64:", ok_float)

    t0 = time.time()
    sum_float = benchmark_float64(benchmark_n_ops)
    t1 = time.time()
    print(f"Heap benchmark float64 sum={sum_float}, time={t1 - t0:.4f}s")

Heap correctness int64: True
Heap benchmark int64 sum=-1783088, time=0.5463s
Heap correctness float64: True
Heap benchmark float64 sum=-324813.8758473227, time=0.5098s


In [None]:
import numpy as np
import numba as nb
from functools import cache
from numba.typed import List

# Import your treap jitclass factory.
from data_structures.array_treap import get_array_treap_1d_items_jitclass


# ---------------------------------------------------------------------
# Helpers: Create a single-element random item for each data type.
# ---------------------------------------------------------------------
@nb.njit
def make_item_int64():
    return np.array([np.random.randint(-1000, 1000)], dtype=np.int64)


@nb.njit
def make_item_float64():
    return np.array([2000.0 * np.random.random() - 1000.0], dtype=np.float64)


# ---------------------------------------------------------------------
# Helpers: Random value generators (compiled so they can be inlined).
# ---------------------------------------------------------------------
@nb.njit
def random_val_int():
    return np.random.randint(-1000, 1000)


@nb.njit
def random_val_float():
    return 2000.0 * np.random.random() - 1000.0


# ---------------------------------------------------------------------
# Factory Function for Treap Tests (only public methods)
# ---------------------------------------------------------------------
@cache
def create_treap_tests(data_type):
    """
    Factory that creates two specialized Numba-jitted functions for testing
    your treap implementation (from array_treap.py). Only public methods
    are tested (insert, __getitem__, slice, remove, and delitem). The inner
    functions do not compare data_type values.

    Returns a tuple:
       (correctness_test, benchmark_test)
    """
    # Create the specialized jitclass for the treap.
    TreapClass = get_array_treap_1d_items_jitclass(data_type)

    # Choose the proper item generator, oracle element type, random value
    # generator, and initial total value—all done outside the njit functions.
    if data_type == nb.int64:
        item_generator = make_item_int64
        list_dtype = nb.int64
        rand_val = random_val_int
        initial_total = 0
    elif data_type == nb.float64:
        item_generator = make_item_float64
        list_dtype = nb.float64
        rand_val = random_val_float
        initial_total = 0.0
    else:
        raise TypeError("Unsupported data_type; use nb.int64 or nb.float64")

    @nb.njit
    def correctness_test(n_ops: int) -> bool:
        """
        Executes a randomized sequence of operations on the treap,
        comparing its behavior against an oracle (a sorted typed list).
        Operations include:
          0: insert(item)
          1: remove(item)
          2: __getitem__(index)
          3: slice(start, stop)
          4: delitem(index)
        Returns True if all checks pass.
        """
        treap = TreapClass(n_ops + 10, 1)  # Use positional arguments.
        oracle = List.empty_list(list_dtype)

        for _ in range(n_ops):
            op = np.random.randint(0, 5)
            if op == 0:
                # INSERT operation.
                item = item_generator()
                treap.insert(item)
                # Insert item[0] into oracle in ascending order.
                inserted = False
                for i in range(len(oracle)):
                    if oracle[i] > item[0]:
                        oracle.insert(i, item[0])
                        inserted = True
                        break
                if not inserted:
                    oracle.append(item[0])

            elif op == 1:
                # REMOVE operation: generate a candidate value.
                val = rand_val()
                arr = np.array([val], dtype=data_type)
                treap.remove(arr)
                # Remove first occurrence from oracle if present.
                for i in range(len(oracle)):
                    if oracle[i] == val:
                        del oracle[i]
                        break

            elif op == 2:
                # __getitem__ test.
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    item_from_treap = treap[idx]
                    if item_from_treap[0] != oracle[idx]:
                        return False

            elif op == 3:
                # slice test.
                if treap.size > 0:
                    start = np.random.randint(-treap.size, treap.size)
                    stop = np.random.randint(-treap.size, treap.size + 1)
                    sliced = treap.slice(start, stop)
                    # Emulate Python slicing on oracle.
                    n = len(oracle)
                    s = start if start >= 0 else start + n
                    e = stop if stop >= 0 else stop + n
                    if s < 0:
                        s = 0
                    if s > n:
                        s = n
                    if e < s:
                        e = s
                    if e > n:
                        e = n
                    expected_len = e - s
                    if sliced.shape[0] != expected_len:
                        return False
                    for j in range(sliced.shape[0]):
                        if sliced[j, 0] != oracle[s + j]:
                            return False

            elif op == 4:
                # delitem test.
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    treap.delitem(idx)
                    del oracle[idx]

            if treap.size != len(oracle):
                return False
        return True

    @nb.njit
    def benchmark_test(n_ops: int):
        """
        Executes a randomized sequence of operations on the treap for benchmarking.
        The operations are the same as in correctness_test, but without an oracle.
        A running total is accumulated (from __getitem__ and slice operations) so that
        the work cannot be optimized away.
        Returns the accumulated total.
        """
        treap = TreapClass(n_ops + 10, 1)
        total = initial_total
        for _ in range(n_ops):
            op = np.random.randint(0, 5)
            if op == 0:
                if treap.size < (n_ops + 10):
                    treap.insert(item_generator())
            elif op == 1:
                # Remove: generate a random value.
                val = rand_val()
                arr = np.array([val], dtype=data_type)
                treap.remove(arr)
            elif op == 2:
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    total += treap[idx][0]
            elif op == 3:
                if treap.size > 0:
                    start = np.random.randint(-treap.size, treap.size)
                    stop = np.random.randint(-treap.size, treap.size + 1)
                    sliced = treap.slice(start, stop)
                    if sliced.shape[0] > 0:
                        total += sliced[0, 0]
            elif op == 4:
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    treap.delitem(idx)
            total += treap.size  # Prevent dead-code elimination.
        return total

    return correctness_test, benchmark_test

In [127]:
# ---------------------------------------------------------------------
# USAGE EXAMPLE
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import time

    correctness_n_ops = 1000
    benchmark_n_ops = 200_000

    np.random.seed(42)

    # Generate tests for int64.
    treap_correctness_int64, treap_benchmark_int64 = create_treap_tests(nb.int64)
    ok_int = treap_correctness_int64(correctness_n_ops)
    print("Treap correctness (int64):", ok_int)

    t0 = time.time()
    total_int = treap_benchmark_int64(benchmark_n_ops)
    t1 = time.time()
    print(f"Treap benchmark (int64): total={total_int}, time={t1 - t0:.4f}s")

    # Generate tests for float64.
    treap_correctness_float64, treap_benchmark_float64 = create_treap_tests(nb.float64)
    ok_float = treap_correctness_float64(correctness_n_ops)
    print("Treap correctness (float64):", ok_float)

    t0 = time.time()
    total_float = treap_benchmark_float64(benchmark_n_ops)
    t1 = time.time()
    print(f"Treap benchmark (float64): total={total_float}, time={t1 - t0:.4f}s")

Treap correctness (int64): True
Treap benchmark (int64): total=-741624, time=0.8888s
Treap correctness (float64): True
Treap benchmark (float64): total=33574833.90024427, time=1.0152s


In [None]:
import numba as nb

data_type = nb.int64
ArrayTreap1 = get_array_treap_1d_items_jitclass(data_type)
a = ArrayTreap1(100, 1)
a.insert([5])
a.insert([100])
a.insert([9])
a.insert([3])
a.insert([0])
print(a)
print(a.slice(0, -1))
a.delitem(0)
print(a)
a.delitem(-2)
print(a)
a.insert([1000])
a.insert([99])
a.insert([999])
a.delitem(-2)
a.delitem(1)
print(a)

[[0], [3], [5], [9], [100]]
[[0]
 [3]
 [5]
 [9]]
[[3], [5], [9], [100]]
[[3], [5], [100]]
[[3], [99], [100], [1000]]


##### Debugging


In [None]:
from functools import cache

# import numba as nb
import numpy as np

# from numba import njit
# from numba.experimental import jitclass


# @njit(inline='always')
def array_is_less(a, b):
    """
    Compare two arrays element-by-element for a strict lexicographical order.
    Return True if 'a' is lexicographically less than 'b', False otherwise.
    """
    for i in range(len(a)):
        if a[i] < b[i]:
            return True
        elif a[i] > b[i]:
            return False
    return False  # They are equal in all positions.


# @njit(inline='always')
def stringify_1d_array(array):
    return '[' + ', '.join([str(entry) for entry in array]) + ']'


@cache
def get_array_treap_1d_items_jitclass(data_type):
    """
    Factory function to create a treap class specialized for 1D items
    using only 'array_is_less' to define BST ordering (and equality detection).
    """
    # assert (
    #     data_type is nb.int64 or data_type is nb.float64
    # ), 'Only 64-bit data types supported as the random key generation is defined up to 10**9'

    # spec = [
    #     ('capacity', nb.int64),
    #     ('item_size', nb.int64),
    #     ('size', nb.int64),
    #     ('root', nb.int64),
    #     ('data', data_type[:, :]),
    #     ('priorities', nb.int64[:]),
    #     ('left', nb.int64[:]),
    #     ('right', nb.int64[:]),
    #     ('parent', nb.int64[:]),
    #     ('free_list', nb.int64[:]),
    #     ('free_list_top', nb.int64),
    #     ('subtree_size', nb.int64[:]),
    # ]

    # @jitclass(spec)
    class ArrayTreap1:
        """
        A treap (tree-based heap) implemented with implicit indexing,
        using 'array_is_less' to compare multi-dimensional items.
        """

        def __init__(self, capacity, item_size):
            """
            Initialize the treap with a maximum capacity and item size.
            Each row of 'data' is considered one item.
            """
            self.capacity = capacity
            self.item_size = item_size
            self.size = 0
            self.root = -1

            # Preallocate arrays
            self.data = np.empty((capacity, item_size), dtype=data_type)
            self.priorities = np.empty(capacity, dtype=np.int64)
            self.left = -np.ones(capacity, dtype=np.int64)
            self.right = -np.ones(capacity, dtype=np.int64)
            self.parent = -np.ones(capacity, dtype=np.int64)

            # Free list: stack of available node indices
            self.free_list = np.empty(capacity, dtype=np.int64)
            for i in range(capacity):
                self.free_list[i] = capacity - 1 - i
            self.free_list_top = capacity

            # For each node, store subtree size
            self.subtree_size = np.zeros(capacity, dtype=np.int64)

        def _recalc(self, x):
            """
            Recalculate the subtree size for node x.
            """
            s = 1
            if self.left[x] != -1:
                s += self.subtree_size[self.left[x]]
            if self.right[x] != -1:
                s += self.subtree_size[self.right[x]]
            self.subtree_size[x] = s

        def _new_node(self, item):
            """
            Allocate a new node from the free list and initialize it.
            """
            if self.free_list_top == 0:
                raise Exception("Treap capacity exceeded")

            self.free_list_top -= 1
            idx = self.free_list[self.free_list_top]
            self.data[idx] = item
            self.priorities[idx] = np.random.randint(1, 10**9)  # random priority
            self.left[idx] = -1
            self.right[idx] = -1
            self.parent[idx] = -1
            self.subtree_size[idx] = 1
            self.size += 1
            return idx

        def rotate_right(self, x):
            print('rotate_right', x)
            y = self.left[x]
            self.left[x] = self.right[y]
            if self.right[y] != -1:
                self.parent[self.right[y]] = x
            self.right[y] = x
            parent_x = self.parent[x]
            self.parent[y] = parent_x
            self.parent[x] = y

            if parent_x == -1:
                self.root = y
            else:
                if self.left[parent_x] == x:
                    self.left[parent_x] = y
                else:
                    self.right[parent_x] = y

            # Recalculate subtree sizes for x and y first.
            self._recalc(x)
            self._recalc(y)

            # Then update the rest of the chain.
            temp = self.parent[y]
            while temp != -1:
                self._recalc(temp)
                temp = self.parent[temp]

        def rotate_left(self, x):
            debug = True  # Set to False once you've debugged the loop

            if debug:
                print(f"\n--- rotate_left called on x={x} ---")
                print(f"Initial state:")
                print(f"  x={x}")
                print(f"  right[x]={self.right[x]}")
                print(f"  left[x]={self.left[x]}")
                print(f"  parent[x]={self.parent[x]}")

            y = self.right[x]

            # Sanity check: if y == -1, rotating left makes no sense
            if y == -1:
                raise RuntimeError(f"rotate_left error: right[{x}] is -1, cannot rotate left.")

            if debug:
                print(f"  y={y}")
                print(f"  left[y]={self.left[y]}")
                print(f"  right[y]={self.right[y]}")
                print(f"  parent[y]={self.parent[y]}")

            # If y is the same as x, we have a pointer cycle already
            if y == x:
                raise RuntimeError(f"rotate_left error: right[x] == x ({x}), cycle detected!")

            # Perform the standard rotate-left transformation
            # Step 1: move y’s left subtree to x’s right subtree
            self.right[x] = self.left[y]
            if self.left[y] != -1:
                self.parent[self.left[y]] = x

            # Step 2: set y’s left pointer to x
            self.left[y] = x

            # Step 3: fix y’s parent pointer
            parent_x = self.parent[x]
            self.parent[y] = parent_x
            self.parent[x] = y

            # Step 4: update the original parent of x to point to y
            if parent_x == -1:
                self.root = y
            else:
                if self.left[parent_x] == x:
                    self.left[parent_x] = y
                elif self.right[parent_x] == x:
                    self.right[parent_x] = y
                else:
                    raise RuntimeError(
                        f"rotate_left error: x={x} is not a child of its parent={parent_x}. "
                        "Tree structure is corrupted."
                    )

            # Debug prints after pointer adjustments
            if debug:
                print("\nAfter pointer reassignments:")
                print(f"  x={x}, parent[x]={self.parent[x]}, left[x]={self.left[x]}, right[x]={self.right[x]}")
                print(f"  y={y}, parent[y]={self.parent[y]}, left[y]={self.left[y]}, right[y]={self.right[y]}")
                print(f"  parent_x={parent_x}, root={self.root}")
                if parent_x != -1:
                    print(f"  parent's left={self.left[parent_x]}, parent's right={self.right[parent_x]}")

            # Recalculate subtree sizes for x and y first
            self._recalc(x)
            self._recalc(y)

            # Then propagate recalculations up the tree
            temp = self.parent[y]
            while temp != -1:
                self._recalc(temp)
                temp = self.parent[temp]

            if debug:
                print("\nAfter subtree size updates:")
                print(f"  subtree_size[x]={self.subtree_size[x]}")
                print(f"  subtree_size[y]={self.subtree_size[y]}")
                print(f"  final parent[y]={self.parent[y]} (should match old parent_x)")
                print("--- rotate_left complete ---\n")

        def insert(self, item):
            """
            Insert a new item into the treap, maintaining BST order via array_is_less.
            """
            new_idx = self._new_node(item)

            # If tree is empty, new node is the root
            if self.root == -1:
                self.root = new_idx
                return

            # Search for insertion point
            cur = self.root
            parent = -1
            path = []
            while cur != -1:
                path.append(cur)
                parent = cur
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                else:
                    # If not less, go right
                    cur = self.right[cur]

            # Attach new_idx to parent
            self.parent[new_idx] = parent
            if array_is_less(item, self.data[parent]):
                self.left[parent] = new_idx
            else:
                self.right[parent] = new_idx

            # Update subtree sizes along the path
            for node in path:
                self.subtree_size[node] += 1

            # Bubble up based on priority
            cur = new_idx
            while self.parent[cur] != -1 and self.priorities[self.parent[cur]] < self.priorities[cur]:
                p = self.parent[cur]
                if self.left[p] == cur:
                    self.rotate_right(p)
                else:
                    self.rotate_left(p)

        def __getitem__(self, index):
            """
            Retrieve the in-order item at position 'index', integer-only.
            Negative indices behave similarly to Python's sequence rules.
            """
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")
            # We do a standard subtree-based 'select'
            node_idx = self.select(index)
            return self.data[node_idx]

        def slice(self, start, stop):
            debug = True  # Turn this off to disable debug prints once you fix the loop.
            max_iter = self.capacity * 2  # Arbitrary limit to catch infinite loops
            visited = set()  # Detect repeated visits to the same node in a single slice call.

            # Handle negative indices for start
            if start < 0:
                start += self.size
            if start < 0:
                start = 0
            if start > self.size:
                start = self.size

            # Handle negative indices for stop
            if stop < 0:
                stop += self.size
            if stop < 0:
                stop = 0
            if stop > self.size:
                stop = self.size

            if stop < start:
                # Empty slice
                return np.empty((0, self.item_size), dtype=self.data.dtype)

            length = stop - start
            out = np.empty((length, self.item_size), dtype=self.data.dtype)

            stack = np.empty(self.capacity, dtype=np.int64)
            stack_top = 0
            cur = self.root
            count = 0
            out_pos = 0

            iteration_count = 0
            while stack_top > 0 or cur != -1:
                iteration_count += 1
                if iteration_count > max_iter:
                    raise RuntimeError(
                        "Infinite loop detected in slice (max_iter exceeded). "
                        "Possible pointer cycle or traversal logic error."
                    )
                if debug:
                    print(
                        f"[slice debug] iteration={iteration_count}, stack_top={stack_top}, "
                        f"cur={cur}, count={count}, out_pos={out_pos}, start={start}, stop={stop}"
                    )

                if cur != -1:
                    # Check if we've visited this node before in the same traversal.
                    if cur in visited:
                        raise RuntimeError(
                            f"Infinite loop detected in slice: node {cur} revisited. "
                            f"This usually means there is a cycle in the tree."
                        )
                    visited.add(cur)

                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                else:
                    stack_top -= 1
                    cur = stack[stack_top]
                    # We are visiting 'cur' in in-order
                    if count >= start and count < stop:
                        out[out_pos] = self.data[cur]
                        out_pos += 1
                        if count + 1 == stop:
                            # Done collecting
                            break
                    count += 1
                    cur = self.right[cur]

            return out

        def search(self, item):
            """
            Search for a node whose stored item is 'item', returning the node index if found, else -1.
            We assume 'item' <-> 'self.data[cur]' if array_is_less says so,
            and equality is the case when neither is less than the other.
            """
            cur = self.root
            while cur != -1:
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                elif array_is_less(self.data[cur], item):
                    cur = self.right[cur]
                else:
                    # Not less in either direction => they are equal
                    return cur
            return -1

        def remove(self, item):
            """
            Remove the node containing 'item' (found via BST search).
            Equivalent to the old 'delete(item)'.
            If 'item' is not found, no change is made.
            """
            # Find node
            cur = self.root
            while cur != -1:
                if array_is_less(item, self.data[cur]):
                    cur = self.left[cur]
                elif array_is_less(self.data[cur], item):
                    cur = self.right[cur]
                else:
                    # Found
                    break

            if cur == -1:
                return  # Not found

            self.delete(cur)

        def delete(self, node_idx):
            """
            Delete the node at index 'node_idx' directly, without searching by item.
            Bubbles the node down until it becomes a leaf, then removes it.
            """
            # Bubble down until it becomes a leaf.
            while self.left[node_idx] != -1 or self.right[node_idx] != -1:
                if self.right[node_idx] == -1 or (
                    self.left[node_idx] != -1
                    and self.priorities[self.left[node_idx]] > self.priorities[self.right[node_idx]]
                ):
                    self.rotate_right(node_idx)
                else:
                    self.rotate_left(node_idx)

                # If node_idx's parent is -1, then node_idx is now the root.
                # (This branch is for when rotations change the root pointer.)
                if self.parent[node_idx] == -1:
                    self.root = node_idx

            # At this point, node_idx should be a leaf.
            # Remove the leaf node from its parent.
            parent = self.parent[node_idx]
            if parent != -1:
                if self.left[parent] == node_idx:
                    self.left[parent] = -1
                else:
                    self.right[parent] = -1

                # Update subtree sizes for all ancestors.
                while parent != -1:
                    self._recalc(parent)
                    parent = self.parent[parent]
            else:
                # If the node is the root, then after deletion the tree is empty.
                self.root = -1

            # Clean up this node's pointers so it doesn't create a cycle if re-used.
            self.left[node_idx] = -1
            self.right[node_idx] = -1
            self.parent[node_idx] = -1

            # Return node_idx to the free list.
            self.free_list[self.free_list_top] = node_idx
            self.free_list_top += 1
            self.size -= 1

        def delitem(self, index):
            """
            Remove the item at in-order position 'index'. Negative indexing is supported.
            Uses select(index) to locate the node, then calls delete(node_idx).
            """
            if index < 0:
                index += self.size
            if index < 0 or index >= self.size:
                raise IndexError("Index out of range")

            node_idx = self.select(index)
            if node_idx == -1:
                # Shouldn't happen if index is valid, but just in case
                raise IndexError("Index out of range")

            self.delete(node_idx)

        def select(self, order):
            """
            Return the index of the node that is the `order`-th element
            in the in-order (sorted) traversal, 0-indexed.
            """
            if order < 0 or order >= self.size:
                return -1
            cur = self.root
            while cur != -1:
                left_count = self.subtree_size[self.left[cur]] if self.left[cur] != -1 else 0
                if order < left_count:
                    cur = self.left[cur]
                elif order == left_count:
                    return cur
                else:
                    order -= left_count + 1
                    cur = self.right[cur]
            return -1

        def inorder(self):
            """
            Perform an in-order traversal and return a 1D NumPy array
            of node indices in 'sorted' order (as defined by array_is_less).
            """
            result = np.empty(self.size, dtype=np.int64)
            stack = np.empty(self.capacity, dtype=np.int64)
            stack_top = 0
            cur = self.root
            count = 0

            while stack_top > 0 or cur != -1:
                if cur != -1:
                    stack[stack_top] = cur
                    stack_top += 1
                    cur = self.left[cur]
                else:
                    stack_top -= 1
                    cur = stack[stack_top]
                    result[count] = cur
                    count += 1
                    cur = self.right[cur]
            return result

        def get(self, idx):
            """
            Retrieve a direct reference to the item stored at node index 'idx'.
            """
            return self.data[idx]

        def __str__(self):
            """
            Return a string representation of the treap's items in in-order order.
            Uses `stringify_1d_array` for formatting.
            """
            if self.size == 0:
                return "[]"
            inord = self.inorder()
            elems = [stringify_1d_array(self.data[node]) for node in inord]
            return "[" + ", ".join(elems) + "]"

    return ArrayTreap1

In [None]:
import numpy as np
import numba as nb
from functools import cache
from numba.typed import List

# Import your treap jitclass factory.
# from data_structures.array_treap import get_array_treap_1d_items_jitclass


# ---------------------------------------------------------------------
# Helpers: Create a single-element random item for each data type.
# ---------------------------------------------------------------------
# @nb.njit
def make_item_int64():
    return np.array([np.random.randint(-1000, 1000)], dtype=np.int64)


# @nb.njit
def make_item_float64():
    return np.array([2000.0 * np.random.random() - 1000.0], dtype=np.float64)


# ---------------------------------------------------------------------
# Helpers: Random value generators (compiled so they can be inlined).
# ---------------------------------------------------------------------
# @nb.njit
def random_val_int():
    return np.random.randint(-1000, 1000)


# @nb.njit
def random_val_float():
    return 2000.0 * np.random.random() - 1000.0


# ---------------------------------------------------------------------
# Factory Function for Treap Tests (only public methods)
# ---------------------------------------------------------------------
@cache
def create_treap_tests(data_type):
    """
    Factory that creates two specialized Numba-jitted functions for testing
    your treap implementation (from array_treap.py). Only public methods
    are tested (insert, __getitem__, slice, remove, and delitem). The inner
    functions do not compare data_type values.

    Returns a tuple:
       (correctness_test, benchmark_test)
    """
    # Create the specialized jitclass for the treap.
    TreapClass = get_array_treap_1d_items_jitclass(data_type)

    # Choose the proper item generator, oracle element type, random value
    # generator, and initial total value—all done outside the njit functions.
    if data_type == np.int64:
        item_generator = make_item_int64
        list_dtype = nb.int64
        rand_val = random_val_int
        initial_total = 0
    elif data_type == np.float64:
        item_generator = make_item_float64
        list_dtype = nb.float64
        rand_val = random_val_float
        initial_total = 0.0
    else:
        raise TypeError("Unsupported data_type; use nb.int64 or nb.float64")

    # @nb.njit
    def correctness_test(n_ops: int) -> bool:
        """
        Executes a randomized sequence of operations on the treap,
        comparing its behavior against an oracle (a sorted typed list).
        Operations include:
          0: insert(item)
          1: remove(item)
          2: __getitem__(index)
          3: slice(start, stop)
          4: delitem(index)
        Returns True if all checks pass.
        """
        treap = TreapClass(n_ops + 10, 1)  # Use positional arguments.
        oracle = List.empty_list(list_dtype)

        for _ in range(n_ops):
            op = np.random.randint(0, 5)
            if op == 0:
                # INSERT operation.
                item = item_generator()
                treap.insert(item)
                # Insert item[0] into oracle in ascending order.
                inserted = False
                for i in range(len(oracle)):
                    if oracle[i] > item[0]:
                        oracle.insert(i, item[0])
                        inserted = True
                        break
                if not inserted:
                    oracle.append(item[0])

            elif op == 1:
                # REMOVE operation: generate a candidate value.
                val = rand_val()
                arr = np.array([val], dtype=data_type)
                treap.remove(arr)
                # Remove first occurrence from oracle if present.
                for i in range(len(oracle)):
                    if oracle[i] == val:
                        del oracle[i]
                        break

            elif op == 2:
                # __getitem__ test.
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    item_from_treap = treap[idx]
                    if item_from_treap[0] != oracle[idx]:
                        return False

            elif op == 3:
                continue
                # slice test.
                if treap.size > 0:
                    start = np.random.randint(-treap.size, treap.size)
                    stop = np.random.randint(-treap.size, treap.size + 1)
                    sliced = treap.slice(start, stop)
                    # Emulate Python slicing on oracle.
                    n = len(oracle)
                    s = start if start >= 0 else start + n
                    e = stop if stop >= 0 else stop + n
                    if s < 0:
                        s = 0
                    if s > n:
                        s = n
                    if e < s:
                        e = s
                    if e > n:
                        e = n
                    expected_len = e - s
                    if sliced.shape[0] != expected_len:
                        return False
                    for j in range(sliced.shape[0]):
                        if sliced[j, 0] != oracle[s + j]:
                            return False

            elif op == 4:
                # delitem test.
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    treap.delitem(idx)
                    del oracle[idx]

            if treap.size != len(oracle):
                return False
        return True

    # @nb.njit
    def benchmark_test(n_ops: int):
        """
        Executes a randomized sequence of operations on the treap for benchmarking.
        The operations are the same as in correctness_test, but without an oracle.
        A running total is accumulated (from __getitem__ and slice operations) so that
        the work cannot be optimized away.
        Returns the accumulated total.
        """
        treap = TreapClass(n_ops + 10, 1)
        total = initial_total
        for _ in range(n_ops):
            op = np.random.randint(0, 5)
            if op == 0:
                if treap.size < (n_ops + 10):
                    treap.insert(item_generator())
            elif op == 1:
                # Remove: generate a random value.
                val = rand_val()
                arr = np.array([val], dtype=data_type)
                treap.remove(arr)
            elif op == 2:
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    total += treap[idx][0]
            elif op == 3:
                continue
                if treap.size > 0:
                    start = np.random.randint(-treap.size, treap.size)
                    stop = np.random.randint(-treap.size, treap.size + 1)
                    sliced = treap.slice(start, stop)
                    if sliced.shape[0] > 0:
                        total += sliced[0, 0]
            elif op == 4:
                if treap.size > 0:
                    idx = np.random.randint(0, treap.size)
                    treap.delitem(idx)
            total += treap.size  # Prevent dead-code elimination.
        return total

    return correctness_test, benchmark_test

In [58]:
# ---------------------------------------------------------------------
# USAGE EXAMPLE
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import time

    # np.random.seed(42)
    n_ops = 1000

    # Generate tests for int64.
    treap_correctness_int64, treap_benchmark_int64 = create_treap_tests(np.int64)
    ok_int = treap_correctness_int64(n_ops)
    print("Treap correctness (int64):", ok_int)

    t0 = time.time()
    total_int = treap_benchmark_int64(n_ops)
    t1 = time.time()
    print(f"Treap benchmark (int64): total={total_int}, time={t1 - t0:.4f}s")

    # Generate tests for float64.
    treap_correctness_float64, treap_benchmark_float64 = create_treap_tests(np.float64)
    ok_float = treap_correctness_float64(n_ops)
    print("Treap correctness (float64):", ok_float)

    t0 = time.time()
    total_float = treap_benchmark_float64(n_ops)
    t1 = time.time()
    print(f"Treap benchmark (float64): total={total_float}, time={t1 - t0:.4f}s")

AssertionError: Only 64-bit data types supported as the random key generation is defined up to 10**9