In [1]:
import os
import sys
import numpy as np
import time
import warnings
import random
import matplotlib.pyplot as plt


# Make the project's `src/` directory importable. The notebook is assumed to be
# launched from its own directory, so `src/` is resolved as a sibling of the CWD.
notebooks_path = os.path.abspath(os.getcwd())
src_path = os.path.abspath(os.path.join(notebooks_path, os.pardir, "src"))

# Prepend (rather than append) so the project packages imported below take
# precedence over any same-named installed packages; skip if already present
# so re-running this cell does not grow sys.path.
if src_path not in sys.path:
    sys.path.insert(0, src_path)

    
from multi_dimension.Multidimension_trees import *
from multi_dimension.Multidimension_solver import *
from multi_dimension.Multidimension_adapted_empirical_measure import *

from measure_sampling.Gen_Path_and_AdaptedTrees import generate_adapted_tree
from trees.tree_Node import *
from trees.treeAnalysis import *
from trees.treeVisualization import *
from trees.save_Load_trees import *
from trees.tree_AWD_utilities import *
from trees.build_trees_from_paths import build_tree_from_paths

from adapted_empirical_measure.AEM_grid import *
from adapted_empirical_measure.AEM_kMeans import *
from benchmark_value_gaussian.Comp_AWD2_Gaussian import *
from awd_trees.Gurobi_AOT import *
from awd_trees.Nested_Dist_Algo import compute_nested_distance, compute_nested_distance_parallel, compute_nested_distance_parallel_generic

from optimal_code.utils import *
from optimal_code.optimal_solver import *

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

2025-03-23 16:58:39,678	INFO worker.py:1841 -- Started a local Ray instance.


In [2]:
# Fixed seed for reproducibility. It is passed explicitly to Lmatrix2paths in
# the next cell; the global NumPy and `random` generators are seeded as well so
# that any helper drawing from them without an explicit seed is reproducible.
# (The previous `np.random.randint(100)` draw was dead: it was immediately
# overwritten by the fixed value.)
random_seed = 0
np.random.seed(random_seed)
random.seed(random_seed)

In [None]:
n_sample = 1000
T = 3
delta_n = 0.1  # passed to path2adaptedpath below (presumably the grid size) — TODO confirm

# Lower-triangular matrices defining the two laws sampled by Lmatrix2paths
# (presumably the Cholesky factors of two zero-mean Gaussian processes).
L = np.array([[1, 0, 0], [2, 4, 0], [3, 2, 1]])
M = np.array([[1, 0, 0], [2, 3, 0], [3, 1, 2]])
assert L.shape == M.shape == (T, T), "L and M must both be T x T"

X, A = Lmatrix2paths(L, n_sample, seed = random_seed, verbose = False)
# NOTE(review): Y reuses X's seed, so both sample sets are presumably driven by
# the same underlying draws — confirm this is intended.
Y, B = Lmatrix2paths(M, n_sample, seed = random_seed, verbose = False)

dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)

# Independent cross-check computed directly from L and M. Assuming zero-mean
# processes X = L xi, Y = M eta with i.i.d. N(0, 1) innovations, a bicausal
# coupling can only correlate xi_t with eta_t. That gives the closed form
#   AW_2^2 = ||L||_F^2 + ||M||_F^2 - 2 * sum_t |(L^T M)_{tt}|
# (cf. Gunasingam & Wong, "Adapted optimal transport between Gaussian processes
# in discrete time"). For the L, M above this is 35 + 28 - 2 * 30 = 3.
# NOTE(review): the recorded output shows dist_bench = 30.0, while the numerical
# solvers later in this notebook give ~3.05 and ~2.52 — check what
# Lmatrix2paths returns in A/B and what adapted_wasserstein_squared expects.
aw2_closed_form = float(
    np.sum(L**2) + np.sum(M**2) - 2 * np.sum(np.abs(np.diag(L.T @ M)))
)
print("Closed-form AW_2^2 from (L, M): ", aw2_closed_form)

adaptedX = path2adaptedpath(X, delta_n = delta_n)
adaptedY = path2adaptedpath(Y, delta_n = delta_n)


Theoretical AW_2^2:  30.0


In [10]:
# Quantization map: q2v[i] is the i-th distinct value occurring in either set of
# adapted paths (np.unique returns them sorted); v2q is the inverse mapping.
q2v = np.unique(np.concatenate([adaptedX, adaptedY], axis=0))
v2q = {k: v for v, k in enumerate(q2v)}  # Value to Quantization

# Quantized paths. q2v is sorted and contains every value, so a binary search
# yields exactly v2q[value] for each entry in one vectorized call, instead of a
# per-element Python dict lookup.
qX = np.searchsorted(q2v, adaptedX)
qY = np.searchsorted(q2v, adaptedY)
# Round-trip check: every code must decode back to the exact original value.
assert np.array_equal(q2v[qX], adaptedX) and np.array_equal(q2v[qY], adaptedY)

# Sort paths; after the transpose the arrays are expected to be
# (n_sample, T+1) — TODO confirm against sort_qpath.
qX = sort_qpath(qX.T)
qY = sort_qpath(qY.T)

# Get conditional distribution mu_{x_{1:t}} = mu_x[t][(x_1,...,x_t)] = {x_{t+1} : mu_{x_{1:t}}(x_{t+1}), ...}
mu_x = qpath2mu_x(qX)
nu_y = qpath2mu_x(qY)

# List representations consumed by nested2_parallel in the next cell.
mu_x_c, mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn = list_repr_mu_x(mu_x, q2v)
nu_y_c, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn = list_repr_mu_x(nu_y, q2v)
# All list except weights should be increasing!


In [8]:
# Run the parallel adapted-OT solver on the list representations built in the
# quantization cell, timed with a monotonic high-resolution clock. The
# arguments are presumably (counts, values, weights, cumulative counts) per
# measure, going by their names — TODO confirm against nested2_parallel.
start_time = time.perf_counter()
AW_2square = nested2_parallel(
    mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn,  # first measure (mu)
    nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn,  # second measure (nu)
    n_processes = 6,
)
end_time = time.perf_counter()

print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)

100%|██████████| 150/150 [00:00<00:00, 151.54it/s]
100%|██████████| 150/150 [00:01<00:00, 148.13it/s]
100%|██████████| 150/150 [00:01<00:00, 136.99it/s]
100%|██████████| 150/150 [00:01<00:00, 132.89it/s]
100%|██████████| 150/150 [00:01<00:00, 125.54it/s]
100%|██████████| 150/150 [00:01<00:00, 116.35it/s]
100%|██████████| 57/57 [00:00<00:00, 93.24it/s] 
100%|██████████| 1/1 [00:00<00:00, 1078.78it/s]


Elapsed time (Adapted OT): 44.6108 seconds
Numerical AW_2^2:  3.0517166666666666


In [9]:
# Comparison pipeline: build uniform-grid adapted empirical measures, turn them
# into trees, and time the parallel nested-distance computation between them
# (result stored as distance_pot).
adapted_X, adapted_weights_X = uniform_empirical_grid_measure(X.T, delta_n = 0.1, use_weights=True)
adapted_Y, adapted_weights_Y = uniform_empirical_grid_measure(Y.T, delta_n = 0.1, use_weights=True)
# Build trees from the adapted paths
tree_1 = build_tree_from_paths(adapted_X, adapted_weights_X)
tree_2 = build_tree_from_paths(adapted_Y, adapted_weights_Y)

# max_depth is read from tree_1 only, so make the equal-depth assumption explicit.
max_depth = get_depth(tree_1)
assert get_depth(tree_2) == max_depth, "tree_1 and tree_2 must have the same depth"

# Compute nested distance and record timing. Use perf_counter (monotonic,
# high resolution), consistent with the Adapted OT timing cell; time.time()
# can jump if the system clock is adjusted.
start_time = time.perf_counter()
distance_pot = compute_nested_distance_parallel(
    tree_1, tree_2, max_depth, power=2, num_chunks=6
)
elapsed_time = time.perf_counter() - start_time

Depth: 2


Parallel Depth 2:   0%|          | 0/6 [00:00<?, ?it/s]

2025-03-23 17:01:09,133	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-23 17:01:09,157	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-23 17:01:09,268	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-23 17:01:09,390	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-23 17:01:09,873	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-23 17:01:10,144	INFO worker.py:1841 -- Started a local Ray instance.


Depth: 1
Depth: 0


In [None]:
# Report the tree-based result in the same labeled format as the Adapted OT cell,
# so the two runs can be compared side by side.
print("Nested distance (power=2): ", distance_pot)
print("Elapsed time (nested distance): {:.4f} seconds".format(elapsed_time))

2.5179
1.8190641403198242
