# Plotting multimarginal OT solutions

In [None]:
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

import matplotlib_inline

# Emit inline notebook figures as SVG and apply the "math.mplstyle" style sheet.
# NOTE(review): "math.mplstyle" is presumably a style file kept next to this
# notebook; plt.style.use needs to be able to resolve it — confirm.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
plt.style.use("math.mplstyle")


def solve_sdp(sigmas):
    """Solve the covariance-coupling semidefinite program for ``sigmas``.

    A block matrix ``M`` is assembled with ``sigmas[i]`` on the diagonal and
    one free cvxpy variable ``C_ij`` per pair ``i < j`` above the diagonal
    (its transpose below). The program maximizes ``sum_{i<j} trace(C_ij)``
    subject to ``M >> 0``. MOSEK is tried first; on ``cp.error.SolverError``
    the problem is re-solved with SCS.

    Args:
        sigmas: Sequence of at least two square NumPy arrays that all share
            the same shape.

    Returns:
        ``(final_objective_value, solved_C, M_star)`` where

        * ``final_objective_value`` is
          ``(p - 1) * sum_i trace(sigmas[i]) - 2 * sum_{i<j} trace(C_ij)``,
        * ``solved_C`` maps the zero-based pair ``(i, j)`` to the optimal
          value of ``C_ij``,
        * ``M_star`` is the numeric value of the full block matrix.

        ``(None, None, None)`` is returned (after printing a warning) when
        the solver status is neither ``"optimal"`` nor
        ``"optimal_inaccurate"``.

    Raises:
        ValueError: If fewer than two matrices are given, if their leading
            dimensions differ, or if any of them is not square.
    """
    p = len(sigmas)
    if p < 2:
        raise ValueError("At least two covariances are required.")

    dims = [s.shape[0] for s in sigmas]
    if not all(d == dims[0] for d in dims):
        raise ValueError("All covariances must have the same dimensions.")

    if not all(s.shape[0] == s.shape[1] for s in sigmas):
        raise ValueError("All covariances must be square.")

    pairs = [(i, j) for i in range(p) for j in range(i + 1, p)]

    # One free cross block per unordered pair; names are 1-based for display.
    C_vars = {
        (i, j): cp.Variable((dims[i], dims[j]), name=f"C_{i+1}_{j+1}")
        for i, j in pairs
    }

    def block(i, j):
        # Diagonal blocks are fixed data; the lower triangle mirrors the upper
        # one so that M is symmetric by construction.
        if i == j:
            return sigmas[i]
        if i < j:
            return C_vars[(i, j)]
        return cp.transpose(C_vars[(j, i)])

    M = cp.bmat([[block(i, j) for j in range(p)] for i in range(p)])

    objective_terms = [cp.trace(C_vars[pair]) for pair in pairs]
    objective = cp.Maximize(cp.sum(objective_terms))
    constraints = [M >> 0]
    problem = cp.Problem(objective, constraints)

    try:
        problem.solve(solver=cp.MOSEK, verbose=False)
    except cp.error.SolverError:
        print("MOSEK not found. Falling back to the SCS solver.")
        problem.solve(solver=cp.SCS, verbose=False)

    if problem.status not in ["optimal", "optimal_inaccurate"]:
        print(f"Warning: Solver finished with status: {problem.status}")
        return None, None, None

    # For centered variables with covariances Sigma_i and cross-covariances
    # C_ij, E||x_i - x_j||^2 = tr(Sigma_i) + tr(Sigma_j) - 2 tr(C_ij). Summed
    # over pairs i < j each marginal appears (p - 1) times, hence the constant
    # below. The second moment is tr(Sigma_i), not tr(Sigma_i @ Sigma_i).
    # NOTE(review): this assumes the reported value is meant to be the sum of
    # pairwise squared-distance costs — confirm against the write-up.
    constant_term = (p - 1) * sum(np.trace(s) for s in sigmas)
    maximized_trace_sum = problem.value
    final_objective_value = constant_term - 2 * maximized_trace_sum
    solved_C = {key: var.value for key, var in C_vars.items()}
    M_star = M.value

    return final_objective_value, solved_C, M_star


def plot_gaussian_density(cov_matrix, title, filename=None):
    """Show filled contours of a zero-mean 2-D Gaussian density.

    The density is evaluated on a fixed 100 x 100 grid covering
    ``[-6, 6] x [-5, 5]``.

    Args:
        cov_matrix: 2 x 2 covariance handed to
            ``scipy.stats.multivariate_normal`` with ``allow_singular=True``.
        title: Descriptive label supplied by callers. It is currently not
            drawn on the figure; the parameter is kept so existing call sites
            keep working.
        filename: Optional file stem. When truthy the figure is also written
            to ``images/<filename>.pdf`` before being shown.
    """
    import os

    mu = np.array([0, 0])

    x = np.linspace(-6, 6, 100)
    y = np.linspace(-5, 5, 100)
    X, Y = np.meshgrid(x, y)

    # Stack the grid into shape (100, 100, 2) so pdf() evaluates every point.
    pos = np.dstack((X, Y))

    rv = multivariate_normal(mu, cov_matrix, allow_singular=True)
    Z = rv.pdf(pos)

    plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    if filename:
        out_path = f"images/{filename}.pdf"
        # savefig does not create missing directories, so make sure the
        # target folder exists before writing.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)

    plt.show()


def run_and_plot_example(sigmas_list, example_name):
    """Solve the SDP for one example, print the result and plot the densities.

    Prints a banner, the objective value and the optimal block matrix, then
    draws one density plot per input covariance followed by a plot of the
    optimal matrix restricted to its two leading singular directions.

    Args:
        sigmas_list: Covariance matrices passed to ``solve_sdp``. They are
            also plotted with ``plot_gaussian_density``, which uses a
            two-dimensional mean, so 2 x 2 matrices are expected.
        example_name: Label used in the printed banner and, after
            sanitising, as the stem of the saved figure file names
            (``<stem>_<k>`` for the inputs, ``<stem>_o`` for the coupling).

    Returns:
        None. If the solver does not return a solution, a message is printed
        and nothing is plotted.
    """
    closing_rule = f"{'='* (34 + len(example_name))}\n"

    print(f"\n{'='*15} {example_name} {'='*15}")

    final_obj, solved_C, M_star = solve_sdp(sigmas_list)

    if final_obj is None:
        print("\nCould not find a solution.")
        print(closing_rule)
        return

    print(f"Objective value: {final_obj:.4f}")
    print("\nOptimal matrix:\n", np.round(M_star, 4))

    # File-system-friendly stem derived from the example name: anything that
    # is not alphanumeric, '-', '_' or '.' is replaced by '_'.
    safe_name = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in example_name
    ) or "example"

    for index, sigma in enumerate(sigmas_list, start=1):
        plot_gaussian_density(
            sigma,
            f"Input density for $\\Sigma_{{{index}}}$",
            filename=f"{safe_name}_{index}",
        )

    # Restrict M_star to the span of its two leading left singular vectors;
    # the resulting 2 x 2 matrix is what gets plotted as the joint density.
    U, _, _ = np.linalg.svd(M_star)
    U_rank2 = U[:, :2]
    M_proj = U_rank2.T @ M_star @ U_rank2

    plot_gaussian_density(
        M_proj,
        "Multi-marginal density (non-degenerate part)",
        filename=f"{safe_name}_o",
    )

    print(closing_rule)


# Example "mm1": three 2x2 covariances with unit-free variances 1.0, 2.0, 1.5
# and off-diagonal entries 0.6, -0.9, 0.8.
S1_ex1 = np.array([[1.0, 0.6], [0.6, 1.0]])
S2_ex1 = np.array([[2.0, -0.9], [-0.9, 2.0]])
S3_ex1 = np.array([[1.5, 0.8], [0.8, 1.5]])
run_and_plot_example([S1_ex1, S2_ex1, S3_ex1], "mm1")

# 2-D rotation matrix for an angle of pi/2 (a quarter turn).
rotation = np.array(
    [
        [np.cos(np.pi / 2), -np.sin(np.pi / 2)],
        [np.sin(np.pi / 2), np.cos(np.pi / 2)],
    ]
)
# Example "mm2": same first and third covariances as "mm1"; the second one is
# conjugated by the quarter-turn rotation, which (up to floating-point
# rounding) turns its off-diagonal entries from -0.9 into +0.9.
S1_ex2 = np.array([[1.0, 0.6], [0.6, 1.0]])
S2_ex2 = rotation @ np.array([[2.0, -0.9], [-0.9, 2.0]]) @ rotation.T
S3_ex2 = np.array([[1.5, 0.8], [0.8, 1.5]])
run_and_plot_example([S1_ex2, S2_ex2, S3_ex2], "mm2")

# Benchmarking multimarginal OT algorithms

In [None]:
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import scipy.linalg
from scipy.optimize import minimize

import matplotlib.pyplot as plt
import matplotlib_inline

# Emit inline notebook figures as SVG and apply the "math.mplstyle" style sheet.
# NOTE(review): "math.mplstyle" is presumably a style file kept next to this
# notebook; plt.style.use needs to be able to resolve it — confirm.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
plt.style.use("math.mplstyle")


def matrix_sqrt(A):
    """Return a symmetric square root of ``A`` from its eigendecomposition.

    ``A`` is decomposed with ``np.linalg.eigh``. Every eigenvalue is floored
    at ``1e-12`` before the root is taken, so zero or slightly negative
    eigenvalues yield a root of ``1e-6`` instead of ``0`` or ``nan``.

    Args:
        A: Square array, treated as symmetric by ``np.linalg.eigh``.

    Returns:
        Array ``V @ diag(sqrt(max(w, 1e-12))) @ V.T`` with ``w, V`` the
        eigenvalues and eigenvectors of ``A``.
    """
    spectrum, basis = np.linalg.eigh(A)
    floored = np.clip(spectrum, 1e-12, None)
    root_diag = np.diag(np.sqrt(floored))
    return basis @ root_diag @ basis.T


def project_to_constraint_manifold(U_list, sigmas):
    """Replace every factor ``U`` by one of the form ``sqrt(sigma) @ Q_orth``.

    For each pair ``(U, sigma)`` the least-squares solution ``Q`` of
    ``sqrt(sigma) @ Q = U`` is computed and then orthonormalised:

    * if ``Q`` has at most as many columns as rows, ``Q_orth`` is the
      orthonormal factor of its reduced QR decomposition;
    * otherwise ``Q_orth = U_q @ Vt_q`` from the thin SVD of ``Q``.

    If the linear algebra fails with ``np.linalg.LinAlgError`` (or ``U``
    contains non-finite entries), a random orthonormal factor is used
    instead, drawn from the global ``np.random`` state.

    Args:
        U_list: Sequence of factor matrices, one per covariance.
        sigmas: Sequence of covariance matrices, paired with ``U_list`` by
            position (extra entries of the longer sequence are ignored).

    Returns:
        List with one projected factor per input pair.
    """
    projected_U = []
    for U, sigma in zip(U_list, sigmas):
        sqrt_sigma = matrix_sqrt(sigma)

        try:
            # Non-finite factors cannot be orthonormalised meaningfully;
            # route them to the random fallback explicitly.
            if not np.all(np.isfinite(U)):
                raise np.linalg.LinAlgError("factor contains non-finite values")

            Q = np.linalg.lstsq(sqrt_sigma, U, rcond=None)[0]

            if Q.shape[1] <= Q.shape[0]:
                Q_orth, _ = np.linalg.qr(Q)
            else:
                U_q, _, Vt_q = np.linalg.svd(Q, full_matrices=False)
                Q_orth = U_q @ Vt_q

        except np.linalg.LinAlgError:
            d = sigma.shape[0]
            # Reduced QR of a d x (d + 1) matrix returns a d x d orthogonal
            # factor, so the fallback factor is square.
            Q_orth, _ = np.linalg.qr(np.random.randn(d, d + 1))

        projected_U.append(sqrt_sigma @ Q_orth)

    return projected_U


def solve_burer_monteiro_manifold(sigmas, max_iter=2000, lr=0.01, verbose=True):
    """Maximize the pairwise trace sum by gradient steps on orthonormal factors.

    Each block factor is parameterised as ``U_i = sqrt(sigmas[i]) @ Q_i`` with
    ``Q_i`` having orthonormal columns, so ``U_i @ U_i.T`` reproduces
    ``sigmas[i]`` (up to the eigenvalue floor applied by ``matrix_sqrt``).
    The method performs fixed-step Riemannian gradient descent on
    ``-sum_{i<j} trace(U_i @ U_j.T)``: the Euclidean gradient is projected
    onto the tangent space at ``Q_i`` and the step is mapped back with a QR
    retraction.

    The initial ``Q_i`` come from the reduced QR decomposition of a
    ``d x (d + 1)`` Gaussian matrix drawn from the global ``np.random``
    state; reduced QR of such a matrix yields a ``d x d`` orthogonal factor,
    so every ``U_i`` is ``d x d``.

    Args:
        sigmas: Non-empty sequence of ``d x d`` covariance matrices.
        max_iter: Maximum number of gradient iterations.
        lr: Step size.
        verbose: When true, print the current trace sum every 200 iterations.

    Returns:
        ``(final_obj, U_list, Sigma_full)`` where ``final_obj`` is
        ``sum_{i<j} trace(U_i @ U_j.T)`` at the last iterate, ``U_list`` holds
        the factors and ``Sigma_full`` is ``U_full @ U_full.T`` for the
        vertically stacked factors.
    """
    p = len(sigmas)
    d = sigmas[0].shape[0]

    sqrt_sigmas = [matrix_sqrt(sigma) for sigma in sigmas]

    Q_list = [np.linalg.qr(np.random.randn(d, d + 1))[0] for _ in range(p)]

    def retract_to_manifold(Q):
        Q_orth, R = np.linalg.qr(Q)
        # LAPACK does not normalise the sign of R's diagonal. Flipping the
        # columns so that diag(R) >= 0 makes the Q factor unique and
        # continuous in its input, which a retraction needs (otherwise a
        # small step can flip whole columns).
        signs = np.where(np.diag(R) < 0, -1.0, 1.0)
        return Q_orth * signs

    def project_to_tangent_space(Q, H):
        return H - Q @ (Q.T @ H + H.T @ Q) / 2

    def pairwise_trace_sum(factors):
        # sum_{i<j} <U_i, U_j> = (||sum_i U_i||_F^2 - sum_i ||U_i||_F^2) / 2,
        # which avoids the O(p^2) loop over pairs.
        total = sum(factors)
        return 0.5 * (np.sum(total * total) - sum(np.sum(U * U) for U in factors))

    for iteration in range(max_iter):
        U_list = [S @ Q for S, Q in zip(sqrt_sigmas, Q_list)]
        U_total = sum(U_list)

        # d(-sum_{i<j} <U_i, U_j>) / dU_i = -(sum_{j != i} U_j) = U_i - U_total;
        # the chain rule through U_i = sqrt(sigma_i) @ Q_i adds the factor
        # sqrt(sigma_i).T on the left.
        grad_Q = [
            project_to_tangent_space(Q, S.T @ (U - U_total))
            for S, Q, U in zip(sqrt_sigmas, Q_list, U_list)
        ]

        if iteration % 200 == 0 and verbose:
            print(
                f"Iteration {iteration}: objective = "
                f"{pairwise_trace_sum(U_list):.6f}"
            )

        Q_list = [
            retract_to_manifold(Q - lr * G) for Q, G in zip(Q_list, grad_Q)
        ]

        # Stop once the (pre-step) projected gradients are numerically zero.
        grad_norm = sum(np.linalg.norm(G) for G in grad_Q)
        if grad_norm < 1e-8:
            break

    U_list = [S @ Q for S, Q in zip(sqrt_sigmas, Q_list)]
    U_full = np.vstack(U_list)
    Sigma_full = U_full @ U_full.T

    final_obj = pairwise_trace_sum(U_list)

    return final_obj, U_list, Sigma_full


def solve_sdp(sigmas):
    """Solve the covariance-coupling SDP and return the maximized trace sum.

    A block matrix ``M`` is assembled with ``sigmas[i]`` on the diagonal and
    one free cvxpy variable ``C_ij`` per pair ``i < j`` above the diagonal
    (its transpose below). The program maximizes ``sum_{i<j} trace(C_ij)``
    subject to ``M >> 0``. MOSEK is tried first; on ``cp.error.SolverError``
    the problem is silently re-solved with SCS.

    Args:
        sigmas: Sequence of at least two square NumPy arrays that all share
            the same shape.

    Returns:
        ``(maximized_trace_sum, solved_C, M_star)``: the optimal objective
        value, a dict mapping the zero-based pair ``(i, j)`` to the optimal
        ``C_ij``, and the numeric value of the full block matrix.
        ``(None, None, None)`` is returned (after printing a warning) when
        the solver status is neither ``"optimal"`` nor
        ``"optimal_inaccurate"``.

    Raises:
        ValueError: If fewer than two matrices are given, if their leading
            dimensions differ, or if any of them is not square.
    """
    p = len(sigmas)
    if p < 2:
        raise ValueError("At least two covariances are required.")

    dims = [s.shape[0] for s in sigmas]
    if not all(d == dims[0] for d in dims):
        raise ValueError("All covariances must have the same dimensions.")

    # Reject non-square inputs up front instead of failing later inside the
    # block-matrix assembly with a less helpful error.
    if not all(s.shape[0] == s.shape[1] for s in sigmas):
        raise ValueError("All covariances must be square.")

    pairs = [(i, j) for i in range(p) for j in range(i + 1, p)]

    # One free cross block per unordered pair; names are 1-based for display.
    C_vars = {
        (i, j): cp.Variable((dims[i], dims[j]), name=f"C_{i+1}_{j+1}")
        for i, j in pairs
    }

    def block(i, j):
        # Diagonal blocks are fixed data; the lower triangle mirrors the upper
        # one so that M is symmetric by construction.
        if i == j:
            return sigmas[i]
        if i < j:
            return C_vars[(i, j)]
        return cp.transpose(C_vars[(j, i)])

    M = cp.bmat([[block(i, j) for j in range(p)] for i in range(p)])

    objective_terms = [cp.trace(C_vars[pair]) for pair in pairs]
    objective = cp.Maximize(cp.sum(objective_terms))
    constraints = [M >> 0]
    problem = cp.Problem(objective, constraints)

    try:
        problem.solve(solver=cp.MOSEK, verbose=False)
    except cp.error.SolverError:
        problem.solve(solver=cp.SCS, verbose=False)

    if problem.status not in ["optimal", "optimal_inaccurate"]:
        print(f"Warning: Solver finished with status: {problem.status}")
        return None, None, None

    solved_C = {key: var.value for key, var in C_vars.items()}
    M_star = M.value
    maximized_trace_sum = problem.value

    return maximized_trace_sum, solved_C, M_star


def generate_random_covariances(p, d, seed=42):
    """Create ``p`` reproducible random ``d x d`` covariance matrices.

    The global NumPy random state is re-seeded with ``seed`` (a side effect
    visible to any later ``np.random`` call). Each matrix is
    ``A @ A.T + 0.1 * I`` for a fresh standard-normal ``d x d`` matrix ``A``.

    Args:
        p: Number of matrices to generate.
        d: Dimension of each matrix.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        List of ``p`` arrays of shape ``(d, d)``.
    """
    np.random.seed(seed)
    ridge = 0.1 * np.eye(d)

    def draw_one():
        factor = np.random.randn(d, d)
        return factor @ factor.T + ridge

    return [draw_one() for _ in range(p)]


def benchmark_methods(p_values, d=3, max_p_for_sdp=15):
    """Time the Burer-Monteiro solver and the SDP solver over several sizes.

    For every ``p`` in ``p_values`` a set of ``p`` random ``d x d``
    covariances is generated, the Burer-Monteiro solver is run with
    ``max_iter=1000`` and, when ``p <= max_p_for_sdp``, the SDP solver is run
    as well.

    Args:
        p_values: Iterable of problem sizes (number of covariances).
        d: Dimension of every covariance matrix.
        max_p_for_sdp: Largest ``p`` for which the SDP is solved.

    Returns:
        Dict of lists, all aligned with ``"p_values"``:

        * ``"p_values"``: the sizes that were run,
        * ``"bm_times"`` / ``"bm_objectives"``: wall time in seconds and
          objective of the Burer-Monteiro solver,
        * ``"sdp_times"`` / ``"sdp_objectives"``: same for the SDP; both are
          ``None`` when the SDP was skipped, and ``inf`` / ``None`` when it
          was attempted but failed,
        * ``"variable_counts"``: tuples ``(sdp_vars, bm_vars)`` with
          ``d**2 * p * (p - 1) // 2`` and ``p * d * (d + 1)``.
    """
    import time

    results = {
        "p_values": [],
        "sdp_times": [],
        "bm_times": [],
        "sdp_objectives": [],
        "bm_objectives": [],
        "variable_counts": [],
    }

    for p in p_values:
        sigmas = generate_random_covariances(p, d)

        sdp_vars = d**2 * p * (p - 1) // 2
        bm_vars = p * d * (d + 1)

        results["p_values"].append(p)
        results["variable_counts"].append((sdp_vars, bm_vars))

        # perf_counter is monotonic and high-resolution, unlike time.time,
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        bm_obj, _, _ = solve_burer_monteiro_manifold(
            sigmas, max_iter=1000, verbose=False
        )
        bm_time = time.perf_counter() - start_time

        results["bm_times"].append(bm_time)
        results["bm_objectives"].append(bm_obj)

        if p <= max_p_for_sdp:
            start_time = time.perf_counter()
            sdp_obj, _, _ = solve_sdp(sigmas)
            sdp_time = time.perf_counter() - start_time

            if sdp_obj is None:
                print("Failed to solve")
                sdp_time = float("inf")
        else:
            # Keep the SDP lists aligned with p_values when the SDP is
            # skipped; consumers filter out None entries.
            sdp_time = None
            sdp_obj = None

        results["sdp_times"].append(sdp_time)
        results["sdp_objectives"].append(sdp_obj)

    return results


def plot_benchmark_results(results):
    """Plot timing, variable-count and objective comparisons of a benchmark.

    Up to three figures are produced, each saved as a PDF under ``images/``
    and then shown:

    * ``computation_times.pdf``: solve time against ``p`` for both methods
      (SDP entries that are ``None`` or ``inf`` are left out),
    * ``variable_counts.pdf``: number of variables against ``p``,
    * ``objective_comparison.pdf``: objective values against ``p`` for the
      sizes where an SDP objective is available; skipped entirely when
      there is none.

    Args:
        results: Dict with the keys ``"p_values"``, ``"sdp_times"``,
            ``"bm_times"``, ``"sdp_objectives"``, ``"bm_objectives"`` and
            ``"variable_counts"`` (a list of ``(sdp_vars, bm_vars)`` tuples).
    """
    import os

    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)

    def finish_figure(ylabel, title, path):
        # Shared axis labelling, legend, save and display for every figure.
        plt.xlabel("Number of measures ($p$)", fontsize=12)
        plt.ylabel(ylabel, fontsize=12)
        plt.title(title, fontsize=14)
        plt.legend(fontsize=12)
        plt.savefig(path)
        plt.show()

    p_vals = results["p_values"]

    # Figure 1: computation time.
    plt.figure()
    valid_sdp = [
        (p, t)
        for p, t in zip(p_vals, results["sdp_times"])
        if t is not None and t != float("inf")
    ]
    if valid_sdp:
        sdp_p, sdp_times = zip(*valid_sdp)
        plt.plot(sdp_p, sdp_times, "-", label="SDP")

    plt.plot(p_vals, results["bm_times"], "-", label="Burer-Monteiro")
    finish_figure(
        "Time (seconds)",
        "Number of measures vs. computation time",
        "images/computation_times.pdf",
    )

    # Figure 2: number of optimisation variables.
    plt.figure()
    sdp_vars, bm_vars = zip(*results["variable_counts"])
    plt.plot(p_vals, sdp_vars, "-", label="SDP: $O(d^2 p^2)$")
    plt.plot(p_vals, bm_vars, "-", label="Burer-Monteiro: $O(d^2 p)$")
    finish_figure(
        "Number of variables",
        "Variable count scaling",
        "images/variable_counts.pdf",
    )

    # Figure 3: objective values, only where the SDP produced one.
    valid_both = [
        (p, s, b)
        for p, s, b in zip(p_vals, results["sdp_objectives"], results["bm_objectives"])
        if s is not None
    ]
    if valid_both:
        plt.figure()
        both_p, both_sdp, both_bm = zip(*valid_both)
        plt.plot(both_p, both_sdp, "-", label="SDP")
        plt.plot(both_p, both_bm, "-", label="Burer-Monteiro")
        finish_figure(
            "Objective value",
            "Objective value comparison",
            "images/objective_comparison.pdf",
        )


# Problem sizes (number of covariance matrices) to benchmark.
p_values = [3, 5, 8, 10, 15, 20, 30, 50, 80, 100]

# max_p_for_sdp=100 equals the largest entry of p_values, so the SDP is solved
# for every size in the list, not only the small ones.
results = benchmark_methods(p_values, d=3, max_p_for_sdp=100)
plot_benchmark_results(results)