## Utility Functions for Tree Analysis
This notebook contains utility functions to compute and analyze values related to tree structures, such as node retrieval and distance matrix computation.

### Function: get_nodes_at_depth
Collect all nodes at a specific depth from the root. Each node is reported as the list of values along the path from the root down to that node.

In [1]:
import numpy as np
from collections import deque

def get_nodes_at_depth(tree_root, depth):
    """Return the value path of every node lying exactly ``depth`` levels below the root.

    Args:
        tree_root: Root node. Nodes expose ``value`` and ``children``, the
            latter an iterable of ``(child, weight)`` pairs.
        depth: Target level; the root itself is level 0.

    Returns:
        A list with one entry per node found at ``depth``. Each entry is the
        list of node values from the root down to that node (inclusive), in
        depth-first, child-order sequence. Branches that end before ``depth``
        contribute nothing.
    """

    def walk(node, prefix, level):
        trail = prefix + [node.value]
        if level == depth:
            # Target level reached: report this node and do not descend further.
            yield trail
        else:
            for successor, _ in node.children:
                yield from walk(successor, trail, level + 1)

    return list(walk(tree_root, [], 0))

### Function: compute_distance_matrix_at_depth
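Compute the matrix of pairwise distances between the nodes of two trees at a specific depth. The distance between two nodes is the L1 distance between their root-to-node value paths, i.e. the sum of the absolute differences of the values along the two paths.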

In [2]:
def compute_distance_matrix_at_depth(tree_1_root, tree_2_root, depth):
    """Build the pairwise L1 distance matrix between the depth-``depth`` paths of two trees.

    Args:
        tree_1_root: Root of the first tree (indexes the rows).
        tree_2_root: Root of the second tree (indexes the columns).
        depth: Level at which the value paths of both trees are collected.

    Returns:
        A NumPy array of shape ``(n1, n2)`` where ``n1`` / ``n2`` are the
        numbers of paths returned by ``get_nodes_at_depth`` for each tree.
        Entry ``[i, j]`` is the sum of absolute element-wise differences
        between path ``i`` of tree 1 and path ``j`` of tree 2.
    """
    row_paths = get_nodes_at_depth(tree_1_root, depth)
    col_paths = get_nodes_at_depth(tree_2_root, depth)

    def l1_gap(left, right):
        # Right-pad the shorter path with zeros so both have equal length
        # before comparing them position by position.
        width = max(len(left), len(right))
        left_ext = left + [0] * (width - len(left))
        right_ext = right + [0] * (width - len(right))
        return sum(abs(u - v) for u, v in zip(left_ext, right_ext))

    result = np.zeros((len(row_paths), len(col_paths)))
    for r, left in enumerate(row_paths):
        for c, right in enumerate(col_paths):
            result[r, c] = l1_gap(left, right)

    return result

### Function: get_paths_to_leaves
Generate all paths from the root to each leaf node up to a specified depth.

In [3]:
def get_paths_to_leaves(tree_root, max_depth):
    """Generate all root-to-leaf value paths, cut off at a specified depth.

    Args:
        tree_root: Root node. Nodes expose ``value`` and ``children``, the
            latter a collection of ``(child, weight)`` pairs.
        max_depth: Level (root is 0) at which descent stops even if the node
            still has children.

    Returns:
        A list of paths in depth-first, child-order sequence. Each path is
        the list of node values from the root to either a childless node or
        a node sitting at ``max_depth``, whichever comes first.
    """

    def walk(node, prefix, level):
        trail = prefix + [node.value]
        if level != max_depth and node.children:
            # Still above the cut-off and not a leaf: keep descending.
            for successor, _ in node.children:
                yield from walk(successor, trail, level + 1)
        else:
            yield trail

    return list(walk(tree_root, [], 0))

### Function: get_node_from_path
Retrieve the node at the end of a given path from the root.

In [4]:
def get_node_from_path(tree_root, path):
    """Given a root node and a path (list of values), return the node at the end of the path.

    Args:
        tree_root: Root node. Nodes expose ``value`` and ``children``, the
            latter an iterable of ``(child, weight)`` pairs.
        path: Sequence of node values. The first entry stands for the root
            itself and is skipped without being compared to ``tree_root.value``;
            an empty or one-element path therefore yields ``tree_root``.

    Returns:
        The node reached by following ``path``. When several children share
        the requested value, the first one in ``children`` order is taken.

    Raises:
        ValueError: If some step of the path has no child with the requested
            value.
    """
    current_node = tree_root
    for step, value in enumerate(path[1:], start=1):
        for child, _ in current_node.children:
            if child.value == value:
                current_node = child
                break
        else:
            # Raise an explicit error rather than letting a bare StopIteration
            # escape, which could silently terminate an enclosing iteration.
            raise ValueError(
                f"Invalid path {path!r}: no child with value {value!r} under "
                f"node with value {current_node.value!r} (path index {step})."
            )
    return current_node

### Function: find_node_by_path
Traverse the tree following a given path and return the final node.

In [5]:
def find_node_by_path(node, path):
    """Follow ``path`` down from ``node`` and return the node it ends at.

    Args:
        node: Starting (root) node. Nodes expose ``value`` and ``children``,
            the latter an iterable of ``(child, weight)`` pairs.
        path: Sequence of node values; the first entry denotes the starting
            node itself and is not checked.

    Returns:
        The node reached at the end of the path, or ``None`` if some step has
        no matching child (a diagnostic is printed in that case). When several
        children match a step, a warning is printed and the first match is
        followed.
    """
    current_node = node
    # path[0] denotes the starting node itself, so only the tail is walked.
    for value in path[1:]:
        hits = []
        for candidate, _ in current_node.children:
            if candidate.value == value:
                hits.append(candidate)

        if len(hits) == 0:
            print(f"Invalid path: No child with value {value} under node with value {current_node.value}. Path so far: {path}")
            return None

        if len(hits) >= 2:
            print(f"Warning: Multiple children with value {value} found under node with value {current_node.value}. Path: {path}")

        # Ambiguities are resolved in favour of the earliest matching child.
        current_node = hits[0]

    return current_node

### Function: compute_marginal_probabilities_for_subset
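Compute marginal probabilities for the direct successors of node1 and node2, i.e. collect the transition probability stored with each child of the nodes reached by `node1_path` (in tree 1) and `node2_path` (in tree 2).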

In [6]:
def compute_marginal_probabilities_for_subset(node1_path, node2_path, tree_1_root, tree_2_root):
    """Compute marginal probabilities for the direct successors of node1 and node2.

    Args:
        node1_path: Value path (root first) identifying a node in tree 1.
        node2_path: Value path (root first) identifying a node in tree 2.
        tree_1_root: Root node of tree 1.
        tree_2_root: Root node of tree 2.

    Returns:
        A pair ``(pi_ratios, pi_tilde_ratios)``. Each is a list holding, in
        ``children`` order, the probability stored alongside every direct
        child of the node reached by ``node1_path`` / ``node2_path``
        respectively. A node without children yields an empty list.
    """
    # Resolve each path to the actual node object in its tree.
    node1 = get_node_from_path(tree_1_root, node1_path)
    node2 = get_node_from_path(tree_2_root, node2_path)

    # Only the probability attached to each outgoing edge is needed, so the
    # successor paths are not materialised.
    pi_ratios = [prob for _, prob in node1.children]
    pi_tilde_ratios = [prob for _, prob in node2.children]

    return pi_ratios, pi_tilde_ratios

### Solver for Optimal Probability Matrix
This function computes the optimal probability matrix that minimizes the cost defined by a given subset of the distance matrix. It uses `scipy.optimize.linprog` to solve the associated linear program, ensuring the constraints on rows and columns (derived from `pi_ratios` and `pi_tilde_ratios`) are satisfied.

In [7]:
from scipy.optimize import linprog

def solver(distance_matrix_subset, pi_ratios, pi_tilde_ratios):
    """
    Solve for the optimal probability matrix that minimizes the cost when
    multiplied with the distance_matrix_subset.

    The linear program is: minimise ``sum(P * D)`` over matrices ``P >= 0``
    whose row sums equal ``pi_ratios`` and whose column sums equal
    ``pi_tilde_ratios``, where ``D`` is ``distance_matrix_subset``.

    Args:
        distance_matrix_subset: 2-D array-like of costs, shape
            ``(num_rows, num_cols)``.
        pi_ratios: Required row sums; indexable with at least ``num_rows``
            entries.
        pi_tilde_ratios: Required column sums; indexable with at least
            ``num_cols`` entries.

    Returns:
        NumPy array of shape ``(num_rows, num_cols)`` with the optimal matrix.

    Raises:
        ValueError: If the solver does not report success, e.g. because the
            row and column sums cannot be satisfied simultaneously.
    """
    cost = np.asarray(distance_matrix_subset, dtype=float)
    num_rows, num_cols = cost.shape

    # Decision variables are the matrix entries in row-major order, so
    # variable ``i * num_cols + j`` corresponds to entry ``[i, j]``.
    c = cost.ravel()

    # Row constraints: one block of ``num_cols`` ones per row.
    row_sum_rows = np.kron(np.eye(num_rows), np.ones((1, num_cols)))
    # Column constraints: a one at offset ``j`` inside every row block.
    col_sum_rows = np.kron(np.ones((1, num_rows)), np.eye(num_cols))
    A_eq = np.vstack([row_sum_rows, col_sum_rows])

    # Explicit indexing so a too-short marginal still fails with IndexError
    # and surplus trailing entries are ignored.
    b_eq = [pi_ratios[i] for i in range(num_rows)]
    b_eq += [pi_tilde_ratios[j] for j in range(num_cols)]

    # Bounds: each entry in the probability matrix should be non-negative.
    bounds = [(0, None)] * (num_rows * num_cols)

    result = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')

    # On failure ``result.x`` may be None; report the solver's diagnosis
    # rather than crashing on the reshape below.
    if not result.success:
        raise ValueError(
            f"Linear program could not be solved (status {result.status}): {result.message}"
        )

    return result.x.reshape(num_rows, num_cols)