In [1]:
import joblib
import numpy as np
import torch

import modularised_utils as mut
import matplotlib.pyplot as plt

import opt_utils as oput

import Linear_Additive_Noise_Models as lanm
import operations as ops
from scipy.linalg import sqrtm

import params

np.random.seed(0)

In [2]:
experiment = 'synth1'

In [3]:
# Define the radius of the Wasserstein balls (epsilon, delta) and the size for both models.
epsilon         = params.radius[experiment][0]
ll_num_envs     = params.n_envs[experiment][0]

delta           = params.radius[experiment][1]
hl_num_envs     = params.n_envs[experiment][1]

# Define the number of samples per environment. Currently every environment has the same number of samples
num_llsamples   = params.n_samples[experiment][0]
num_hlsamples   = params.n_samples[experiment][1]

In [4]:
Dll = mut.load_samples(experiment)[None][0] 
Gll = mut.load_ll_model(experiment)[0]
Ill = mut.load_ll_model(experiment)[1]


Dhl = mut.load_samples(experiment)[None][1] 
Ghl = mut.load_hl_model(experiment)[0]
Ihl = mut.load_hl_model(experiment)[1]

omega = mut.load_omega_map(experiment)

In [5]:
ll_coeffs = mut.get_coefficients(Dll, Gll)
hl_coeffs = mut.get_coefficients(Dhl, Ghl) 

In [6]:
# # [Not suggested] In case we want to explore also the interventional --> worse estimation!
# Dlls, Dhls = [], []
# for dpair in list(mut.load_samples(experiment).values()):
#     Dlls.append(dpair[0])
#     Dhls.append(dpair[1])
    
# ll_coeffs = mut.get_coefficients(Dlls, Gll)
# hl_coeffs = mut.get_coefficients(Dhls, Ghl) 

In [7]:
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl, Ghl, hl_coeffs)

In [8]:
LLmodels = {}
for iota in Ill:
    LLmodels[iota] = lanm.LinearAddSCM(Gll, ll_coeffs, iota)
    
HLmodels, Dhl_samples = {}, {}
for eta in Ihl:
    HLmodels[eta] = lanm.LinearAddSCM(Ghl, hl_coeffs, eta)

In [18]:
mu_L    = torch.from_numpy(mu_U_ll_hat)
Sigma_L = torch.from_numpy(Sigma_U_ll_hat)

mu_H    = torch.from_numpy(mu_U_hl_hat)
Sigma_H = torch.from_numpy(Sigma_U_hl_hat)

l = mu_L.shape[0]
h = mu_H.shape[0]

# Given estimates (mu_L, Sigma_L, mu_H, Sigma_H)
hat_mu_L    =  torch.from_numpy(mu_U_ll_hat) 
hat_Sigma_L =  torch.from_numpy(Sigma_U_ll_hat)

hat_mu_H    =  torch.from_numpy(mu_U_hl_hat)
hat_Sigma_H =  torch.from_numpy(Sigma_U_hl_hat)

lambda_L =.2
lambda_H =.3
eta      = .01
max_iter = 10

Sigma_L = hat_Sigma_L
Sigma_H = hat_Sigma_H


T = torch.exp(torch.from_numpy(np.random.randn(2, 3)).float())

In [52]:
# Define the necessary functions using PyTorch for automatic differentiation
def F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta):
    term1 = 0
    term2 = 0
    term3 = 0

    # Loop to compute the sum of terms
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  # Convert to float32
        V_i = T @ L_i  # Matrix multiplication, ensure V_i is float32
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  # Convert to float32

        term1 += torch.norm(torch.matmul(V_i, mu_L) - torch.matmul(H_i, mu_H))**2 + torch.trace(torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))) + torch.trace(torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T)))

    term2 = lambda_L * (epsilon**2 - torch.norm(mu_L - hat_mu_L)**2 - torch.norm(sqrtm_svd(Sigma_L) - sqrtm_svd(hat_Sigma_L))**2)
    term3 = lambda_H * (delta**2 - torch.norm(mu_H - hat_mu_H)**2 - torch.norm(sqrtm_svd(Sigma_H) - sqrtm_svd(hat_Sigma_H))**2)

    return term1 / n + term2 + term3

# Proximal operator for Sigma_L (using soft-thresholding)
def prox_Sigma_L(Sigma_L, lambda_L, LLmodels, HLmodels, Sigma_H):
    # Using the Frobenius norm as a soft-thresholding operator for Sigma_L
    prox = torch.zeros_like(Sigma_L)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  

        V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
        sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
        prox_i          = prox_operator(sqrtm_V_Sigma_V, lambda_L)
        ll_term         = torch.linalg.pinv(V_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.pinv(V_i).T

        H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
        sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
        hl_term         = torch.norm(sqrtm_H_Sigma_H, p='fro') 
       
        prox += ll_term * hl_term

    prox *= (2 / n)
    prox = diagonalize(prox)
    return prox

# Proximal operator for Sigma_H (using soft-thresholding)
def prox_Sigma_H(Sigma_H, lambda_H, LLmodels, HLmodels, Sigma_L):
    prox = torch.zeros_like(Sigma_H)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  
       
        H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
        sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
        prox_i          = prox_operator(sqrtm_H_Sigma_H, lambda_H)
        hl_term         = torch.linalg.inv(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.inv(H_i).T
        #hl_term        = torch.inverse(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.inverse(H_i).T

        V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
        sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
        ll_term         = torch.norm(sqrtm_V_Sigma_V, p='fro') 
        
        prox     += ll_term * hl_term

    prox *= (2 / n)

    prox = diagonalize(prox)
    return prox

# Proximal operator of a matrix frobenious norm
def prox_operator(A, lambda_param):
    frobenius_norm = torch.norm(A, p='fro')
    scaling_factor = torch.max(1 - lambda_param / frobenius_norm, torch.zeros_like(frobenius_norm))
    return scaling_factor * A

def diagonalize(A):
    # Get eigenvalues and eigenvectors
    eigvals, eigvecs = torch.linalg.eig(A)  
    eigvals_real     = eigvals.real  
    eigvals_real     = torch.sqrt(eigvals_real)  # Take the square root of the eigenvalues

    return torch.diag(eigvals_real)

def sqrtm_svd(A):
    # Compute the SVD of A
    U, S, V = torch.svd(A)
    
    # Take the square root of the singular values
    S_sqrt = torch.sqrt(torch.clamp(S, min=0.0))  # Ensure non-negative singular values
    
    # Reconstruct the square root matrix
    sqrt_A = U @ torch.diag(S_sqrt) @ V.T
    
    return sqrt_A

def sqrtm_eig(A):
    eigvals, eigvecs = torch.linalg.eig(A)
    eigvals_real = eigvals.real
    
    # Ensure eigenvalues are non-negative for the square root to be valid
    eigvals_sqrt = torch.sqrt(torch.clamp(eigvals_real, min=0.0))  # Square root of non-negative eigenvalues

    # Reconstruct the square root of the matrix using the eigenvectors
    # Make sure the eigenvectors are also real
    eigvecs_real = eigvecs.real
    
    # Reconstruct the matrix square root
    sqrt_A = eigvecs_real @ torch.diag(eigvals_sqrt) @ eigvecs_real.T
    
    return sqrt_A


# Optimization loop using autograd and PyProximal (maximize using gradient ascent)
def optimize(LLmodels, HLmodels, mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta, lambda_L, lambda_H, eta, max_iter):
    mu_L.requires_grad_(True)  # Enable autograd for mu_L
    Sigma_L_half.requires_grad_(True)  # Enable autograd for Sigma_L
    mu_H.requires_grad_(True)  # Enable autograd for mu_H
    Sigma_H_half.requires_grad_(True)  # Enable autograd for Sigma_H

    for t in range(max_iter):
        print(f"Iteration {t}")
        
        objective = F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
        objective.backward()

        with torch.no_grad():
            mu_L += eta * mu_L.grad  # Ascent for mu_L
            mu_H += eta * mu_H.grad  # Ascent for mu_H
            
            print(f"Sigma_L: {Sigma_L.grad}")
            Sigma_L_half += eta * Sigma_L.grad  # Ascent for Sigma_L
            Sigma_H += eta * Sigma_H.grad  # Ascent for Sigma_H
            Sigma_L = prox_Sigma_L(Sigma_L_half, lambda_L, LLmodels, HLmodels, Sigma_H)
            print(Sigma_L)  
            Sigma_H = prox_Sigma_H(Sigma_H_half, lambda_H, LLmodels, HLmodels, Sigma_L)
            
            #Zero the gradients after the update
            mu_L.grad.zero_()
            mu_H.grad.zero_()
            Sigma_L.grad.zero_()
            Sigma_H.grad.zero_()

            # if mu_L.grad is not None:
            #     mu_L.grad.zero_()
            # if mu_H.grad is not None:
            #     mu_H.grad.zero_()
            # if Sigma_L.grad is not None:
            #     Sigma_L.grad.zero_()
            # if Sigma_H.grad is not None:
            #     Sigma_H.grad.zero_()

        # Print progress
        if t % 10 == 0:
            print(f"Iteration {t}, Objective Value: {objective.item()}")

    return mu_L, Sigma_L, mu_H, Sigma_H


In [53]:
hat_mu_L    = torch.from_numpy(mu_U_ll_hat).float()
hat_Sigma_L = torch.from_numpy(Sigma_U_ll_hat).float()

hat_mu_H    = torch.from_numpy(mu_U_hl_hat).float()
hat_Sigma_H = torch.from_numpy(Sigma_U_hl_hat).float()

l = hat_mu_L.shape[0]
h = hat_mu_H.shape[0]


# Gelbrich initialization
ll_moments      = mut.sample_moments_U(mu_hat = mu_U_ll_hat, Sigma_hat = Sigma_U_ll_hat, bound = epsilon, num_envs = 1)
mu_L0, Sigma_L0 = ll_moments[0]
#mu_L0, Sigma_L0 = torch.from_numpy(mu_L0), torch.from_numpy(Sigma_L0)

hl_moments      = mut.sample_moments_U(mu_hat = mu_U_hl_hat, Sigma_hat = Sigma_U_hl_hat, bound = delta, num_envs = 1)
mu_H0, Sigma_H0 = hl_moments[0]
#mu_H0, Sigma_H0 = torch.from_numpy(mu_H0), torch.from_numpy(Sigma_H0)


T = torch.exp(torch.from_numpy(np.random.randn(2, 3)).float())

In [None]:
def optimize_max(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, lambda_L, lambda_H, lambda_param, eta, num_steps_max):

    for t in range(num_steps_max): 
        #print('mu_L before update:', mu_L)
        mu_L         = update_mu_L(T, mu_L, mu_H, LLmodels, HLmodels, lambda_L, hat_mu_L, eta)
        # print('mu_L after update:', mu_L)
        # print('mu_H before update:', mu_H)
        mu_H         = update_mu_H(T, mu_L, mu_H, LLmodels, HLmodels, lambda_H, hat_mu_H, eta)
        # print('mu_H after update:', mu_H)

        # print('Sigma_L before update:', Sigma_L)
        Sigma_L_half = update_Sigma_L_half(T, Sigma_L, LLmodels, lambda_L, hat_Sigma_L, eta)
        Sigma_L      = update_Sigma_L(T, Sigma_L_half, LLmodels, Sigma_H, HLmodels, lambda_param)
        # print('Sigma_L after update:', Sigma_L)
        
        # print('Sigma_H before update:', Sigma_H)
        Sigma_H_half = update_Sigma_H_half(T, Sigma_H, HLmodels, lambda_H, hat_Sigma_H, eta)
        Sigma_H      = update_Sigma_H(T, Sigma_H_half, LLmodels, Sigma_L, HLmodels, lambda_param)
        # print('Sigma_H after update:', Sigma_H)
        
        mu_L, Sigma_L, mu_H, Sigma_H = enforce_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
        # print('mu_L after constraints:', mu_L)
        # print('Sigma_L after constraints:', Sigma_L)
        # print('mu_H after constraints:', mu_H)
        # print('Sigma_H after constraints:', Sigma_H)
        # print( )
        # Compute the objective function for the current iteration
        obj = 0
        
        for i, iota in enumerate(Ill):
            L_i = torch.from_numpy(LLmodels[iota].compute_mechanism())
            V_i = T @ L_i.float()
            H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()
                        
            L_i_mu_L = V_i @ mu_L
            H_i_mu_H = H_i @ mu_H
            term1 = torch.norm(L_i_mu_L.float() - H_i_mu_H.float())**2
            
            V_Sigma_V = V_i.float() @ Sigma_L.float() @ V_i.T.float()
            H_Sigma_H = H_i.float() @ Sigma_H.float() @ H_i.T.float()

            term2 = torch.trace(V_Sigma_V)
            term3 = torch.trace(H_Sigma_H)
            
            sqrtVSV = oput.sqrtm_svd(V_Sigma_V)
            sqrtHSH = oput.sqrtm_svd(H_Sigma_H)

            #term4 = -2*torch.trace(oput.sqrtm_svd(sqrtHSH @ V_Sigma_V @ sqrtHSH))
            term4 = -2*torch.norm(oput.sqrtm_svd(sqrtVSV) @ oput.sqrtm_svd(sqrtHSH), 'nuc')
            
            obj = obj + (term1 + term2 + term3 + term4)
        
        obj = obj/i
        
        print(f"Max step {t+1}/{num_steps_max}, Objective: {obj.item()}")

    return mu_L, Sigma_L, mu_H, Sigma_H

def optimize_min(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, num_steps_min, optimizer_T):

    objective_T = 0  # Initialize the objective for this step

    for step in range(num_steps_min):
        objective_T = 0  # Reset objective at the start of each step
        for n, iota in enumerate(Ill):
            L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()
            H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()

            L_i_mu_L = L_i @ mu_L  
            H_i_mu_H = H_i @ mu_H 

            term1 = torch.norm(T @ L_i_mu_L - H_i_mu_H) ** 2
            term2 = torch.trace(T @ L_i @ Sigma_L @ L_i.T @ T.T)
            term3 = torch.trace(H_i @ Sigma_H @ H_i.T)
            
            L_i_Sigma_L = T @ L_i @ Sigma_L @ L_i.T @ T.T
            H_i_Sigma_H = H_i @ Sigma_H @ H_i.T

            # Using the SVD square root term
            term4 = -2 * torch.norm(oput.sqrtm_svd(L_i_Sigma_L) @ oput.sqrtm_svd(H_i_Sigma_H), 'nuc')

            objective_T += term1 + term2 + term3 + term4

        objective_T = objective_T/n

        optimizer_T.zero_grad() # Clear previous gradients
        objective_T.backward(retain_graph=True)  # Backpropagate to compute gradients
        optimizer_T.step()      # Update T using the optimizer

        print(f"Min step {step+1}/{num_steps_min}, Objective: {objective_T.item()}")

    return objective_T, T  # Return both the objective and T

In [56]:
mu_L    = torch.from_numpy(mu_L0).float()
Sigma_L = torch.from_numpy(Sigma_L0).float()
mu_H    = torch.from_numpy(mu_H0).float()
Sigma_H = torch.from_numpy(Sigma_H0).float()

mu_L, Sigma_L, mu_H, Sigma_H = optimize(LLmodels, HLmodels, mu_L.float(), Sigma_L.float(), mu_H.float(), Sigma_H.float(),
                                         hat_mu_L.float(), hat_Sigma_L.float(), hat_mu_H.float(), hat_Sigma_H.float(),
                                           epsilon, delta, lambda_L=.7, lambda_H=.8, eta=.01, max_iter=10)


Iteration 0
Sigma_L: tensor([[0.4711, 1.2888, 1.0865],
        [1.2888, 5.5278, 3.1481],
        [1.0865, 3.1481, 2.5457]])
Updated Sigma_L: tensor([[0.0066, 0.0000, 0.0000],
        [0.0000, 1.6441, 0.0000],
        [0.0000, 0.0000, 2.2323]])


AttributeError: 'NoneType' object has no attribute 'zero_'

In [None]:
def wass_mean_barycenter(struc_matrices, mu):

    n = len(struc_matrices)

    mu_barycenter = np.sum([S @ mu for S in struc_matrices], axis=0) / n


    return mu_barycenter
    
    # Initialize the covariance matrix
    Sigma_barycenter = Sigma  # Start with the shared covariance matrix
    
    # Iterate to refine Sigma_barycenter
    for iteration in range(max_iter):
        Sigma_barycenter_half = sqrtm(Sigma_barycenter)  # Compute sqrt(Σ_barycenter)
        sum_term = np.zeros_like(Sigma)
        
        for i in range(n):
            # Calculate the transformed covariance L_i * Sigma_L * L_i^T
            struc_i_transformed = struc_matrices[i] @ Sigma @ struc_matrices[i].T
            # Compute the square root of the term
            term = sqrtm(Sigma_barycenter_half @ struc_i_transformed @ Sigma_barycenter_half)
            sum_term += term
        
        # Update Sigma_barycenter (take the average)
        new_Sigma_barycenter = sum_term / n
        
        for iteration in range(max_iter):
        # Check for convergence
        if np.linalg.norm(new_Sigma_barycenter - Sigma_barycenter) < tol:
            break
        
        # Update for the next iteration
        Sigma_barycenter = new_Sigma_barycenter
    
    return mu_barycenter, Sigma_barycenter

In [None]:
L_matrices = []  # List of L_i matrices
for iota in Ill:
    L_matrices.append(LLmodels[iota].compute_mechanism())

H_matrices = []  # List of H_i matrices
for eta in Ihl:
    H_matrices.append(HLmodels[eta].compute_mechanism())

mu_bary_L, Sigma_bary_L = oput.compute_gauss_barycenter(L_matrices, mu_U_ll_hat, Sigma_U_ll_hat)
mu_bary_H, Sigma_bary_H = oput.compute_gauss_barycenter(H_matrices, mu_U_hl_hat, Sigma_U_hl_hat)

print("Low-level barycenter Mean:", mu_bary_L)
print("Low-level barycenter Covariance:", Sigma_bary_L)
print( )
print("High-level barycenter Mean:", mu_bary_H)
print("High-level barycenter Covariance:", Sigma_bary_H)

V                 = oput.sample_projection(mu_U_ll_hat.shape[0], mu_U_hl_hat.shape[0], use_stiefel=False)
mu_bary_L_proj    = V @ mu_bary_L
Sigma_bary_L_proj = V @ Sigma_bary_L @ V.T

monge, A = oput.monge_map(mu_bary_L_proj, Sigma_bary_L_proj, mu_bary_H, Sigma_bary_H)
T        = V.T @ A

In [59]:
def wass_mean_barycenter(struc_matrices, mu):
    mu_barycenter = np.sum([S @ mu for S in struc_matrices], axis=0) / len(struc_matrices)
    
    return mu_barycenter

In [63]:
L_matrices = []  # List of L_i matrices
for iota in Ill:
    L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
    V_i = T @ L_i
    L_matrices.append(V_i)

H_matrices = []  # List of H_i matrices
for eta in Ihl:
    H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  
    H_matrices.append(H_i)


In [66]:
def update_L_ma

[tensor([[0.3555, 2.0766, 0.8979],
         [0.5018, 0.7747, 1.1855]]),
 tensor([[0.3555, 2.0766, 0.8979],
         [0.5018, 0.7747, 1.1855]]),
 tensor([[0.3555, 1.9770, 0.8763],
         [0.5018, 0.6341, 1.1551]]),
 tensor([[0.3555, 2.0766, 0.8979],
         [0.5018, 0.7747, 1.1855]]),
 tensor([[0.3555, 1.9770, 0.8763],
         [0.5018, 0.6341, 1.1551]]),
 tensor([[0.3555, 1.9770, 0.8763],
         [0.5018, 0.6341, 1.1551]])]

In [62]:
wass_mean_barycenter(H_matrices, mu_U_hl_hat)

array([ 0.0042843 , -0.00863504])

In [None]:
for n, iota in enumerate(Ill):
    L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
    V_i = T @ L_i  
    mu_bary += V_i @ mu_U_ll_hat

mu_bary = mu_bary / n

    

In [146]:
def monge(m_1, S_1, m_2, S_2):
    inner      = torch.matmul(oput.sqrtm_svd(S_1), torch.matmul(S_2, oput.sqrtm_svd(S_1)))
    sqrt_inner = oput.sqrtm_svd(inner)
    A          = torch.matmul(torch.inverse(oput.sqrtm_svd(S_1)), torch.matmul(sqrt_inner, torch.inverse(oput.sqrtm_svd(S_1))))
   
    # Define the Monge map as a function τ(x) = m_2 + A(x - m_1)
    def tau(x):
        return m_2 + A @ (x - m_1)

    return tau, A

def compute_monge_matrix(S_1, S_2):
    return torch.inverse(oput.sqrtm_svd(S_1)) @ oput.sqrtm_svd(oput.sqrtm_svd(S_1) @ S_2 @ oput.sqrtm_svd(S_1)) @ torch.inverse(oput.sqrtm_svd(S_1))

def update_S_bary(Sigma, models, S_bary):
    S_bary_new = torch.zeros_like(Sigma)
    for n, struc_mat in enumerate(models):
        S_bary_new += oput.sqrtm_svd(oput.sqrtm_svd(S_bary) @ struc_mat @ Sigma @ struc_mat.T @ oput.sqrtm_svd(S_bary))
    
    return S_bary_new /n

def compute_L_matrices(T, LLmodels, Ill):
    for iota in Ill:
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        L_matrices.append(V_i)

    return L_matrices

def compute_H_matrices(HLmodels, Ihl):
    for eta in Ihl:
        H_i = torch.from_numpy(HLmodels[eta].compute_mechanism()).float()  
        H_matrices.append(H_i)

    return H_matrices

def update_mu_bary(struc_matrices, mu):
    struc_matrices_tensor = torch.stack(struc_matrices)
    mu_barycenter         = torch.sum(struc_matrices_tensor @ mu, dim=0) / len(struc_matrices)

    return mu_barycenter

In [150]:
def barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H, LLmodels, HLmodels, Ill, Ihl, max_iter=10, tol=1e-6):

    # Initialize the optimal transport map tau and the transformation matrix T
    tau = lambda x: x
    T   = torch.randn(Sigma_H.shape[0], Sigma_L.shape[0])
    print(T.shape)
    # Initialize the structural matrices    
    L_matrices = compute_L_matrices(T, LLmodels, Ill)
    H_matrices = compute_H_matrices(HLmodels, Ihl)

    # Initilize the barycenteric means and covariances
    mu_bary_L = update_mu_bary(L_matrices, mu_L)
    mu_bary_H = update_mu_bary(H_matrices, mu_H)

    S_bary_L = torch.sum(torch.stack([S @ Sigma_L @ S.T for S in L_matrices]), dim=0) / len(L_matrices)
    S_bary_H = torch.sum(torch.stack([S @ Sigma_H @ S.T for S in H_matrices]), dim=0) / len(H_matrices)

    for _ in range(max_iter):
        # Save previous values to check for convergence
        mu_bary_L_old = mu_bary_L.clone()
        mu_bary_H_old = mu_bary_H.clone()
        S_bary_L_old = S_bary_L.clone()
        S_bary_H_old = S_bary_H.clone()
        T_old = T.clone()

        tau, T     = monge(mu_bary_L, S_bary_L, mu_bary_H, S_bary_H)
        print(T.shape)
        L_matrices = compute_L_matrices(T, LLmodels, Ill)

        mu_bary_L  = update_mu_bary(L_matrices, mu_L)
        mu_bary_H  = update_mu_bary(H_matrices, mu_H)
        
        S_bary_L   = update_S_bary(Sigma_L, L_matrices, S_bary_L)
        S_bary_H   = update_S_bary(Sigma_H, H_matrices, S_bary_H)

        
        # Check for convergence
        if check_convergence(mu_bary_L_old, mu_bary_H_old, S_bary_L_old, S_bary_H_old, T_old,
                             mu_bary_L, mu_bary_H, S_bary_L, S_bary_H, T, tol):
            print(f"Converged after {_+1} iterations.")
            break

        return tau, T, mu_bary_L, S_bary_L, mu_bary_H, S_bary_H

In [151]:
# Define the convergence check function
def check_convergence(mu_bary_L_old, mu_bary_H_old, S_bary_L_old, S_bary_H_old, T_old, 
                      mu_bary_L, mu_bary_H, S_bary_L, S_bary_H, T, tol):
    # Frobenius norm of the difference in barycenter means
    mu_diff_L = torch.norm(mu_bary_L - mu_bary_L_old, p='fro')
    mu_diff_H = torch.norm(mu_bary_H - mu_bary_H_old, p='fro')
    
    # Frobenius norm of the difference in barycenter covariances
    S_diff_L = torch.norm(S_bary_L - S_bary_L_old, p='fro')
    S_diff_H = torch.norm(S_bary_H - S_bary_H_old, p='fro')
    
    # Frobenius norm of the difference in the transformation matrix
    T_diff = torch.norm(T - T_old, p='fro')
    
    # Check if all differences are below the tolerance
    if mu_diff_L < tol and mu_diff_H < tol and S_diff_L < tol and S_diff_H < tol and T_diff < tol:
        return True
    return False

In [152]:
tau, T, mu_bary_L, S_bary_L, mu_bary_H, S_bary_H = barycentric_optimization(mu_L, mu_H, Sigma_L, Sigma_H, LLmodels, HLmodels, Ill, Ihl, max_iter=10, tol=1e-6)

print(f"Final optimal map T: {T}")

torch.Size([2, 3])
torch.Size([2, 2])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 3x3)