In [26]:
import numpy as np
import time

# Assume the modified tree functions (TreeNode, build_tree_from_paths, etc.) have been defined or imported.
# For example, TreeNode can remain unchanged:
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.children = []
    def add_child(self, child_node, probability):
        self.children.append((child_node, probability))

In [27]:
import numpy as np

# A helper to convert a value to a hashable type.
def to_hashable(x):
    if isinstance(x, np.ndarray):
        return tuple(x.tolist())
    return x

# Modified pad_paths: if the node value is a vector, pad with a zero vector.
def pad_paths(paths, pad_value=None):
    max_length = max(len(path) for path in paths)
    # Determine pad_value based on the first element if not provided.
    if pad_value is None:
        first_val = paths[0][0]
        if isinstance(first_val, (np.ndarray, list)):
            first_arr = np.array(first_val)
            pad_value = np.zeros(first_arr.shape)
        else:
            pad_value = 0
    padded = np.array([path + [pad_value] * (max_length - len(path)) for path in paths])
    return padded

# Modified build_tree_from_paths: works with vector node values.
def build_tree_from_paths(sample_paths, weights):
    """
    Builds a weighted tree from sample paths. Each path is assumed to be a list/array
    where each element can be a scalar (d=1) or a vector (d>1).
    """
    # Use to_hashable for checking the starting value.
    start_value = sample_paths[0][0]
    start_key = to_hashable(start_value)
    for path in sample_paths:
        if to_hashable(path[0]) != start_key:
            raise ValueError("All sample paths must have the same value at time step 0.")
    
    total_weight = sum(weights)
    if abs(total_weight - 1.0) > 1e-6:
        raise ValueError("The sum of weights must equal 1. Got sum(weights) = {}".format(total_weight))
    
    tree_dict = {"value": start_value, "children": {}}
    for path, path_weight in zip(sample_paths, weights):
        current = tree_dict
        for value in path[1:]:
            key = to_hashable(value)
            if key not in current["children"]:
                current["children"][key] = {"node": {"value": value, "children": {}}, "weight": 0.0}
            current["children"][key]["weight"] += path_weight
            current = current["children"][key]["node"]
    
    def convert_tree_dict(node_dict):
        node = TreeNode(node_dict["value"])
        children = node_dict["children"]
        if children:
            total = sum(child_info["weight"] for child_info in children.values())
            for child_key, child_info in children.items():
                child_node = convert_tree_dict(child_info["node"])
                probability = child_info["weight"] / total if total > 0 else 0
                node.add_child(child_node, probability)
        return node

    return convert_tree_dict(tree_dict)

# Modified find_node_by_path: uses to_hashable for vector comparisons.
def find_node_by_path(node, path):
    current_node = node
    for value in path[1:]:
        key = to_hashable(value)
        matching_children = [
            child for child, _ in current_node.children if to_hashable(child.value) == key
        ]
        if not matching_children:
            print(f"Invalid path: No child with value {value} under node with value {current_node.value}.")
            return None
        if len(matching_children) > 1:
            print(f"Warning: Multiple children with value {value} found under node with value {current_node.value}.")
        current_node = matching_children[0]
    return current_node

# Modified uniform_empirical_grid_measure: supports both 2D (d=1) and 3D (d>1) sample paths.
def uniform_empirical_grid_measure(data, delta_n=None, use_weights=False):
    # If data is 2D, each path is a 1D sequence (d = 1).
    if data.ndim == 2:
        num_path, t = data.shape
        if delta_n is None:
            delta_n = 1 / (num_path ** (1 / t))
        grid_func = lambda x: np.floor(x / delta_n + 0.5) * delta_n
        quantized_data = grid_func(data)
        quantized_data[:, 0] = data[:, 0]
        if not use_weights:
            return quantized_data
        else:
            unique_paths, indices, counts = np.unique(
                quantized_data, axis=0, return_inverse=True, return_counts=True
            )
            weights = counts / num_path
            return unique_paths, weights
    # If data is 3D, each path has shape (T+1, d).
    elif data.ndim == 3:
        num_path, t, d = data.shape
        if delta_n is None:
            # For multi-dimensional paths, one may adjust the exponent as needed.
            delta_n = 1 / (num_path ** (1 / t))
        grid_func = lambda x: np.floor(x / delta_n + 0.5) * delta_n
        quantized_data = grid_func(data)
        quantized_data[:, 0, :] = data[:, 0, :]
        if not use_weights:
            return quantized_data
        else:
            # Convert each sample path (2D array) into a hashable tuple of tuples.
            quantized_paths = [tuple(map(tuple, quantized_data[i])) for i in range(num_path)]
            # Preserve the original order.
            unique_paths_list = []
            counts_list = []
            for p in quantized_paths:
                if p in unique_paths_list:
                    counts_list[unique_paths_list.index(p)] += 1
                else:
                    unique_paths_list.append(p)
                    counts_list.append(1)
            weights = np.array(counts_list) / num_path
            # Convert unique paths back to numpy arrays.
            unique_paths = np.array([np.array(up) for up in unique_paths_list])
            return unique_paths, weights
    else:
        raise ValueError("Data must be either 2D (d=1) or 3D (d>1).")

# Modified compute_distance_matrix_at_depth: if node values are vectors, use Euclidean norm.
def compute_distance_matrix_at_depth(tree_1_root, tree_2_root, depth, power):
    nodes_at_depth_tree1 = get_nodes_at_depth(tree_1_root, depth)
    nodes_at_depth_tree2 = get_nodes_at_depth(tree_2_root, depth)
    arr1 = pad_paths(nodes_at_depth_tree1)  # shape: (m, L) or (m, L, d)
    arr2 = pad_paths(nodes_at_depth_tree2)  # shape: (n, L) or (n, L, d)
    
    if arr1.ndim == 2:  # d = 1 (scalar values)
        diff = np.abs(arr1[:, None, :] - arr2[None, :, :]) ** power
        distance_matrix = diff.sum(axis=2)
    elif arr1.ndim == 3:  # d > 1 (vector values)
        diff = arr1[:, None, :, :] - arr2[None, :, :, :]  # shape (m, n, L, d)
        diff_norm = np.linalg.norm(diff, axis=3) ** power  # norm over the d-dimensions at each time step
        distance_matrix = diff_norm.sum(axis=2)
    return distance_matrix


In [28]:
# Helper function to compare two node values.
def value_equal(v1, v2):
    return to_hashable(v1) == to_hashable(v2)

def get_nodes_at_depth(tree_root, depth):
    """Collect all nodes at a specific depth from the root.
       Returns each path as a list of node values (which can be scalars or vectors)."""
    nodes = []

    def traverse(node, path, current_depth):
        if current_depth == depth:
            nodes.append(path + [node.value])
            return
        for child, _ in node.children:
            traverse(child, path + [node.value], current_depth + 1)

    traverse(tree_root, [], 0)
    return nodes


def get_paths_to_leaves(tree_root, max_depth):
    """Generate all paths from the root to each leaf node up to a specified depth.
       Each path is returned as a list of node values."""
    paths = []

    def traverse(node, path, depth):
        if depth == max_depth or not node.children:
            paths.append(path + [node.value])
            return
        for child, _ in node.children:
            traverse(child, path + [node.value], depth + 1)

    traverse(tree_root, [], 0)
    return paths


def get_node_from_path(tree_root, path):
    """Given a root node and a path (list of values), return the node at the end of the path.
       Uses value_equal() to compare node values, ensuring consistency for both d=1 and d>1."""
    current_node = tree_root
    for value in path[1:]:
        current_node = next(
            child for child, _ in current_node.children if value_equal(child.value, value)
        )
    return current_node


def compute_marginal_probabilities_for_subset(
    node1_path, node2_path, tree_1_root, tree_2_root
):
    """Compute marginal probabilities for the direct successors of node1 and node2.
       It extracts the nodes at the end of the paths and then builds their successor paths."""
    node1 = get_node_from_path(tree_1_root, node1_path)
    node2 = get_node_from_path(tree_2_root, node2_path)

    successors_node1 = [
        (node1_path + [child.value], prob) for child, prob in node1.children
    ]
    successors_node2 = [
        (node2_path + [child.value], prob) for child, prob in node2.children
    ]

    pi_ratios = [prob for _, prob in successors_node1]
    pi_tilde_ratios = [prob for _, prob in successors_node2]

    return pi_ratios, pi_tilde_ratios

In [29]:
def get_sample_paths(tree_root):
    """
    Extracts all possible paths from the root to the leaves of the tree
    along with their associated probabilities.

    Parameters:
    - tree_root (TreeNode): The root node of the tree.

    Returns:
    - tuple: A tuple containing:
        - paths_array (np.ndarray): 2D array where each row is a path.
        - probabilities_array (np.ndarray): 1D array of path probabilities.
    """
    paths = []
    probabilities = []

    def traverse(node, current_path, current_prob):
        """
        Recursively traverses the tree to collect paths and probabilities.

        Parameters:
        - node (TreeNode): The current node being traversed.
        - current_path (list): The path taken to reach the current node.
        - current_prob (float): The cumulative probability of the current path.
        """
        new_path = current_path + [node.value]

        if not node.children:
            paths.append(new_path)
            probabilities.append(current_prob)
            return

        for child, prob in node.children:
            traverse(child, new_path, current_prob * prob)

    # Initialize traversal from the root
    traverse(tree_root, [], 1.0)

    # Convert lists to NumPy arrays
    paths_array = np.array(paths)
    probabilities_array = np.array(probabilities)

    return [paths_array, probabilities_array]


def display_tree_data(paths_weights, tree_name):
    """
    Displays the paths and their associated probabilities for a given tree.

    Parameters:
    - paths_weights (tuple): A tuple containing paths and probabilities.
    - tree_name (str): The name of the tree for display purposes.
    """
    print(f"\n{tree_name} (Path and Weight Format):")
    print("Paths:")
    print(paths_weights[0])
    print("Weights:")
    print(paths_weights[1])


def get_depth(tree_root):
    """
    Calculates the depth (height) of the tree, starting at 0 for the root node.

    Parameters:
    - tree_root (TreeNode): The root node of the tree.

    Returns:
    - int: The depth of the tree.
    """
    if tree_root is None:
        return -1  # Assuming an empty tree has a depth of -1 (convenient for us)

    # If the node has no children, its depth is 0
    if not tree_root.children:
        return 0

    # Since all paths have the same depth, we can check just one path
    first_child, _ = tree_root.children[0]
    return 1 + get_depth(first_child)

In [30]:

def nested_optimal_transport_loop(
    tree1_root, tree2_root, max_depth, method, lambda_reg, power
):
    """
    Computes the nested optimal transport plan between two trees.

    Parameters:
    - tree1_root (TreeNode): Root of the first tree.
    - tree2_root (TreeNode): Root of the second tree.
    - max_depth (int): Maximum depth to compute.
    - method (str): Solver method: "Sinkhorn", "solver_lp", or "solver_pot".
    - lambda_reg (float): Regularization parameter for Sinkhorn (only used if method="Sinkhorn").

    Returns:
    - float: Computed nested distance.
    - dict: Dictionary of probability matrices for each step.
    """
    if method == "Sinkhorn" and lambda_reg <= 0:
        raise ValueError("Lambda must be positive when using Sinkhorn iteration.")
    elif method not in (
        "solver_lp",
        "solver_pot",
        "Sinkhorn",
        "solver_sinkhorn",
        "solver_lp_pot",
        "solver_pot_sinkhorn",
        "solver_jax",
    ):
        raise ValueError(
            "Method must be one of 'Sinkhorn', 'solver_lp', 'solver_jax', 'solver_lp_pot', 'solver_pot_sinkhorn' or 'solver_pot'."
        )

    probability_matrices = {}
    full_distance_matrix = compute_distance_matrix_at_depth(
        tree1_root, tree2_root, max_depth, power
    )

    for depth in range(max_depth - 1, -1, -1):
        paths_tree1 = get_nodes_at_depth(tree1_root, depth)
        paths_tree2 = get_nodes_at_depth(tree2_root, depth)

        nodes_tree1 = [find_node_by_path(tree1_root, path) for path in paths_tree1]
        nodes_tree2 = [find_node_by_path(tree2_root, path) for path in paths_tree2]

        children_count_tree1 = [len(node.children) for node in nodes_tree1 if node]
        children_count_tree2 = [len(node.children) for node in nodes_tree2 if node]

        updated_distance_matrix = np.zeros((len(paths_tree1), len(paths_tree2)))

        for i, path1 in enumerate(paths_tree1):
            for j, path2 in enumerate(paths_tree2):
                step_name = (depth, to_hashable(path1[-1]), to_hashable(path2[-1]))

                start_row, end_row = sum(children_count_tree1[:i]), sum(
                    children_count_tree1[: i + 1]
                )
                start_col, end_col = sum(children_count_tree2[:j]), sum(
                    children_count_tree2[: j + 1]
                )
                sub_matrix = full_distance_matrix[start_row:end_row, start_col:end_col]

                pi_ratios, pi_tilde_ratios = compute_marginal_probabilities_for_subset(
                    path1, path2, tree1_root, tree2_root
                )

                if method == "Sinkhorn":
                    probability_matrix = Sinkhorn_iteration(
                        sub_matrix,
                        pi_ratios,
                        pi_tilde_ratios,
                        stopping_criterion=1e-4,
                        lambda_reg=lambda_reg,
                    )
                elif method == "solver_lp_pot":
                    probability_matrix = solver_lp_pot(
                        sub_matrix, pi_ratios, pi_tilde_ratios
                    )
                elif method == "solver_lp":
                    probability_matrix = solver_lp(
                        sub_matrix, pi_ratios, pi_tilde_ratios
                    )
                elif method == "solver_jax":
                    probability_matrix = solver_jax(
                        sub_matrix, pi_ratios, pi_tilde_ratios, epsilon=(1 / lambda_reg)
                    )
                elif method == "solver_pot_sinkhorn":
                    probability_matrix = solver_pot_sinkhorn(
                        sub_matrix, pi_ratios, pi_tilde_ratios, epsilon=(1 / lambda_reg)
                    )
                elif method == "solver_pot_1D":
                    probability_matrix = solver_pot_1D(
                        sub_matrix, pi_ratios, pi_tilde_ratios
                    )
                else:
                    probability_matrix = solver_pot(
                        sub_matrix, pi_ratios, pi_tilde_ratios
                    )

                cost = np.sum(probability_matrix * sub_matrix)
                probability_matrices[step_name] = probability_matrix
                updated_distance_matrix[i, j] = cost

        full_distance_matrix = updated_distance_matrix

    return full_distance_matrix[0][0], probability_matrices


# DO NOT USE FOR BIG PROBLEMS
def compute_final_probability_matrix(
    probability_matrices, tree1_root, tree2_root, max_depth
):
    """
    Combines probability matrices along all paths to compute the final probability matrix.

    Parameters:
    - probability_matrices (dict): Dictionary of probability matrices for each step.
    - tree1_root (TreeNode): Root of the first tree.
    - tree2_root (TreeNode): Root of the second tree.
    - max_depth (int): Maximum depth considered.

    Returns:
    - np.ndarray: Final probability matrix representing nested distance.
    """
    paths_tree1 = get_paths_to_leaves(tree1_root, max_depth)
    paths_tree2 = get_paths_to_leaves(tree2_root, max_depth)
    final_prob_matrix = np.zeros((len(paths_tree1), len(paths_tree2)))

    for i, path1 in enumerate(paths_tree1):
        for j, path2 in enumerate(paths_tree2):
            probability = 1.0
            for depth in range(max_depth):
                if depth >= len(path1) or depth >= len(path2):
                    break
                step_name = (depth, path1[depth], path2[depth])
                prob_matrix = probability_matrices.get(step_name, None)
                if prob_matrix is None or prob_matrix.size == 0:
                    probability = 0
                    break
                next_node1 = path1[depth + 1] if depth + 1 < len(path1) else None
                next_node2 = path2[depth + 1] if depth + 1 < len(path2) else None
                successors_node1 = [
                    child[-1]
                    for child in get_paths_to_leaves(tree1_root, depth + 1)
                    if child[:-1] == path1[: depth + 1]
                ]
                successors_node2 = [
                    child[-1]
                    for child in get_paths_to_leaves(tree2_root, depth + 1)
                    if child[:-1] == path2[: depth + 1]
                ]
                try:
                    index1 = successors_node1.index(next_node1)
                    index2 = successors_node2.index(next_node2)
                except ValueError:
                    probability = 0
                    break
                probability *= prob_matrix[index1, index2]
            final_prob_matrix[i, j] = probability
    return final_prob_matrix


def compute_nested_distance(
    tree1_root,
    tree2_root,
    max_depth,
    return_matrix=False,
    method="solver_lp",
    lambda_reg=0,
    power=1,
):
    """
    Computes the nested Wasserstein distance between two trees.

    Parameters:
    - tree1_root (TreeNode): Root of the first tree.
    - tree2_root (TreeNode): Root of the second tree.
    - max_depth (int): Maximum depth to compute.
    - return_matrix (bool): If True, returns the final probability matrix.
    - method (str): Solver method: "Sinkhorn", "solver_lp", or "solver_pot".
    - lambda_reg (float): Regularization parameter for Sinkhorn.

    Returns:
    - float: Computed nested distance.
    - np.ndarray (optional): Final probability matrix if return_matrix is True.
    """
    distance, probability_matrices = nested_optimal_transport_loop(
        tree1_root, tree2_root, max_depth, method, lambda_reg, power
    )

    if return_matrix:
        final_prob_matrix = compute_final_probability_matrix(
            probability_matrices, tree1_root, tree2_root, max_depth
        )
        return distance, final_prob_matrix
    return distance

In [31]:
import ot

def solver_lp_pot(distance_matrix_subset, pi_ratios, pi_tilde_ratios, reg=1e-2):
    """
    Solve for the optimal transport plan using the POT library's (fast!) EMD solver.

    Parameters:
    - distance_matrix_subset (np.ndarray): A 2D cost matrix.
    - pi_ratios (np.ndarray): 1D source distribution (row marginals).
    - pi_tilde_ratios (np.ndarray): 1D target distribution (column marginals).

    Returns:
    - np.ndarray: The optimal transport plan (probability matrix).
    """
    pi_ratios = np.array(pi_ratios, dtype=np.float64)
    pi_tilde_ratios = np.array(pi_tilde_ratios, dtype=np.float64)

    return ot.lp.emd(pi_ratios, pi_tilde_ratios, distance_matrix_subset)

In [63]:
# Set normalization flag (set to False to use L0 and M0 directly)
normalize = False

# Define factor matrices (6x6) for d=2, T=3
L0 = np.array([
    [1, 0, 0, 0],
    [1, 2, 0, 0],
    [1, 2, 3, 0],
    [7, 5, 4, 9]
])
A0 = L0 @ L0.T
L = L0 / np.sqrt(np.trace(A0)) if normalize else L0
A = L @ L.T

M0 = np.array([
    [2, 0, 0, 0],
    [3, 1, 0, 0],
    [1, 4, 2, 0],
    [8, 5, 3, 7]
])
B0 = M0 @ M0.T
M = M0 / np.sqrt(np.trace(B0)) if normalize else M0
B = M @ M.T

# Set dimension parameters: for d=2, T=3 (thus total dimension = 6)
d = 2
T = 2
dim = d * T  # 4

n_sample_plot = 600  # number of sample paths

X_paths = []
Y_paths = []
for _ in range(n_sample_plot):
    # Generate noise as a vector in R^{dim}
    noise1 = np.random.normal(size=(dim,))  # shape: (6,)
    noise2 = np.random.normal(size=(dim,))  # shape: (6,)
    # Obtain increments: these are vectors in R^{dim} (6,)
    X_increments = L @ noise1  # shape: (6,)
    Y_increments = M @ noise2  # shape: (6,)
    # Reshape into (T, d) = (3, 2)
    X_increments = X_increments.reshape((T, d))
    Y_increments = Y_increments.reshape((T, d))
    # (Optionally, if you still want to prepend a zero step, do it here.)
    X_sample = np.vstack([np.zeros((1, d)), X_increments])
    Y_sample = np.vstack([np.zeros((1, d)), Y_increments])

    X_paths.append(X_sample)
    Y_paths.append(Y_sample)
    
X_paths = np.array(X_paths)  # shape: (160, 3, 2)
Y_paths = np.array(Y_paths)  # shape: (160, 3, 2)

# Adapt the empirical measure using uniform grid quantization.
adapted_X, adapted_weights_X = uniform_empirical_grid_measure(X_paths, use_weights=True)
adapted_Y, adapted_weights_Y = uniform_empirical_grid_measure(Y_paths, use_weights=True)

# Build trees from the adapted paths.
adapted_tree_1 = build_tree_from_paths(adapted_X, adapted_weights_X)
adapted_tree_2 = build_tree_from_paths(adapted_Y, adapted_weights_Y)

max_depth = get_depth(adapted_tree_1)
start_time = time.time()

# Compute the nested distance using your chosen optimal transport solver (here, "solver_lp_pot" is used).
distance_pot = compute_nested_distance(
    adapted_tree_1,
    adapted_tree_2,
    max_depth,
    method="solver_lp_pot",
    return_matrix=False,
    lambda_reg=0,
    power=2,
)
end_time = time.time()

print("Nested distance:", distance_pot)
print("Computation time: {:.4f} seconds".format(end_time - start_time))

Nested distance: 18.467465828293193
Computation time: 184.9159 seconds


In [64]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

notebooks_path = os.path.abspath(os.getcwd()) 
src_path = os.path.abspath(os.path.join(notebooks_path, "../src"))

if src_path not in sys.path:
    sys.path.insert(0, src_path)

from benchmark_value_gaussian.Comp_AWD2_Gaussian import *


# Define zero mean vectors for both processes in R^(d*T)
a = np.zeros(dim)
b = np.zeros(dim)

# Compute the adapted Wasserstein squared distance for the custom Gaussian process
distance_aw2 = adapted_wasserstein_squared(a, A, b, B, d, T)
print("Adapted Wasserstein Squared Distance for custom Gaussian process:", distance_aw2)


Adapted Wasserstein Squared Distance for custom Gaussian process: 16.32897915629934
