In [1]:
# Setup for performance metrics - create nested directories
import random
from time import sleep
perf_base = "spytial_perf"
def get_perf_path(structure, size):
    return perf_base + "_" + structure + "_" + f"{size}.json"
PI = 30
SIZES = [5, 10, 25, 50]


In [3]:
import sys
from pathlib import Path


# Add the parent directory to the Python path
sys.path.append(str(Path().resolve().parent))

from spytial import *
from spytial.annotations import *
from spytial.annotations import flag


# B Trees

In [6]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Any

## Selector that gives us parent -> idx -> child 
pic = "{p: BTreeNode, i : int, c : BTreeNode| c in (((p.children).idx)[i]) }"

## WE WANT PARENT->CHILD AS WELL.
pc = "{p: BTreeNode, c : BTreeNode| c in (((p.children).idx)[int]) }"

# ------------------------- Node -------------------------
@dataclass
@attribute(field="t")
@attribute(field="leaf")
@hideAtom(selector="list + bool + int")
@group(selector="{p : BTreeNode, k : str | k in (((p.keys).idx)[int]) } + {p, pp : BTreeNode | p = pp}", name="keys")
@align(selector="{p : BTreeNode, k : str | k in (((p.keys).idx)[int]) } + {p, pp : BTreeNode | p = pp}", direction="horizontal")
@inferredEdge(selector=pic, name="child")
@orientation(selector=pic, directions=["below"])
@align(selector=f"{{ c1, c2 :  BTreeNode | (some p : BTreeNode | ((p->c1) + (p->c2))  in {pc}   )}}", direction="horizontal") ## This?
### This really deep nested alignment is the issue in terms of perf
#@align(selector=f"{{ c1, c2 : BTreeNode | (some p : BTreeNode, i : int, j : int | i != j and (p->i->c1 + p->j->c2 in {pic})  )    }}", direction="horizontal")
class BTreeNode:
    t: int
    leaf: bool = True
    keys: List[Any] = field(default_factory=list)           # sorted
    children: List["BTreeNode"] = field(default_factory=list)  # len = len(keys)+1 if !leaf

    # SEARCH (B-TREE-SEARCH)
    def search(self, k: Any) -> Optional[Tuple["BTreeNode", int]]:
        i = 0
        while i < len(self.keys) and k > self.keys[i]:
            i += 1
        if i < len(self.keys) and k == self.keys[i]:
            return (self, i)
        if self.leaf:
            return None
        return self.children[i].search(k)

    # SPLIT-CHILD (B-TREE-SPLIT-CHILD)
    def split_child(self, i: int) -> None:
        t = self.t
        y = self.children[i]
        z = BTreeNode(t=t, leaf=y.leaf)

        mid = y.keys[t-1]
        z.keys = y.keys[t:]            # upper t-1 keys
        y.keys = y.keys[:t-1]          # lower t-1 keys

        if not y.leaf:
            z.children = y.children[t:]
            y.children = y.children[:t]

        self.children.insert(i+1, z)
        self.keys.insert(i, mid)

    # INSERT-NONFULL (B-TREE-INSERT-NONFULL)
    def insert_nonfull(self, k: Any) -> None:
        i = len(self.keys) - 1
        if self.leaf:
            self.keys.append(k)
            j = len(self.keys) - 1
            while j > 0 and self.keys[j-1] > k:
                self.keys[j] = self.keys[j-1]
                j -= 1
            self.keys[j] = k
            return

        while i >= 0 and k < self.keys[i]:
            i -= 1
        i += 1
        if len(self.children[i].keys) == 2*self.t - 1:
            self.split_child(i)
            if k > self.keys[i]:
                i += 1
        self.children[i].insert_nonfull(k)

    # ----- DELETE helpers -----
    def _pred(self, idx: int) -> Any:
        cur = self.children[idx]
        while not cur.leaf:
            cur = cur.children[-1]
        return cur.keys[-1]

    def _succ(self, idx: int) -> Any:
        cur = self.children[idx+1]
        while not cur.leaf:
            cur = cur.children[0]
        return cur.keys[0]

    def _merge(self, idx: int) -> None:
        """Merge children[idx], keys[idx], children[idx+1] into children[idx]."""
        child = self.children[idx]
        sibling = self.children[idx+1]
        child.keys.append(self.keys[idx])
        child.keys.extend(sibling.keys)
        if not child.leaf:
            child.children.extend(sibling.children)
        del self.keys[idx]
        del self.children[idx+1]

    def _fill(self, idx: int) -> None:
        """Ensure children[idx] has at least t keys."""
        t = self.t
        if idx > 0 and len(self.children[idx-1].keys) >= t:
            # Borrow from left
            left = self.children[idx-1]
            child = self.children[idx]
            child.keys.insert(0, self.keys[idx-1])
            if not child.leaf:
                child.children.insert(0, left.children.pop())
            self.keys[idx-1] = left.keys.pop()
        elif idx < len(self.children)-1 and len(self.children[idx+1].keys) >= t:
            # Borrow from right
            right = self.children[idx+1]
            child = self.children[idx]
            child.keys.append(self.keys[idx])
            if not child.leaf:
                child.children.append(right.children.pop(0))
            self.keys[idx] = right.keys.pop(0)
        else:
            # Merge with a sibling
            if idx < len(self.children)-1:
                self._merge(idx)
            else:
                self._merge(idx-1)

    # DELETE from subtree rooted here. Assumes this node is non-empty.
    def delete_nonempty(self, k: Any) -> None:
        t = self.t
        idx = 0
        while idx < len(self.keys) and k > self.keys[idx]:
            idx += 1

        # Case 1: key present in this node
        if idx < len(self.keys) and self.keys[idx] == k:
            if self.leaf:
                # 1a: key in leaf
                del self.keys[idx]
                return
            # 1b: key in internal node
            if len(self.children[idx].keys) >= t:
                pred = self._pred(idx)
                self.keys[idx] = pred
                self.children[idx].delete_nonempty(pred)
            elif len(self.children[idx+1].keys) >= t:
                succ = self._succ(idx)
                self.keys[idx] = succ
                self.children[idx+1].delete_nonempty(succ)
            else:
                self._merge(idx)
                self.children[idx].delete_nonempty(k)
            return

        # Case 2: key not present here
        if self.leaf:
            return  # key not found
        # Ensure the child we descend to has at least t keys
        if len(self.children[idx].keys) == t - 1:
            self._fill(idx)
            # After fill, structure may have changed. Decide next child:
            if idx > len(self.keys):   # merged with left
                idx -= 1
        self.children[idx].delete_nonempty(k)

# ------------------------- Tree -------------------------
class BTree:
    def __init__(self, t: int):
        if t < 2:
            raise ValueError("Minimum degree t must be ≥ 2")
        self.t = t
        self.root = BTreeNode(t=t, leaf=True)

    def search(self, k: Any) -> Optional[Tuple[BTreeNode, int]]:
        return self.root.search(k)

    # INSERT (B-TREE-INSERT)
    def insert(self, k: Any) -> None:
        r = self.root
        if len(r.keys) == 2*self.t - 1:
            s = BTreeNode(t=self.t, leaf=False, keys=[], children=[r])
            s.split_child(0)
            self.root = s
            s.insert_nonfull(k)
        else:
            r.insert_nonfull(k)

    # DELETE (top-level)
    def delete(self, k: Any) -> None:
        if not self.root.keys and self.root.leaf:
            return
        self.root.delete_nonempty(k)
        # Shrink height if root became empty and is internal
        if not self.root.keys and not self.root.leaf:
            self.root = self.root.children[0]

    # Utilities
    def inorder(self) -> List[Any]:
        out: List[Any] = []
        def dfs(x: BTreeNode):
            if x.leaf:
                out.extend(x.keys)
            else:
                for i, key in enumerate(x.keys):
                    dfs(x.children[i])
                    out.append(key)
                dfs(x.children[-1])
        dfs(self.root)
        return out

    def __contains__(self, k: Any) -> bool:
        return self.search(k) is not None


![b-tree](./img/b-trees.png)

In [7]:
## Now build and display a b tree whose keys are the consonants of English (as in CLRS Fig 18.1)

bt = BTree(t=2)
for k in ['F', 'S', 'Q', 'K', 'C']:
    bt.insert(k)
evaluate(bt, method="browser")
diagram(bt)

## Performance

In [9]:
STRUCTURE = "btree"
import string
for size in [30]: #SIZES:
    values = random.sample(string.ascii_uppercase + string.ascii_lowercase, size)

    b = BTree(2)
    for val in values:
        b.insert(val)
    print(b)
    ## Its the generation of layout that causes some issues here.
    diagram(b, method="browser", perf_path=get_perf_path(STRUCTURE, size), perf_iterations=1)


<__main__.BTree object at 0x103e3d400>


# van Emde Boas (vEB) Tree (CLRS Ch. 20)


In [None]:
## VEB Trees Code Here ###
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Dict

def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

def upper_sqrt(u: int) -> int:
    # 2^{ceil(lg u / 2)}
    k = u.bit_length() - 1
    return 1 << ((k + 1) // 2)

def lower_sqrt(u: int) -> int:
    # 2^{floor(lg u / 2)}
    k = u.bit_length() - 1
    return 1 << (k // 2)

@dataclass
@hideAtom(selector="NoneType + int + {d : dict | no (d.kv)}")
@attribute(field="min")
@attribute(field="max")
@attribute(field="u")
@orientation(selector="{x,y : VEB | x->y in summary} + {x: VEB, d : dict | (some d.kv) and x->d in cluster} + kv", directions=["below"])
class VEB:
    u: int
    min: Optional[int] = None
    max: Optional[int] = None
    summary: Optional["VEB"] = None
    cluster: Dict[int, "VEB"] = field(default_factory=dict)

    def __post_init__(self):
        if not is_power_of_two(self.u):
            raise ValueError("u must be a power of two")
        if self.u < 2:
            raise ValueError("u must be >= 2")

    # index math
    def _high(self, x: int) -> int:
        return x // lower_sqrt(self.u)

    def _low(self, x: int) -> int:
        return x % lower_sqrt(self.u)

    def _index(self, high: int, low: int) -> int:
        return high * lower_sqrt(self.u) + low

    # base checks
    def _empty(self) -> bool:
        return self.min is None

    # queries
    def member(self, x: int) -> bool:
        if self.min == x or self.max == x:
            return True
        if self.u == 2:
            return False
        c = self._high(x)
        if c in self.cluster:
            return self.cluster[c].member(self._low(x))
        return False

    def minimum(self) -> Optional[int]:
        return self.min

    def maximum(self) -> Optional[int]:
        return self.max

    # insert into empty VEB
    def _empty_insert(self, x: int) -> None:
        self.min = x
        self.max = x

    # insert (CLRS VEB-INSERT)
    def insert(self, x: int) -> None:
        if self._empty():
            self._empty_insert(x)
            return
        if x < self.min:  # swap into min
            x, self.min = self.min, x
        if self.u > 2:
            c = self._high(x)
            i = self._low(x)
            if c not in self.cluster or self.cluster[c]._empty():
                # create cluster if absent
                if c not in self.cluster:
                    self.cluster[c] = VEB(lower_sqrt(self.u))
                if self.summary is None:
                    self.summary = VEB(upper_sqrt(self.u))
                # first insertion into this cluster
                self.summary.insert(c)
                self.cluster[c]._empty_insert(i)
            else:
                self.cluster[c].insert(i)
        if x > self.max:
            self.max = x

    # delete (CLRS VEB-DELETE)
    def delete(self, x: int) -> None:
        if self.min == self.max:
            # single element
            if self.min == x:
                self.min = None
                self.max = None
            return

        if self.u == 2:
            # u == 2 and two distinct values exist
            if x == 0:
                self.min = 1
            else:
                self.min = 0
            self.max = self.min
            return

        if x == self.min:
            # replace min with first element of the first non-empty cluster
            assert self.summary is not None
            first_cluster = self.summary.minimum()
            if first_cluster is None:
                # should not happen because >1 elements exist
                return
            x = self._index(first_cluster, self.cluster[first_cluster].minimum())
            self.min = x

        # delete x from its cluster
        c = self._high(x)
        i = self._low(x)
        if c in self.cluster:
            self.cluster[c].delete(i)
            if self.cluster[c]._empty():
                # remove cluster from summary
                self.summary.delete(c)
                # optional: free empty cluster to save space
                # del self.cluster[c]

                # if all clusters empty, max is min
                if self.summary._empty():
                    self.summary = None
                    self.max = self.min
                else:
                    # recompute max from last non-empty cluster
                    last_cluster = self.summary.maximum()
                    self.max = self._index(last_cluster, self.cluster[last_cluster].maximum())
            else:
                # cluster still has elements; update max if needed
                if x == self.max:
                    last_cluster = self._high(self.max)
                    self.max = self._index(last_cluster, self.cluster[last_cluster].maximum())

    # successor (CLRS VEB-SUCCESSOR)
    def successor(self, x: int) -> Optional[int]:
        if self._empty():
            return None
        if self.u == 2:
            if x == 0 and self.max == 1:
                return 1
            return None

        if x < self.min:
            return self.min

        c = self._high(x)
        i = self._low(x)

        # try in same cluster
        if c in self.cluster and not self.cluster[c]._empty() and i < self.cluster[c].maximum():
            offset = self.cluster[c].successor(i)
            return self._index(c, offset)

        # find next cluster
        if self.summary is None:
            return None
        next_cluster = self.summary.successor(c)
        if next_cluster is None:
            return None
        return self._index(next_cluster, self.cluster[next_cluster].minimum())

    # predecessor (CLRS VEB-PREDECESSOR)
    def predecessor(self, x: int) -> Optional[int]:
        if self._empty():
            return None
        if self.u == 2:
            if x == 1 and self.min == 0:
                return 0
            return None

        if x > self.max:
            return self.max

        c = self._high(x)
        i = self._low(x)

        # try in same cluster
        if c in self.cluster and not self.cluster[c]._empty() and i > self.cluster[c].minimum():
            offset = self.cluster[c].predecessor(i)
            return self._index(c, offset)

        # find previous cluster
        if self.summary is None:
            if x > self.min:
                return self.min
            return None
        prev_cluster = self.summary.predecessor(c)
        if prev_cluster is None:
            if x > self.min:
                return self.min
            return None
        return self._index(prev_cluster, self.cluster[prev_cluster].maximum())


![veb](./img/veb-tree.png)

In [None]:
# Build a VEB tree for the set {2,3,4,5,7,14,15}
S = [2, 3, 4, 5, 7, 14, 15]
u = 16  # universe size must be a power of two >= max(S)+1

veb = VEB(u)
for x in S:
    veb.insert(x)

print("min:", veb.minimum(), "max:", veb.maximum())
print("members:", sorted(x for x in S if veb.member(x)))
diagram(veb)


## Performance

In [None]:
STRUCTURE = "vebtree"
for size in SIZES:
    # vEB requires universe size u as power of two; use 64 for sufficient range
    u = 64
    values = random.sample(range(0, u), size)

    veb = VEB(u=u)
    for val in values:
        veb.insert(val)

    diagram(veb, method="browser", perf_path=get_perf_path(STRUCTURE, size), perf_iterations=PI)
    sleep(2)  # pause to allow browser to open

# Interval Trees


In [None]:

# Interval node augments RBNode. We use key == low for ordering.
@attribute(field="low")
@attribute(field="high")
@attribute(field="max")
class IntervalNode(RBNode):
    def __init__(self, low, high, color=RED, left=None, right=None, parent=None):
        super().__init__(key=low, color=color, left=left if left is not None else NIL,
                         right=right if right is not None else NIL,
                         parent=parent if parent is not None else NIL)
        self.low = low
        self.high = high
        self.max = high

class IntervalTree(RBTree):
    def __init__(self):
        super().__init__()
        # Ensure NIL has a max field used by augmentation
        if not hasattr(NIL, "max"):
            setattr(NIL, "max", 0)

    # ---- augmentation helpers ----
    def _update(self, x: IntervalNode):
        if x is NIL:
            return
        lm = x.left.max if x.left is not NIL else 0
        rm = x.right.max if x.right is not NIL else 0
        x.max = max(x.high, lm, rm)

    def _update_up(self, x: IntervalNode):
        while x is not NIL:
            old = x.max
            self._update(x)
            if x.max == old:
                break
            x = x.parent

    # ---- rotations with max maintenance ----
    def left_rotate(self, x: IntervalNode):
        y = x.right
        assert y is not NIL
        # standard rotate
        x.right = y.left
        if y.left is not NIL:
            y.left.parent = x
        y.parent = x.parent
        if x.parent is NIL:
            self.root = y
        elif x is x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y
        y.left = x
        x.parent = y
        # update augmented fields bottom-up
        self._update(x)
        self._update(y)
        self._update_up(y.parent)

    def right_rotate(self, y: IntervalNode):
        x = y.left
        assert x is not NIL
        # standard rotate
        y.left = x.right
        if x.right is not NIL:
            x.right.parent = y
        x.parent = y.parent
        if y.parent is NIL:
            self.root = x
        elif y is y.parent.left:
            y.parent.left = x
        else:
            y.parent.right = x
        x.right = y
        y.parent = x
        # update augmented fields bottom-up
        self._update(y)
        self._update(x)
        self._update_up(x.parent)

    # ---- interval overlap ----
    @staticmethod
    def _overlap(a_low, a_high, b_low, b_high) -> bool:
        return not (a_high < b_low or b_high < a_low)

    # ---- search one overlapping interval ----
    def interval_search(self, low: int, high: int):
        x = self.root
        while x is not NIL and not self._overlap(low, high, x.low, x.high):
            if x.left is not NIL and x.left.max >= low:
                x = x.left
            else:
                x = x.right
        return None if x is NIL else x

    # ---- insert (creates IntervalNode, maintains max) ----
    def insert(self, low: int, high: int):
        z = IntervalNode(low, high, color=RED)
        y = NIL
        x = self.root
        # BST insert by low key
        while x is not NIL:
            y = x
            x = x.left if z.key < x.key else x.right
        z.parent = y
        if y is NIL:
            self.root = z
        elif z.key < y.key:
            y.left = z
        else:
            y.right = z
        # fix max on path up
        self._update_up(z)
        # RB fix
        self._insert_fixup(z)
        return z

    # ---- delete (CLRS RB-DELETE + max maintenance). Expects a node handle. ----
    def delete(self, z: IntervalNode):
        # local helpers
        def transplant(u, v):
            if u.parent is NIL:
                self.root = v
            elif u is u.parent.left:
                u.parent.left = v
            else:
                u.parent.right = v
            v.parent = u.parent

        def minimum(x):
            while x.left is not NIL:
                x = x.left
            return x

        y = z
        y_orig_color = y.color
        if z.left is NIL:
            x = z.right
            transplant(z, z.right)
            start = x.parent
        elif z.right is NIL:
            x = z.left
            transplant(z, z.left)
            start = x.parent
        else:
            y = minimum(z.right)
            y_orig_color = y.color
            x = y.right
            if y.parent is z:
                x.parent = y
            else:
                transplant(y, y.right)
                y.right = z.right
                y.right.parent = y
            transplant(z, y)
            y.left = z.left
            y.left.parent = y
            y.color = z.color
            # update y first (its children changed)
            self._update(y)
            start = y

        # update augmented info upward from the structural change
        self._update_up(start)

        if y_orig_color == BLACK:
            self._delete_fixup(x)

    def _delete_fixup(self, x: IntervalNode):
        while x is not self.root and x.color is BLACK:
            if x is x.parent.left:
                w = x.parent.right
                if w.color is RED:
                    w.color = BLACK
                    x.parent.color = RED
                    self.left_rotate(x.parent)
                    w = x.parent.right
                if w.left.color is BLACK and w.right.color is BLACK:
                    w.color = RED
                    x = x.parent
                else:
                    if w.right.color is BLACK:
                        w.left.color = BLACK
                        w.color = RED
                        self.right_rotate(w)
                        w = x.parent.right
                    w.color = x.parent.color
                    x.parent.color = BLACK
                    w.right.color = BLACK
                    self.left_rotate(x.parent)
                    x = self.root
            else:
                w = x.parent.left
                if w.color is RED:
                    w.color = BLACK
                    x.parent.color = RED
                    self.right_rotate(x.parent)
                    w = x.parent.left
                if w.right.color is BLACK and w.left.color is BLACK:
                    w.color = RED
                    x = x.parent
                else:
                    if w.left.color is BLACK:
                        w.right.color = BLACK
                        w.color = RED
                        self.left_rotate(w)
                        w = x.parent.left
                    w.color = x.parent.color
                    x.parent.color = BLACK
                    w.left.color = BLACK
                    self.right_rotate(x.parent)
                    x = self.root
        x.color = BLACK
        # restore augmentation on the path up
        self._update_up(x)

    # ---- optional: exact lookup by (low, high) ----
    def find_exact(self, low: int, high: int):
        x = self.root
        while x is not NIL:
            if low == x.low and high == x.high:
                return x
            x = x.left if low < x.low else x.right
        return None


![interval-trees](./img/interval-tree.png)

In [None]:
it = IntervalTree()
for low, high in [(16,21), (8,9), (25,30), (5,8), (15,23), (17,19), (26,26), (0,3), (6,10), (19,20)]:
    it.insert(low, high)
diagram(it)

## Performance Testing

In [None]:
STRUCTURE = "intervaltree"
for size in SIZES:
    # Generate random intervals
    intervals = []
    for _ in range(size):
        low = random.randint(0, 1000)
        high = random.randint(low + 1, 1000 + 10)  # ensure high > low
        intervals.append((low, high))

    it = IntervalTree()
    for low, high in intervals:
        it.insert(low, high)

    diagram(it, method="browser", perf_path=get_perf_path(STRUCTURE, size), perf_iterations=PI)
    sleep(2)  # pause to allow browser to open