In [1]:
import joblib
import numpy as np
import torch

import modularised_utils as mut
import matplotlib.pyplot as plt

import opt_utils as oput

import Linear_Additive_Noise_Models as lanm
import operations as ops
from scipy.linalg import sqrtm

import params

np.random.seed(0)

In [2]:
experiment = 'synth1'

In [3]:
# Define the radius of the Wasserstein balls (epsilon, delta) and the size for both models.
epsilon         = params.radius[experiment][0]
ll_num_envs     = params.n_envs[experiment][0]

delta           = params.radius[experiment][1]
hl_num_envs     = params.n_envs[experiment][1]

# Define the number of samples per environment. Currently every environment has the same number of samples
num_llsamples   = params.n_samples[experiment][0]
num_hlsamples   = params.n_samples[experiment][1]

In [4]:
Dll = mut.load_samples(experiment)[None][0] 
Gll = mut.load_ll_model(experiment)[0]
Ill = mut.load_ll_model(experiment)[1]


Dhl = mut.load_samples(experiment)[None][1] 
Ghl = mut.load_hl_model(experiment)[0]
Ihl = mut.load_hl_model(experiment)[1]

omega = mut.load_omega_map(experiment)

In [5]:
ll_coeffs = mut.get_coefficients(Dll, Gll)
hl_coeffs = mut.get_coefficients(Dhl, Ghl) 

In [6]:
# # [Not suggested] In case we want to explore also the interventional --> worse estimation!
# Dlls, Dhls = [], []
# for dpair in list(mut.load_samples(experiment).values()):
#     Dlls.append(dpair[0])
#     Dhls.append(dpair[1])
    
# ll_coeffs = mut.get_coefficients(Dlls, Gll)
# hl_coeffs = mut.get_coefficients(Dhls, Ghl) 

In [7]:
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl, Ghl, hl_coeffs)

In [8]:
LLmodels = {}
for iota in Ill:
    LLmodels[iota] = lanm.LinearAddSCM(Gll, ll_coeffs, iota)
    
HLmodels, Dhl_samples = {}, {}
for eta in Ihl:
    HLmodels[eta] = lanm.LinearAddSCM(Ghl, hl_coeffs, eta)

In [18]:
mu_L    = torch.from_numpy(mu_U_ll_hat)
Sigma_L = torch.from_numpy(Sigma_U_ll_hat)

mu_H    = torch.from_numpy(mu_U_hl_hat)
Sigma_H = torch.from_numpy(Sigma_U_hl_hat)

l = mu_L.shape[0]
h = mu_H.shape[0]

# Given estimates (mu_L, Sigma_L, mu_H, Sigma_H)
hat_mu_L    =  torch.from_numpy(mu_U_ll_hat) 
hat_Sigma_L =  torch.from_numpy(Sigma_U_ll_hat)

hat_mu_H    =  torch.from_numpy(mu_U_hl_hat)
hat_Sigma_H =  torch.from_numpy(Sigma_U_hl_hat)

lambda_L =.2
lambda_H =.3
eta      = .01
max_iter = 10

Sigma_L = hat_Sigma_L
Sigma_H = hat_Sigma_H


T = torch.exp(torch.from_numpy(np.random.randn(2, 3)).float())

In [52]:
# Define the necessary functions using PyTorch for automatic differentiation
def F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta):
    term1 = 0
    term2 = 0
    term3 = 0

    # Loop to compute the sum of terms
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  # Convert to float32
        V_i = T @ L_i  # Matrix multiplication, ensure V_i is float32
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  # Convert to float32

        term1 += torch.norm(torch.matmul(V_i, mu_L) - torch.matmul(H_i, mu_H))**2 + torch.trace(torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))) + torch.trace(torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T)))

    term2 = lambda_L * (epsilon**2 - torch.norm(mu_L - hat_mu_L)**2 - torch.norm(sqrtm_svd(Sigma_L) - sqrtm_svd(hat_Sigma_L))**2)
    term3 = lambda_H * (delta**2 - torch.norm(mu_H - hat_mu_H)**2 - torch.norm(sqrtm_svd(Sigma_H) - sqrtm_svd(hat_Sigma_H))**2)

    return term1 / n + term2 + term3

# Proximal operator for Sigma_L (using soft-thresholding)
def prox_Sigma_L(Sigma_L, lambda_L, LLmodels, HLmodels, Sigma_H):
    # Using the Frobenius norm as a soft-thresholding operator for Sigma_L
    prox = torch.zeros_like(Sigma_L)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  

        V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
        sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
        prox_i          = prox_operator(sqrtm_V_Sigma_V, lambda_L)
        ll_term         = torch.linalg.pinv(V_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.pinv(V_i).T

        H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
        sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
        hl_term         = torch.norm(sqrtm_H_Sigma_H, p='fro') 
       
        prox += ll_term * hl_term

    prox *= (2 / n)
    prox = diagonalize(prox)
    return prox

# Proximal operator for Sigma_H (using soft-thresholding)
def prox_Sigma_H(Sigma_H, lambda_H, LLmodels, HLmodels, Sigma_L):
    prox = torch.zeros_like(Sigma_H)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  
       
        H_Sigma_H       = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T))
        sqrtm_H_Sigma_H = sqrtm_svd(H_Sigma_H)
        prox_i          = prox_operator(sqrtm_H_Sigma_H, lambda_H)
        hl_term         = torch.linalg.inv(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.linalg.inv(H_i).T
        #hl_term        = torch.inverse(H_i) @ torch.matmul(prox_i, prox_i.T) @ torch.inverse(H_i).T

        V_Sigma_V       = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
        sqrtm_V_Sigma_V = sqrtm_svd(V_Sigma_V)
        ll_term         = torch.norm(sqrtm_V_Sigma_V, p='fro') 
        
        prox     += ll_term * hl_term

    prox *= (2 / n)

    prox = diagonalize(prox)
    return prox

# Proximal operator of a matrix frobenious norm
def prox_operator(A, lambda_param):
    frobenius_norm = torch.norm(A, p='fro')
    scaling_factor = torch.max(1 - lambda_param / frobenius_norm, torch.zeros_like(frobenius_norm))
    return scaling_factor * A

def diagonalize(A):
    # Get eigenvalues and eigenvectors
    eigvals, eigvecs = torch.linalg.eig(A)  
    eigvals_real     = eigvals.real  
    eigvals_real     = torch.sqrt(eigvals_real)  # Take the square root of the eigenvalues

    return torch.diag(eigvals_real)

def sqrtm_svd(A):
    # Compute the SVD of A
    U, S, V = torch.svd(A)
    
    # Take the square root of the singular values
    S_sqrt = torch.sqrt(torch.clamp(S, min=0.0))  # Ensure non-negative singular values
    
    # Reconstruct the square root matrix
    sqrt_A = U @ torch.diag(S_sqrt) @ V.T
    
    return sqrt_A

def sqrtm_eig(A):
    eigvals, eigvecs = torch.linalg.eig(A)
    eigvals_real = eigvals.real
    
    # Ensure eigenvalues are non-negative for the square root to be valid
    eigvals_sqrt = torch.sqrt(torch.clamp(eigvals_real, min=0.0))  # Square root of non-negative eigenvalues

    # Reconstruct the square root of the matrix using the eigenvectors
    # Make sure the eigenvectors are also real
    eigvecs_real = eigvecs.real
    
    # Reconstruct the matrix square root
    sqrt_A = eigvecs_real @ torch.diag(eigvals_sqrt) @ eigvecs_real.T
    
    return sqrt_A


# Optimization loop using autograd and PyProximal (maximize using gradient ascent)
def optimize(LLmodels, HLmodels, mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta, lambda_L, lambda_H, eta, max_iter):
    mu_L.requires_grad_(True)  # Enable autograd for mu_L
    Sigma_L_half.requires_grad_(True)  # Enable autograd for Sigma_L
    mu_H.requires_grad_(True)  # Enable autograd for mu_H
    Sigma_H_half.requires_grad_(True)  # Enable autograd for Sigma_H

    for t in range(max_iter):
        print(f"Iteration {t}")
        
        objective = F_func(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, lambda_L, lambda_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
        objective.backward()

        with torch.no_grad():
            mu_L += eta * mu_L.grad  # Ascent for mu_L
            mu_H += eta * mu_H.grad  # Ascent for mu_H
            
            print(f"Sigma_L: {Sigma_L.grad}")
            Sigma_L_half += eta * Sigma_L.grad  # Ascent for Sigma_L
            Sigma_H += eta * Sigma_H.grad  # Ascent for Sigma_H
            Sigma_L = prox_Sigma_L(Sigma_L_half, lambda_L, LLmodels, HLmodels, Sigma_H)
            print(Sigma_L)  
            Sigma_H = prox_Sigma_H(Sigma_H_half, lambda_H, LLmodels, HLmodels, Sigma_L)
            
            #Zero the gradients after the update
            mu_L.grad.zero_()
            mu_H.grad.zero_()
            Sigma_L.grad.zero_()
            Sigma_H.grad.zero_()

            # if mu_L.grad is not None:
            #     mu_L.grad.zero_()
            # if mu_H.grad is not None:
            #     mu_H.grad.zero_()
            # if Sigma_L.grad is not None:
            #     Sigma_L.grad.zero_()
            # if Sigma_H.grad is not None:
            #     Sigma_H.grad.zero_()

        # Print progress
        if t % 10 == 0:
            print(f"Iteration {t}, Objective Value: {objective.item()}")

    return mu_L, Sigma_L, mu_H, Sigma_H


In [53]:
hat_mu_L    = torch.from_numpy(mu_U_ll_hat).float()
hat_Sigma_L = torch.from_numpy(Sigma_U_ll_hat).float()

hat_mu_H    = torch.from_numpy(mu_U_hl_hat).float()
hat_Sigma_H = torch.from_numpy(Sigma_U_hl_hat).float()

l = hat_mu_L.shape[0]
h = hat_mu_H.shape[0]


# Gelbrich initialization
ll_moments      = mut.sample_moments_U(mu_hat = mu_U_ll_hat, Sigma_hat = Sigma_U_ll_hat, bound = epsilon, num_envs = 1)
mu_L0, Sigma_L0 = ll_moments[0]
#mu_L0, Sigma_L0 = torch.from_numpy(mu_L0), torch.from_numpy(Sigma_L0)

hl_moments      = mut.sample_moments_U(mu_hat = mu_U_hl_hat, Sigma_hat = Sigma_U_hl_hat, bound = delta, num_envs = 1)
mu_H0, Sigma_H0 = hl_moments[0]
#mu_H0, Sigma_H0 = torch.from_numpy(mu_H0), torch.from_numpy(Sigma_H0)


T = torch.exp(torch.from_numpy(np.random.randn(2, 3)).float())

In [56]:
mu_L    = torch.from_numpy(mu_L0).float()
Sigma_L = torch.from_numpy(Sigma_L0).float()
mu_H    = torch.from_numpy(mu_H0).float()
Sigma_H = torch.from_numpy(Sigma_H0).float()

mu_L, Sigma_L, mu_H, Sigma_H = optimize(LLmodels, HLmodels, mu_L.float(), Sigma_L.float(), mu_H.float(), Sigma_H.float(),
                                         hat_mu_L.float(), hat_Sigma_L.float(), hat_mu_H.float(), hat_Sigma_H.float(),
                                           epsilon, delta, lambda_L=.7, lambda_H=.8, eta=.01, max_iter=10)


Iteration 0
Sigma_L: tensor([[0.4711, 1.2888, 1.0865],
        [1.2888, 5.5278, 3.1481],
        [1.0865, 3.1481, 2.5457]])
Updated Sigma_L: tensor([[0.0066, 0.0000, 0.0000],
        [0.0000, 1.6441, 0.0000],
        [0.0000, 0.0000, 2.2323]])


AttributeError: 'NoneType' object has no attribute 'zero_'