In [1]:
from copy import deepcopy
from dataclasses import dataclass, field

@dataclass
class Node:
    """One node of a labelled text tree.

    ``text`` and ``children`` describe the tree itself. The underscore-prefixed
    fields are bookkeeping that ``format`` overwrites each time it renders a
    node; until a node has been rendered they keep their defaults.
    """

    # Label rendered for this node.
    text: str
    # Ordered sub-nodes; a fresh list per instance.
    children: list["Node"] = field(default_factory=list)
    # Child indices leading from the rendered root down to this node.
    _path: list[int] = field(default_factory=list)
    # Rendered length of this subtree plus the ancestor overhead estimate
    # that ``format`` accumulates on the way down.
    _chars: int = 0
    # Rendered length of this subtree alone. Declared here so that it is a
    # real field (with a default) instead of an attribute that only appears
    # once ``format`` has run.
    _subtree_chars: int = 0

In [2]:
def format(node: Node, indent_level: int = 0, cumul_ub=0, parent_path: list[int] = []) -> Node:
    indent = "  " * indent_level
    # Start with role and optional text
    result = f"{indent}{node.text}" 
    # Recursively format children
    if len(node.children) > 0:
        result += " {\n"
        for i, child in enumerate(node.children):
            result += format(child, indent_level + 1, cumul_ub + len(node.text) + 10*(indent_level+1), parent_path + [i])
        result += indent + "}\n"
    else:
        result += "\n"

    node._subtree_chars = len(result)
    node._chars = len(result) + cumul_ub
    node._path = parent_path

    return result

def func(node: Node) -> tuple[Node, str]:
    """Render ``node`` with ``format`` and hand back both the node and the text.

    ``format`` writes ``_chars`` and ``_path`` onto every node it visits, so the
    returned node is the same object, now annotated.

    Returns:
        A ``(node, rendered_text)`` pair.
    """
    rendered = format(node)
    return (node, rendered)

In [3]:
# Sample tree used by the rest of the notebook. `tree` is the annotated root and
# `f` its rendered text. Note that two root children share the text
# "child0-urwoewjs".
tree, f = func(
    Node(
        text="root-fjeowhfuohewu",
        children=[
            Node(text="mektoub"),
            Node(
                text="child0-urwoewjs",
                children=[Node(text="child0.0-fejwoujfow")],
            ),
            Node(
                text="child0-urwoewjs",
                children=[Node(text="child0.0-fejwoujfow")],
            ),
            Node(text="child1-fjoewiklrfn"),
            Node(
                text="child2-ufiejkndsv",
                children=[
                    Node(
                        text="child2.1-jfeouwjnkfdsjfkdnkqwfeuwifhuew",
                        children=[
                            Node(text="child2.1.0-jfeouwjnkfdsjfkdnkqwfeuwifhuew"),
                        ],
                    ),
                    Node(text="child2.2-jfeouwjnkfdsjfkdnkqw"),
                ],
            ),
        ],
    )
)

In [4]:
def find_node_by_path(root: "Node", path: list[int]) -> "Node":
    """Return the node reached from ``root`` by following ``path``.

    Args:
        root: Node to start from.
        path: Child indices to follow, one per level. An empty path selects
            ``root`` itself.

    Returns:
        The node at the end of the path (the original object, not a copy).

    Raises:
        IndexError: If an index does not exist among a node's children.
    """
    # Iterative walk. The earlier recursive form called an undefined name
    # (`find_node`), so every non-empty path failed with NameError.
    node = root
    for way in path:
        node = node.children[way]
    return node

In [5]:
def partial_tree_by_path(root: "Node", path: list[int]) -> "Node":
    """Build the slice of the tree that contains only the branch along ``path``.

    Every node on the way down is re-created with a single child (the next
    step of the path). The node the path ends on is kept together with its
    whole subtree.

    The result shares no node with ``root``: ancestors are new ``Node``
    objects and the terminal subtree is deep-copied. This matters because
    ``format`` writes ``_chars`` and ``_path`` onto every node it renders, so
    rendering a partial tree that aliased the source would silently rewrite
    the source tree's bookkeeping.

    Args:
        root: Root of the source tree.
        path: Child indices to follow. An empty path selects the whole tree.

    Returns:
        A new tree holding only the selected branch.

    Raises:
        IndexError: If an index does not exist among a node's children.
    """
    if not path:
        # End of the path: keep the full subtree, but as an independent copy.
        return deepcopy(root)
    subtree = partial_tree_by_path(root.children[path[0]], path[1:])
    return Node(text=root.text, children=[subtree])

In [6]:
# Render the whole tree. format() also rewrites the bookkeeping attributes
# (_chars, _path) on every node it visits.
print(format(tree))

root-fjeowhfuohewu {
  mektoub
  child0-urwoewjs {
    child0.0-fejwoujfow
  }
  child0-urwoewjs {
    child0.0-fejwoujfow
  }
  child1-fjoewiklrfn
  child2-ufiejkndsv {
    child2.1-jfeouwjnkfdsjfkdnkqwfeuwifhuew {
      child2.1.0-jfeouwjnkfdsjfkdnkqwfeuwifhuew
    }
    child2.2-jfeouwjnkfdsjfkdnkqw
  }
}



In [8]:
# Render only the branch that leads to the node at path [4, 0, 0].
print(format(partial_tree_by_path(tree, [4,0,0])))

root-fjeowhfuohewu {
  child2-ufiejkndsv {
    child2.1-jfeouwjnkfdsjfkdnkqwfeuwifhuew {
      child2.1.0-jfeouwjnkfdsjfkdnkqwfeuwifhuew
    }
  }
}



In [9]:
# Path that format() recorded for the first child of the root.
tree.children[0]._path

[0]

In [10]:
def split(node: "Node", threshold: int) -> list[tuple[int, list[int]]]:
    """Cut the tree under ``node`` into subtrees whose ``_chars`` fit ``threshold``.

    Walks top-down and stops at the first node on each branch whose ``_chars``
    is within the threshold. ``_chars`` and ``_path`` are read as-is, so the
    tree must already have been annotated by ``format``.

    Args:
        node: Root of the (sub)tree to cut.
        threshold: Largest ``_chars`` value accepted for one piece.

    Returns:
        ``(chars, path)`` pairs in depth-first order, one per emitted subtree,
        where ``chars`` is that subtree's ``_chars`` and ``path`` its ``_path``.

    Raises:
        ValueError: If a leaf exceeds ``threshold`` and so cannot be split further.
    """
    if node._chars <= threshold:
        return [(node._chars, node._path)]
    if not node.children:
        raise ValueError("Move to notte.sdk to handle very long context webpages.")

    # The recursive call already emits a fitting child as a single pair, so no
    # separate threshold test is needed here.
    pieces = []
    for child in node.children:
        pieces.extend(split(child, threshold))
    return pieces

In [11]:
# Size threshold passed to split() here and to partition_array() further down.
gamma = 165
# split() yields (chars, path) pairs; transpose them into two parallel lists.
chars, paths = (list(column) for column in zip(*split(tree, gamma)))

In [12]:
# _chars value of each piece returned by split(), in order.
chars

[38, 76, 76, 49, 165, 99]

In [13]:
# _path of each piece returned by split(), aligned index-by-index with `chars`.
paths

[[0], [1], [2], [3], [4, 0], [4, 1]]

In [14]:
def partition_array(arr, gamma):
    """Group consecutive indices of ``arr`` into the fewest runs with sum <= ``gamma``.

    Dynamic programme over prefixes: ``best[k]`` is the minimum number of runs
    covering ``arr[:k]``, and ``last_start[k]`` is where the final run of that
    optimum begins. When several optima exist, the one found first while
    scanning the last run from shortest to longest is kept.

    Args:
        arr: Sequence of numbers. The inner scan stops at the first run whose
            sum exceeds ``gamma``, which assumes non-negative values.
        gamma: Largest allowed sum for one run.

    Returns:
        A list of index lists, e.g. ``[[0, 1], [2, 3], [4]]``. Returns ``[]``
        when ``arr`` is empty or when no valid grouping is found (for instance
        a single element larger than ``gamma``).
    """
    n = len(arr)
    if n == 0:
        # Nothing to group; previously this raised IndexError on arr[0].
        return []

    unreachable = float('inf')
    best = [0] + [unreachable] * n
    last_start = [0] * (n + 1)

    for end in range(1, n + 1):
        running = 0
        # Grow the last run leftwards: arr[start:end].
        for start in range(end - 1, -1, -1):
            running += arr[start]
            if running > gamma:
                break
            # Strict '<' keeps the first (shortest) last run among ties.
            if best[start] + 1 < best[end]:
                best[end] = best[start] + 1
                last_start[end] = start

    if best[n] == unreachable:
        return []

    # Follow the back-pointers from the end, then restore left-to-right order.
    groups = []
    end = n
    while end > 0:
        start = last_start[end]
        groups.append(list(range(start, end)))
        end = start
    groups.reverse()
    return groups

In [15]:
def merge_trees(trees: list[Node]) -> Node | None:
    """Merge trees that share a root text into one tree.

    Children of all input roots are grouped by their ``text``, in first-seen
    order. A group with one node is reused as-is (same object); a group with
    several nodes is merged recursively.

    Limitation: grouping is by text only, so sibling nodes that carry the same
    text are folded into a single node, whether they come from different
    inputs or from the same one.

    Args:
        trees: Trees to merge; all roots must have the same ``text``.

    Returns:
        A new root node, or ``None`` when ``trees`` is empty.

    Raises:
        ValueError: If the roots do not all share the same text.
    """
    if not trees:
        return None

    root_text = trees[0].text
    if any(tree.text != root_text for tree in trees):
        raise ValueError("Cannot merge trees with different root texts")

    # text -> every child carrying that text, across all input roots.
    grouped = {}
    for tree in trees:
        for child in tree.children:
            grouped.setdefault(child.text, []).append(child)

    merged_children = [
        group[0] if len(group) == 1 else merge_trees(group)
        for group in grouped.values()
    ]
    return Node(text=root_text, children=merged_children)

In [16]:
# Group consecutive piece indices so that each group's summed `chars` stays within gamma.
partitions = partition_array(chars, gamma)

In [17]:
# Each inner list holds the indices (into `chars` / `paths`) that form one group.
partitions

[[0, 1], [2, 3], [4], [5]]

In [18]:
# Build one tree per partition and collect them in `trs`, printing each as it is built.
trs = []
for partition in partitions:
    if len(partition) > 1:
        # Several pieces share this partition: build one single-branch tree
        # per path, then merge them under the common root.
        _paths = [paths[i] for i in partition]
        partials = [partial_tree_by_path(tree, path) for path in _paths]
        merged = merge_trees(partials)
        print(format(merged))
        trs.append(merged)
    else:
        # A lone piece needs no merging; its single-branch tree is the result.
        path = paths[partition[0]]
        partial = partial_tree_by_path(tree, path)
        print(format(partial))
        trs.append(partial)

root-fjeowhfuohewu {
  mektoub
  child0-urwoewjs {
    child0.0-fejwoujfow
  }
}

root-fjeowhfuohewu {
  child0-urwoewjs {
    child0.0-fejwoujfow
  }
  child1-fjoewiklrfn
}

root-fjeowhfuohewu {
  child2-ufiejkndsv {
    child2.1-jfeouwjnkfdsjfkdnkqwfeuwifhuew {
      child2.1.0-jfeouwjnkfdsjfkdnkqwfeuwifhuew
    }
  }
}

root-fjeowhfuohewu {
  child2-ufiejkndsv {
    child2.2-jfeouwjnkfdsjfkdnkqw
  }
}



In [19]:
# Round-trip check: compares the rendering of all partition trees merged back
# together with the rendering of the original tree.
# NOTE(review): this evaluates to False. merge_trees() groups siblings by text,
# so the two root children that share the text "child0-urwoewjs" are folded
# into one node and the merged rendering is shorter than the original.
format(merge_trees(trs)) == format(tree)

False