In [None]:
%%HTML
<style>
    body {
        --vscode-font-family: "Noto Serif"
    }
</style>

# Union–Find (Disjoint Set Union, DSU)

Union–Find maintains a partition of elements into disjoint sets and supports two core operations efficiently:

- find(x): return a canonical representative (root) of the set containing x
- union(x, y): merge the sets containing x and y

With union by size/rank and path compression, each operation runs in amortized near-constant time: O(α(n)), where α is the inverse Ackermann function. α(n) grows far more slowly than log n and is at most 4 for any input size that can occur in practice.

## Why and when to use
Common use cases:
- Dynamic connectivity: quickly connect items and query if two items are in the same group
- Connected components in an undirected graph (count components, group nodes)
- Kruskal's Minimum Spanning Tree (MST)
- Cycle detection in undirected graphs
- Accounts merge / friend circles (group by equivalence)
- Constraint satisfaction like equality equations (e.g., a==b, a!=c)
- Percolation / image segmentation / unioning grid cells

## API we'll implement
- Add(x): add an element if not present
- Find(x): return the root representative; adds x implicitly if new
- Union(x, y): union-by-rank with path compression; returns the new root
- Connected(x, y): true if x and y are in the same set
- SizeOf(x): size of the set containing x
- Count(): number of disjoint sets currently tracked
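- Groups(): map root -> elements in that group
- RankOf(x): rank of the root of the set containing x
- Elements(): all tracked elements, in no particular order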

In [None]:
package main

import (
    "fmt"
)

// UnionFind is a disjoint-set (union-find) structure over comparable elements
// of type T. Elements are registered through Add, or implicitly the first
// time they are passed to Find.
type UnionFind[T comparable] struct {
    parent map[T]T   // parent link of every tracked element; a root is its own parent
    size   map[T]int // number of elements in a set, keyed by the set's root (Union drops the absorbed root's entry)
    rank   map[T]int // rank Union compares to pick the surviving root; kept only for roots
    count  int       // number of disjoint sets currently tracked
}

// NewUnionFind returns an empty UnionFind that tracks no elements and
// therefore reports a Count of zero.
func NewUnionFind[T comparable]() *UnionFind[T] {
    uf := new(UnionFind[T]) // count starts at its zero value
    uf.parent = map[T]T{}
    uf.size = map[T]int{}
    uf.rank = map[T]int{}
    return uf
}

// NewUnionFindWithElements returns a UnionFind in which every entry of
// elements forms its own singleton set. Duplicate entries are registered once.
func NewUnionFindWithElements[T comparable](elements []T) *UnionFind[T] {
    set := NewUnionFind[T]()
    for i := range elements {
        set.Add(elements[i])
    }
    return set
}

// Add registers x as its own one-element set. Calling it with an element that
// is already tracked changes nothing.
func (uf *UnionFind[T]) Add(x T) {
    if _, seen := uf.parent[x]; seen {
        return
    }

    // A fresh element is a root: its own parent, size 1, rank 0.
    uf.parent[x] = x
    uf.size[x] = 1
    uf.rank[x] = 0
    uf.count++
}

// Find returns the root (representative) of the set containing x.
//
// An element that is not tracked yet is added as a singleton and returned as
// its own root. For tracked elements, every node visited on the way to the
// root is re-linked directly to the root (path compression).
func (uf *UnionFind[T]) Find(x T) T {
    if _, known := uf.parent[x]; !known {
        uf.Add(x)
        return x
    }

    // First pass: climb parent links until reaching a self-parented node.
    rep := x
    for {
        next := uf.parent[rep]
        if next == rep {
            break
        }
        rep = next
    }

    // Second pass: point every node on the walked path straight at the root.
    for cur := x; cur != rep; {
        next := uf.parent[cur]
        uf.parent[cur] = rep
        cur = next
    }

    return rep
}

// Union merges the sets containing a and b and returns the root of the merged
// set. If both already share a root, that root is returned and nothing
// changes. Untracked arguments are added implicitly by Find.
func (uf *UnionFind[T]) Union(a, b T) T {
    rootA, rootB := uf.Find(a), uf.Find(b)
    if rootA == rootB {
        return rootA
    }

    // Union by rank: the higher-ranked root survives; on a tie a's root wins.
    winner, loser := rootA, rootB
    if uf.rank[winner] < uf.rank[loser] {
        winner, loser = loser, winner
    }
    tied := uf.rank[winner] == uf.rank[loser]

    uf.parent[loser] = winner
    uf.size[winner] += uf.size[loser]

    // Size and rank are only kept for roots; the absorbed root loses both.
    delete(uf.size, loser)
    delete(uf.rank, loser)

    // Merging two trees of equal rank raises the surviving root's rank by one.
    if tied {
        uf.rank[winner]++
    }

    uf.count--
    return winner
}

// Connected reports whether a and b share a root. Both arguments go through
// Find, so an element that was not tracked before is added as a singleton.
func (uf *UnionFind[T]) Connected(a, b T) bool {
    rootA := uf.Find(a)
    rootB := uf.Find(b)
    return rootA == rootB
}

// SizeOf returns the number of elements in the set containing x. An untracked
// x is added by Find first, so the result for it is 1.
func (uf *UnionFind[T]) SizeOf(x T) int {
    return uf.size[uf.Find(x)]
}

// RankOf returns the rank stored for the root of the set containing x. An
// untracked x is added by Find first, so the result for it is 0.
func (uf *UnionFind[T]) RankOf(x T) int {
    return uf.rank[uf.Find(x)]
}

// Count returns the number of disjoint sets currently tracked. The counter
// goes up by one for every newly added element (including elements added
// implicitly by Find) and down by one for every Union that merges two
// distinct sets.
func (uf *UnionFind[T]) Count() int {
    return uf.count
}

// Elements returns every tracked element in a newly allocated slice. The
// order follows map iteration and is therefore unspecified.
func (uf *UnionFind[T]) Elements() []T {
    out := make([]T, len(uf.parent))
    next := 0
    for member := range uf.parent {
        out[next] = member
        next++
    }
    return out
}

// Groups returns a map from each set's root to the elements of that set.
// Member order inside a group follows map iteration and is unspecified.
// Every element is resolved through Find, which compresses paths as a side
// effect.
func (uf *UnionFind[T]) Groups() map[T][]T {
    byRoot := map[T][]T{}
    for member := range uf.parent {
        rep := uf.Find(member)
        members := byRoot[rep]
        byRoot[rep] = append(members, member)
    }
    return byRoot
}

// String formats the structure as its set count followed by the root ->
// members grouping produced by Groups.
func (uf *UnionFind[T]) String() string {
    grouped := uf.Groups()
    return fmt.Sprintf("UnionFind(count=%d, groups=%v)", uf.count, grouped)
}

## Examples

Below are examples showing basic usage and connected components in a graph.

In [None]:
%%
// Quick sanity tests: walk through the API on a handful of ints. The values
// noted in the comments are what the implementation above produces for this
// exact call sequence.
uf := NewUnionFindWithElements([]int{1, 2, 3, 4, 5})
fmt.Println("Initial count:", uf.Count()) // 5 singletons

uf.Union(1, 2) // equal ranks: 1 stays root
uf.Union(3, 4) // equal ranks: 3 stays root
fmt.Println("After unions, connected(1,2):", uf.Connected(1, 2)) // true
fmt.Println("Connected(2,3):", uf.Connected(2, 3))               // false
fmt.Println("Size of set containing 1:", uf.SizeOf(1))           // 2

// Both roots have rank 1, so the first argument's root (1) wins and its rank
// becomes 2.
root := uf.Union(2, 3)
fmt.Println("Root after union(2,3):", root, "rank:", uf.RankOf(root))
fmt.Println("Connected(1,4):", uf.Connected(1, 4))       // true
fmt.Println("Size of set containing 4:", uf.SizeOf(4))   // 4
fmt.Println("Final count:", uf.Count())                  // 2: {1,2,3,4} and {5}

// Implicit add: 99 is unknown, so Connected registers it as a new singleton
// (raising the count) before reporting false.
fmt.Println("Connected(99,5):", uf.Connected(99, 5))
uf.Union(99, 5)
fmt.Println("After union(99,5), connected(99,5):", uf.Connected(99, 5)) // true
// Printed via the String method; member order inside each group may vary
// between runs.
fmt.Println(uf)

In [None]:
// Example: Counting connected components in an undirected graph.

// countComponents builds a union-find over the vertices 0..n-1, merges the two
// endpoints of every edge, and returns the number of connected components
// together with the root -> members grouping.
//
// An edge endpoint outside 0..n-1 is not rejected: Union registers it
// implicitly, so it is counted and grouped like any other vertex.
func countComponents(n int, edges [][2]int) (int, map[int][]int) {
    uf := NewUnionFind[int]()
    for i := 0; i < n; i++ {
        uf.Add(i)
    }

    for _, edge := range edges {
        uf.Union(edge[0], edge[1])
    }

    return uf.Count(), uf.Groups()
}

%%
// GoNB wraps everything below the %% marker in func main, so the function
// declaration above has to stay outside of it.
n := 7
edges := [][2]int{{0, 1}, {1, 2}, {3, 4}, {5, 6}}
count, groups := countComponents(n, edges)
fmt.Println("Components:", count) // 3: {0,1,2}, {3,4}, {5,6}
fmt.Println("Groups:")
// Map iteration order is unspecified, so the groups may print in any order.
for root, group := range groups {
    fmt.Printf("%d: %v\n", root, group)
}