In [1]:
import joblib
import numpy as np
import torch

import modularised_utils as mut
import matplotlib.pyplot as plt

import opt_utils as oput

import Linear_Additive_Noise_Models as lanm
import operations as ops
from scipy.linalg import sqrtm

import params

np.random.seed(0)

In [2]:
experiment = 'synth1'

In [3]:
# Define the radius of the Wasserstein balls (epsilon, delta) and the size for both models.
epsilon         = params.radius[experiment][0]
ll_num_envs     = params.n_envs[experiment][0]

delta           = params.radius[experiment][1]
hl_num_envs     = params.n_envs[experiment][1]

# Define the number of samples per environment. Currently every environment has the same number of samples
num_llsamples   = params.n_samples[experiment][0]
num_hlsamples   = params.n_samples[experiment][1]

In [4]:
Dll = mut.load_samples(experiment)[None][0] 
Gll = mut.load_ll_model(experiment)[0]
Ill = mut.load_ll_model(experiment)[1]


Dhl = mut.load_samples(experiment)[None][1] 
Ghl = mut.load_hl_model(experiment)[0]
Ihl = mut.load_hl_model(experiment)[1]

omega = mut.load_omega_map(experiment)

In [5]:
ll_coeffs = mut.get_coefficients(Dll, Gll)
hl_coeffs = mut.get_coefficients(Dhl, Ghl) 

In [6]:
# # [Not suggested] In case we want to explore also the interventional --> worse estimation!
# Dlls, Dhls = [], []
# for dpair in list(mut.load_samples(experiment).values()):
#     Dlls.append(dpair[0])
#     Dhls.append(dpair[1])
    
# ll_coeffs = mut.get_coefficients(Dlls, Gll)
# hl_coeffs = mut.get_coefficients(Dhls, Ghl) 

In [7]:
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl, Ghl, hl_coeffs)

In [8]:
LLmodels = {}
for iota in Ill:
    LLmodels[iota] = lanm.LinearAddSCM(Gll, ll_coeffs, iota)
    
HLmodels, Dhl_samples = {}, {}
for eta in Ihl:
    HLmodels[eta] = lanm.LinearAddSCM(Ghl, hl_coeffs, eta)

In [27]:
L_matrices = compute_struc_matrices(LLmodels, Ill)

for L in L_matrices:
    print(L @ Sigma_U_ll_hat @ L.T)

tensor([[1.1607, 0.5735, 0.0607],
        [0.5735, 2.0470, 0.2167],
        [0.0607, 0.2167, 1.0000]])
tensor([[1.1607, 0.5735, 0.0607],
        [0.5735, 2.0470, 0.2167],
        [0.0607, 0.2167, 1.0000]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 2.0470, 0.2167],
        [0.0000, 0.2167, 1.0000]])
tensor([[1.1607, 0.5735, 0.0607],
        [0.5735, 2.0470, 0.2167],
        [0.0607, 0.2167, 1.0000]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 2.0470, 0.2167],
        [0.0000, 0.2167, 1.0000]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 2.0470, 0.2167],
        [0.0000, 0.2167, 1.0000]])


In [31]:
mu_U_ll_hat    = torch.from_numpy(np.array([0, 0, 0])).float()  
Sigma_U_ll_hat = torch.from_numpy(np.diag([1, 2, 1])).float() 

mu_U_hl_hat    = torch.from_numpy(np.array([0, 0])).float()  
Sigma_U_hl_hat = torch.from_numpy(np.diag([1, 1])).float()

In [32]:
mu_L    = mu_U_ll_hat
Sigma_L = Sigma_U_ll_hat

mu_H    = mu_U_hl_hat
Sigma_H = Sigma_U_hl_hat

l = mu_L.shape[0]
h = mu_H.shape[0]

lambda_L =.2
lambda_H =.3
eta      = .01
max_iter = 10

In [29]:
mu_L    = torch.from_numpy(mu_U_ll_hat).float()
Sigma_L = torch.from_numpy(Sigma_U_ll_hat).float()

mu_H    = torch.from_numpy(mu_U_hl_hat).float()
Sigma_H = torch.from_numpy(Sigma_U_hl_hat).float()

l = mu_L.shape[0]
h = mu_H.shape[0]

lambda_L =.2
lambda_H =.3
eta      = .01
max_iter = 10

TypeError: expected np.ndarray (got Tensor)

In [10]:
# # Define the necessary functions using PyTorch for automatic differentiation
# def F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta):
#     term1 = 0
#     term2 = 0
#     term3 = 0

#     # Loop to compute the sum of terms
#     for n, iota in enumerate(Ill):
#         L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  # Convert to float32
#         V_i = T @ L_i  # Matrix multiplication, ensure V_i is float32
#         H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  # Convert to float32

#         term1 += torch.norm(torch.matmul(V_i, mu_L) - torch.matmul(H_i, mu_H))**2 + torch.trace(torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))) + torch.trace(torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T)))

#     term2 = lambda_L * (epsilon**2 - torch.norm(mu_L - hat_mu_L)**2 - torch.norm(sqrtm_svd(Sigma_L) - sqrtm_svd(hat_Sigma_L))**2)
#     term3 = lambda_H * (delta**2 - torch.norm(mu_H - hat_mu_H)**2 - torch.norm(sqrtm_svd(Sigma_H) - sqrtm_svd(hat_Sigma_H))**2)

#     return term1 / n + term2 + term3

# # Proximal operator for Sigma_L (using soft-thresholding)
# def prox_Sigma_L(Sigma_L, lambda_L, LLmodels, HLmodels, Sigma_H):
#     # Using the Frobenius norm as a soft-thresholding operator for Sigma_L
#     prox = torch.zeros_like(Sigma_L)
#     for n, iota in enumerate(Ill):
#         L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
#         V_i = T @ L_i  
#         H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  

#         V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
#         sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
#         prox_i          = prox_operator(sqrtm_V_Sigma_V, lambda_L)
#         ll_term         = torch.linalg.pinv(V_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.pinv(V_i).T

#         H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
#         sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
#         hl_term         = torch.norm(sqrtm_H_Sigma_H, p='fro') 
       
#         prox += ll_term * hl_term

#     prox *= (2 / n)
#     prox = diagonalize(prox)
#     return prox

# # Proximal operator for Sigma_H (using soft-thresholding)
# def prox_Sigma_H(Sigma_H, lambda_H, LLmodels, HLmodels, Sigma_L):
#     prox = torch.zeros_like(Sigma_H)
#     for n, iota in enumerate(Ill):
#         L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
#         V_i = T @ L_i  
#         H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  
       
#         H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
#         sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
#         prox_i          = prox_operator(sqrtm_H_Sigma_H, lambda_H)
#         hl_term         = torch.linalg.inv(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.inv(H_i).T
#         #hl_term        = torch.inverse(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.inverse(H_i).T

#         V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
#         sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
#         ll_term         = torch.norm(sqrtm_V_Sigma_V, p='fro') 
        
#         prox     += ll_term * hl_term

#     prox *= (2 / n)

#     prox = diagonalize(prox)
#     return prox

# # Proximal operator of a matrix frobenious norm
# def prox_operator(A, lambda_param):
#     frobenius_norm = torch.norm(A, p='fro')
#     scaling_factor = torch.max(1 - lambda_param / frobenius_norm, torch.zeros_like(frobenius_norm))
#     return scaling_factor * A

# def diagonalize(A):
#     # Get eigenvalues and eigenvectors
#     eigvals, eigvecs = torch.linalg.eig(A)  
#     eigvals_real     = eigvals.real  
#     eigvals_real     = torch.sqrt(eigvals_real)  # Take the square root of the eigenvalues

#     return torch.diag(eigvals_real)

# def sqrtm_svd(A):
#     # Compute the SVD of A
#     U, S, V = torch.svd(A)
    
#     # Take the square root of the singular values
#     S_sqrt = torch.sqrt(torch.clamp(S, min=0.0))  # Ensure non-negative singular values
    
#     # Reconstruct the square root matrix
#     sqrt_A = U @ torch.diag(S_sqrt) @ V.T
    
#     return sqrt_A

# def sqrtm_eig(A):
#     eigvals, eigvecs = torch.linalg.eig(A)
#     eigvals_real = eigvals.real
    
#     # Ensure eigenvalues are non-negative for the square root to be valid
#     eigvals_sqrt = torch.sqrt(torch.clamp(eigvals_real, min=0.0))  # Square root of non-negative eigenvalues

#     # Reconstruct the square root of the matrix using the eigenvectors
#     # Make sure the eigenvectors are also real
#     eigvecs_real = eigvecs.real
    
#     # Reconstruct the matrix square root
#     sqrt_A = eigvecs_real @ torch.diag(eigvals_sqrt) @ eigvecs_real.T
    
#     return sqrt_A


# # Optimization loop using autograd and PyProximal (maximize using gradient ascent)
# def optimize(LLmodels, HLmodels, mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta, lambda_L, lambda_H, eta, max_iter):
#     mu_L.requires_grad_(True)  # Enable autograd for mu_L
#     Sigma_L_half.requires_grad_(True)  # Enable autograd for Sigma_L
#     mu_H.requires_grad_(True)  # Enable autograd for mu_H
#     Sigma_H_half.requires_grad_(True)  # Enable autograd for Sigma_H

#     for t in range(max_iter):
#         print(f"Iteration {t}")
        
#         objective = F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
#         objective.backward()

#         with torch.no_grad():
#             mu_L += eta * mu_L.grad  # Ascent for mu_L
#             mu_H += eta * mu_H.grad  # Ascent for mu_H
            
#             print(f"Sigma_L: {Sigma_L.grad}")
#             Sigma_L_half += eta * Sigma_L.grad  # Ascent for Sigma_L
#             Sigma_H += eta * Sigma_H.grad  # Ascent for Sigma_H
#             Sigma_L = prox_Sigma_L(Sigma_L_half, lambda_L, LLmodels, HLmodels, Sigma_H)
#             print(Sigma_L)  
#             Sigma_H = prox_Sigma_H(Sigma_H_half, lambda_H, LLmodels, HLmodels, Sigma_L)
            
#             #Zero the gradients after the update
#             mu_L.grad.zero_()
#             mu_H.grad.zero_()
#             Sigma_L.grad.zero_()
#             Sigma_H.grad.zero_()

#             # if mu_L.grad is not None:
#             #     mu_L.grad.zero_()
#             # if mu_H.grad is not None:
#             #     mu_H.grad.zero_()
#             # if Sigma_L.grad is not None:
#             #     Sigma_L.grad.zero_()
#             # if Sigma_H.grad is not None:
#             #     Sigma_H.grad.zero_()

#         # Print progress
#         if t % 10 == 0:
#             print(f"Iteration {t}, Objective Value: {objective.item()}")

#     return mu_L, Sigma_L, mu_H, Sigma_H


In [11]:
# def optimize_max(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, lambda_L, lambda_H, lambda_param, eta, num_steps_max):

#     for t in range(num_steps_max): 
#         #print('mu_L before update:', mu_L)
#         mu_L         = update_mu_L(T, mu_L, mu_H, LLmodels, HLmodels, lambda_L, hat_mu_L, eta)
#         # print('mu_L after update:', mu_L)
#         # print('mu_H before update:', mu_H)
#         mu_H         = update_mu_H(T, mu_L, mu_H, LLmodels, HLmodels, lambda_H, hat_mu_H, eta)
#         # print('mu_H after update:', mu_H)

#         # print('Sigma_L before update:', Sigma_L)
#         Sigma_L_half = update_Sigma_L_half(T, Sigma_L, LLmodels, lambda_L, hat_Sigma_L, eta)
#         Sigma_L      = update_Sigma_L(T, Sigma_L_half, LLmodels, Sigma_H, HLmodels, lambda_param)
#         # print('Sigma_L after update:', Sigma_L)
        
#         # print('Sigma_H before update:', Sigma_H)
#         Sigma_H_half = update_Sigma_H_half(T, Sigma_H, HLmodels, lambda_H, hat_Sigma_H, eta)
#         Sigma_H      = update_Sigma_H(T, Sigma_H_half, LLmodels, Sigma_L, HLmodels, lambda_param)
#         # print('Sigma_H after update:', Sigma_H)
        
#         mu_L, Sigma_L, mu_H, Sigma_H = enforce_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
#         # print('mu_L after constraints:', mu_L)
#         # print('Sigma_L after constraints:', Sigma_L)
#         # print('mu_H after constraints:', mu_H)
#         # print('Sigma_H after constraints:', Sigma_H)
#         # print( )
#         # Compute the objective function for the current iteration
#         obj = 0
        
#         for i, iota in enumerate(Ill):
#             L_i = torch.from_numpy(LLmodels[iota].compute_mechanism())
#             V_i = T @ L_i.float()
#             H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()
                        
#             L_i_mu_L = V_i @ mu_L
#             H_i_mu_H = H_i @ mu_H
#             term1 = torch.norm(L_i_mu_L.float() - H_i_mu_H.float())**2
            
#             V_Sigma_V = V_i.float() @ Sigma_L.float() @ V_i.T.float()
#             H_Sigma_H = H_i.float() @ Sigma_H.float() @ H_i.T.float()

#             term2 = torch.trace(V_Sigma_V)
#             term3 = torch.trace(H_Sigma_H)
            
#             sqrtVSV = oput.sqrtm_svd(V_Sigma_V)
#             sqrtHSH = oput.sqrtm_svd(H_Sigma_H)

#             #term4 = -2*torch.trace(oput.sqrtm_svd(sqrtHSH @ V_Sigma_V @ sqrtHSH))
#             term4 = -2*torch.norm(oput.sqrtm_svd(sqrtVSV) @ oput.sqrtm_svd(sqrtHSH), 'nuc')
            
#             obj = obj + (term1 + term2 + term3 + term4)
        
#         obj = obj/i
        
#         print(f"Max step {t+1}/{num_steps_max}, Objective: {obj.item()}")

#     return mu_L, Sigma_L, mu_H, Sigma_H

# def optimize_min(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, num_steps_min, optimizer_T):

#     objective_T = 0  # Initialize the objective for this step

#     for step in range(num_steps_min):
#         objective_T = 0  # Reset objective at the start of each step
#         for n, iota in enumerate(Ill):
#             L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()
#             H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()

#             L_i_mu_L = L_i @ mu_L  
#             H_i_mu_H = H_i @ mu_H 

#             term1 = torch.norm(T @ L_i_mu_L - H_i_mu_H) ** 2
#             term2 = torch.trace(T @ L_i @ Sigma_L @ L_i.T @ T.T)
#             term3 = torch.trace(H_i @ Sigma_H @ H_i.T)
            
#             L_i_Sigma_L = T @ L_i @ Sigma_L @ L_i.T @ T.T
#             H_i_Sigma_H = H_i @ Sigma_H @ H_i.T

#             # Using the SVD square root term
#             term4 = -2 * torch.norm(oput.sqrtm_svd(L_i_Sigma_L) @ oput.sqrtm_svd(H_i_Sigma_H), 'nuc')

#             objective_T += term1 + term2 + term3 + term4

#         objective_T = objective_T/n

#         optimizer_T.zero_grad() # Clear previous gradients
#         objective_T.backward(retain_graph=True)  # Backpropagate to compute gradients
#         optimizer_T.step()      # Update T using the optimizer

#         print(f"Min step {step+1}/{num_steps_min}, Objective: {objective_T.item()}")

#     return objective_T, T  # Return both the objective and T

In [40]:
def compute_struc_matrices(models, I):
    matrices = []
    for iota in I:
        M_i = torch.from_numpy(models[iota].compute_mechanism()).float()  
        matrices.append(M_i)

    return matrices

def compute_mu_bary(struc_matrices, mu):
    struc_matrices_tensor = torch.stack(struc_matrices)
    mu_barycenter         = torch.sum(struc_matrices_tensor @ mu, dim=0) / len(struc_matrices)

    return mu_barycenter

def compute_Sigma_bary(matrices, Sigma, initialization, max_iter, tol):

    Sigma_matrices = []
    for M in matrices:
        Sigma_matrices.append(M @ Sigma @ M.T)

    return covariance_bary_optim(Sigma_matrices, initialization, max_iter, tol)

def covariance_bary_optim(Sigma_list, initialization, max_iter, tol):
    
    if initialization == 'psd':
        S_0 = create_psd_matrix(Sigma_list[0].shape[0])
    elif initialization == 'avg':
        S_0 = sum(Sigma_list) / len(Sigma_list)
    
    S_n = S_0.clone()
    n   = len(Sigma_list)  # Number of matrices
    lambda_j = 1.0 / n   # Equal weights
    
    for n in range(max_iter):
        S_n_old = S_n.clone()

        S_n_inv_half = oput.sqrtm_svd(regmat(torch.inverse(S_n)))
        
        # Compute the sum of S_n^(1/2) Σ_j S_n^(1/2)
        sum_term = torch.zeros_like(S_n)
        for Sigma_j in Sigma_list:
            S_n_half   = oput.sqrtm_svd(regmat(S_n))
            inner_term = torch.matmul(torch.matmul(S_n_half, Sigma_j), S_n_half)
            sqrt_term  = oput.sqrtm_svd(regmat(inner_term))
            sum_term  += lambda_j * sqrt_term
        # Square the sum term
        squared_sum = torch.matmul(sum_term, sum_term.T)

        S_n_next = torch.matmul(torch.matmul(S_n_inv_half, squared_sum), S_n_inv_half)
        S_n = S_n_next

        if torch.norm(S_n - S_n_old, p='fro') < tol:
            print(f"Converged after {n+1} iterations")
            break
            
    return S_n

def monge(m1, S1, m2, S2):
    inner      = torch.matmul(oput.sqrtm_svd(S1), torch.matmul(S2, oput.sqrtm_svd(S1)))
    sqrt_inner = oput.sqrtm_svd(inner)
    A          = torch.matmul(torch.inverse(oput.sqrtm_svd(regmat(S1))), torch.matmul(sqrt_inner, torch.inverse(oput.sqrtm_svd(regmat(S1)))))  

    # Define the Monge map as a function τ(x) = m_2 + A(x - m_1)
    def tau(x):
        return m2 + A @ (x - m1)

    return tau, A

def regmat(matrix, eps=1e-10):
    # Replace NaN and Inf values with finite numbers
    matrix = torch.nan_to_num(matrix, nan=0.0, posinf=1e10, neginf=-1e10)
    
    # Add a small epsilon to the diagonal for numerical stability
    if matrix.dim() == 2 and matrix.size(0) == matrix.size(1):
        matrix = matrix + eps * torch.eye(matrix.size(0), device=matrix.device)
    
    return matrix


In [41]:
def create_psd_matrix(size):
    A = torch.randn(size, size).float()

    return torch.matmul(A, A.T)

# PCA Projection from higher to lower dimension
def pca_projection(Sigma, target_dim):
    """
    Project a d×d matrix to a k×k matrix where k < d
    Args:
        Sigma: source matrix (d×d)
        target_dim: target dimension k
    Returns:
        k×k projected matrix
    """
    # Perform eigenvalue decomposition
    eigenvalues, eigenvectors = torch.linalg.eigh(Sigma)
    
    # Sort eigenvalues and eigenvectors in descending order
    sorted_indices = torch.argsort(eigenvalues, descending=True)
    eigenvalues = eigenvalues[sorted_indices]
    eigenvectors = eigenvectors[:, sorted_indices]
    
    # Take only the top target_dim eigenvectors
    V = eigenvectors[:, :target_dim]  # d×k matrix
    
    # Project the covariance matrix
    Sigma_projected = torch.matmul(torch.matmul(V.T, Sigma), V)  # k×k matrix
    
    return Sigma_projected, V

# SVD Projection from higher to lower dimension
def svd_projection(Sigma, target_dim):
    """
    Project a d×d matrix to a k×k matrix where k < d using SVD
    Args:
        Sigma: source matrix (d×d)
        target_dim: target dimension k
    Returns:
        k×k projected matrix
    """
    # Perform SVD
    U, S, V = torch.svd(Sigma)
    
    # Take only the first target_dim components
    U_k = U[:, :target_dim]  # d×k matrix
    S_k = S[:target_dim]     # k singular values
    
    # Project the covariance matrix
    Sigma_projected = torch.matmul(torch.matmul(U_k.T, Sigma), U_k)  # k×k matrix
    
    return Sigma_projected, U_k

def project_covariance(Sigma, n, method):
    if method == 'pca':
        return pca_projection(Sigma, n)
    elif method == 'svd':
        return svd_projection(Sigma, n)
    else:
        raise ValueError(f"Unknown projection method: {method}")

In [68]:
def barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H, LLmodels, HLmodels, Ill, Ihl, projection_method, initialization, max_iter, tol):

    h, l = mu_H.shape[0], mu_L.shape[0]

    # Initialize the structural matrices    
    L_matrices = compute_struc_matrices(LLmodels, Ill)
    H_matrices = compute_struc_matrices(HLmodels, Ihl)

    # Initilize the barycenteric means and covariances
    print("Computing barycentric mu_L")
    mu_bary_L = compute_mu_bary(L_matrices, mu_L)
    print("mu_bary_L:", mu_bary_L)  
    print("\nComputing barycentric mu_H")
    mu_bary_H = compute_mu_bary(H_matrices, mu_H)
    print("mu_bary_H:", mu_bary_H)  

    print("\nComputing barycentric Sigma_L")
    Sigma_bary_L = compute_Sigma_bary(L_matrices, Sigma_L, initialization, max_iter, tol)
    print(Sigma_bary_L)
    print("\nComputing barycentric Sigma_H")
    Sigma_bary_H = compute_Sigma_bary(H_matrices, Sigma_H, initialization, max_iter, tol)
    print(Sigma_bary_H)
    
    proj_Sigma_bary_L, Tp = project_covariance(Sigma_bary_L, h, projection_method)
    proj_mu_bary_L        = torch.matmul(Tp.T, mu_bary_L)

    tau, A = monge(proj_mu_bary_L, proj_Sigma_bary_L, mu_bary_H, Sigma_bary_H)

    T = torch.matmul(A, Tp.T)

    return tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H

In [70]:
tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H = barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H,
                                                                                        LLmodels, HLmodels, Ill, Ihl,
                                                                                        'svd', 'avg', 100, 1e-6)

Computing barycentric mu_L
mu_bary_L: tensor([0., 0., 0.])

Computing barycentric mu_H
mu_bary_H: tensor([0., 0.])

Computing barycentric Sigma_L
tensor([[-0.1602, -0.1863, -1.3203],
        [-1.2593, -1.3067, -1.3496],
        [-1.1682,  0.0345, -0.3648]])

Computing barycentric Sigma_H
Converged after 1 iterations
tensor([[1.3602, 0.6002],
        [0.6002, 1.0000]])


In [57]:
x = torch.randn(10, l)
for x_i in x:
    print(T@x_i)

tensor([-1.3200, -0.6395])
tensor([ 0.4221, -0.6327])
tensor([-1.5095, -0.3283])
tensor([1.8620, 0.5970])
tensor([-0.2151, -0.3287])
tensor([ 0.3774, -0.0894])
tensor([1.6335, 0.2247])
tensor([0.4873, 0.4357])
tensor([-0.0981, -1.5215])
tensor([-1.0676, -0.5375])


In [45]:
y = Tp.T@x[0]



tensor([-0.6622, -0.9276])
tensor([-0.6622, -0.9276])


In [67]:
# from tqdm import tqdm
# import time
# from datetime import datetime, timedelta
# import torch

# def barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H, 
#                            LLmodels, HLmodels, Ill, Ihl, 
#                            projection_method, initialization, 
#                            max_iter, tol, pbar=None):

#     start_time = time.time()
#     h, l = mu_H.shape[0], mu_L.shape[0]

#     # Initialize progress tracking
#     if pbar:
#         pbar.set_postfix({'stage': 'Initializing'})
#         pbar.update(0)

#     # Initialize the structural matrices    
#     L_matrices = compute_struc_matrices(LLmodels, Ill)
#     H_matrices = compute_struc_matrices(HLmodels, Ihl)

#     if pbar:
#         pbar.set_postfix({'stage': 'Computing barycentric means'})
#         pbar.update(10)  # Update progress by 10%

#     # Initialize the barycentric means and covariances
#     print("Computing barycentric mu_L")
#     mu_bary_L = compute_mu_bary(L_matrices, mu_L)
#     print("mu_bary_L:", mu_bary_L)  

#     if pbar:
#         pbar.update(10)

#     print("\nComputing barycentric mu_H")
#     mu_bary_H = compute_mu_bary(H_matrices, mu_H)
#     print("mu_bary_H:", mu_bary_H)  

#     if pbar:
#         pbar.set_postfix({'stage': 'Computing barycentric covariances'})
#         pbar.update(20)

#     print("\nComputing barycentric Sigma_L")
#     Sigma_bary_L = compute_Sigma_bary(L_matrices, Sigma_L, initialization, max_iter, tol)
#     print(Sigma_bary_L)

#     if pbar:
#         pbar.update(20)

#     print("\nComputing barycentric Sigma_H")
#     Sigma_bary_H = compute_Sigma_bary(H_matrices, Sigma_H, initialization, max_iter, tol)
#     print(Sigma_bary_H)
    
#     if pbar:
#         pbar.set_postfix({'stage': 'Projecting covariance'})
#         pbar.update(20)

#     proj_Sigma_bary_L, Tp = project_covariance(Sigma_bary_L, h, projection_method)
#     proj_mu_bary_L = torch.matmul(Tp.T, mu_bary_L)

#     if pbar:
#         pbar.set_postfix({'stage': 'Computing Monge map'})
#         pbar.update(10)

#     tau, A = monge(proj_mu_bary_L, proj_Sigma_bary_L, mu_bary_H, Sigma_bary_H)

#     if pbar:
#         pbar.set_postfix({
#             'stage': 'Finalizing',
#             'tau': f'{tau:.4f}' if isinstance(tau, (int, float)) else 'N/A'
#         })
#         pbar.update(10)

#     T = torch.matmul(A, Tp.T)

#     return tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H

# # Main execution wrapper
# def run_barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H,
#                                LLmodels, HLmodels, Ill, Ihl,
#                                projection_method='svd', 
#                                initialization='avg', 
#                                max_iter=100, 
#                                tol=1e-5):
    
#     start_time = time.time()
#     start_datetime = datetime.now()

#     print(f"\nStarting Barycentric Optimization at {start_datetime}")
#     print(f"Maximum iterations: {max_iter}")
#     print("Parameters:")
#     print(f"- Projection method: {projection_method}")
#     print(f"- Initialization: {initialization}")
#     print(f"- Tolerance: {tol}")

#     # Create progress bar (100 total steps for all stages)
#     with tqdm(total=100, desc="Barycentric Optimization") as pbar:
#         try:
#             # Run optimization with progress tracking
#             tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H = barycentric_optimization(
#                 mu_L=mu_L, 
#                 mu_H=mu_H, 
#                 Sigma_L=Sigma_L, 
#                 Sigma_H=Sigma_H,
#                 LLmodels=LLmodels, 
#                 HLmodels=HLmodels, 
#                 Ill=Ill, 
#                 Ihl=Ihl,
#                 projection_method=projection_method, 
#                 initialization=initialization, 
#                 max_iter=max_iter, 
#                 tol=tol,
#                 pbar=pbar
#             )
            
#             # Calculate execution time
#             end_time = time.time()
#             execution_time = end_time - start_time
            
#             # Print timing information
#             print("\nOptimization Complete!")
#             print(f"Started at: {start_datetime}")
#             print(f"Finished at: {datetime.now()}")
#             print(f"Total execution time: {timedelta(seconds=int(execution_time))}")
            
#             # Print final results
#             print("\nFinal Results:")
#             print(f"tau: {tau}")
#             print(f"T shape: {T.shape}")
#             print(f"Tp shape: {Tp.shape}")
            
#             return tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H

#         except Exception as e:
#             print(f"\nError during optimization: {e}")
#             end_time = time.time()
#             print(f"Time until error: {timedelta(seconds=int(end_time - start_time))}")
#             raise

# # Usage
# tau, T, Tp, mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H = run_barycentric_optimization(
#     mu_L, mu_H, Sigma_L, Sigma_H,
#     LLmodels, HLmodels, Ill, Ihl,
#     projection_method='svd',
#     initialization='psd',
#     max_iter=100,
#     tol=1e-5
# )


Starting Barycentric Optimization at 2024-12-02 16:20:10.238005
Maximum iterations: 100
Parameters:
- Projection method: svd
- Initialization: psd
- Tolerance: 1e-05


Barycentric Optimization:  20%|██        | 20/100 [00:00<00:00, 694.66it/s, stage=Computing barycentric covariances]

Computing barycentric mu_L
mu_bary_L: tensor([0., 0., 0.])

Computing barycentric mu_H
mu_bary_H: tensor([0., 0.])

Computing barycentric Sigma_L


Barycentric Optimization: 100%|██████████| 100/100 [00:01<00:00, 78.40it/s, stage=Finalizing, tau=N/A]              

tensor([[-0.8460, -0.0126, -1.2985],
        [-1.9967, -1.2195, -1.6254],
        [-1.0787, -0.1480,  0.2046]])

Computing barycentric Sigma_H
Converged after 2 iterations
tensor([[1.3602, 0.6002],
        [0.6002, 1.0000]])

Optimization Complete!
Started at: 2024-12-02 16:20:10.238005
Finished at: 2024-12-02 16:20:11.525951
Total execution time: 0:00:01

Final Results:
tau: <function monge.<locals>.tau at 0x19a759580>
T shape: torch.Size([2, 3])
Tp shape: torch.Size([3, 2])





Barycentric gradient

In [36]:
def optimize_min(mu_L, Sigma_L, mu_H, Sigma_H, num_steps, seed, tol=1e-2):

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    T = torch.randn(mu_H.shape[0], mu_L.shape[0], requires_grad=True)

    optimizer_T        = torch.optim.Adam([T], lr=0.01)
    previous_objective = float('inf')
    objective_T        = 0  # Reset objective at the start of each step
    # Optimization loop
    for step in range(num_steps):
        objective_T = 0  # Reset objective at the start of each step

        # Calculate each term of the Wasserstein distance
        term1 = torch.norm(T @ mu_L - mu_H) ** 2  # Squared Euclidean distance between transformed means
        term2 = torch.trace(T @ Sigma_L @ T.T)   # Trace term for low-level covariance
        term3 = torch.trace(Sigma_H)             # Trace term for high-level covariance
        
        # Compute the intermediate covariance matrices
        T_Sigma_L_T      = torch.matmul(T, torch.matmul(Sigma_L, T.T))
        T_Sigma_L_T_sqrt = oput.sqrtm_svd(T_Sigma_L_T)
        Sigma_H_sqrt     = oput.sqrtm_svd(Sigma_H)
        
        # Coupling term using nuclear norm
        term4 = -2 * torch.norm(T_Sigma_L_T_sqrt @ Sigma_H_sqrt, p='nuc')

        # Total objective is the sum of terms
        objective_T += term1 + term2 + term3 + term4

        if abs(previous_objective - objective_T.item()) < tol:
            print(f"Converged at step {step + 1}/{num_steps_min} with objective: {objective_T.item()}")
            break

        # Update previous objective
        previous_objective = objective_T.item()

        # Perform optimization step
        optimizer_T.zero_grad()  # Clear gradients
        objective_T.backward(retain_graph=True)  # Backpropagate
        optimizer_T.step()  # Update T

        print(f"Min step {step+1}/{num_steps}, Objective: {objective_T.item()}")

    return objective_T.item(), T  # Return final objective and optimized T


In [37]:
# Run optimization
num_steps = 1000
seed      = 42

final_objective, optimized_T = optimize_min(mu_bary_L, Sigma_bary_L, mu_bary_H, Sigma_bary_H, num_steps, seed)

print(f"Final Objective: {final_objective}")
print(f"Optimized T: {optimized_T}")


Min step 1/1000, Objective: -0.6442282199859619
Min step 2/1000, Objective: -0.7537491321563721
Min step 3/1000, Objective: -0.8608345985412598
Min step 4/1000, Objective: -0.965545654296875
Min step 5/1000, Objective: -1.0683064460754395
Min step 6/1000, Objective: -1.1698195934295654
Min step 7/1000, Objective: -1.2707304954528809
Min step 8/1000, Objective: -1.3716082572937012
Min step 9/1000, Objective: -1.4730181694030762
Min step 10/1000, Objective: -1.5754728317260742
Min step 11/1000, Objective: -1.679353952407837
Min step 12/1000, Objective: -1.7849090099334717
Min step 13/1000, Objective: -1.8922863006591797
Min step 14/1000, Objective: -2.0015642642974854
Min step 15/1000, Objective: -2.1127963066101074
Min step 16/1000, Objective: -2.226024627685547
Min step 17/1000, Objective: -2.3412821292877197
Min step 18/1000, Objective: -2.4586050510406494
Min step 19/1000, Objective: -2.5780253410339355
Min step 20/1000, Objective: -2.6995835304260254
Min step 21/1000, Objective: -2.

In [38]:
mu_U_hl_hat

tensor([0., 0.])

In [39]:
mu_bary_L

tensor([0., 0., 0.])