In [14]:
class UnionFind:
    """
    A lightweight implementation of the Union-Find (disjoint set) data structure.

    This class efficiently manages which elements belong to the same connected "group."
    When two elements are found to overlap or otherwise be related, calling `union(x, y)`
    merges them into one set. All elements in the same set share the same "root," which
    can be retrieved by `find(x)`.

    Typical usage in the context of layout analysis:

    1. Use a spatial index (e.g., R-tree) to find overlapping clusters. For each pair of
       clusters that significantly overlap, union their IDs.
    2. After processing all overlaps, call `get_groups()` to retrieve the sets of mutually
       connected elements. For example:
       - If cluster 1 overlaps cluster 2 and cluster 2 overlaps cluster 3, all three
         (1, 2, 3) end up in the same group, even if 1 and 3 do not directly overlap.
       - Clusters with no overlaps remain alone in their own sets.
    3. In each group, you can select a "best" representative cluster (e.g., highest
       confidence or largest bounding box) and merge other clusters' cells into it.
       This helps eliminate duplicates while preserving the data from the "loser" clusters.

    The internal mechanism uses:
    - **Path Compression** in `find(x)`: every node visited on the way to the root is
      re-pointed directly at the root, so later lookups are faster.
    - **Union by Rank** in `union(x, y)`: the root with the smaller rank is attached
      beneath the root with the larger rank, which keeps trees shallow.

    Attributes:
        parent (Dict[int, int]): Maps each element to its parent. Roots point to themselves.
        rank (Dict[int, int]): Upper bound on the height of the tree under each root.
            Ranks start at 0 and only grow when two roots of equal rank are merged.
    """

    def __init__(self, elements):
        """Create one singleton set per element.

        Args:
            elements: Any iterable of hashable items (a list, a dict keys view, a
                generator, ...). Duplicates collapse into a single entry.
        """
        # Materialize once: `elements` may be a one-shot iterator. Iterating it twice would
        # leave `rank` empty and make the first `union` call fail with KeyError.
        elements = list(elements)
        self.parent = {elem: elem for elem in elements}
        self.rank = {elem: 0 for elem in elements}

    def find(self, x):
        """Return the root representative of the set containing `x`.

        Compresses the path from `x` to its root as a side effect.

        Raises:
            KeyError: if `x` was not one of the elements given to the constructor.
        """
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass (path compression): re-point every node on the path at the root.
        # Done iteratively so arbitrarily deep chains cannot hit the recursion limit.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing `x` and `y`.

        This is a no-op if they are already in the same set.

        Raises:
            KeyError: if `x` or `y` is not a known element.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return

        if self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        elif self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            # Equal ranks: either root may win. The merged tree can become one level
            # taller, so bump the winner's rank.
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def get_groups(self) -> Dict[int, List[int]]:
        """
        Returns the final partitioning of elements after all unions.

        The returned dictionary (a `collections.defaultdict(list)`) maps the "root"
        representative of each group to a list of all elements in that group. For example:

            {
                root1: [elementA, elementB, elementC],
                root2: [elementD],
                ...
            }

        where root1 and root2 are the `find(...)` results for each group's representative.
        Members are listed in the order the elements were originally supplied. Groups are
        keyed in the order their first member is encountered.
        """
        from collections import defaultdict
        groups = defaultdict(list)
        for elem in self.parent:
            groups[self.find(elem)].append(elem)
        return groups

In [15]:
# Four test clusters keyed by id. Box 1 overlaps box 2 in [8, 10] x [8, 10], box 4 overlaps
# both 1 and 2, and box 3 touches none of the others.
valid_clusters = {
    cluster_id: Cluster(
        id=cluster_id,
        bbox=BoundingBox(l=left, t=top, r=right, b=bottom),
        cells=[],
        label='text',
    )
    for cluster_id, (left, top, right, bottom) in {
        1: (0, 0, 10, 10),
        2: (8, 8, 18, 18),
        3: (20, 20, 30, 30),
        4: (9, 9, 12, 12),
    }.items()
}


In [16]:
uf = UnionFind(list(valid_clusters))  # start with one singleton set per cluster id

In [17]:
uf.union(4, 1)  # box 4 (9..12) overlaps box 1 (0..10) on both axes

In [18]:
uf.union(1, 2)  # boxes 1 and 2 overlap in [8, 10] x [8, 10]

In [19]:
# Clusters 1, 2 and 4 are transitively connected (4-1, 1-2); cluster 3 was never unioned.
assert sorted(sorted(members) for members in uf.get_groups().values()) == [[1, 2, 4], [3]]
uf.get_groups()

defaultdict(list, {4: [1, 2, 4], 3: [3]})

In [20]:
# Clusters 1, 2 and 4 end up in one group (joined via union(4, 1) and union(1, 2)); cluster 3 stays alone in its own set.