In [1]:
class Edge:
    """Weighted edge between two nodes, ordered by weight.

    Only ``weight`` takes part in ordering, so a list of edges can be
    sorted directly (``list.sort`` / ``sorted`` rely on ``__lt__``).

    Attributes:
        weight: cost of the edge; must support ``<`` and ``>``.
        start: node the edge starts from.
        target: node the edge leads to.
    """
    def __init__(self, weight, start, target):
        self.weight = weight
        self.start = start  # Node
        self.target = target  # Node

    def __repr__(self):
        return 'Edge(weight={}, start={}, target={})'.format(
            self.weight, self.start, self.target)

    def __cmp__(self, other):
        """Three-way comparison by weight: -1, 0 or 1.

        Python 3 never calls ``__cmp__`` implicitly; it is kept so explicit
        callers still work. The built-in ``cmp`` is gone in Python 3, so
        the result is computed from the two boolean comparisons.
        """
        return (self.weight > other.weight) - (self.weight < other.weight)

    def __lt__(self, other):
        """Return True when this edge is strictly lighter than ``other``."""
        return self.weight < other.weight
        
class Node:
    """Graph vertex that is also an element of a disjoint-set forest.

    Attributes:
        name: label used when the node is printed.
        parent: parent node in the disjoint-set tree, or None while the
            node is a root.
        set_: the DisjointSet currently containing this node; None until
            a DisjointSet is created for it.
    """
    def __init__(self, name):
        self.name = name
        self.parent = None
        self.set_ = None

    def __repr__(self):
        # Show only the label so sets and edges print compactly.
        # str() keeps __repr__ valid when the name is not a string.
        return str(self.name)
        
class DisjointSet:
    """One set of a disjoint-set (union-find) structure.

    Membership is tracked in two ways: ``nodes`` holds the members
    explicitly, while each member's ``parent`` link forms a tree whose
    root is ``root``. Every member's ``set_`` attribute points back here.
    """
    def __init__(self, node):
        """Create a singleton set whose only member and root is ``node``."""
        self.nodes = {node}
        self.root = node
        node.set_ = self

    def __str__(self):
        # A set emptied by a merge prints as the word 'Empty'.
        return str(self.nodes) if self.nodes else 'Empty'

    def __len__(self):
        return len(self.nodes)

    @staticmethod
    def find(node):
        """Return the root of ``node``'s tree, compressing the path walked."""
        # First pass: climb the parent links up to the root.
        top = node
        while top.parent is not None:
            top = top.parent

        # Second pass: re-point every node on the walked path at the root.
        current = node
        while current is not top:
            following = current.parent
            current.parent = top
            current = following

        return top

    @staticmethod
    def merge(s1, s2):
        """Merge the smaller set into the larger one (ties go into ``s1``)."""
        if s1 is s2:  # already the same set
            return

        # Strictly smaller s1 is absorbed by s2; otherwise s2 is absorbed.
        if len(s1) < len(s2):
            absorbed, keeper = s1, s2
        else:
            absorbed, keeper = s2, s1

        absorbed.root.parent = keeper.root

        for member in absorbed.nodes:  # re-home every moved node
            member.set_ = keeper
        keeper.nodes.update(absorbed.nodes)
        absorbed.nodes = set()

In [2]:
class Kruskal:
    """Kruskal's minimum-spanning-tree algorithm with a merge trace.

    Attributes:
        spanning_tree: edges accepted so far, in the order they were picked.
        edges: the input edges ordered by ascending weight.
        nodes: the input nodes.
        sets: one DisjointSet per input node, in input order; merged-away
            sets stay in the list as empty sets so the trace columns
            remain aligned.
    """
    def __init__(self, nodes, edges):
        self.spanning_tree = []
        # Sort a copy so the caller's edge list keeps its original order.
        self.edges = sorted(edges)  # O(E log E)
        self.nodes = nodes
        self.sets = [DisjointSet(node) for node in nodes]

    def run(self):
        """Build the spanning tree, printing the sets around each merge."""
        self.logging(init=True)

        for edge in self.edges:
            r1 = DisjointSet.find(edge.start)
            r2 = DisjointSet.find(edge.target)

            if r1.set_ is not r2.set_:  # endpoints are in different sets
                # Print the state *before* this merge.
                self.logging()

                DisjointSet.merge(r1.set_, r2.set_)
                self.spanning_tree.append(edge)

                if len(self.spanning_tree) == len(self.nodes) - 1:
                    # With n-1 edges selected every remaining edge would
                    # be discarded, so print the final state and stop.
                    self.logging()
                    break

    def logging(self, init=False):
        """Print one trace row.

        Args:
            init: when True print the column-index header and a ruler
                instead of the current contents of the sets.
        """
        # Column 2 is wider than the others (25 vs 7 characters).
        widths = [25 if i == 2 else 7 for i in range(len(self.sets))]
        row_format = ''.join('{{!s:>{}}} '.format(w) for w in widths)

        if init:
            print(row_format.format(*range(0, len(self.sets))))
            # Each column is followed by one separating space.
            print('=' * (sum(widths) + len(widths)))
        else:
            print(row_format.format(*self.sets))

In [3]:
# construct A,B,C,D,E,F,G Nodes; `graph` keeps them in that order
node_str = 'ABCDEFG'
graph = [Node(s) for s in node_str]

# Bind every node to a variable of the same name explicitly, rather than
# injecting names through locals(), so the names used below are visible here.
A, B, C, D, E, F, G = graph

# linked nodes: Edge(weight, start, target)
edges = [
    Edge(2, A, B),
    Edge(6, A, C),
    Edge(5, A, E),
    Edge(10, A, F),
    Edge(3, B, D),
    Edge(3, B, E),
    Edge(1, C, D),
    Edge(2, C, F),
    Edge(4, D, E),
    Edge(5, D, G),
    Edge(5, F, G),
]

![image](https://storage.googleapis.com/ssivart/super9-blog/kruskal.png)

In [4]:
# Build the solver: the constructor orders the edges by weight and gives
# every node its own single-element DisjointSet.
algorithm = Kruskal(graph, edges)

## Set merging history

In [5]:
# Run Kruskal. Each printed row shows the disjoint sets, one column per
# initial set; a set that was merged into another prints as 'Empty'.
algorithm.run()

      0       1                         2       3       4       5       6 
    {A}     {B}                       {C}     {D}     {E}     {F}     {G} 
    {A}     {B}                    {D, C}   Empty     {E}     {F}     {G} 
 {A, B}   Empty                    {D, C}   Empty     {E}     {F}     {G} 
 {A, B}   Empty                 {F, D, C}   Empty     {E}   Empty     {G} 
  Empty   Empty           {F, D, A, B, C}   Empty     {E}   Empty     {G} 
  Empty   Empty        {F, A, B, C, D, E}   Empty   Empty   Empty     {G} 
  Empty   Empty     {F, A, G, B, C, D, E}   Empty   Empty   Empty   Empty 


## Spanning tree edges

In [6]:
# List the accepted edges in the order the algorithm picked them.
for edge in algorithm.spanning_tree:
    print(edge.start, '--->', edge.target)

C ---> D
A ---> B
C ---> F
B ---> D
B ---> E
D ---> G


## Disjoint-set tree structure

In [7]:
# Show each node's parent in the disjoint-set tree (None marks a root).
for node in graph:
    parent = node.parent.name if node.parent else None
    print('{} ---> {}'.format(node.name, parent))

A ---> C
B ---> C
C ---> None
D ---> C
E ---> C
F ---> C
G ---> C
