In [None]:
import matplotlib.pyplot as plt
import matplotlib_inline

# Ask the inline backend to emit figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
# Apply the "math.mplstyle" style sheet to all figures created afterwards
# (presumably a style file shipped next to this notebook — confirm).
plt.style.use("math.mplstyle")

# IGW and Wasserstein interpolations

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import qr, eigh
from scipy.stats import multivariate_normal


def stiefel_objective(X, Lambda1, Lambda2, v1, v2):
    """Evaluate tr(Lambda1 X Lambda2 X^T) + 2 v1^T X v2.

    Args:
        X: Matrix at which the objective is evaluated.
        Lambda1: Left matrix of the trace (quadratic) term.
        Lambda2: Right matrix of the trace (quadratic) term.
        v1: Left vector of the linear term.
        v2: Right vector of the linear term.

    Returns:
        The scalar value of the objective at ``X``.
    """
    quadratic_part = np.trace(Lambda1 @ X @ Lambda2 @ X.T)
    linear_part = 2 * v1.T @ X @ v2
    return quadratic_part + linear_part


def stiefel_gradient(X, Lambda1, Lambda2, v1, v2):
    """Return 2 * Lambda1 X Lambda2 + 2 * v1 v2^T.

    For symmetric ``Lambda1`` and ``Lambda2`` this is the Euclidean gradient
    with respect to ``X`` of tr(Lambda1 X Lambda2 X^T) + 2 v1^T X v2.

    Args:
        X: Point at which the gradient is evaluated.
        Lambda1: Left matrix of the quadratic term.
        Lambda2: Right matrix of the quadratic term.
        v1: Left vector of the linear term.
        v2: Right vector of the linear term.

    Returns:
        Array with the same shape as ``X``.
    """
    from_quadratic = 2 * Lambda1 @ X @ Lambda2
    from_linear = 2 * np.outer(v1, v2)
    return from_quadratic + from_linear


def stiefel_riemannian_gradient(X, euclidean_grad):
    """Project a Euclidean gradient onto the tangent space at ``X``.

    Computes ``G - X sym(X^T G)`` where ``sym(M) = (M + M^T) / 2``.

    Args:
        X: Current point.
        euclidean_grad: Euclidean gradient ``G`` evaluated at ``X``.

    Returns:
        The projected gradient, with the same shape as ``euclidean_grad``.
    """
    cross = X.T @ euclidean_grad
    symmetric_part = (cross + cross.T) * 0.5
    return euclidean_grad - X @ symmetric_part


def stiefel_retraction(X, direction, step_size):
    """Move from ``X`` along ``direction`` and re-orthonormalise via QR.

    The Q factor of ``X + step_size * direction`` is returned, with column
    signs flipped so that the diagonal of the matching R factor is
    non-negative (zero diagonal entries are treated as positive).

    Args:
        X: Current point.
        direction: Step direction, same shape as ``X``.
        step_size: Scalar multiplier applied to ``direction``.

    Returns:
        The sign-normalised Q factor.
    """
    moved = X + step_size * direction
    # Square inputs use the full decomposition, all others the economic one.
    qr_mode = "full" if moved.shape[0] == moved.shape[1] else "economic"
    Q, R = qr(moved, mode=qr_mode)

    column_signs = np.sign(np.diag(R))
    column_signs = np.where(column_signs == 0, 1, column_signs)
    return Q @ np.diag(column_signs)


def optimize_stiefel(
    Lambda1, Lambda2, v1, v2, d1, d2, max_iter=500, tol=1e-8, step_size=0.01
):
    """Maximise ``stiefel_objective`` by projected gradient ascent with QR retraction.

    Starting from ``np.eye(d1, d2)``, each iteration steps along the projected
    gradient, halves the step while the objective would decrease (at most ten
    times), and accepts the candidate only if the objective does not decrease.

    Args:
        Lambda1, Lambda2: Matrices of the quadratic term of the objective.
        v1, v2: Vectors of the linear term of the objective.
        d1, d2: Number of rows and columns of the iterate ``X``.
        max_iter: Maximum number of iterations.
        tol: Stop once the Frobenius norm of the projected gradient is below this.
        step_size: Initial step size; adapted during the run.

    Returns:
        Tuple ``(X, objective_history)`` with the final iterate and the list of
        objective values recorded at the start of every iteration.
    """
    # The start point is deterministic and nothing random is drawn, so the
    # global NumPy RNG is deliberately left untouched (no reseeding here).
    X = np.eye(d1, d2)
    objective_history = []

    for i in range(max_iter):
        obj = stiefel_objective(X, Lambda1, Lambda2, v1, v2)
        objective_history.append(obj)

        euclidean_grad = stiefel_gradient(X, Lambda1, Lambda2, v1, v2)
        riem_grad = stiefel_riemannian_gradient(X, euclidean_grad)

        grad_norm = np.linalg.norm(riem_grad)
        if grad_norm < tol:
            print(f"Converged after {i} iterations")
            break

        X_new = stiefel_retraction(X, riem_grad, step_size)
        obj_new = stiefel_objective(X_new, Lambda1, Lambda2, v1, v2)

        # Backtracking: shrink the step while the candidate lowers the objective.
        backtrack_count = 0
        while obj_new < obj and step_size > 1e-12 and backtrack_count < 10:
            step_size *= 0.5
            X_new = stiefel_retraction(X, riem_grad, step_size)
            obj_new = stiefel_objective(X_new, Lambda1, Lambda2, v1, v2)
            backtrack_count += 1

        if obj_new >= obj:
            # Accept the step and cautiously enlarge the step size (capped at 0.1).
            X = X_new
            step_size = min(step_size * 1.05, 0.1)
        else:
            # Reject the step; retry from the same point with a smaller step.
            step_size *= 0.5

    return X, objective_history


def solve_igw_transport(m1, m2, Sigma1, Sigma2):
    """Compute a linear map and the IGW distance between two Gaussians.

    The covariances are diagonalised, ``optimize_stiefel`` is run in the
    eigenbases, and the resulting cross-covariance ``C_opt`` is turned into the
    matrix ``A = C_opt^T pinv(Sigma1)``.

    The input arrays are never modified: when a covariance needs
    regularisation, a regularised copy is used instead.

    Args:
        m1: Mean of the first Gaussian (1-D array of length ``d1``).
        m2: Mean of the second Gaussian (1-D array of length ``d2``).
        Sigma1: ``d1 x d1`` covariance of the first Gaussian.
        Sigma2: ``d2 x d2`` covariance of the second Gaussian.

    Returns:
        Tuple ``(A, igw_distance, obj_history)``: the ``d2 x d1`` matrix ``A``,
        the distance value, and the objective history returned by
        ``optimize_stiefel``.
    """
    d1, d2 = len(m1), len(m2)

    # Covariances are symmetric, so use the symmetric eigenvalue solver (always
    # real, consistent with the eigh calls below). Regularise out of place so
    # the caller's matrices stay untouched.
    if not np.all(np.linalg.eigvalsh(Sigma1) > 1e-12):
        Sigma1 = Sigma1 + np.eye(d1) * 1e-6
    if not np.all(np.linalg.eigvalsh(Sigma2) > 1e-12):
        Sigma2 = Sigma2 + np.eye(d2) * 1e-6

    Lambda1_vals, Q1 = eigh(Sigma1)
    Lambda2_vals, Q2 = eigh(Sigma2)

    # Guard the square roots below against tiny negative round-off.
    Lambda1_vals = np.maximum(Lambda1_vals, 1e-12)
    Lambda2_vals = np.maximum(Lambda2_vals, 1e-12)

    Lambda1 = np.diag(Lambda1_vals)
    Lambda2 = np.diag(Lambda2_vals)

    Sigma1_sqrt = Q1 @ np.diag(np.sqrt(Lambda1_vals)) @ Q1.T
    Sigma2_sqrt = Q2 @ np.diag(np.sqrt(Lambda2_vals)) @ Q2.T

    # Linear-term vectors, expressed in the eigenbases of the covariances.
    v1 = Q1.T @ (Sigma1_sqrt @ m1)
    v2 = Q2.T @ (Sigma2_sqrt @ m2)

    X_opt, obj_history = optimize_stiefel(Lambda1, Lambda2, v1, v2, d1, d2)

    # Map the optimiser's solution back to the original coordinates.
    C_opt = Sigma1_sqrt @ Q1 @ X_opt @ Q2.T @ Sigma2_sqrt

    A = C_opt.T @ np.linalg.pinv(Sigma1)

    I1 = np.trace(Sigma1 @ Sigma1) + 2 * (m1.T @ Sigma1 @ m1) + np.linalg.norm(m1) ** 4
    I2 = np.trace(Sigma2 @ Sigma2) + 2 * (m2.T @ Sigma2 @ m2) + np.linalg.norm(m2) ** 4
    I3 = (
        np.trace(C_opt.T @ C_opt)
        + 2 * (m1.T @ C_opt @ m2)
        + np.linalg.norm(m1) ** 2 * np.linalg.norm(m2) ** 2
    )

    # Clamp at zero so round-off cannot turn a (near-)zero distance into NaN.
    igw_distance = np.sqrt(max(I1 + I2 - 2 * I3, 0.0))

    return A, igw_distance, obj_history


def plot_contour_gaussian(mean, cov, t_value, x_range, y_range):
    """Draw, save and show a filled contour plot of a 2-D Gaussian density.

    The density is evaluated on a 200 x 200 grid, the origin is marked with a
    red dot, and the figure is written to
    ``images/igw_interpolation_<t_value>.pdf`` (two decimals). The ``images``
    directory is created if it does not exist yet.

    Args:
        mean: Mean passed to ``multivariate_normal``.
        cov: Covariance passed to ``multivariate_normal``.
        t_value: Number formatted into the output file name.
        x_range: ``(min, max)`` extent of the grid along x.
        y_range: ``(min, max)`` extent of the grid along y.
    """
    import os

    x = np.linspace(x_range[0], x_range[1], 200)
    y = np.linspace(y_range[0], y_range[1], 200)
    X, Y = np.meshgrid(x, y)

    # Stack the grid into an (..., 2) array of points for the pdf evaluation.
    pos = np.dstack((X, Y))
    Z = multivariate_normal(mean, cov).pdf(pos)

    plt.figure()
    plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

    plt.plot(
        0,
        0,
        "o",
        markersize=12,
        label="Origin",
        color="red",
        markeredgecolor="white",
        markeredgewidth=1,
    )
    plt.legend(fontsize=14)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs("images", exist_ok=True)
    plt.savefig(f"images/igw_interpolation_{t_value:.2f}.pdf")
    plt.show()


# First Gaussian: mean m1 and a covariance with eigenvalues 10 and 1 whose
# principal axes are rotated by theta = -pi/6.
m1 = np.array([10, 10])
theta = -np.pi / 6
rotation1 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma1 = rotation1 @ np.array([[10, 0], [0, 1]]) @ rotation1.T

# Second Gaussian: same spectrum, rotated by pi/4 (theta is reused here).
m2 = np.array([-10, 10])
theta = np.pi / 4
rotation2 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma2 = rotation2 @ np.array([[10, 0], [0, 1]]) @ rotation2.T

# Solve for the linear map A, the IGW distance and the optimiser's objective trace.
A, igw_dist, obj_hist = solve_igw_transport(m1, m2, Sigma1, Sigma2)


def transport_map(x):
    """Apply the affine map ``x -> m2 + A (x - m1)`` using the notebook globals."""
    centered = x - m1
    return A @ centered + m2

# Covariance of the first Gaussian after applying the linear part A
# (not used further in the statements below).
transported_cov = A @ Sigma1 @ A.T

# Coordinates that should be visible in the plots: both means and the origin.
all_means_x = [m1[0], m2[0], 0]
all_means_y = [m1[1], m2[1], 0]

# Three standard deviations along the widest principal axis of either covariance.
eigenvals1, _ = np.linalg.eigh(Sigma1)
eigenvals2, _ = np.linalg.eigh(Sigma2)
max_std = 3 * np.sqrt(max(np.max(eigenvals1), np.max(eigenvals2)))

# The x extent is derived from the data; the y extent is fixed by hand.
x_range = (min(all_means_x) - max_std, max(all_means_x) + max_std)
y_range = (-2, 17)

# Interpolation times at which a density is plotted.
t_values = [0, 0.33, 0.67, 1.0]

for t in t_values:
    if t == 0:
        # Start point: the first Gaussian itself.
        interpolated_mean = m1
        interpolated_cov = Sigma1
    elif t == 1:
        # End point: the second Gaussian itself.
        interpolated_mean = m2
        interpolated_cov = Sigma2
    else:
        # Push the first Gaussian through the linear part (1 - t) I + t A;
        # the mean moves along the straight line from m1 to m2.
        interpolated_mean = (1 - t) * m1 + t * m2
        B = (1 - t) * np.eye(2) + t * A
        interpolated_cov = B @ Sigma1 @ B.T

    plot_contour_gaussian(interpolated_mean, interpolated_cov, t, x_range, y_range)

# Objective value recorded at each iteration of the Stiefel optimisation.
plt.figure()
plt.plot(obj_hist)
plt.title("Convergence of Stiefel Optimization")
plt.xlabel("Iteration")
plt.ylabel("Objective Value")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import qr, eigh
from scipy.stats import multivariate_normal


def stiefel_objective(X, Lambda1, Lambda2, v1, v2):
    """Evaluate tr(Lambda1 X Lambda2 X^T) + 2 v1^T X v2.

    Args:
        X: Matrix at which the objective is evaluated.
        Lambda1: Left matrix of the trace (quadratic) term.
        Lambda2: Right matrix of the trace (quadratic) term.
        v1: Left vector of the linear term.
        v2: Right vector of the linear term.

    Returns:
        The scalar value of the objective at ``X``.
    """
    quadratic_part = np.trace(Lambda1 @ X @ Lambda2 @ X.T)
    linear_part = 2 * v1.T @ X @ v2
    return quadratic_part + linear_part


def stiefel_gradient(X, Lambda1, Lambda2, v1, v2):
    """Return 2 * Lambda1 X Lambda2 + 2 * v1 v2^T.

    For symmetric ``Lambda1`` and ``Lambda2`` this is the Euclidean gradient
    with respect to ``X`` of tr(Lambda1 X Lambda2 X^T) + 2 v1^T X v2.

    Args:
        X: Point at which the gradient is evaluated.
        Lambda1: Left matrix of the quadratic term.
        Lambda2: Right matrix of the quadratic term.
        v1: Left vector of the linear term.
        v2: Right vector of the linear term.

    Returns:
        Array with the same shape as ``X``.
    """
    from_quadratic = 2 * Lambda1 @ X @ Lambda2
    from_linear = 2 * np.outer(v1, v2)
    return from_quadratic + from_linear


def stiefel_riemannian_gradient(X, euclidean_grad):
    """Project a Euclidean gradient onto the tangent space at ``X``.

    Computes ``G - X sym(X^T G)`` where ``sym(M) = (M + M^T) / 2``.

    Args:
        X: Current point.
        euclidean_grad: Euclidean gradient ``G`` evaluated at ``X``.

    Returns:
        The projected gradient, with the same shape as ``euclidean_grad``.
    """
    cross = X.T @ euclidean_grad
    symmetric_part = (cross + cross.T) * 0.5
    return euclidean_grad - X @ symmetric_part


def stiefel_retraction(X, direction, step_size):
    """Move from ``X`` along ``direction`` and re-orthonormalise via QR.

    The Q factor of ``X + step_size * direction`` is returned, with column
    signs flipped so that the diagonal of the matching R factor is
    non-negative (zero diagonal entries are treated as positive).

    Args:
        X: Current point.
        direction: Step direction, same shape as ``X``.
        step_size: Scalar multiplier applied to ``direction``.

    Returns:
        The sign-normalised Q factor.
    """
    moved = X + step_size * direction
    # Square inputs use the full decomposition, all others the economic one.
    qr_mode = "full" if moved.shape[0] == moved.shape[1] else "economic"
    Q, R = qr(moved, mode=qr_mode)

    column_signs = np.sign(np.diag(R))
    column_signs = np.where(column_signs == 0, 1, column_signs)
    return Q @ np.diag(column_signs)


def optimize_stiefel(
    Lambda1, Lambda2, v1, v2, d1, d2, max_iter=500, tol=1e-8, step_size=0.01
):
    """Maximise ``stiefel_objective`` by projected gradient ascent with QR retraction.

    Starting from ``np.eye(d1, d2)``, each iteration steps along the projected
    gradient, halves the step while the objective would decrease (at most ten
    times), and accepts the candidate only if the objective does not decrease.

    Args:
        Lambda1, Lambda2: Matrices of the quadratic term of the objective.
        v1, v2: Vectors of the linear term of the objective.
        d1, d2: Number of rows and columns of the iterate ``X``.
        max_iter: Maximum number of iterations.
        tol: Stop once the Frobenius norm of the projected gradient is below this.
        step_size: Initial step size; adapted during the run.

    Returns:
        Tuple ``(X, objective_history)`` with the final iterate and the list of
        objective values recorded at the start of every iteration.
    """
    # The start point is deterministic and nothing random is drawn, so the
    # global NumPy RNG is deliberately left untouched (no reseeding here).
    X = np.eye(d1, d2)
    objective_history = []

    for i in range(max_iter):
        obj = stiefel_objective(X, Lambda1, Lambda2, v1, v2)
        objective_history.append(obj)

        euclidean_grad = stiefel_gradient(X, Lambda1, Lambda2, v1, v2)
        riem_grad = stiefel_riemannian_gradient(X, euclidean_grad)

        grad_norm = np.linalg.norm(riem_grad)
        if grad_norm < tol:
            print(f"Converged after {i} iterations")
            break

        X_new = stiefel_retraction(X, riem_grad, step_size)
        obj_new = stiefel_objective(X_new, Lambda1, Lambda2, v1, v2)

        # Backtracking: shrink the step while the candidate lowers the objective.
        backtrack_count = 0
        while obj_new < obj and step_size > 1e-12 and backtrack_count < 10:
            step_size *= 0.5
            X_new = stiefel_retraction(X, riem_grad, step_size)
            obj_new = stiefel_objective(X_new, Lambda1, Lambda2, v1, v2)
            backtrack_count += 1

        if obj_new >= obj:
            # Accept the step and cautiously enlarge the step size (capped at 0.1).
            X = X_new
            step_size = min(step_size * 1.05, 0.1)
        else:
            # Reject the step; retry from the same point with a smaller step.
            step_size *= 0.5

    return X, objective_history


def solve_igw_transport(m1, m2, Sigma1, Sigma2):
    """Compute a linear map and the IGW distance between two Gaussians.

    The covariances are diagonalised, ``optimize_stiefel`` is run in the
    eigenbases, and the resulting cross-covariance ``C_opt`` is turned into the
    matrix ``A = C_opt^T pinv(Sigma1)``.

    The input arrays are never modified: when a covariance needs
    regularisation, a regularised copy is used instead.

    Args:
        m1: Mean of the first Gaussian (1-D array of length ``d1``).
        m2: Mean of the second Gaussian (1-D array of length ``d2``).
        Sigma1: ``d1 x d1`` covariance of the first Gaussian.
        Sigma2: ``d2 x d2`` covariance of the second Gaussian.

    Returns:
        Tuple ``(A, igw_distance, obj_history)``: the ``d2 x d1`` matrix ``A``,
        the distance value, and the objective history returned by
        ``optimize_stiefel``.
    """
    d1, d2 = len(m1), len(m2)

    # Covariances are symmetric, so use the symmetric eigenvalue solver (always
    # real, consistent with the eigh calls below). Regularise out of place so
    # the caller's matrices stay untouched.
    if not np.all(np.linalg.eigvalsh(Sigma1) > 1e-12):
        Sigma1 = Sigma1 + np.eye(d1) * 1e-6
    if not np.all(np.linalg.eigvalsh(Sigma2) > 1e-12):
        Sigma2 = Sigma2 + np.eye(d2) * 1e-6

    Lambda1_vals, Q1 = eigh(Sigma1)
    Lambda2_vals, Q2 = eigh(Sigma2)

    # Guard the square roots below against tiny negative round-off.
    Lambda1_vals = np.maximum(Lambda1_vals, 1e-12)
    Lambda2_vals = np.maximum(Lambda2_vals, 1e-12)

    Lambda1 = np.diag(Lambda1_vals)
    Lambda2 = np.diag(Lambda2_vals)

    Sigma1_sqrt = Q1 @ np.diag(np.sqrt(Lambda1_vals)) @ Q1.T
    Sigma2_sqrt = Q2 @ np.diag(np.sqrt(Lambda2_vals)) @ Q2.T

    # Linear-term vectors, expressed in the eigenbases of the covariances.
    v1 = Q1.T @ (Sigma1_sqrt @ m1)
    v2 = Q2.T @ (Sigma2_sqrt @ m2)

    X_opt, obj_history = optimize_stiefel(Lambda1, Lambda2, v1, v2, d1, d2)

    # Map the optimiser's solution back to the original coordinates.
    C_opt = Sigma1_sqrt @ Q1 @ X_opt @ Q2.T @ Sigma2_sqrt

    A = C_opt.T @ np.linalg.pinv(Sigma1)

    I1 = np.trace(Sigma1 @ Sigma1) + 2 * (m1.T @ Sigma1 @ m1) + np.linalg.norm(m1) ** 4
    I2 = np.trace(Sigma2 @ Sigma2) + 2 * (m2.T @ Sigma2 @ m2) + np.linalg.norm(m2) ** 4
    I3 = (
        np.trace(C_opt.T @ C_opt)
        + 2 * (m1.T @ C_opt @ m2)
        + np.linalg.norm(m1) ** 2 * np.linalg.norm(m2) ** 2
    )

    # Clamp at zero so round-off cannot turn a (near-)zero distance into NaN.
    igw_distance = np.sqrt(max(I1 + I2 - 2 * I3, 0.0))

    return A, igw_distance, obj_history


def plot_contour_gaussian(mean, cov, t_value, x_range, y_range):
    """Draw, save and show a filled contour plot of a 2-D Gaussian density.

    The density is evaluated on a 200 x 200 grid, the origin is marked with a
    red dot, and the figure is written to
    ``images/igw_interpolation_shifted_<t_value>.pdf`` (two decimals). The
    ``images`` directory is created if it does not exist yet.

    Args:
        mean: Mean passed to ``multivariate_normal``.
        cov: Covariance passed to ``multivariate_normal``.
        t_value: Number formatted into the output file name.
        x_range: ``(min, max)`` extent of the grid along x.
        y_range: ``(min, max)`` extent of the grid along y.
    """
    import os

    x = np.linspace(x_range[0], x_range[1], 200)
    y = np.linspace(y_range[0], y_range[1], 200)
    X, Y = np.meshgrid(x, y)

    # Stack the grid into an (..., 2) array of points for the pdf evaluation.
    pos = np.dstack((X, Y))
    Z = multivariate_normal(mean, cov).pdf(pos)

    plt.figure()
    plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

    plt.plot(
        0,
        0,
        "o",
        markersize=12,
        label="Origin",
        color="red",
        markeredgecolor="white",
        markeredgewidth=1,
    )
    plt.legend(fontsize=14)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs("images", exist_ok=True)
    plt.savefig(f"images/igw_interpolation_shifted_{t_value:.2f}.pdf")
    plt.show()

# First Gaussian: mean m1 (now below the x axis) and a covariance with
# eigenvalues 10 and 1 whose principal axes are rotated by theta = -pi/6.
m1 = np.array([10, -10])
theta = -np.pi / 6
rotation1 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma1 = rotation1 @ np.array([[10, 0], [0, 1]]) @ rotation1.T

# Second Gaussian: same spectrum, rotated by pi/4 (theta is reused here).
m2 = np.array([-10, -10])
theta = np.pi / 4
rotation2 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma2 = rotation2 @ np.array([[10, 0], [0, 1]]) @ rotation2.T

# Solve for the linear map A, the IGW distance and the optimiser's objective trace.
A, igw_dist, obj_hist = solve_igw_transport(m1, m2, Sigma1, Sigma2)

def transport_map(x):
    """Apply the affine map ``x -> m2 + A (x - m1)`` using the notebook globals."""
    centered = x - m1
    return A @ centered + m2


# Covariance of the first Gaussian after applying the linear part A
# (not used further in the statements below).
transported_cov = A @ Sigma1 @ A.T

# Coordinates of the two means, used to size the plotting window.
all_means_x = [m1[0], m2[0]]
all_means_y = [m1[1], m2[1]]

# Three standard deviations along the widest principal axis of either covariance.
eigenvals1, _ = np.linalg.eigh(Sigma1)
eigenvals2, _ = np.linalg.eigh(Sigma2)
max_std = 3 * np.sqrt(max(np.max(eigenvals1), np.max(eigenvals2)))

# The x extent is derived from the data; the y extent is fixed by hand.
x_range = (min(all_means_x) - max_std, max(all_means_x) + max_std)
y_range = (-17, 2)

# Interpolation times at which a density is plotted.
t_values = [0, 0.33, 0.67, 1.0]

for t in t_values:
    if t == 0:
        # Start point: the first Gaussian itself.
        interpolated_mean = m1
        interpolated_cov = Sigma1
    elif t == 1:
        # End point: the second Gaussian itself.
        interpolated_mean = m2
        interpolated_cov = Sigma2
    else:
        # Push the first Gaussian through the linear part (1 - t) I + t A;
        # the mean moves along the straight line from m1 to m2.
        interpolated_mean = (1 - t) * m1 + t * m2
        B = (1 - t) * np.eye(2) + t * A
        interpolated_cov = B @ Sigma1 @ B.T

    plot_contour_gaussian(interpolated_mean, interpolated_cov, t, x_range, y_range)

# Objective value recorded at each iteration of the Stiefel optimisation.
plt.figure()
plt.plot(obj_hist)
plt.title("Convergence of Stiefel Optimization")
plt.xlabel("Iteration")
plt.ylabel("Objective Value")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal


def psd_sqrt(A, eps=1e-12):
    """Symmetric square root of ``A`` via its eigendecomposition.

    Args:
        A: Symmetric matrix (only handled through ``np.linalg.eigh``).
        eps: Lower bound applied to the eigenvalues before the square root.

    Returns:
        ``V diag(sqrt(w)) V^T`` with the clipped eigenvalues ``w``.
    """
    eigvals, eigvecs = np.linalg.eigh(A)
    root_vals = np.sqrt(np.clip(eigvals, eps, None))
    return eigvecs @ np.diag(root_vals) @ eigvecs.T


def psd_invsqrt(A, eps=1e-12):
    """Inverse symmetric square root of ``A`` via its eigendecomposition.

    Args:
        A: Symmetric matrix (only handled through ``np.linalg.eigh``).
        eps: Lower bound applied to the eigenvalues before inversion.

    Returns:
        ``V diag(1 / sqrt(w)) V^T`` with the clipped eigenvalues ``w``.
    """
    eigvals, eigvecs = np.linalg.eigh(A)
    inv_root_vals = 1.0 / np.sqrt(np.clip(eigvals, eps, None))
    return eigvecs @ np.diag(inv_root_vals) @ eigvecs.T


def solve_w2_transport(m1, m2, Sigma1, Sigma2):
    """Linear transport matrix and 2-Wasserstein distance between two Gaussians.

    Args:
        m1: Mean of the first Gaussian.
        m2: Mean of the second Gaussian.
        Sigma1: Covariance of the first Gaussian.
        Sigma2: Covariance of the second Gaussian.

    Returns:
        Tuple ``(A, W2)`` where
        ``A = Sigma1^{-1/2} (Sigma1^{1/2} Sigma2 Sigma1^{1/2})^{1/2} Sigma1^{-1/2}``
        and ``W2`` is the square root of the squared distance clamped at zero.
    """
    dim = len(m1)
    ridge = 1e-12 * np.eye(dim)

    # Symmetrising builds fresh arrays, so the caller's inputs are not modified;
    # the tiny ridge is then added on top.
    Sigma1 = 0.5 * (Sigma1 + Sigma1.T) + ridge
    Sigma2 = 0.5 * (Sigma2 + Sigma2.T) + ridge

    root1 = psd_sqrt(Sigma1)
    inv_root1 = psd_invsqrt(Sigma1)
    middle_root = psd_sqrt(root1 @ Sigma2 @ root1)

    A = inv_root1 @ middle_root @ inv_root1

    # Squared mean gap plus the covariance (trace) terms.
    squared = np.sum((m1 - m2) ** 2)
    squared = squared + np.trace(Sigma1)
    squared = squared + np.trace(Sigma2)
    squared = squared - 2.0 * np.trace(middle_root)

    # Clamp so round-off cannot produce the square root of a negative number.
    return A, np.sqrt(max(squared, 0.0))


def plot_contour_gaussian(mean, cov, t_value, x_range, y_range):
    """Draw, save and show a filled contour plot of a 2-D Gaussian density.

    The density is evaluated on a 200 x 200 grid, the origin is marked with a
    red dot, and the figure is written to
    ``images/w2_interpolation_<t_value>.pdf`` (two decimals). The ``images``
    directory is created if it does not exist yet.

    Args:
        mean: Mean passed to ``multivariate_normal``.
        cov: Covariance passed to ``multivariate_normal``.
        t_value: Number formatted into the output file name.
        x_range: ``(min, max)`` extent of the grid along x.
        y_range: ``(min, max)`` extent of the grid along y.
    """
    import os

    x = np.linspace(x_range[0], x_range[1], 200)
    y = np.linspace(y_range[0], y_range[1], 200)
    X, Y = np.meshgrid(x, y)

    # Stack the grid into an (..., 2) array of points for the pdf evaluation.
    pos = np.dstack((X, Y))
    Z = multivariate_normal(mean, cov).pdf(pos)

    plt.figure()
    plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

    plt.plot(
        0,
        0,
        "o",
        markersize=12,
        label="Origin",
        color="red",
        markeredgecolor="white",
        markeredgewidth=1,
    )
    plt.legend(fontsize=14)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs("images", exist_ok=True)
    plt.savefig(f"images/w2_interpolation_{t_value:.2f}.pdf")
    plt.show()


# First Gaussian: mean m1 and a covariance with eigenvalues 10 and 1 whose
# principal axes are rotated by theta = -pi/6.
m1 = np.array([10, 10])
theta = -np.pi / 6
rotation1 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma1 = rotation1 @ np.array([[10, 0], [0, 1]]) @ rotation1.T

# Second Gaussian: same spectrum, rotated by pi/4 (theta is reused here).
m2 = np.array([-10, 10])
theta = np.pi / 4
rotation2 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma2 = rotation2 @ np.array([[10, 0], [0, 1]]) @ rotation2.T

# Linear transport matrix A and the 2-Wasserstein distance between the two Gaussians.
A, w2_dist = solve_w2_transport(m1, m2, Sigma1, Sigma2)

def transport_map(x):
    """Apply the affine map ``x -> m2 + A (x - m1)`` using the notebook globals."""
    centered = x - m1
    return A @ centered + m2

# Covariance of the first Gaussian after applying the linear part A
# (not used further in the statements below).
transported_cov = A @ Sigma1 @ A.T

# Coordinates of the two means, used to size the plotting window.
all_means_x = [m1[0], m2[0]]
all_means_y = [m1[1], m2[1]]

# Three standard deviations along the widest principal axis of either covariance.
eigenvals1, _ = np.linalg.eigh(Sigma1)
eigenvals2, _ = np.linalg.eigh(Sigma2)
max_std = 3 * np.sqrt(max(np.max(eigenvals1), np.max(eigenvals2)))

# The x extent is derived from the data; the y extent is fixed by hand.
x_range = (min(all_means_x) - max_std, max(all_means_x) + max_std)
y_range = (-2, 17)

# Interpolation times at which a density is plotted.
t_values = [0, 0.33, 0.67, 1.0]

for t in t_values:
    if t == 0:
        # Start point: the first Gaussian itself.
        interpolated_mean = m1
        interpolated_cov = Sigma1
    elif t == 1:
        # End point: the second Gaussian itself.
        interpolated_mean = m2
        interpolated_cov = Sigma2
    else:
        # Push the first Gaussian through the linear part (1 - t) I + t A;
        # the mean moves along the straight line from m1 to m2.
        interpolated_mean = (1 - t) * m1 + t * m2
        B = (1 - t) * np.eye(2) + t * A
        interpolated_cov = B @ Sigma1 @ B.T

    plot_contour_gaussian(interpolated_mean, interpolated_cov, t, x_range, y_range)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal


def psd_sqrt(A, eps=1e-12):
    """Symmetric square root of ``A`` via its eigendecomposition.

    Args:
        A: Symmetric matrix (only handled through ``np.linalg.eigh``).
        eps: Lower bound applied to the eigenvalues before the square root.

    Returns:
        ``V diag(sqrt(w)) V^T`` with the clipped eigenvalues ``w``.
    """
    eigvals, eigvecs = np.linalg.eigh(A)
    root_vals = np.sqrt(np.clip(eigvals, eps, None))
    return eigvecs @ np.diag(root_vals) @ eigvecs.T


def psd_invsqrt(A, eps=1e-12):
    """Inverse symmetric square root of ``A`` via its eigendecomposition.

    Args:
        A: Symmetric matrix (only handled through ``np.linalg.eigh``).
        eps: Lower bound applied to the eigenvalues before inversion.

    Returns:
        ``V diag(1 / sqrt(w)) V^T`` with the clipped eigenvalues ``w``.
    """
    eigvals, eigvecs = np.linalg.eigh(A)
    inv_root_vals = 1.0 / np.sqrt(np.clip(eigvals, eps, None))
    return eigvecs @ np.diag(inv_root_vals) @ eigvecs.T


def solve_w2_transport(m1, m2, Sigma1, Sigma2):
    """Linear transport matrix and 2-Wasserstein distance between two Gaussians.

    Args:
        m1: Mean of the first Gaussian.
        m2: Mean of the second Gaussian.
        Sigma1: Covariance of the first Gaussian.
        Sigma2: Covariance of the second Gaussian.

    Returns:
        Tuple ``(A, W2)`` where
        ``A = Sigma1^{-1/2} (Sigma1^{1/2} Sigma2 Sigma1^{1/2})^{1/2} Sigma1^{-1/2}``
        and ``W2`` is the square root of the squared distance clamped at zero.
    """
    dim = len(m1)
    ridge = 1e-12 * np.eye(dim)

    # Symmetrising builds fresh arrays, so the caller's inputs are not modified;
    # the tiny ridge is then added on top.
    Sigma1 = 0.5 * (Sigma1 + Sigma1.T) + ridge
    Sigma2 = 0.5 * (Sigma2 + Sigma2.T) + ridge

    root1 = psd_sqrt(Sigma1)
    inv_root1 = psd_invsqrt(Sigma1)
    middle_root = psd_sqrt(root1 @ Sigma2 @ root1)

    A = inv_root1 @ middle_root @ inv_root1

    # Squared mean gap plus the covariance (trace) terms.
    squared = np.sum((m1 - m2) ** 2)
    squared = squared + np.trace(Sigma1)
    squared = squared + np.trace(Sigma2)
    squared = squared - 2.0 * np.trace(middle_root)

    # Clamp so round-off cannot produce the square root of a negative number.
    return A, np.sqrt(max(squared, 0.0))


def plot_contour_gaussian(mean, cov, t_value, x_range, y_range):
    """Draw, save and show a filled contour plot of a 2-D Gaussian density.

    The density is evaluated on a 200 x 200 grid, the origin is marked with a
    red dot, and the figure is written to
    ``images/w2_interpolation_shifted_<t_value>.pdf`` (two decimals). The
    ``images`` directory is created if it does not exist yet.

    Args:
        mean: Mean passed to ``multivariate_normal``.
        cov: Covariance passed to ``multivariate_normal``.
        t_value: Number formatted into the output file name.
        x_range: ``(min, max)`` extent of the grid along x.
        y_range: ``(min, max)`` extent of the grid along y.
    """
    import os

    x = np.linspace(x_range[0], x_range[1], 200)
    y = np.linspace(y_range[0], y_range[1], 200)
    X, Y = np.meshgrid(x, y)

    # Stack the grid into an (..., 2) array of points for the pdf evaluation.
    pos = np.dstack((X, Y))
    Z = multivariate_normal(mean, cov).pdf(pos)

    plt.figure()
    plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

    plt.plot(
        0,
        0,
        "o",
        markersize=12,
        label="Origin",
        color="red",
        markeredgecolor="white",
        markeredgewidth=1,
    )
    plt.legend(fontsize=14)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs("images", exist_ok=True)
    plt.savefig(f"images/w2_interpolation_shifted_{t_value:.2f}.pdf")
    plt.show()

# First Gaussian: mean m1 (now below the x axis) and a covariance with
# eigenvalues 10 and 1 whose principal axes are rotated by theta = -pi/6.
m1 = np.array([10, -10])
theta = -np.pi / 6
rotation1 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma1 = rotation1 @ np.array([[10, 0], [0, 1]]) @ rotation1.T

# Second Gaussian: same spectrum, rotated by pi/4 (theta is reused here).
m2 = np.array([-10, -10])
theta = np.pi / 4
rotation2 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Sigma2 = rotation2 @ np.array([[10, 0], [0, 1]]) @ rotation2.T

# Linear transport matrix A and the 2-Wasserstein distance between the two Gaussians.
A, w2_dist = solve_w2_transport(m1, m2, Sigma1, Sigma2)

def transport_map(x):
    """Apply the affine map ``x -> m2 + A (x - m1)`` using the notebook globals."""
    centered = x - m1
    return A @ centered + m2

# Covariance of the first Gaussian after applying the linear part A
# (not used further in the statements below).
transported_cov = A @ Sigma1 @ A.T

# Coordinates of the two means, used to size the plotting window.
all_means_x = [m1[0], m2[0]]
all_means_y = [m1[1], m2[1]]

# Three standard deviations along the widest principal axis of either covariance.
eigenvals1, _ = np.linalg.eigh(Sigma1)
eigenvals2, _ = np.linalg.eigh(Sigma2)
max_std = 3 * np.sqrt(max(np.max(eigenvals1), np.max(eigenvals2)))

# The x extent is derived from the data; the y extent is fixed by hand.
x_range = (min(all_means_x) - max_std, max(all_means_x) + max_std)
y_range = (-17, 2)

# Interpolation times at which a density is plotted.
t_values = [0, 0.33, 0.67, 1.0]

for t in t_values:
    if t == 0:
        # Start point: the first Gaussian itself.
        interpolated_mean = m1
        interpolated_cov = Sigma1
    elif t == 1:
        # End point: the second Gaussian itself.
        interpolated_mean = m2
        interpolated_cov = Sigma2
    else:
        # Push the first Gaussian through the linear part (1 - t) I + t A;
        # the mean moves along the straight line from m1 to m2.
        interpolated_mean = (1 - t) * m1 + t * m2
        B = (1 - t) * np.eye(2) + t * A
        interpolated_cov = B @ Sigma1 @ B.T

    plot_contour_gaussian(interpolated_mean, interpolated_cov, t, x_range, y_range)

# IGW and Wasserstein barycenters

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from scipy.linalg import sqrtm

import matplotlib_inline

# Ask the inline backend to emit figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
# Apply the "math.mplstyle" style sheet to all figures created afterwards
# (presumably a style file shipped next to this notebook — confirm).
plt.style.use("math.mplstyle")

In [None]:
# Common mean (the origin) shared by the densities plotted in the following cells.
mu = np.array([0, 0])

# First covariance: eigenvalues 100 and 10, principal axes rotated by -pi/6.
theta = -np.pi / 6
rotation1 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
cov1 = rotation1 @ np.array([[100, 0], [0, 10]]) @ rotation1.T

# 100 x 100 evaluation grid on [-40, 40]^2; X, Y and pos are reused by later cells.
x = np.linspace(-40, 40, 100)
y = np.linspace(-40, 40, 100)
X, Y = np.meshgrid(x, y)

pos = np.dstack((X, Y))
rv = multivariate_normal(mu, cov1)
Z = rv.pdf(pos)

contourf = plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Note: savefig requires the "images" directory to exist already.
plt.savefig("images/density1.png")

In [None]:
# Second covariance: same eigenvalues (100 and 10), principal axes rotated by pi/4.
theta = np.pi / 4
rotation2 = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
cov2 = rotation2 @ np.array([[100, 0], [0, 10]]) @ rotation2.T

# Reuses mu and the grid (X, Y, pos) defined in an earlier cell.
rv = multivariate_normal(mu, cov2)
Z = rv.pdf(pos)

contourf = plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Note: savefig requires the "images" directory to exist already.
plt.savefig("images/density2.png")

In [None]:
# Linear map T with T cov1 T^T = cov2, written in the form
#   T = cov2^{1/2} (cov2^{1/2} cov1 cov2^{1/2})^{-1/2} cov2^{1/2}.
# The matrix square root of cov2 is computed once and reused, instead of
# calling sqrtm(cov2) four times inside a single expression.
cov2_sqrt = sqrtm(cov2)
brenier_map = (
    cov2_sqrt @ np.linalg.inv(sqrtm(cov2_sqrt @ cov1 @ cov2_sqrt)) @ cov2_sqrt
)
# Covariance of N(0, cov1) pushed forward by the midpoint map (I + T) / 2.
wasserstein_barycenter_cov = (
    1 / 4 * (np.eye(2) + brenier_map) @ cov1 @ (np.eye(2) + brenier_map).T
)

# Reuses mu and the grid (X, Y, pos) defined in an earlier cell.
rv = multivariate_normal(mu, wasserstein_barycenter_cov)
Z = rv.pdf(pos)

contourf = plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Note: savefig requires the "images" directory to exist already.
plt.savefig("images/density_w2.png")

In [None]:
# Covariance used for the IGW barycenter plot: axis-aligned, with the same
# eigenvalues (100 and 10) as cov1 and cov2. The value is hard-coded here,
# not computed by the notebook.
igw_barycenter_cov = np.array([[100, 0], [0, 10]])

# Rebuilds the same 100 x 100 grid on [-40, 40]^2 that an earlier cell defined.
x = np.linspace(-40, 40, 100)
y = np.linspace(-40, 40, 100)
X, Y = np.meshgrid(x, y)

pos = np.dstack((X, Y))
rv = multivariate_normal(mu, igw_barycenter_cov)
Z = rv.pdf(pos)

contourf = plt.contourf(X, Y, Z, cmap="viridis", alpha=0.8)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Note: savefig requires the "images" directory to exist already.
plt.savefig("images/density_igw.png")