# In [None]:
"""
Alice and Bob have an undirected graph of n nodes and three types of edges:

Type 1: Can be traversed by Alice only.
Type 2: Can be traversed by Bob only.
Type 3: Can be traversed by both Alice and Bob.
Given an array edges where edges[i] = [type_i, u_i, v_i] represents a bidirectional edge of type type_i between nodes u_i and v_i, find the maximum number of edges you can remove so that, after removing them, the graph can still be fully traversed by both Alice and Bob. The graph is fully traversed by Alice and Bob if, starting from any node, they can reach all other nodes.

Return the maximum number of edges you can remove, or return -1 if Alice and Bob cannot fully traverse the graph.


Example 1:
    Input: n = 4, edges = [[3,1,2],[3,2,3],[1,1,3],[1,2,4],[1,1,2],[2,3,4]]
    Output: 2
    Explanation: If we remove the 2 edges [1,1,2] and [1,1,3], the graph will still be fully traversable by Alice and Bob. Removing any additional edge would break that, so the maximum number of edges we can remove is 2.

Example 2:
    Input: n = 4, edges = [[3,1,2],[3,2,3],[1,1,4],[2,1,4]]
    Output: 0
    Explanation: Notice that removing any edge will not make the graph fully traversable by Alice and Bob.


Example 3:
    Input: n = 4, edges = [[3,2,3],[1,1,2],[2,3,4]]
    Output: -1
    Explanation: In the current graph, Alice cannot reach node 4 from the other nodes. Likewise, Bob cannot reach node 1. Therefore it's impossible to make the graph fully traversable.

Constraints:
    1 <= n <= 10^5
    1 <= edges.length <= min(10^5, 3 * n * (n - 1) / 2)
    edges[i].length == 3
    1 <= type_i <= 3
    1 <= u_i < v_i <= n
    All tuples (type_i, u_i, v_i) are distinct.
    
TIP:
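    Work through the logic and convince yourself it is correct.
    1. This is a variant of Kruskal's spanning-tree algorithm.
    2. First connect nodes using the common (type 3) edges: a shared edge serves both players, so it is never worse than spending one Alice edge and one Bob edge. Then check whether the remaining type 1 and type 2 edges can each complete a spanning tree for Alice and Bob respectively.
    3. Use union-find (disjoint set union) to track the connected components.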
"""

from typing import List
class Solution:
    def maxNumEdgesToRemove(self, n: int, edges: List[List[int]]) -> int:
        """Return the maximum number of removable edges, or -1 if impossible.

        Greedy, Kruskal-style union-find:
          * shared (type 3) edges are added first, because one shared edge
            helps both Alice and Bob at once;
          * from that shared forest, Alice's (type 1) edges must complete a
            spanning tree, and then, from the same shared forest, Bob's
            (type 2) edges must as well.
        Every edge not needed for those spanning trees can be removed.

        Args:
            n: number of nodes, labelled 1..n.
            edges: [type, u, v] triples with 1-based node labels.

        Returns:
            len(edges) minus the number of edges kept, or -1 if Alice or Bob
            cannot reach every node.
        """
        parent = list(range(n))
        rank = [0] * n
        target = n - 1  # edge count of a spanning tree over n nodes

        def root(node):
            # Locate the representative, then point every node on the path at it.
            top = node
            while parent[top] != top:
                top = parent[top]
            while node != top:
                nxt = parent[node]
                parent[node] = top
                node = nxt
            return top

        def merge_all(pairs, already):
            """Union each pair in order; return how many joined two components.

            Stops early once `already` plus the new joins reach a spanning tree.
            """
            joined = 0
            for a, b in pairs:
                ra, rb = root(a), root(b)
                if ra == rb:
                    continue
                # Union by rank: hang the shallower tree under the deeper one.
                if rank[ra] > rank[rb]:
                    ra, rb = rb, ra
                parent[ra] = rb
                if rank[ra] == rank[rb]:
                    rank[rb] += 1
                joined += 1
                if already + joined == target:
                    break
            return joined

        # Group edges by type, switching to 0-based node indices.
        by_type = {1: [], 2: [], 3: []}
        for kind, u, v in edges:
            by_type[kind].append((u - 1, v - 1))

        shared = merge_all(by_type[3], 0)
        if shared == target:
            # Shared edges alone span the graph; everything else is redundant.
            return len(edges) - shared

        # Remember the shared forest so Bob can start from it after Alice.
        snapshot_parent, snapshot_rank = parent[:], rank[:]

        alice = merge_all(by_type[1], shared)
        if shared + alice != target:
            return -1

        parent[:] = snapshot_parent
        rank[:] = snapshot_rank
        bob = merge_all(by_type[2], shared)
        if shared + bob != target:
            return -1

        return len(edges) - (shared + alice + bob)

# In [None]:


# Try - 1
from typing import List
class Solution:
    def maxNumEdgesToRemove(self, n: int, edges: List[List[int]]) -> int:
        """Return the maximum number of removable edges, or -1 if impossible.

        Kruskal-style greedy with a union-find:
          1. Add shared (type 3) edges first; each one counts for both players.
          2. Starting from that shared forest, add Alice's (type 1) edges and
             check that she reaches every node.
          3. Reset to the shared forest and do the same with Bob's (type 2)
             edges.
        Every edge not used in steps 1-3 is redundant and can be removed.

        Args:
            n: number of nodes, labelled 1..n.
            edges: [type, u, v] triples with 1-based node labels. Any type
                other than 1 or 2 is treated as shared.

        Returns:
            len(edges) minus the number of edges kept, or -1 if Alice or Bob
            cannot traverse the whole graph.
        """
        parent = list(range(n))
        rank = [0] * n
        target = n - 1  # edge count of a spanning tree over n nodes

        def find(x):
            # Path compression; union by rank keeps the recursion depth O(log n).
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def union(x, y):
            """Merge the sets of x and y; return False if already in one set."""
            px, py = find(x), find(y)
            if px == py:
                return False
            if rank[px] < rank[py]:
                parent[px] = py
            elif rank[py] < rank[px]:
                parent[py] = px
            else:
                parent[px] = py
                rank[py] += 1
            return True

        def connect(pairs, already):
            """Union `pairs` in order; return how many merged two components.

            Stops as soon as `already` plus the merged edges span all n nodes.
            """
            used = 0
            for u, v in pairs:
                if union(u, v):
                    used += 1
                    if already + used == target:
                        break
            return used

        # Bucket edges by who may use them, converting to 0-based indices.
        alice_only, bob_only, shared = [], [], []
        for t, u, v in edges:
            pair = (u - 1, v - 1)
            if t == 1:
                alice_only.append(pair)
            elif t == 2:
                bob_only.append(pair)
            else:
                shared.append(pair)

        common = connect(shared, 0)
        if common == target:
            # Shared edges alone span the graph; everything else is redundant.
            return len(edges) - common

        saved_parent, saved_rank = parent[:], rank[:]

        alice = connect(alice_only, common)
        if common + alice != target:
            return -1

        # Restore the shared forest in place so find/union keep using these lists.
        parent[:] = saved_parent
        rank[:] = saved_rank
        bob = connect(bob_only, common)
        if common + bob != target:
            return -1

        return len(edges) - (common + alice + bob)