# Union–Find (Disjoint Set Union, DSU)

Union–Find maintains a partition of elements into disjoint sets and supports two core operations efficiently:

- find(x): return a canonical representative (root) of the set containing x
- union(x, y): merge the sets containing x and y

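With union by size or rank combined with path compression, each operation runs in amortized O(α(n)) time, where α is the inverse Ackermann function. α(n) grows far more slowly than log n and is at most 4 for any n that fits in memory, so in practice each operation is effectively constant time.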
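
For the common case where the elements are the integers 0..n, the two operations and both heuristics fit in a few lines with plain vectors instead of hash maps. The sketch below is illustrative only: Dsu is a hypothetical name, not part of the API implemented later, and it links by size rather than by rank. Combined with path compression, either choice gives the same amortized bound.

```rust
/// Hypothetical array-backed union-find over 0..n (not the notebook's UnionFind).
/// Uses path compression in find and union by size in union.
struct Dsu {
    parent: Vec<usize>,
    size: Vec<usize>, // size[r] is meaningful only while r is a root
}

impl Dsu {
    fn new(n: usize) -> Self {
        // Every element starts as its own root, in a set of size 1.
        Dsu { parent: (0..n).collect(), size: vec![1; n] }
    }

    /// Returns the root of x and repoints every node on the path directly at it.
    fn find(&mut self, x: usize) -> usize {
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        let mut cur = x;
        while cur != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets of a and b by hanging the smaller tree under the larger.
    /// Returns false if a and b were already in the same set.
    fn union(&mut self, a: usize, b: usize) -> bool {
        let (mut big, mut small) = (self.find(a), self.find(b));
        if big == small {
            return false;
        }
        if self.size[big] < self.size[small] {
            std::mem::swap(&mut big, &mut small);
        }
        self.parent[small] = big;
        let moved = self.size[small];
        self.size[big] += moved;
        true
    }
}

fn main() {
    let mut dsu = Dsu::new(5);
    dsu.union(0, 1);
    dsu.union(3, 4);
    println!("{}", dsu.find(0) == dsu.find(1)); // true
    println!("{}", dsu.find(1) == dsu.find(3)); // false
}
```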

## Why and when to use
Common use cases:
- Incremental connectivity: connect items and quickly query whether two items are in the same group (connections can be added but not removed)
- Connected components in an undirected graph (count components, group nodes)
- Kruskal's Minimum Spanning Tree (MST)
- Cycle detection in undirected graphs
- Accounts merge / friend circles (group by equivalence)
- Equality constraint satisfaction (e.g., a==b, b==c, a!=c): union every equality, then check that no inequality relates two members of the same set
- Percolation / image segmentation / unioning grid cells
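
As a concrete instance of the Kruskal use case: sort the edges by weight and keep each edge only if union-find reports that its endpoints are still in different components, since an edge inside one component would close a cycle. This sketch assumes vertices numbered 0..n and hypothetical (u, v, weight) edge tuples; on a disconnected graph it returns a minimum spanning forest. Dsu and kruskal are illustrative names, not part of the API below.

```rust
/// Hypothetical minimal union-find for this sketch (not the notebook's UnionFind):
/// path compression plus union by size over vertices 0..n.
struct Dsu {
    parent: Vec<usize>,
    size: Vec<usize>,
}

impl Dsu {
    fn new(n: usize) -> Self {
        Dsu { parent: (0..n).collect(), size: vec![1; n] }
    }

    fn find(&mut self, x: usize) -> usize {
        let p = self.parent[x];
        if p == x {
            return x;
        }
        let root = self.find(p);
        self.parent[x] = root; // path compression
        root
    }

    /// Returns true if the call joined two different components.
    fn union(&mut self, a: usize, b: usize) -> bool {
        let (mut ra, mut rb) = (self.find(a), self.find(b));
        if ra == rb {
            return false;
        }
        if self.size[ra] < self.size[rb] {
            std::mem::swap(&mut ra, &mut rb);
        }
        self.parent[rb] = ra;
        let moved = self.size[rb];
        self.size[ra] += moved;
        true
    }
}

/// Kruskal's algorithm: returns the total weight and the chosen (u, v, weight)
/// edges of a minimum spanning forest.
fn kruskal(n: usize, edges: &[(usize, usize, u64)]) -> (u64, Vec<(usize, usize, u64)>) {
    let mut sorted = edges.to_vec();
    sorted.sort_by_key(|&(_, _, w)| w); // cheapest edges first
    let mut dsu = Dsu::new(n);
    let mut total = 0;
    let mut chosen = Vec::new();
    for (u, v, w) in sorted {
        // Keep the edge only if it connects two different components (no cycle).
        if dsu.union(u, v) {
            total += w;
            chosen.push((u, v, w));
        }
    }
    (total, chosen)
}

fn main() {
    let edges: [(usize, usize, u64); 5] = [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4), (1, 3, 5)];
    let (total, chosen) = kruskal(4, &edges);
    // MST weight 7 via [(0, 1, 1), (1, 2, 2), (2, 3, 4)]
    println!("MST weight {} via {:?}", total, chosen);
}
```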

## API we'll implement
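- new() / with_elements(iter): create an empty structure, or one with each element in its own singleton set
- add(x): add x as a singleton set if not already present
- find(x): return the root representative, compressing the path on the way; adds x implicitly if new
- union(x, y): merge the two sets using union by rank; returns the root of the merged set
- connected(x, y): true if x and y are in the same set
- size_of(x): size of the set containing x
- rank_of(x): rank of the root of the set containing x
- count(): number of disjoint sets currently tracked
- elements(): all tracked elements
- groups(): map root -> elements in that group
- description(): summary string with the count and the groups (requires T: Debug)

Because union, connected, size_of, and rank_of all call find, passing them an unseen element adds it as a new singleton set and increases count().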

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// A disjoint-set union (Union–Find) structure over hashable elements.
#[derive(Debug, Clone)]
pub struct UnionFind<T> {
    parent: HashMap<T, T>,
    size: HashMap<T, usize>, // entries kept for roots only
    rank: HashMap<T, usize>, // entries kept for roots only
    count: usize,            // number of disjoint sets
}

impl<T> UnionFind<T>
where
    T: Clone + Eq + Hash,
{
    /// Create an empty Union-Find structure.
    pub fn new() -> Self {
        UnionFind {
            parent: HashMap::new(),
            size: HashMap::new(),
            rank: HashMap::new(),
            count: 0,
        }
    }

    /// Create a Union-Find with each initial element in its own set.
    pub fn with_elements<I>(elements: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let mut uf = Self::new();
        for elem in elements {
            uf.add(elem);
        }
        uf
    }

    /// Add x as a singleton set if it has not been seen before.
    pub fn add(&mut self, x: T) {
        if !self.parent.contains_key(&x) {
            self.parent.insert(x.clone(), x.clone());
            self.size.insert(x.clone(), 1);
            self.rank.insert(x, 0);
            self.count += 1;
        }
    }

    /// Find the representative (root) of the set containing x, adding x if unseen.
    pub fn find(&mut self, x: T) -> T {
        if !self.parent.contains_key(&x) {
            self.add(x.clone());
            return x;
        }

        // First pass: walk parent pointers up to the root.
        let mut root = x.clone();
        while root != *self.parent.get(&root).unwrap() {
            root = self.parent.get(&root).unwrap().clone();
        }

        // Second pass (path compression): point every node on the path directly at the root.
        let mut current = x;
        while current != root {
            let parent = self.parent.get(&current).unwrap().clone();
            self.parent.insert(current.clone(), root.clone());
            current = parent;
        }

        root
    }

    /// Union the sets containing a and b using union by rank: the root with the
    /// smaller rank goes under the other; on a tie the new root's rank grows by one.
    /// Returns the root of the merged set.
    pub fn union(&mut self, a: T, b: T) -> T {
        let root_a = self.find(a);
        let root_b = self.find(b);

        if root_a == root_b {
            return root_a;
        }

        let rank_a = *self.rank.get(&root_a).unwrap();
        let rank_b = *self.rank.get(&root_b).unwrap();

        // Attach the lower-rank root under the higher-rank one (root_a wins ties).
        let (new_root, old_root) = if rank_a < rank_b { (root_b, root_a) } else { (root_a, root_b) };

        self.parent.insert(old_root.clone(), new_root.clone());
        let size_new = self.size.get(&new_root).unwrap() + self.size.get(&old_root).unwrap();
        self.size.insert(new_root.clone(), size_new);
        self.size.remove(&old_root); // old_root is no longer a root

        if rank_a == rank_b {
            let new_rank = self.rank.get(&new_root).unwrap() + 1;
            self.rank.insert(new_root.clone(), new_rank);
        }
        self.rank.remove(&old_root);

        self.count -= 1;
        new_root
    }

    /// Return true if a and b are in the same set.
    pub fn connected(&mut self, a: T, b: T) -> bool {
        self.find(a) == self.find(b)
    }

    /// Size of the set containing x.
    pub fn size_of(&mut self, x: T) -> usize {
        let root = self.find(x);
        *self.size.get(&root).unwrap_or(&0)
    }

    /// Rank of the root of x's set (an upper bound on that tree's height).
    pub fn rank_of(&mut self, x: T) -> usize {
        let root = self.find(x);
        *self.rank.get(&root).unwrap_or(&0)
    }

    /// Number of disjoint sets.
    pub fn count(&self) -> usize {
        self.count
    }

    /// All tracked elements, in unspecified order.
    pub fn elements(&self) -> Vec<T> {
        self.parent.keys().cloned().collect()
    }

    /// Mapping of root -> elements in that group.
    pub fn groups(&mut self) -> HashMap<T, Vec<T>> {
        let mut groups = HashMap::new();
        // Snapshot the keys first, because find needs &mut self.
        for elem in self.elements() {
            let root = self.find(elem.clone());
            groups.entry(root).or_insert_with(Vec::new).push(elem);
        }
        groups
    }

    /// Human-readable summary; only this method needs T: Debug.
    pub fn description(&mut self) -> String
    where
        T: std::fmt::Debug,
    {
        let groups = self.groups();
        format!("UnionFind(count={}, groups={:?})", self.count, groups)
    }
}

impl<T> Default for UnionFind<T>
where
    T: Clone + Eq + Hash,
{
    fn default() -> Self {
        Self::new()
    }
}
```
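
The union method above always hangs the lower-rank root under the higher-rank one. The following sketch shows why that matters; its helper names (root_of, depth, build, max_height) are hypothetical, and path compression is deliberately left out so the raw tree shape stays visible. Hanging one root under the other unconditionally lets a chain of unions build a tree of height n - 1, while union by rank keeps the height at most floor(log2 n).

```rust
/// Hypothetical helpers comparing tree height with and without union by rank.
/// Path compression is intentionally omitted.
fn root_of(parent: &[usize], mut x: usize) -> usize {
    while parent[x] != x {
        x = parent[x];
    }
    x
}

/// Number of parent links from x up to its root.
fn depth(parent: &[usize], mut x: usize) -> usize {
    let mut d = 0;
    while parent[x] != x {
        x = parent[x];
        d += 1;
    }
    d
}

/// Applies each union in order and returns the resulting parent array.
/// With by_rank == false, the root of a is always hung under the root of b.
fn build(n: usize, pairs: &[(usize, usize)], by_rank: bool) -> Vec<usize> {
    let mut parent: Vec<usize> = (0..n).collect();
    let mut rank = vec![0usize; n];
    for &(a, b) in pairs {
        let (ra, rb) = (root_of(&parent, a), root_of(&parent, b));
        if ra == rb {
            continue;
        }
        if !by_rank || rank[ra] < rank[rb] {
            parent[ra] = rb;
        } else if rank[ra] > rank[rb] {
            parent[rb] = ra;
        } else {
            // Equal ranks: ra becomes the root and its rank grows by one.
            parent[rb] = ra;
            rank[ra] += 1;
        }
    }
    parent
}

fn max_height(parent: &[usize]) -> usize {
    (0..parent.len()).map(|x| depth(parent, x)).max().unwrap_or(0)
}

fn main() {
    let n: usize = 16;
    // A chain of unions: (0,1), (1,2), ..., (14,15).
    let chain: Vec<(usize, usize)> = (0..n - 1).map(|i| (i, i + 1)).collect();
    println!("naive height:   {}", max_height(&build(n, &chain, false))); // 15
    println!("by-rank height: {}", max_height(&build(n, &chain, true))); // 1
    println!("log2(n) bound:  {}", n.ilog2()); // 4
}
```

Pairwise doubling unions ((0,1), (2,3), ..., then (0,2), (4,6), ..., and so on) push union by rank to exactly that bound: height 4 for n = 16.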

## Examples

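Below are examples showing basic usage and connected components in a graph. Because groups() is built on a HashMap, the order of the groups, and of the elements inside each group, can differ between runs.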

```rust
// Quick sanity tests
let mut uf = UnionFind::with_elements(vec![1, 2, 3, 4, 5]);
println!("Initial count: {}", uf.count());

uf.union(1, 2);
uf.union(3, 4);
println!("After unions, connected(1,2): {}", uf.connected(1, 2));
println!("Connected(2,3): {}", uf.connected(2, 3));
println!("Size of set containing 1: {}", uf.size_of(1));

let root = uf.union(2, 3);
println!("Root after union(2,3): {} rank: {}", root, uf.rank_of(root));
println!("Connected(1,4): {}", uf.connected(1, 4));
println!("Size of set containing 4: {}", uf.size_of(4));
println!("Final count: {}", uf.count());

// Implicit add: connected() inserts the unseen 99 as a new singleton set before comparing
println!("Connected(99,5): {}", uf.connected(99, 5));
uf.union(99, 5);
println!("After union(99,5), connected(99,5): {}", uf.connected(99, 5));
println!("{}", uf.description());
```

```text
Initial count: 5
After unions, connected(1,2): true
Connected(2,3): false
Size of set containing 1: 2
Root after union(2,3): 1 rank: 2
Connected(1,4): true
Size of set containing 4: 4
Final count: 2
Connected(99,5): false
After union(99,5), connected(99,5): true
UnionFind(count=2, groups={99: [99, 5], 1: [4, 2, 3, 1]})
```
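
The equality-constraint use case from the list at the top follows the same pattern in two passes: union everything that must be equal, then look for an inequality whose two sides ended up in the same set. This sketch assumes a hypothetical input format of four-byte strings such as "a==b" or "a!=b" over lowercase single-letter variables, and it uses a fixed 26-slot array instead of the UnionFind type so that it runs on its own.

```rust
/// Finds the root of x in a 26-slot forest, compressing the path on the way.
fn find(parent: &mut [usize; 26], x: usize) -> usize {
    let mut root = x;
    while parent[root] != root {
        root = parent[root];
    }
    let mut cur = x;
    while cur != root {
        let next = parent[cur];
        parent[cur] = root;
        cur = next;
    }
    root
}

/// Returns true if all equations can hold at once.
/// Assumed (hypothetical) format: exactly four ASCII bytes, "x==y" or "x!=y",
/// with lowercase single-letter variables.
fn equations_possible(equations: &[&str]) -> bool {
    let mut parent = [0usize; 26];
    for (i, p) in parent.iter_mut().enumerate() {
        *p = i; // each variable starts in its own set
    }
    let var = |c: u8| (c - b'a') as usize;

    // Pass 1: union every pair of variables that must be equal.
    // Plain linking without rank is enough for a 26-element universe.
    for eq in equations.iter().map(|e| e.as_bytes()) {
        if eq[1] == b'=' {
            let (ra, rb) = (find(&mut parent, var(eq[0])), find(&mut parent, var(eq[3])));
            parent[ra] = rb;
        }
    }
    // Pass 2: an inequality between members of the same set is a contradiction.
    for eq in equations.iter().map(|e| e.as_bytes()) {
        if eq[1] == b'!' && find(&mut parent, var(eq[0])) == find(&mut parent, var(eq[3])) {
            return false;
        }
    }
    true
}

fn main() {
    println!("{}", equations_possible(&["a==b", "b==c", "a!=c"])); // false
    println!("{}", equations_possible(&["a==b", "b!=c"])); // true
}
```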


```rust
// Example: Counting connected components in an undirected graph
fn count_components(n: usize, edges: &[(usize, usize)]) -> (usize, std::collections::HashMap<usize, Vec<usize>>) {
    let mut uf = UnionFind::with_elements(0..n);

    for &(u, v) in edges {
        uf.union(u, v);
    }

    let count = uf.count();
    let groups = uf.groups();
    (count, groups)
}

let n = 7;
let edges = [(0, 1), (1, 2), (3, 4), (5, 6)];
let (count, groups) = count_components(n, &edges);
println!("Components: {}", count);
println!("Groups:");
for (root, group) in groups {
    println!("{}: {:?}", root, group);
}
```

```text
Components: 3
Groups:
3: [3, 4]
0: [2, 0, 1]
5: [5, 6]
```
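
Cycle detection in an undirected graph, another use case from the list at the top, is the same loop as count_components with one extra check: if union finds both endpoints already in the same set, that edge closes a cycle. This standalone sketch uses its own hypothetical Dsu type over vertices 0..n with union by rank and path compression, so it does not depend on the UnionFind cell; first_cycle_edge is likewise an illustrative name.

```rust
use std::cmp::Ordering;

/// Hypothetical minimal union-find over 0..n (not the notebook's UnionFind).
struct Dsu {
    parent: Vec<usize>,
    rank: Vec<u8>,
}

impl Dsu {
    fn new(n: usize) -> Self {
        Dsu { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, x: usize) -> usize {
        let p = self.parent[x];
        if p != x {
            let root = self.find(p);
            self.parent[x] = root; // path compression
        }
        self.parent[x]
    }

    /// Merges the sets of a and b by rank; returns false if they were already connected.
    fn union(&mut self, a: usize, b: usize) -> bool {
        let (ra, rb) = (self.find(a), self.find(b));
        if ra == rb {
            return false;
        }
        match self.rank[ra].cmp(&self.rank[rb]) {
            Ordering::Less => self.parent[ra] = rb,
            Ordering::Greater => self.parent[rb] = ra,
            Ordering::Equal => {
                self.parent[rb] = ra;
                self.rank[ra] += 1;
            }
        }
        true
    }
}

/// Returns the first edge (in input order) whose endpoints were already
/// connected, i.e. the edge that closes a cycle, or None if the graph is a forest.
fn first_cycle_edge(n: usize, edges: &[(usize, usize)]) -> Option<(usize, usize)> {
    let mut dsu = Dsu::new(n);
    edges.iter().copied().find(|&(u, v)| !dsu.union(u, v))
}

fn main() {
    let forest: [(usize, usize); 4] = [(0, 1), (1, 2), (3, 4), (5, 6)];
    println!("{:?}", first_cycle_edge(7, &forest)); // None
    let with_cycle: [(usize, usize); 3] = [(0, 1), (1, 2), (2, 0)];
    println!("{:?}", first_cycle_edge(3, &with_cycle)); // Some((2, 0))
}
```

Self-loops and repeated edges are reported as cycles too, because both endpoints are already in the same set when the edge is processed.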