In [1]:
import bisect
from typing import Dict, List, Set, Tuple


class Interval:
    """A 1-D range from ``min_val`` to ``max_val`` tagged with an integer ``id``.

    Instances order by ``min_val`` only, which lets ``bisect`` keep a list of
    intervals sorted by lower bound. ``__lt__`` also accepts a bare number on
    the right-hand side (``interval < 3.0``), so a sorted list of intervals can
    be searched with a raw coordinate.
    """

    def __init__(self, min_val: float, max_val: float, id: int):
        self.min_val, self.max_val, self.id = min_val, max_val, id

    def __lt__(self, other):
        # Unwrap another Interval to its lower bound; anything else (e.g. a
        # float point) is compared against the lower bound directly.
        bound = other.min_val if isinstance(other, Interval) else other
        return self.min_val < bound

class IntervalTree:
    """
    A sorted-list index of 1D intervals supporting insertion and
    point-containment ("stabbing") queries.

    Despite the name, this is not a balanced interval tree. It keeps a plain
    list of intervals sorted by their minimum boundary and inserts with the
    built-in `bisect` module so the list stays in ascending `min_val` order.
    A query scans only the intervals whose `min_val` is <= the point, so it
    costs O(k) for k such intervals (O(n) in the worst case).

    Attributes:
        intervals (List[Interval]):
            A list of Interval objects, sorted by each Interval's `min_val`.

    Methods:
        insert(min_val: float, max_val: float, id: int):
            Inserts a new interval into the structure while keeping it sorted.

        find_containing(point: float) -> Set[int]:
            Finds all interval IDs whose intervals contain the specified point.
    """

    def __init__(self):
        """Create an empty index."""
        self.intervals: List[Interval] = []  # Sorted by min_val

    def insert(self, min_val: float, max_val: float, id: int):
        """Insert the interval [min_val, max_val] tagged with ``id``.

        ``bisect.insort`` keeps ``intervals`` sorted by ``min_val``; an
        interval whose ``min_val`` ties an existing one is placed after it.
        """
        bisect.insort(self.intervals, Interval(min_val, max_val, id))

    def find_containing(self, point: float) -> Set[int]:
        """Find all intervals containing the point.

        An interval contains ``point`` when ``min_val <= point <= max_val``
        (both ends inclusive). Intervals stored with ``min_val > max_val``
        never match.

        Args:
            point: The coordinate to look up.

        Returns:
            The set of IDs of every containing interval (empty if none).
        """
        result: Set[int] = set()
        for interval in self.intervals:
            # `intervals` is sorted by min_val: once one starts after the
            # point, every remaining interval does too.
            if interval.min_val > point:
                break
            # Do not stop at a non-containing interval: one that starts
            # earlier can still extend further (for point 12, [6, 10] misses
            # while the earlier [4, 15] hits).
            if point <= interval.max_val:
                result.add(interval.id)
        return result

In [2]:
# Demo data: four overlapping intervals given as (min_val, max_val, id).
tree = IntervalTree()
for lo, hi, interval_id in [(1, 5, 101), (3, 7, 102), (6, 10, 103), (4, 15, 104)]:
    tree.insert(lo, hi, interval_id)

In [3]:
# Look up every 1-D interval ("cluster") whose range contains the point;
# the result is the set of matching interval IDs.
point_to_check = 1
tree.find_containing(point_to_check)

{101}

In [4]:
# Point 9 falls inside both [6, 10] (id 103) and [4, 15] (id 104).
tree.find_containing(point=9)

{103, 104}

In [6]:
from rtree import index

from docling.datamodel.base_models import BoundingBox, Cell, Cluster

In [7]:
class SpatialClusterIndex:
    """Efficient spatial indexing for clusters using R-tree and interval trees.

    Every indexed cluster is stored, keyed by ``cluster.id``, in:

    * ``spatial_index``  -- an ``rtree`` 2-D index over ``bbox.as_tuple()``;
    * ``x_intervals``    -- an ``IntervalTree`` over ``[bbox.l, bbox.r]``;
    * ``y_intervals``    -- an ``IntervalTree`` over ``[bbox.t, bbox.b]``;
    * ``clusters_by_id`` -- a dict mapping the ID back to its ``Cluster``.

    ``add_cluster`` and ``remove_cluster`` keep all four in sync.
    """

    def __init__(self, clusters: List[Cluster]):
        """Create empty indexes, then add every cluster in ``clusters``."""
        p = index.Property()
        p.dimension = 2
        self.spatial_index = index.Index(properties=p)
        self.x_intervals = IntervalTree()
        self.y_intervals = IntervalTree()
        self.clusters_by_id: Dict[int, Cluster] = {}

        for cluster in clusters:
            self.add_cluster(cluster)

    def add_cluster(self, cluster: Cluster):
        """Insert ``cluster`` into the R-tree, both interval trees and the ID map."""
        bbox = cluster.bbox
        self.spatial_index.insert(cluster.id, bbox.as_tuple())
        self.x_intervals.insert(bbox.l, bbox.r, cluster.id)
        self.y_intervals.insert(bbox.t, bbox.b, cluster.id)
        self.clusters_by_id[cluster.id] = cluster

    def remove_cluster(self, cluster: Cluster):
        """Remove ``cluster`` from the R-tree, both interval trees and the ID map.

        The R-tree delete is given both the ID and ``cluster.bbox.as_tuple()``,
        so the box presumably must match the one used in ``add_cluster``.

        Raises:
            KeyError: if ``cluster.id`` is not in ``clusters_by_id``.
        """
        self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
        del self.clusters_by_id[cluster.id]
        # IntervalTree has no delete operation, so drop this ID's entries by
        # filtering each sorted list (filtering preserves the min_val order).
        # Without this, find_candidates kept returning IDs of removed clusters.
        for tree in (self.x_intervals, self.y_intervals):
            tree.intervals = [iv for iv in tree.intervals if iv.id != cluster.id]

    def find_candidates(self, bbox: BoundingBox) -> Set[int]:
        """Find potential overlapping cluster IDs using all indexes.

        Returns the union of:

        * IDs whose R-tree entry intersects ``bbox.as_tuple()``;
        * IDs whose x-interval contains ``bbox.l`` or ``bbox.r``;
        * IDs whose y-interval contains ``bbox.t`` or ``bbox.b``.

        The interval hits only require agreement on a single axis, so the
        result can include clusters that do not actually overlap ``bbox``;
        treat it as a candidate set for an exact follow-up check.
        """
        spatial = set(self.spatial_index.intersection(bbox.as_tuple()))
        x_candidates = self.x_intervals.find_containing(
            bbox.l
        ) | self.x_intervals.find_containing(bbox.r)
        y_candidates = self.y_intervals.find_containing(
            bbox.t
        ) | self.y_intervals.find_containing(bbox.b)
        return spatial | x_candidates | y_candidates

In [14]:
# Mock data: four text clusters (with empty cell lists) laid out along the
# diagonal so that each box overlaps the next one.
cluster_specs = [
    # (id, (l, t, r, b), confidence)
    (1, (0, 0, 5, 5), 0.9),
    (2, (4, 4, 10, 10), 0.8),
    (3, (8, 8, 12, 12), 0.7),
    (4, (10, 10, 15, 15), 0.95),
]
clusters = [
    Cluster(
        id=cluster_id,
        bbox=BoundingBox(l=left, t=top, r=right, b=bottom),
        confidence=confidence,
        cells=[],
        label='text',
    )
    for cluster_id, (left, top, right, bottom), confidence in cluster_specs
]

# Build the R-tree + interval indexes over the mock clusters.
spatial_index = SpatialClusterIndex(clusters)

In [15]:
# Query with cluster 1's own box: the candidates are cluster 1 itself and
# cluster 2, whose box overlaps it on [4, 5] x [4, 5].
query_bbox = clusters[0].bbox
spatial_index.find_candidates(query_bbox)

{1, 2}