Initialize

parent[i] ← i for i=1..n

rank[i] ← 0 for i=1..n

Find(x) — path compression

If parent[x] == x: return x

Else: parent[x] ← Find(parent[x]); return parent[x]

Union(u, v) — union by rank

ru ← Find(u); rv ← Find(v)

If ru == rv: return

If rank[ru] < rank[rv]: parent[ru] ← rv

Else if rank[ru] > rank[rv]: parent[rv] ← ru

Else: parent[rv] ← ru; rank[ru]++

Usage pattern

For each connection (u, v): Union(u, v)

To check same set: Find(u) == Find(v)

Complexity

Amortized O(α(n)) per operation, where α is the inverse Ackermann function (effectively constant for any practical n)

Notes

Path compression + rank = fastest practical DSU

Works with either 0-indexed or 1-indexed elements, as long as one convention is used consistently

In [4]:
# disjoint_set.py

class DisjointSet:
    """Union-find over integer labels with union by rank and path compression.

    Slots ``0..n`` are allocated, so labels may be 0-based or 1-based as long
    as none exceeds ``n``. Labels are not validated before indexing.
    """

    def __init__(self, n):
        """Make every label in ``0..n`` the root of its own singleton set."""
        slots = n + 1
        self.parent = list(range(slots))  # parent[i] == i marks a root
        self.rank = [0 for _ in range(slots)]  # only consulted for roots

    def find(self, x):
        """Return the representative of ``x``'s set, flattening the path to it."""
        # First pass: climb to the root without modifying anything.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: re-point every node on the walked path at the root.
        node = x
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, u, v):
        """Merge the sets holding ``u`` and ``v``; do nothing if already joined."""
        root_u = self.find(u)
        root_v = self.find(v)
        if root_u == root_v:
            return

        rank_u, rank_v = self.rank[root_u], self.rank[root_v]
        if rank_u < rank_v:
            # The lower-rank root goes under the higher one; ranks are unchanged.
            self.parent[root_u] = root_v
            return

        self.parent[root_v] = root_u
        if rank_u == rank_v:
            # Equal ranks: the surviving root's rank grows by one.
            self.rank[root_u] += 1


if __name__ == "__main__":
    # Labels run from 1 to 7, so the structure is sized by the largest label.
    ds = DisjointSet(7)

    # Each pair is one connection to merge, applied in this order.
    for u, v in [(1, 2), (2, 3), (4, 5), (6, 7), (5, 6), (3, 7)]:
        ds.union(u, v)

    # After the merges above, 1..7 share one set, so both lines print the same root.
    for node in (1, 7):
        print(ds.find(node))


4
4


Initialization:
parent[i] = i, size[i] = 1

Find(x):
Return ultimate parent using path compression.

Union(u, v):

Find ultimate parents pu, pv

If they differ, attach the root of the smaller tree to the root of the larger tree

Add the smaller tree's size to the size of the new root

Complexity:
Amortized O(α(n)) per operation — effectively constant time

In [None]:
# disjoint_set_size.py

class DisjointSet:
    """Union-find over integer labels with union by size and path compression.

    Slots ``0..n`` are allocated, so labels may be 0-based or 1-based as long
    as none exceeds ``n``. Labels are not validated before indexing.
    """

    def __init__(self, n):
        """Make every label in ``0..n`` the root of its own singleton set."""
        slots = n + 1
        self.parent = list(range(slots))  # parent[i] == i marks a root
        self.size = [1 for _ in range(slots)]  # element count, kept up to date at roots

    def find(self, x):
        """Return the representative of ``x``'s set, flattening the path to it."""
        # First pass: climb to the root without modifying anything.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: re-point every node on the walked path at the root.
        node = x
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, u, v):
        """Merge the sets holding ``u`` and ``v``; do nothing if already joined."""
        keep = self.find(u)
        absorb = self.find(v)
        if keep == absorb:
            return

        # Make ``keep`` the root of the larger set; on a tie, u's root survives.
        if self.size[keep] < self.size[absorb]:
            keep, absorb = absorb, keep

        self.parent[absorb] = keep
        self.size[keep] += self.size[absorb]


if __name__ == "__main__":
    # Labels run from 1 to 7, so the structure is sized by the largest label.
    ds = DisjointSet(7)

    # Each pair is one connection to merge, applied in this order.
    connections = [(1, 2), (2, 3), (4, 5), (6, 7), (5, 6), (3, 7)]
    for u, v in connections:
        ds.union(u, v)

    # Dump the raw arrays first: the find() calls below may still compress paths.
    print("Parent array:", ds.parent)
    print("Size array:", ds.size)
    root_of_1 = ds.find(1)
    print("Find(1):", root_of_1)
    root_of_7 = ds.find(7)
    print("Find(7):", root_of_7)
