# Minimum Spanning Tree


## Helpers


### Disjoint Set


# In [105]:
class DisjointSet:
    """Union-find (disjoint-set) structure using union by rank and path compression.

    Each vertex passed to the constructor starts in its own singleton set.
    ``find`` returns the root representative of an item's set and ``union``
    merges the sets containing two items.
    """

    def __init__(self, vertices: list) -> None:
        """Create one singleton set per element of *vertices*."""
        self.vertices = vertices
        # Every vertex starts as its own parent, i.e. the root of its own set.
        self.parent = {vertex: vertex for vertex in vertices}
        # Rank is an upper bound on tree height; used to keep trees shallow.
        self.rank = dict.fromkeys(vertices, 0)

    def find(self, item):
        """Return the root representative of the set containing *item*.

        Implemented iteratively so long parent chains cannot exceed the
        interpreter's recursion limit. Every node on the walked path is
        re-pointed directly at the root (path compression).

        Raises:
            KeyError: if *item* was not one of the constructor's vertices.
        """
        root = item
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: make every node on the path point straight at root.
        while item != root:
            next_item = self.parent[item]
            self.parent[item] = root
            item = next_item
        return root

    def union(self, x, y) -> None:
        """Merge the sets containing *x* and *y* (no-op if already merged).

        The shallower tree (lower rank) is attached under the deeper one; on
        a tie, *y*'s root goes under *x*'s root and that root's rank grows.
        """
        xroot = self.find(x)
        yroot = self.find(y)
        if xroot == yroot:
            # Already in the same set; re-linking would wrongly inflate rank.
            return
        if self.rank[xroot] < self.rank[yroot]:
            xroot, yroot = yroot, xroot
        self.parent[yroot] = xroot
        if self.rank[xroot] == self.rank[yroot]:
            self.rank[xroot] += 1

### Graph


# In [106]:
class Graph:
    """Container for a graph, its vertex labels, and a computed MST.

    ``graph`` holds either an edge list of ``[start, end, weight]`` entries
    (as built by ``add_edge``) or an adjacency matrix, depending on which
    algorithm will consume it. ``MST`` collects ``[start, end, weight]``
    entries produced by an MST algorithm.
    """

    def __init__(self, no_of_vertices: int, graph: "list | None" = None, nodes: "list | None" = None) -> None:
        """Create a graph with *no_of_vertices* vertices.

        Args:
            no_of_vertices: number of vertices the algorithms should span.
            graph: initial edge list / adjacency matrix; a fresh empty list
                when omitted (never shared between instances).
            nodes: initial vertex labels; a fresh empty list when omitted.
        """
        self.noOfVertices = no_of_vertices
        # None sentinels avoid the shared-mutable-default pitfall: every Graph
        # built without explicit lists gets its own independent lists.
        self.graph = [] if graph is None else graph
        self.nodes = [] if nodes is None else nodes
        self.MST = []

    def add_edge(self, start_node, end_node, weight) -> None:
        """Append a ``[start_node, end_node, weight]`` entry to the edge list."""
        self.graph.append([start_node, end_node, weight])

    def add_node(self, value) -> None:
        """Append a vertex label."""
        self.nodes.append(value)

    def print_solution(self) -> None:
        """Print each MST edge as ``start -> end: weight``, one per line."""
        for startNode, endNode, weight in self.MST:
            print('%s -> %s: %s' % (startNode, endNode, weight))

## Algorithms


### Kruskal's Algorithm


# In [107]:
def KruskalAlgo(customGraph: Graph) -> None:
    """Compute a minimum spanning tree with Kruskal's algorithm and print it.

    ``customGraph.graph`` must be an edge list of ``[start, end, weight]``
    entries whose endpoints appear in ``customGraph.nodes``. The edge list is
    replaced by a copy sorted by weight. Edges are taken in ascending weight
    order, and each edge whose endpoints lie in different components is
    appended to ``customGraph.MST``. The result is printed via
    ``print_solution``.

    If the edges run out before ``noOfVertices - 1`` tree edges are found
    (disconnected graph), the spanning forest gathered so far is printed
    instead of raising ``IndexError``.
    """
    dset = DisjointSet(vertices=customGraph.nodes)
    target = customGraph.noOfVertices - 1
    edges_added = 0
    customGraph.graph = sorted(customGraph.graph, key=lambda item: item[2])
    for startNode, endNode, weight in customGraph.graph:
        if edges_added >= target:
            break
        x = dset.find(item=startNode)
        y = dset.find(item=endNode)
        # Only keep edges that join two different components (no cycles).
        if x != y:
            edges_added += 1
            customGraph.MST.append([startNode, endNode, weight])
            dset.union(x=x, y=y)
    customGraph.print_solution()

### Prim's Algorithm


# In [108]:
import math
import sys

def PrimsAlgo(customGraph: Graph) -> None:
    """Compute a minimum spanning tree with Prim's algorithm and print it.

    ``customGraph.graph`` must be an adjacency matrix where ``graph[i][j]`` is
    the weight between ``nodes[i]`` and ``nodes[j]`` and a falsy value (``0``)
    means "no edge". Growth starts from vertex 0. At each step the cheapest
    edge from a visited to an unvisited vertex is appended to
    ``customGraph.MST`` as ``[node_i, node_j, weight]``. The result is
    printed via ``print_solution``.

    If no edge leaves the visited set (disconnected graph), the tree built so
    far is printed instead of raising or appending duplicate edges. An empty
    graph prints whatever ``MST`` already holds.
    """
    n = customGraph.noOfVertices
    matrix = customGraph.graph
    if n == 0:
        customGraph.print_solution()
        return
    visited = [False] * n
    visited[0] = True
    for _ in range(n - 1):
        best = math.inf
        src = dst = None
        for i in range(n):
            if not visited[i]:
                continue
            for j in range(n):
                weight = matrix[i][j]
                # Strict '<' keeps the first minimum found, as before.
                if not visited[j] and weight and weight < best:
                    best, src, dst = weight, i, j
        if dst is None:
            # Remaining vertices are unreachable from the visited set.
            break
        customGraph.MST.append([customGraph.nodes[src], customGraph.nodes[dst], matrix[src][dst]])
        visited[dst] = True
    customGraph.print_solution()


# Main


# In [109]:
def makeKruskalGraph() -> Graph:
    """Build the 5-vertex example graph, in edge-list form, for Kruskal's algorithm.

    Fresh ``graph``/``nodes`` lists are passed explicitly so this graph never
    shares storage with other ``Graph`` instances through constructor defaults.
    """
    g = Graph(no_of_vertices=5, graph=[], nodes=[])

    for label in ('A', 'B', 'C', 'D', 'E'):
        g.add_node(label)

    # Each undirected edge is listed once from each endpoint.
    edges = [
        ('A', 'B', 5), ('A', 'C', 13), ('A', 'E', 15),
        ('B', 'A', 5), ('B', 'C', 10), ('B', 'D', 8),
        ('C', 'A', 13), ('C', 'B', 10), ('C', 'D', 6), ('C', 'E', 20),
        ('D', 'B', 8), ('D', 'C', 6),
        ('E', 'A', 15), ('E', 'C', 20),
    ]
    for start, end, weight in edges:
        g.add_edge(start_node=start, end_node=end, weight=weight)

    return g

def makePrimsGraph() -> Graph:
    """Return the 5-vertex example graph, in adjacency-matrix form, for Prim's algorithm.

    Entry ``[i][j]`` is the weight of the edge between ``labels[i]`` and
    ``labels[j]``; ``0`` marks the absence of an edge.
    """
    labels = ['A', 'B', 'C', 'D', 'E']
    adjacency = [
        [0, 10, 20, 0, 0],
        [10, 0, 30, 5, 0],
        [20, 30, 0, 15, 6],
        [0, 5, 15, 0, 8],
        [0, 0, 6, 8, 0],
    ]
    return Graph(no_of_vertices=len(labels), graph=adjacency, nodes=labels)
    
    

if __name__ == '__main__':
    # Demo: print the MST of the adjacency-matrix example using Prim's algorithm.
    #
    # Other demos (disabled):
    #   KruskalAlgo(customGraph=makeKruskalGraph())
    #   ds = DisjointSet(vertices=['A', 'B', 'C', 'D', 'E'])
    #   ds.union(x='A', y='B'); ds.union(x='A', y='C'); print(ds.find(item='C'))
    prims_demo = makePrimsGraph()
    PrimsAlgo(customGraph=prims_demo)


# Expected output of the Prim's demo above:
# A -> B: 10
# B -> D: 5
# D -> E: 8
# E -> C: 6
