In [18]:
from collections import deque

def heavy_light_decomposition(tree, root=0):
    """Partition a rooted tree into heavy chains (heavy-light decomposition).

    A vertex continues its parent's chain when it is that parent's heavy
    child (the first child, in adjacency order, with the largest subtree);
    every other child starts a new chain.

    Parameters
    ----------
    tree : dict or list
        Adjacency structure where ``tree[v]`` iterates over the neighbours of
        vertex ``v``. Neighbour lists may hold children only or be undirected;
        the edge back to the parent is skipped either way.
    root : hashable, default 0
        Vertex the tree is rooted at.

    Returns
    -------
    list[list]
        The chains, each listed top-down (starting at the vertex nearest the
        root). Chain 0 starts at ``root``; the other chains are numbered in
        the order a pre-order walk (heavy child first, then the remaining
        children in adjacency order) reaches their top vertex. An empty
        ``tree`` gives ``[]``.

    Both passes use explicit stacks, so deep trees (e.g. long paths) do not
    hit Python's recursion limit.
    """
    if not tree:
        return []

    no_parent = object()  # sentinel that can never equal a vertex label

    # Pass 1: record each vertex's parent while walking in pre-order, then
    # sum subtree sizes bottom-up by scanning that order backwards (every
    # descendant appears after its ancestor in pre-order, so it is finished
    # before being added to its parent).
    parent = {root: no_parent}
    order = []
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        for child in tree[v]:
            if child != parent[v]:
                parent[child] = v
                stack.append(child)

    subtree_size = dict.fromkeys(order, 1)
    for v in reversed(order):
        if parent[v] is not no_parent:
            subtree_size[parent[v]] += subtree_size[v]

    # Pass 2: pre-order walk visiting the heavy child first. A chain index of
    # None marks a light child, which opens a new chain when it is visited.
    chains = []
    stack = [(root, None)]
    while stack:
        v, chain_id = stack.pop()
        if chain_id is None:
            chain_id = len(chains)
            chains.append([])
        chains[chain_id].append(v)

        children = [c for c in tree[v] if c != parent[v]]
        if not children:
            continue
        # max() keeps the first of several equally large children, which is
        # the same tie-breaking as a strict '>' scan.
        heavy_child = max(children, key=lambda c: subtree_size[c])
        # Push in reverse so the pops come out as: heavy child, then light
        # children in adjacency order.
        for child in reversed(children):
            if child != heavy_child:
                stack.append((child, None))
        stack.append((heavy_child, chain_id))

    return chains

# Example usage: a complete binary tree on 7 vertices, written as child lists
# (leaves map to an empty list).
tree = {0: [1, 2], 1: [3, 4], 2: [5, 6], 3: [], 4: [], 5: [], 6: []}

heavy_chains = heavy_light_decomposition(tree)
for i, chain in enumerate(heavy_chains):
    print(f"Heavy Chain {i}: {chain}")


Heavy Chain 0: [0, 1, 3]
Heavy Chain 1: [4]
Heavy Chain 2: [2, 5]
Heavy Chain 3: [6]


In [6]:
from collections import defaultdict

class TreeNode:
    """Binary tree node holding a value plus optional left/right children."""

    def __init__(self, val):
        self.val = val
        # Children start empty; builders attach subtrees after construction.
        self.left, self.right = None, None

def build_balanced_bst(vertices):
    """Build a height-balanced binary tree over ``vertices``.

    Each subtree is rooted at the lower middle element of its slice, so an
    in-order traversal yields ``vertices`` in their given order. Nodes are
    placed by list position, not by value: the result is a binary *search*
    tree only if ``vertices`` is sorted.

    Returns the root ``TreeNode``, or ``None`` when ``vertices`` is empty.
    """
    def subtree(lo, hi):
        # Builds the subtree for the half-open index range [lo, hi).
        if lo >= hi:
            return None
        mid = (lo + hi - 1) // 2  # lower middle of [lo, hi - 1]
        node = TreeNode(vertices[mid])
        node.left = subtree(lo, mid)
        node.right = subtree(mid + 1, hi)
        return node

    return subtree(0, len(vertices))

def construct_tree_from_decomposition(tree, heavy_chains):
    """Build one balanced binary tree per heavy chain.

    Parameters
    ----------
    tree :
        Unused; kept only so existing call sites keep working.
    heavy_chains : iterable of sequences
        Chains, e.g. as returned by ``heavy_light_decomposition``.

    Returns
    -------
    dict
        Maps each chain's index (its position in ``heavy_chains``) to the
        root returned by ``build_balanced_bst`` for a copy of that chain's
        vertices. That tree is ordered by position in the chain, so an
        in-order traversal reproduces the chain. An empty chain maps to
        ``None``.
    """
    return {
        chain_id: build_balanced_bst(list(chain))
        for chain_id, chain in enumerate(heavy_chains)
    }

def print_bst(root):
    """Print every value of a binary tree in in-order (left, node, right).

    One value is printed per line; an empty tree (``None``) prints nothing.
    Uses an explicit stack instead of recursion.
    """
    pending = []
    node = root
    while pending or node:
        # Descend as far left as possible, remembering the path back up.
        while node:
            pending.append(node)
            node = node.left
        node = pending.pop()
        print(node.val)
        node = node.right


# Example usage: a 12-vertex tree written as child lists (leaves map to []).
tree = {
    0: [1, 2], 1: [3, 4], 2: [5, 6], 3: [7, 8],
    4: [], 5: [9, 10], 6: [], 7: [],
    8: [11], 9: [], 10: [], 11: [],
}

heavy_chains = heavy_light_decomposition(tree)
path_to_bst = construct_tree_from_decomposition(tree, heavy_chains)

# Print each chain's balanced tree in-order, which lists the chain's vertices
# in their original chain order.
for chain_id in path_to_bst:
    print(f"BST for Chain {chain_id}:")
    print_bst(path_to_bst[chain_id])



BST for Chain 0:
0
1
3
8
11
BST for Chain 1:
7
BST for Chain 2:
4
BST for Chain 3:
2
5
9
BST for Chain 4:
10
BST for Chain 5:
6
