# Prim's Algorithm (Finding a Minimum Spanning Tree, MST)

> # Min-Priority Queue for Prim's Algorithm

In [1]:
import heapq

class PriorityQueue:
    """Min-priority queue over hashable tasks, backed by ``heapq``.

    A task's priority can be changed by calling ``add_task`` again. The old
    heap entry is lazily deleted: it is marked as removed and skipped later
    by ``pop_task``, so no O(n) heap search is needed.
    """

    def __init__(self):
        # Binary heap of mutable entries [priority, insertion_counter, task].
        self._queue = []
        # task -> its live entry (the very same list object stored in the heap).
        self._entry_finder = {}
        # Unique marker for lazily deleted entries. A private object() can never
        # be equal/identical to a real task, unlike a string placeholder such as
        # '<removed-task>', which a caller could legitimately enqueue.
        self._REMOVED = object()
        # Monotonic counter. It breaks ties between equal priorities in
        # (re)insertion order, so tasks themselves are never compared.
        self._counter = 0

    def add_task(self, task, priority=0):
        """Add ``task`` with ``priority``, or replace its priority if already queued."""
        if task in self._entry_finder:
            # Invalidate the existing entry; a fresh one is pushed below.
            self.remove_task(task)
        entry = [priority, self._counter, task]
        self._counter += 1
        self._entry_finder[task] = entry
        heapq.heappush(self._queue, entry)

    def remove_task(self, task):
        """Mark ``task`` as removed.

        Raises:
            KeyError: if ``task`` is not currently queued.
        """
        entry = self._entry_finder.pop(task)
        # The heap holds the same list object, so this also marks the heap copy.
        entry[-1] = self._REMOVED

    def pop_task(self):
        """Remove and return the task with the lowest priority.

        Ties are broken by the order in which tasks were (re)added.

        Raises:
            KeyError: if no live task remains.
        """
        while self._queue:
            _priority, _count, task = heapq.heappop(self._queue)
            if task is not self._REMOVED:
                del self._entry_finder[task]
                return task
        raise KeyError('pop from an empty priority queue')

> ## Prim's algorithm 

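>> vertices : 0 ~ (num_vertex-1)  
>> adjacent : dict[vertex:int, List[Tuple[vertex:int, weight:float]]] — undirected graph, so every edge is listed from both endpoints  
>> key : List[float] — key[u] is the weight of the lightest edge found so far that connects u to the growing tree; when the algorithm finishes it is the weight of the MST edge (pi[u], u), and 0 for the source  
>> pi : List[int] — pi[u] is the parent of u in the MST (None for the source)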

In [43]:
from typing import Callable, Dict, List, Optional, Tuple
INF = float("inf")  # initial key of every vertex not yet connected to the tree


def _vertex_label(u: int) -> str:
    """Default display name for vertex ``u``: 'a'..'z' for 0..25, else ``str(u)``."""
    return chr(ord('a') + u) if 0 <= u < 26 else str(u)


def mst_prim(num_vertex: int,
             adjacent: Dict[int, List[Tuple[int, float]]],
             source: int,
             verbose: bool = True,
             label: Optional[Callable[[int], str]] = None,
             ) -> Tuple[List[float], List[Optional[int]]]:
    """Compute a minimum spanning tree with Prim's algorithm.

    Args:
        num_vertex: number of vertices; vertices are ``0 .. num_vertex-1``.
        adjacent: adjacency lists ``u -> [(v, weight), ...]``. The graph should
            be undirected, with every edge listed from both endpoints. A vertex
            without an entry is treated as having no edges.
        source: root vertex of the tree.
        verbose: if true, print ``(<label>) is popped.`` each time a vertex
            leaves the priority queue.
        label: maps a vertex to the name printed when ``verbose``. Defaults to
            'a'..'z' for vertices 0..25 and ``str(u)`` otherwise.

    Returns:
        ``(key, pi)`` where ``pi[u]`` is the parent of ``u`` in the tree (None
        for the source) and ``key[u]`` is the weight of the tree edge
        ``(pi[u], u)`` (0 for the source). If the graph is disconnected, each
        further component is rooted at the first of its vertices to be popped
        (key INF, pi None), so the result describes a spanning forest.

    Raises:
        IndexError: if ``source`` is not in ``range(num_vertex)``.
    """
    if not 0 <= source < num_vertex:
        raise IndexError(f"source {source} is out of range for {num_vertex} vertices")
    if label is None:
        label = _vertex_label

    key = [INF] * num_vertex   # cheapest known edge weight linking u to the tree
    pi = [None] * num_vertex   # parent of u in the tree
    key[source] = 0

    pq = PriorityQueue()
    for u in range(num_vertex):
        pq.add_task(task=u, priority=key[u])

    in_tree = set()
    # add_task replaces an existing entry, so each vertex is queued exactly
    # once and exactly num_vertex pops drain the queue. This avoids peeking
    # at the queue's private state.
    for _ in range(num_vertex):
        u = pq.pop_task()
        if verbose:
            print(f"({label(u)}) is popped.")
        in_tree.add(u)
        try:
            neighbors = adjacent[u]
        except KeyError:
            neighbors = ()  # vertex with no adjacency entry has no edges
        for v, weight in neighbors:
            # Only vertices still outside the tree may get a cheaper connecting edge.
            if v not in in_tree and weight < key[v]:
                pi[v] = u
                key[v] = weight
                pq.add_task(v, weight)
    return key, pi
    

In [46]:
# Example: undirected weighted graph on vertices a..i (indices 0..8).
# Each edge is written in both directions, as mst_prim expects.
num_vertex = 9
adjacent = {u: [] for u in range(num_vertex)}
# Named edge_list_text (not `input`) so the builtin input() stays usable in the kernel.
edge_list_text = """a b 4
b a 4
a h 8
h a 8
b h 11
h b 11
b c 8
c b 8
c i 2
i c 2
h i 7
i h 7
h g 1
g h 1
g i 6
i g 6
g f 2
f g 2
c f 4
f c 4
c d 7
d c 7
d f 14
f d 14
d e 9
e d 9
f e 10
e f 10"""

# Bidirectional maps between vertex labels ('a', 'b', ...) and indices (0, 1, ...),
# derived from num_vertex instead of a hard-coded range.
char_to_num = {chr(ord('a') + i): i for i in range(num_vertex)}
num_to_char = {i: ch for ch, i in char_to_num.items()}

for line in edge_list_text.splitlines():
    if not line.strip():
        continue  # tolerate blank lines in the edge list
    u_char, v_char, weight_text = line.split()
    adjacent[char_to_num[u_char]].append((char_to_num[v_char], int(weight_text)))

# Prim's algorithm works on undirected graphs: every (u -> v, w) needs its (v -> u, w) twin.
assert all(
    (u, w) in adjacent[v] for u, nbrs in adjacent.items() for v, w in nbrs
), "edge list is not symmetric"

key, pi = mst_prim(num_vertex, adjacent, 0)
print(f"key is {key}")
print(f"pi is {pi}")
print(f"Sum of Edges in MST : {sum(key)}")
assert sum(key) == 37, "unexpected MST weight for this example graph"
print()
print("PI    :" + "".join("X   " if p is None else f"{num_to_char[p]}   " for p in pi))
print("VERTEX:" + "".join(f"{num_to_char[i]}   " for i in range(num_vertex)), end="")

(a) is popped.
(b) is popped.
(h) is popped.
(g) is popped.
(f) is popped.
(c) is popped.
(i) is popped.
(d) is popped.
(e) is popped.
key is [0, 4, 4, 7, 9, 2, 1, 8, 2]
pi is [None, 0, 5, 2, 3, 6, 7, 0, 2]
Sum of Edges in MST : 37

PI    :X   a   f   c   d   g   h   a   c   
VERTEX:a   b   c   d   e   f   g   h   i   

In [47]:
# Bing Chat's version, for comparison: lazy Prim's algorithm over a heap of (weight, node) entries
import heapq

def prim(graph, start):
    """Lazy Prim's algorithm over a heap of candidate edges.

    ``graph`` maps each node to a list of ``(weight, neighbor)`` entries.
    Returns ``(weight, node)`` pairs in the order nodes join the tree. The
    weight is that of the edge used to attach the node (0 for ``start``); the
    parent node is not recorded.
    """
    tree = []
    seen = set()
    frontier = [(0, start)]
    while frontier:
        cost, vertex = heapq.heappop(frontier)
        if vertex in seen:
            continue  # stale entry: vertex was already attached more cheaply
        seen.add(vertex)
        tree.append((cost, vertex))
        for candidate in graph[vertex]:
            if candidate[1] in seen:
                continue
            heapq.heappush(frontier, candidate)
    return tree

# Example graph: node -> list of (weight, neighbor); every edge appears from both ends
graph = {
    'a': [(2, 'b'), (3, 'c')],
    'b': [(2, 'a'), (3, 'c'), (5, 'd')],
    'c': [(3, 'a'), (3, 'b'), (1, 'd')],
    'd': [(5, 'b'), (1, 'c')],
}

mst = prim(graph, 'a')
print("Minimum Spanning Tree:")
# prim() always returns at least the start node, so this prints one tuple per line.
print(*mst, sep="\n")

Minimum Spanning Tree:
(0, 'a')
(2, 'b')
(3, 'c')
(1, 'd')
