# Super Maximum Cost Queries
Victoria has a tree $T$ consisting of $N$ nodes numbered from $1$ to $N$. 
Each edge $E_i = (U_i, V_i)$ has an integer weight $W_i$.
Let's define the cost $C$ of a path from node $X$ to node $Y$ as the maximum weight of any edge on the unique path from $X$ to $Y$.

Victoria needs help processing $Q$ queries on tree $T$, where each query contains two integers $L$ and $R$ such that $L \leq R$. 
For each query, she wants to print the number of different paths in $T$ whose cost lies in the inclusive range $[L, R]$. Paths are unordered and connect two distinct nodes: the path from $X$ to $Y$ is the same as the path from $Y$ to $X$.

In [34]:
class disjointSet:
    """Union-find (disjoint-set) structure over the integers ``0 .. N-1``.

    ``memo[i]`` holds the parent of ``i``; an element whose parent is itself
    is the root (representative) of its set.
    """

    def __init__(self, N):
        # Every element starts as a singleton set and is its own root.
        self.memo = list(range(N))

    def find(self, a):
        """Return the root of the set containing ``a``.

        Uses path halving: each visited node is re-pointed at its grandparent,
        which flattens long parent chains so later calls stay cheap
        (logarithmic amortized) instead of walking O(N) links every time.
        Roots are never modified here, so the returned root is exactly the
        one a plain parent walk would reach.
        """
        parent = self.memo
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b`` (no-op if already joined).

        The root of ``a``'s set is attached under the root of ``b``'s set.
        """
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.memo[ra] = rb

    def getComponents(self):
        """Group all elements by set.

        Returns:
            dict mapping each root to the ascending list of elements in its set.
        """
        comps = {}
        # Iterate over indices, not memo itself: find() rewrites memo entries.
        for i in range(len(self.memo)):
            comps.setdefault(self.find(i), []).append(i)
        return comps

In [52]:
def solve_single(tree, L, R):
    """Count the paths whose cost (heaviest edge) lies in ``[L, R]`` for one query.

    The answer is (pairs connected by edges of weight < L) subtracted from
    (pairs connected by edges of weight <= R).

    Args:
        tree: edges as ``[u, v, w]`` with 1-based node labels; the node count
            is taken to be ``len(tree) + 1``.
        L: inclusive lower bound on the path cost.
        R: inclusive upper bound on the path cost.

    Returns:
        Number of unordered node pairs whose connecting path has cost in [L, R].
    """
    def connected_pairs(dsu):
        # A component of k nodes contributes C(k, 2) unordered pairs.
        return sum(
            len(members) * (len(members) - 1) // 2
            for members in dsu.getComponents().values()
        )

    dsu = disjointSet(len(tree) + 1)

    # Pairs joined only through edges lighter than L have cost < L.
    for head, tail, weight in tree:
        if weight < L:
            dsu.union(head - 1, tail - 1)
    below_l = connected_pairs(dsu)

    # Keep merging every edge up to R: pairs now joined have cost <= R.
    for head, tail, weight in tree:
        if weight <= R:
            dsu.union(head - 1, tail - 1)
    up_to_r = connected_pairs(dsu)

    return up_to_r - below_l

In [61]:
def solve(tree, queries):
    """Answer every query with the number of paths whose cost lies in ``[L, R]``.

    The cost of a path is its heaviest edge. Adding edges in ascending weight
    order (Kruskal-style) while counting how many node pairs become connected
    gives ``pairs[k]``: the number of paths made only of the ``k`` lightest
    edges. Given the sorted edge ``weights``, paths of cost ``< L`` number
    ``pairs[bisect_left(weights, L)]``. Paths of cost ``<= R`` number
    ``pairs[bisect_right(weights, R)]``. Each answer is the difference of the
    two. Total work is O((N + Q) log N) rather than one full component scan
    per distinct query boundary.

    Args:
        tree: edges as ``[u, v, w]`` with 1-based node labels; a tree on
            ``len(tree) + 1`` nodes is assumed.
        queries: iterable of ``[L, R]`` pairs.

    Returns:
        List of path counts, one per query, in query order.
    """
    from bisect import bisect_left, bisect_right

    edges = sorted(tree, key=lambda edge: edge[2])
    n_nodes = len(edges) + 1
    parent = list(range(n_nodes))
    size = [1] * n_nodes

    def find(a):
        # Path halving keeps the parent chains short.
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    weights = []
    pairs = [0]  # pairs[k]: node pairs connected using only the k lightest edges
    for u, v, w in edges:
        ru, rv = find(u - 1), find(v - 1)
        connected = pairs[-1]
        if ru != rv:
            # Every node of one component now reaches every node of the other;
            # edges are processed in ascending order, so w is the maximum
            # weight on each of these newly formed paths.
            if size[ru] < size[rv]:
                ru, rv = rv, ru
            connected += size[ru] * size[rv]
            parent[rv] = ru
            size[ru] += size[rv]
        weights.append(w)
        pairs.append(connected)

    return [
        pairs[bisect_right(weights, R)] - pairs[bisect_left(weights, L)]
        for L, R in queries
    ]

In [62]:
# Sample instance. V and Q match this sample's node count (labels 1..5) and
# query count. Each edge is [u, v, weight]; each query is [L, R].
V, Q = 5, 5
tree = [[1, 2, 3], [1, 4, 2], [2, 5, 6], [3, 4, 1]]
queries = [[1, 1], [1, 2], [2, 3], [2, 5], [1, 6]]


In [63]:
# Answer all sample queries in one batched call.
result = solve(tree=tree, queries=queries)

In [64]:
# Echo the per-query answers as this cell's output (captured below).
result

[1, 3, 5, 5, 10]