## Kruskal's Algorithm
- A connected graph has exactly one MST weight, but there can be multiple MSTs with this weight
- Critical edge: An edge that, if removed from the graph, would increase the MST weight (or disconnect the graph), so it must be in every MST
- Pseudo-critical edge: An edge that appears in some MSTs but not all; it is not necessary to reach the MST weight, but it can be included without increasing the weight
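
To make the two definitions concrete, here is a minimal brute-force sketch that enumerates every spanning tree of a tiny graph and classifies edges by which MSTs contain them. The function names `spanning_trees` and `classify_by_enumeration` are illustrative and not part of the solution; the sketch assumes a connected graph with edges given as `(u, v, weight)`, and it is only practical for very small inputs.

```python
from itertools import combinations


def spanning_trees(n, edges):
    """Yield (weight, set of edge indices) for every spanning tree."""
    for subset in combinations(range(len(edges)), n - 1):
        parent = list(range(n))

        def find(x):
            while parent[x] != x:
                parent[x] = parent[parent[x]]  # path halving
                x = parent[x]
            return x

        acyclic = True
        for idx in subset:
            u, v, _ = edges[idx]
            root_u, root_v = find(u), find(v)
            if root_u == root_v:  # this edge would close a cycle
                acyclic = False
                break
            parent[root_u] = root_v
        if acyclic:  # n - 1 acyclic edges on n nodes form a spanning tree
            yield sum(edges[idx][2] for idx in subset), set(subset)


def classify_by_enumeration(n, edges):
    """Return (critical, pseudo_critical) edge indices; assumes a connected graph."""
    trees = list(spanning_trees(n, edges))
    best = min(weight for weight, _ in trees)
    msts = [tree for weight, tree in trees if weight == best]
    in_all = set.intersection(*msts)  # edges present in every MST
    in_some = set.union(*msts)        # edges present in at least one MST
    return sorted(in_all), sorted(in_some - in_all)


# Triangle 0-1-2 with equal weights, plus a single edge reaching node 3.
edges = [(0, 1, 1), (1, 2, 1), (0, 2, 1), (2, 3, 5)]
print(classify_by_enumeration(4, edges))  # → ([3], [0, 1, 2])
```

Edge 3 is the only way to reach node 3, so every MST contains it (critical). Any two of the three triangle edges work, so each of them is in some MSTs but not all (pseudo-critical).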

1. Sort all edges in increasing order of weight
2. Run Kruskal's algorithm to find the MST weight
    - Repeatedly select the smallest edge that doesn't form a cycle with the edges already in the MST
    - Use Union Find to keep track of the disjoint subsets
3. For each edge, remove it from the graph and re-calculate the MST weight using Kruskal's
    - If the weight increases or the graph becomes disconnected, the edge is critical
    - If the edge is not critical, check whether it is pseudo-critical by running Kruskal's with the edge forced into the tree
        - If the final weight equals the MST weight, the edge is part of at least one MST and thus is pseudo-critical
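
The "remove" and "force" variants in step 3 can be sketched as one small helper. This is an illustrative sketch, not the notebook's implementation: `mst_weight` and its `skip` / `force` parameters are hypothetical names, edges are assumed to be `(u, v, weight)` tuples, and the union find is simplified (no rank) to keep it short.

```python
def mst_weight(n, edges, skip=None, force=None):
    """Kruskal's algorithm; returns None if no spanning tree exists.

    skip  -- index of an edge to leave out (hypothetical parameter)
    force -- index of an edge to add before all others (hypothetical parameter)
    """
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(u, v):
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            return False  # already connected, edge would form a cycle
        parent[root_u] = root_v
        return True

    total, used = 0, 0
    if force is not None:
        u, v, w = edges[force]
        union(u, v)
        total, used = w, 1
    # Visit the remaining edges in increasing order of weight.
    for idx in sorted(range(len(edges)), key=lambda i: edges[i][2]):
        if idx == skip or idx == force:
            continue
        u, v, w = edges[idx]
        if union(u, v):
            total += w
            used += 1
    return total if used == n - 1 else None


edges = [(0, 1, 1), (1, 2, 1), (0, 2, 1), (2, 3, 5), (0, 3, 6)]
base = mst_weight(4, edges)                  # 7
print(mst_weight(4, edges, skip=3) > base)   # True: edge 3 is critical
print(mst_weight(4, edges, force=0) == base) # True: edge 0 is in some MST
print(mst_weight(4, edges, force=4) == base) # False: edge 4 is in no MST
```

Returning `None` for a disconnected result keeps the "graph is disconnected" case separate from the "weight increased" case, which callers would need to check before comparing weights.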

### Algorithm
1. Sort all edges by weight, keeping each edge's original index
2. Initialize a union find and calculate the standard MST weight
3. Iterate over each edge
    - Calculate the MST weight without the edge
        - If the result is disconnected or its total weight is larger than the standard MST weight, the edge is critical
    - Otherwise, calculate the MST weight with the edge forced in
        - If the total weight is the same as the standard MST weight, the edge is pseudo-critical

In [None]:
class UnionFind:
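    def __init__(self, size):
        self.root = list(range(size))  # each node starts as its own root
        self.rank = [0] * size         # upper bound on tree height, used by union
        self.num_sets = size           # number of disjoint sets remaining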
    
    def find(self, node):
        if self.root[node] == node:
            return node

        # Path compression: point the node directly at its set's root.
        self.root[node] = self.find(self.root[node])
        return self.root[node]

    def union(self, node1, node2):
        root1 = self.find(node1)
        root2 = self.find(node2)

        if root1 == root2:
            return False
        
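        # Union by rank: attach the shallower tree under the deeper one.
        if self.rank[root1] > self.rank[root2]:
            self.root[root2] = root1
        elif self.rank[root1] < self.rank[root2]:
            self.root[root1] = root2
        else:
            # Equal ranks: pick root2 as the new root and bump its rank.
            self.root[root1] = root2
            self.rank[root2] += 1
        self.num_sets -= 1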

        return True

class Solution:
    def find_critical_and_pseudo_critical_edges(self, n, edges):
        # Append each edge's original index so it survives sorting by weight.
        sorted_edges = sorted(
            ([edge[0], edge[1], edge[2], i] for i, edge in enumerate(edges)),
            key=lambda edge: edge[2],
        )
    
        # Baseline: weight of an unconstrained MST.
        union_find = UnionFind(n)
        mst_cost = 0
        for node1, node2, cost, _ in sorted_edges:
            if union_find.union(node1, node2):
                mst_cost += cost
        
        critical = []
        pseudo_critical = []
        for edge1_node1, edge1_node2, cost1, i in sorted_edges:
            # Critical check: rebuild the MST with edge i excluded.
            union_find_ignore = UnionFind(n)
            ignore_cost = 0
            for edge2_node1, edge2_node2, cost2, j in sorted_edges:
                if i != j and union_find_ignore.union(edge2_node1, edge2_node2):
                    ignore_cost += cost2
            # Without edge i the graph is disconnected or the MST is heavier.
            if union_find_ignore.num_sets != 1 or ignore_cost > mst_cost:
                critical.append(i)
                continue

            # Pseudo-critical check: add edge i first, then run Kruskal's on the rest.
            union_find_include = UnionFind(n)
            include_cost = cost1
            union_find_include.union(edge1_node1, edge1_node2)
            for edge2_node1, edge2_node2, cost2, j in sorted_edges:
                if i != j and union_find_include.union(edge2_node1, edge2_node2):
                    include_cost += cost2
            # Forcing edge i did not raise the weight, so some MST contains it.
            if include_cost == mst_cost:
                pseudo_critical.append(i)
        
        return [critical, pseudo_critical]