# Prim's Algorithm

In [20]:
from typing import List
import collections
import heapq

def minimumCost(N: int, connections: List[List[int]]) -> int:
    '''
    Prim's Algorithm:
    1) Initialize a tree with a single node, chosen
    arbitrarily from the graph.
    2) Grow the tree by one edge: of the edges that
    connect the tree to vertices not yet in the tree,
    find the minimum-weight edge, and transfer it to the tree.
    3) Repeat step 2 (until all vertices are in the tree).
    '''
    # city1 <-> city2 may have multiple different cost connections,
    # so use a list of tuples. Nested dict will break algorithm.
    # ここの持ち方は、与えられた問題による
    G = collections.defaultdict(list)
    for node1, node2, cost in connections:
        # 優先度キューに入れるために (cost, node) の順
        G[node1].append((cost, node2))
        G[node2].append((cost, node1))
    
    print(G)

    # queue = [(cost, node)] # heappop するときは、cost を見ている。
    # コストが小さい=優先度が高い、優先度キュー
    queue = [(0, 1)]  # [1] Arbitrary starting point N costs 0.
    visited = set()
    total = 0
    while queue and len(visited) < N: # [3] Exit if all cities are visited.
        # cost is always least cost connection in queue.
        cost, city = heapq.heappop(queue)
        print("pop", cost, city)
        if city not in visited:
            visited.add(city)
            total += cost # [2] Grow tree by one edge.
            for edge_cost, next_city in G[city]:

                heapq.heappush(queue, (edge_cost, next_city))
            print(city)
            print("queue に行けるnodeを追加", queue)
    return total if len(visited) == N else -1

N = 3
connections = [[1,2,5],[1,3,6],[2,3,1]]
print(minimumCost(N, connections))

defaultdict(<class 'list'>, {1: [(5, 2), (6, 3)], 2: [(5, 1), (1, 3)], 3: [(6, 1), (1, 2)]})
pop 0 1
1
queue に行けるnodeを追加 [(5, 2), (6, 3)]
pop 5 2
2
queue に行けるnodeを追加 [(1, 3), (6, 3), (5, 1)]
pop 1 3
3
queue に行けるnodeを追加 [(1, 2), (5, 1), (6, 1), (6, 3)]
6


In [None]:
# Find Union

In [26]:
def find(x):
    # BaseCase
    if par[x] == x:
        return x
    else:
        return find(par[x])
    
def same(x,y):
    return find(x) == find(y)

def unite(x,y):
    x = find(x)
    y = find(y)
    if x == y:
        return 0 # 既に同じグループ
    par[x] = y
    return True

N = 7
par = {node: node for node in range(1, N+1)}
print(par)
unite(1,5)
unite(2,6)
unite(3,4)
unite(2,4)
print(par)

print(same(2,3))

{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}
{1: 5, 2: 6, 3: 4, 4: 4, 5: 5, 6: 4, 7: 7}
True


### union by size
size の小さい方を大きい方にuniteする。木の高さが O(logN) に抑えられる

In [29]:
def find(x):
    # BaseCase
    if par[x] == x:
        return x
    else:
        return find(par[x])
    
def same(x,y):
    return find(x) == find(y)

def unite(x,y):
    size_x = 0
    size_y = 0
    
    x = find(x)
    y = find(y)
    
    for i in range(1, N+1):
        if x == find(i):
            size_x +=1
        if y == find(i):
            size_y += 1

    if x == y:
        return 0 # 既に同じグループ
    if size_x <= size_y:
        small = x
        large = y
    else:
        small = y
        large  = x
    par[small] = large
    return True

N = 7
par = {node: node for node in range(1, N+1)}
print(par)
unite(1,5)
unite(2,6)
unite(3,4)
unite(2,4)
print(par)

print(same(2,3))

{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}
{1: 5, 2: 6, 3: 4, 4: 4, 5: 5, 6: 4, 7: 7}
True


### 経路圧縮

In [25]:
def find(x):
    if par[x] == x:
        return x
    else:
        par[x] = find(par[x]) #par[x] を更新してをしている経路圧縮
        return par[x]
    
def same(x,y):
    return find(x) == find(y)

def unite(x,y):
    x = find(x)
    y = find(y)
    if x == y:
        return 0 # 既に同じグループ
    par[x] = y
    return True

N = 7
par = {node: node for node in range(1, N+1)}
print(par)
unite(1,5)
unite(2,6)
unite(3,4)
unite(2,4)
print(par)

print(same(2,3))

{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}
{1: 5, 2: 6, 3: 4, 4: 4, 5: 5, 6: 4, 7: 7}


# Kruskal's Algorithm

In [None]:
def minimumCost(N: int, connections: List[List[int]]) -> int:
    '''
    Kruskal's Algorithm:
    1) Create a forest F (a set of trees), where each vertex in 
    the graph is a separate tree.
    2) Create a set S containing all the edges in the graph.
    3) While S is nonempty and F is not yet spanning (fully connected):
        3A) Remove an edge with minimum weight from S
        3B) If the removed edge connects two different trees then 
        add it to the forest F, combining two trees into a single tree.
        
    time complexity: sort connections by cost(O(ElogV)), union process (O(Eα(V))
    '''
    def find(node):
        """自分の親を探す"""
        # Recursively re-set city's parent to its parent's parent.
        # Build the bush: ideally each tree/set is of height 1.
        # 自分が親でなかったら再帰
        if parent[node] != node:
            parent[node] = find(parent[node]) # indirect な繋がり方を探すために recursive。経路圧縮している。
        # 自分が親なら自分の番号を返す
        return parent[node]

    def union(c1, c2):
        """親を統合して、c1 と c2 の親が既に同じ(つまり統合済)なら False, 異なっていて統合したら True"""
        root1, root2 = find(c1), find(c2)
        
        # 既に尾長同じである = c1 と c2 をつないでしまうとサイクルができてしまう
        if root1 == root2:
            return False
        # ルートを同じにする = c1 と c2 をつなぐことと同じ
        parent[root2] = root1  # Always join roots!
        return True

    # [1] Keep track of disjoint sets. Initially each city is its own set.
    # まずは全部の要素がそれ自信を親として持つ = どこともつながっていない状態
    parent = {city: city for city in range(1, N+1)}
    # [2] Sort connections by cost(connections[2]), so we are always picking minimum cost edge.
    connections.sort(key=lambda x: x[2])
    total = 0
    for city1, city2, cost in connections:  # [3A] 一番辺のコストが小さいものから追加。 
        if union(city1, city2):  # [3B]
            total += cost
    # Check that all cities are connected.
    root = find(1)
    return total if all(root == find(city) for city in range(1, N+1)) else -1

N = 3
connections = [[1,2,5],[1,3,6],[2,3,1]]
print(minimumCost(N, connections))