<a href="https://colab.research.google.com/github/sanadv/Algorithms/blob/main/CS_160_MST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import heapq

def prim_jarnik_mst(graph):
    """
    Finds the Minimum Spanning Tree (MST) using the Prim-Jarnik algorithm.

    The tree is grown from the first vertex of ``graph`` using a binary heap
    with lazy deletion: whenever a cheaper connecting edge to an outside
    vertex is found, a new heap entry is pushed, and outdated entries are
    skipped when they are popped. If the graph is disconnected, a new tree is
    started from the next unvisited vertex (in dictionary order), so the
    result is a minimum spanning forest.

    Parameters:
        graph: A dictionary where keys are nodes and values are lists of tuples
            (neighbor, weight). For an undirected graph every edge should be
            listed under both endpoints. A neighbor that does not appear as a
            key is treated as a vertex with no outgoing edges.

    Returns:
        mst: A list of edges (u, v, weight) that form the MST (or spanning
            forest), in the order they were added.
        total_weight: The total weight of the MST. An empty graph yields
            ([], 0).
    """
    mst = []  # To store MST edges
    total_weight = 0
    in_mst = set()  # Vertices already attached to the tree/forest
    best = {}  # Cheapest known edge weight connecting each outside vertex
    # Unique, increasing tie-breaker: equal keys never fall through to
    # comparing vertices (which may not be orderable).
    seq = 0

    for root in graph:
        if root in in_mst:
            continue  # Already covered by an earlier tree
        # Heap entries: (key, seq, vertex, edge that would attach the vertex)
        pq = [(0, seq, root, None)]
        seq += 1
        while pq:
            # Extract vertex with minimum key
            key, _, u, edge = heapq.heappop(pq)
            if u in in_mst:
                continue  # Stale entry: u was already attached more cheaply
            in_mst.add(u)
            if edge is not None:
                # Add edge to MST
                mst.append(edge)
                total_weight += key

            # Relaxation step
            for v, weight in graph.get(u, ()):
                if v not in in_mst and weight < best.get(v, float('inf')):
                    best[v] = weight
                    heapq.heappush(pq, (weight, seq, v, (u, v, weight)))
                    seq += 1

    return mst, total_weight


# Example graph: an undirected, weighted graph stored as an adjacency list.
# Each vertex maps to its (neighbor, weight) pairs; every edge is listed
# under both of its endpoints.
graph = {
    'A': [('B', 1), ('C', 4)],
    'B': [('A', 1), ('C', 2), ('D', 6)],
    'C': [('A', 4), ('B', 2), ('D', 3)],
    'D': [('B', 6), ('C', 3)],
}

# Run the Prim-Jarnik algorithm and unpack (edge list, total weight).
mst, total_weight = prim_jarnik_mst(graph)

# Report the result: a header line, one line per tree edge, then the total.
print(
    "Edges in the Minimum Spanning Tree:",
    *(f"{u} -- {v} (weight: {weight})" for u, v, weight in mst),
    sep="\n",
)
print(f"\nTotal weight of the MST: {total_weight}")


Edges in the Minimum Spanning Tree:
A -- B (weight: 1)
B -- C (weight: 2)
C -- D (weight: 3)

Total weight of the MST: 6


In [None]:
import heapq

class DisjointSet:
    """Union-find (disjoint-set) structure over a fixed set of hashable vertices.

    ``find`` uses path compression and ``union`` uses union by rank.
    """

    def __init__(self, vertices):
        """Create one singleton set per vertex.

        ``vertices`` may be any iterable, including a one-shot generator; it is
        consumed exactly once so ``parent`` and ``rank`` cover the same vertices.
        """
        vertices = list(vertices)
        self.parent = {v: v for v in vertices}  # each vertex starts as its own root
        self.rank = {v: 0 for v in vertices}  # rank of each root, used to balance unions

    def find(self, v):
        """Return the root representative of the set containing ``v``.

        Every vertex on the path from ``v`` is re-pointed directly at the root
        (path compression). Raises KeyError if ``v`` is not a known vertex.
        """
        root = v
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: path compression
        while v != root:
            nxt = self.parent[v]
            self.parent[v] = root
            v = nxt
        return root

    def union(self, v1, v2):
        """Merge the sets containing ``v1`` and ``v2``.

        Returns True if two distinct sets were merged, False if ``v1`` and
        ``v2`` were already in the same set.
        """
        root1 = self.find(v1)
        root2 = self.find(v2)
        if root1 == root2:
            return False
        # Union by rank: hang the lower-rank root under the higher-rank one
        if self.rank[root1] < self.rank[root2]:
            root1, root2 = root2, root1
        self.parent[root2] = root1
        if self.rank[root1] == self.rank[root2]:
            self.rank[root1] += 1
        return True


def kruskal_with_heapq(graph):
    """
    Finds the Minimum Spanning Tree (MST) using Kruskal's algorithm with a heap.

    Parameters:
        graph: A dictionary where keys are nodes and values are lists of tuples
            (neighbor, weight). For an undirected graph every edge is usually
            listed under both endpoints; each (unordered pair, weight) is
            considered once. A neighbor that does not appear as a key is still
            treated as a vertex.

    Returns:
        mst: A list of edges (u, v, weight) that form the MST (or a spanning
            forest if the graph is disconnected), in the order they were added.
        total_weight: The total weight of the MST.
    """
    # Step 1: Collect each undirected edge once, then build a min-heap of them
    vertices = list(graph)
    known = set(vertices)
    seen = set()
    edges = []
    for node in graph:
        for neighbor, weight in graph[node]:
            if neighbor not in known:
                known.add(neighbor)
                vertices.append(neighbor)
            key = (frozenset((node, neighbor)), weight)
            if key in seen:
                continue  # Reverse direction of an edge already collected
            seen.add(key)
            # len(edges) is a unique tie-breaker, so equal weights never
            # compare vertices (which may not be orderable).
            edges.append((weight, len(edges), node, neighbor))
    heapq.heapify(edges)

    # Step 2: Initialize disjoint set
    dsu = DisjointSet(vertices)

    mst = []  # To store the edges in the MST
    total_weight = 0

    # Step 3: Process edges in increasing weight; stop once the tree is complete
    while edges and len(mst) < len(vertices) - 1:
        weight, _, u, v = heapq.heappop(edges)
        if dsu.find(u) != dsu.find(v):
            dsu.union(u, v)
            mst.append((u, v, weight))
            total_weight += weight

    return mst, total_weight


# Graph input: the same undirected, weighted adjacency list as the Prim cell;
# each vertex maps to its (neighbor, weight) pairs.
graph = {
    'A': [('B', 1), ('C', 4)],
    'B': [('A', 1), ('C', 2), ('D', 6)],
    'C': [('A', 4), ('B', 2), ('D', 3)],
    'D': [('B', 6), ('C', 3)],
}

# Compute the MST with Kruskal's algorithm.
mst, total_weight = kruskal_with_heapq(graph)

# Report the result: a header line, one "u - v: w" line per edge, then the total.
print(
    "Edges in the MST:",
    *(f"{u} - {v}: {w}" for u, v, w in mst),
    sep="\n",
)
print(f"Total weight of MST: {total_weight}")
