# Union-Find 算法

- https://labuladong.gitee.io/algo/2/20/40/
- https://www.cnblogs.com/jiduxia/p/9602878.html

## 问题本质
- 连通性是一种等价关系（自反、对称、传递）；一般凡是包含等价关系的问题，都可以考虑使用 UF 算法！
 
## 几个操作
- union：一开始有 n 个分量；每次成功的 union（两个节点原本不连通）都会把两个分量合并为一个，使连通分量数减一
- find：返回指定节点的连通分量标识符
- connected：判断两个节点是否连通（即根节点是否相同）
- count：返回当前连通分量的个数

- 下面的 UF 一步到位，直接给出带路径压缩（路径减半）的优化实现；演进过程可以看后文的 Quick-Find、Quick-Union 以及加权 Union-Find

In [89]:
class UF():
    """Union-Find (disjoint set) over nodes ``0 .. n-1`` with path halving.

    Each component is a tree stored in ``self.parent``: ``parent[x]`` is the
    parent of node ``x`` and a root is its own parent. Two nodes are
    connected iff ``find`` returns the same root for both.
    """

    def __init__(self, n, *, verbose=True):
        """Create ``n`` singleton components.

        Args:
            n: total number of nodes in the graph (nodes are ``0 .. n-1``).
            verbose: when true (the default, preserving the original trace
                output), print the parent array after construction and after
                every successful union.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        # Total number of nodes.
        self.n = n
        # Initially nothing is connected, so there are n components.
        self.countN = n
        # parent[x] is the parent of node x; every node starts as its own root.
        self.parent = list(range(n))
        self.verbose = verbose
        if self.verbose:
            print(self.parent)       # e.g. [0, 1, 2, 3, 4]

    def union(self, p, q):
        """Connect nodes ``p`` and ``q`` by hanging p's root under q's root.

        Does nothing if they are already connected; otherwise the number of
        components drops by one.
        """
        rootP = self.find(p)
        rootQ = self.find(q)

        if rootP == rootQ:
            return

        # Merge the two trees; attaching rootQ under rootP would be equally valid.
        self.parent[rootP] = rootQ
        self.countN -= 1       # two components became one
        if self.verbose:
            print(self.parent)

    def find(self, p):
        """Return the root of the component containing node ``p``.

        Uses path halving: each node visited on the way up is re-pointed to
        its grandparent, which flattens the tree for later calls.

        Raises:
            IndexError: if ``p`` is not in ``0 .. n-1``. Negative values are
                rejected explicitly because Python list indexing would
                otherwise silently wrap them to nodes at the end of the list.
        """
        if not 0 <= p < len(self.parent):
            raise IndexError(f"node {p} out of range for {len(self.parent)} nodes")
        parent = self.parent
        while parent[p] != p:      # walk up until we reach the root
            parent[p] = parent[parent[p]]
            p = parent[p]
        return p

    def count(self,):
        """Return the current number of connected components."""
        return self.countN

    def connected(self, p, q):
        """Return True if ``p`` and ``q`` are in the same component."""
        rootP = self.find(p)
        rootQ = self.find(q)
        return rootP == rootQ

In [90]:
class countComponents():
    """Answer connectivity questions about an undirected graph using ``UF``.

    Constructor args:
        n: total number of nodes (labelled ``0 .. n-1``).
        edges: iterable of edges; ``e[0]`` and ``e[1]`` are the endpoints.
    """

    def __init__(self, n, edges):
        # Total number of nodes.
        self.n = n
        self.edges = edges

    def _build_uf(self):
        """Return a fresh ``UF`` over ``self.n`` nodes with every edge unioned."""
        # Rebuilt on each call so the answer always reflects the current self.edges.
        uf = UF(self.n)
        for e in self.edges:
            uf.union(e[0], e[1])
        return uf

    def countC(self,):
        """Return the number of connected components of the graph."""
        return self._build_uf().count()

    def connectedC(self, p, q):
        """Return True if nodes ``p`` and ``q`` are connected in the graph."""
        return self._build_uf().connected(p, q)

In [92]:
n = 5

# Alternative input with two components, {0, 1, 2} and {3, 4} (expected count 2).
# The edge 0-2 is also present but not drawn.
# edges = [[0, 1], [1, 2], [0, 2], [3, 4]]
#  0          3
#  |          |
#  1 --- 2    4

# Active input: a single chain 0-1-2-3-4 (expected count 1).
edges = [[0, 1], [1, 2], [2, 3], [3, 4]]
#  0           4
#  |           |
#  1 --- 2 --- 3

s = countComponents(n, edges)
s.countC()

[0, 1, 2, 3, 4]
[1, 1, 2, 3, 4]
[1, 2, 2, 3, 4]
[1, 2, 3, 3, 4]
[1, 2, 3, 4, 4]


1

# Quick_Find
- union 每次都要遍历整个 ids 数组（O(n)），n 次 union 就是平方级复杂度，不适合大规模数据

In [68]:
class Quick_Find:
    """Quick-find union-find: ``ids[i]`` is the component id of node ``i``.

    ``find`` and ``connect`` are O(1), but ``union`` rewrites the whole
    array (O(N)), so N unions cost O(N^2).
    """

    def __init__(self,N):
        # Number of components; every node starts in its own component.
        self.count = N
        # Each component's identifier is initially the node itself.
        self.ids = list(range(self.count))
        print("self.ids0",self.ids)

    def connect(self,p,q):
        """Return True if ``p`` and ``q`` are in the same component."""
        return self.find(p) == self.find(q)

    def find(self,p):
        """Return the component identifier of node ``p``."""
        return self.ids[p]

    def union(self,p,q):
        """Merge the component of ``p`` into the component of ``q``."""
        pId = self.find(p)
        qId = self.find(q)

        # p and q are already in the same component: nothing to do.
        if pId == qId:
            return

        # Relabel every node of p's component (not just p) with q's id.
        for i, component in enumerate(self.ids):
            if component == pId:
                self.ids[i] = qId
        self.count -= 1
        print("self.ids",self.ids)

    def getcount(self):
        """Return the current number of components."""
        return self.count

In [69]:
class countComponents():
    """Count the connected components of an undirected graph with Quick_Find."""

    def __init__(self, n, edges):
        # Edge list; each edge is indexed as edge[0], edge[1].
        self.edges = edges
        # Total number of nodes.
        self.n = n

    def countC(self):
        """Union every edge and return the resulting number of components."""
        finder = Quick_Find(self.n)
        for edge in self.edges:
            left, right = edge[0], edge[1]
            finder.union(left, right)
        return finder.getcount()

In [70]:
# Two components: {0, 1, 2} (edges 0-1, 1-2 and the undrawn 0-2) and {3, 4}.
#  0          3
#  |          |
#  1 --- 2    4
n, edges = 5, [[0, 1], [1, 2], [0, 2], [3, 4]]

# Alternative input: a single chain 0-1-2-3-4 (one component).
# edges = [[0, 1], [1, 2], [2, 3], [3, 4]]
#  0           4
#  |           |
#  1 --- 2 --- 3

s = countComponents(n, edges)
s.countC()

self.ids0 [0, 1, 2, 3, 4]
self.ids [1, 1, 2, 3, 4]
self.ids [2, 2, 2, 3, 4]
self.ids [2, 2, 2, 4, 4]


2

# Quick_Union
- 不适合大规模数据：最坏情况下树会退化成一条链，find 为 O(n)，n 次操作仍可能是平方级复杂度
- 重点理解如何用 ids（父节点）数组构造树——所有分量合在一起是一片森林！

In [80]:
class Quick_Union:
    """Quick-union: each component is a tree encoded in a parent array.

    ``ids[i]`` holds the parent of node ``i``; a node whose parent is itself
    is a root, and two nodes are connected iff they share a root.
    """

    def __init__(self, N):
        # Every node starts as the root of its own single-node tree.
        self.ids = list(range(N))
        self.count = N
        print("self.ids0", self.ids)

    def connect(self, p, q):
        """Return True if ``p`` and ``q`` have the same root."""
        return self.find(p) == self.find(q)

    def find(self, p):
        """Follow parent links from ``p`` until reaching a root and return it."""
        node = p
        parent = self.ids[node]
        while parent != node:
            node = parent
            parent = self.ids[node]
        return node

    def union(self, p, q):
        """Join the trees of ``p`` and ``q`` by pointing p's root at q's root."""
        root_p, root_q = self.find(p), self.find(q)
        if root_p != root_q:
            self.ids[root_p] = root_q
            self.count -= 1
            print("self.ids", self.ids)

    def getcount(self):
        """Return the current number of components."""
        return self.count

In [81]:
class countComponents():
    """Count the connected components of an undirected graph with Quick_Union."""

    def __init__(self, n, edges):
        # Edge list; each edge is indexed as edge[0], edge[1].
        self.edges = edges
        # Total number of nodes.
        self.n = n

    def countC(self):
        """Union every edge and return the resulting number of components."""
        forest = Quick_Union(self.n)
        for edge in self.edges:
            forest.union(edge[0], edge[1])
        components = forest.getcount()
        return components

In [84]:
n = 5

# Alternative input with two components, {0, 1, 2} and {3, 4}:
# edges = [[0, 1], [1, 2], [0, 2], [3, 4]]
#  0          3
#  |          |
#  1 --- 2    4

# Active input: a single chain 0-1-2-3-4 (one component).
#  0           4
#  |           |
#  1 --- 2 --- 3
edges = [
    [0, 1],
    [1, 2],
    [2, 3],
    [3, 4],
]

# The chain-shaped tree built here is the tricky part to understand: each
# union hangs the root of p under the root of q, so ids ends up encoding the
# chain 0 -> 1 -> 2 -> 3 -> 4.
#   self.ids0 [0, 1, 2, 3, 4]
#   self.ids [1, 1, *, *, *]
#   self.ids [*, 2, 2, *, *]
#   self.ids [*, *, 3, 3, *]
#   self.ids [*, *, *, 4, 4]

s = countComponents(n, edges)
s.countC()

self.ids0 [0, 1, 2, 3, 4]
self.ids [1, 1, 2, 3, 4]
self.ids [1, 2, 2, 3, 4]
self.ids [1, 2, 3, 3, 4]
self.ids [1, 2, 3, 4, 4]


1

# Weighted_Union_Find

In [52]:
class Weighted_Union_Find:
    """Union-find that keeps trees shallow using union by size.

    ``ids[i]`` is the parent of node ``i`` (roots point to themselves) and
    ``size[r]`` is the number of nodes in the tree rooted at ``r``. The
    smaller tree is always attached beneath the larger tree's root.
    """

    def __init__(self, N):
        self.count = N
        self.ids = list(range(N))
        # Tree sizes (the weights); only the entries at roots stay meaningful.
        self.size = [1] * N

    def connect(self, p, q):
        """Return True if ``p`` and ``q`` share a root."""
        return self.find(p) == self.find(q)

    def find(self, p):
        """Follow parent links from ``p`` to its root (no path compression)."""
        node = p
        while node != self.ids[node]:
            node = self.ids[node]
        return node

    def union(self, p, q):
        """Merge the trees of ``p`` and ``q``, smaller tree under larger root."""
        root_p = self.find(p)
        root_q = self.find(q)
        if root_p == root_q:
            return

        # On a size tie, q's root is attached beneath p's root.
        if self.size[root_p] < self.size[root_q]:
            small, large = root_p, root_q
        else:
            small, large = root_q, root_p
        self.ids[small] = large
        self.size[large] += self.size[small]
        self.count -= 1

    def getcount(self):
        """Return the current number of components."""
        return self.count


In [53]:
class countComponents():
    """Count the connected components of an undirected graph with weighted union-find."""

    def __init__(self, n, edges):
        # Edge list; each edge is indexed as edge[0], edge[1].
        self.edges = edges
        # Total number of nodes.
        self.n = n

    def countC(self):
        """Union every edge and return the resulting number of components."""
        weighted = Weighted_Union_Find(self.n)
        for edge in self.edges:
            a, b = edge[0], edge[1]
            weighted.union(a, b)
        return weighted.getcount()

In [55]:
# Active input: two components, {0, 1, 2} (edges 0-1, 1-2 and the undrawn 0-2)
# and {3, 4}.
#  0          3
#  |          |
#  1 --- 2    4
n = 5
edges = [
    [0, 1],
    [1, 2],
    [0, 2],
    [3, 4],
]

# Alternative input: a single chain 0-1-2-3-4 (one component).
# edges = [[0, 1], [1, 2], [2, 3], [3, 4]]
#  0           4
#  |           |
#  1 --- 2 --- 3

s = countComponents(n, edges)
s.countC()

2