In [1]:
import joblib
import numpy as np
import torch

import modularised_utils as mut
import matplotlib.pyplot as plt

import opt_utils as oput

import Linear_Additive_Noise_Models as lanm
import operations as ops
from scipy.linalg import sqrtm

import params

np.random.seed(0)

In [2]:
experiment = 'synth1'

In [3]:
# Define the radius of the Wasserstein balls (epsilon, delta) and the size for both models.
epsilon         = params.radius[experiment][0]
ll_num_envs     = params.n_envs[experiment][0]

delta           = params.radius[experiment][1]
hl_num_envs     = params.n_envs[experiment][1]

# Define the number of samples per environment. Currently every environment has the same number of samples
num_llsamples   = params.n_samples[experiment][0]
num_hlsamples   = params.n_samples[experiment][1]

In [4]:
Dll = mut.load_samples(experiment)[None][0] 
Gll = mut.load_ll_model(experiment)[0]
Ill = mut.load_ll_model(experiment)[1]


Dhl = mut.load_samples(experiment)[None][1] 
Ghl = mut.load_hl_model(experiment)[0]
Ihl = mut.load_hl_model(experiment)[1]

omega = mut.load_omega_map(experiment)

In [5]:
ll_coeffs = mut.get_coefficients(Dll, Gll)
hl_coeffs = mut.get_coefficients(Dhl, Ghl) 

In [6]:
# # [Not suggested] In case we want to explore also the interventional --> worse estimation!
# Dlls, Dhls = [], []
# for dpair in list(mut.load_samples(experiment).values()):
#     Dlls.append(dpair[0])
#     Dhls.append(dpair[1])
    
# ll_coeffs = mut.get_coefficients(Dlls, Gll)
# hl_coeffs = mut.get_coefficients(Dhls, Ghl) 

In [7]:
U_ll_hat, mu_U_ll_hat, Sigma_U_ll_hat = mut.lan_abduction(Dll, Gll, ll_coeffs)
U_hl_hat, mu_U_hl_hat, Sigma_U_hl_hat = mut.lan_abduction(Dhl, Ghl, hl_coeffs)

In [8]:
LLmodels = {}
for iota in Ill:
    LLmodels[iota] = lanm.LinearAddSCM(Gll, ll_coeffs, iota)
    
HLmodels, Dhl_samples = {}, {}
for eta in Ihl:
    HLmodels[eta] = lanm.LinearAddSCM(Ghl, hl_coeffs, eta)

In [9]:
# # Ambiguity set construction: Based on epsilon and delta include distribution (as many as the num_envs) that
# # pass the "gelbrich" test.
# ll_moments = mut.sample_moments_U(mu_hat    = mu_U_ll_hat,
#                                   Sigma_hat = Sigma_U_ll_hat,
#                                   bound     = epsilon,
#                                   num_envs  = ll_num_envs)

# A_ll       = mut.sample_distros_Gelbrich(ll_moments) #Low-level: A_epsilon


# hl_moments = mut.sample_moments_U(mu_hat    = mu_U_hl_hat,
#                                   Sigma_hat = Sigma_U_hl_hat,
#                                   bound     = delta,
#                                   num_envs  = hl_num_envs)

# A_hl       = mut.sample_distros_Gelbrich(hl_moments) #High-level A_delta

In [10]:
# abstraction_errors             = {}
# abstraction_env_errors         = {}
# max_env_avg_interv_error_value = -np.inf
# max_env_avg_interv_error_key   = None
# distance_err                   = 'wass'

# for lenv in A_ll:

#     Dll_noise      = lenv.sample(num_llsamples)[0]
#     ll_environment = mut.get_exogenous_distribution(Dll_noise)

#     for henv in A_hl:
#         Dhl_noise      = henv.sample(num_hlsamples)[0]
#         hl_environment = mut.get_exogenous_distribution(Dhl_noise)

#         total_ui_error = 0
#         num_distros    = len(Ill)

#         n, m  = len(LLmodels[None].endogenous_vars), len(HLmodels[None].endogenous_vars)

#         T     = mut.sample_stoch_matrix(n, m)

#         for iota in Ill:
#             llcm   = LLmodels[iota]
#             hlcm   = HLmodels[omega[iota]]
#             llmech = llcm.compute_mechanism()
#             hlmech = hlcm.compute_mechanism()
#             error  = mut.ui_error_dist(distance_err, lenv, henv, llmech, hlmech, T)

#             total_ui_error += error

#         avg_interv_error = total_ui_error/num_distros

#         if avg_interv_error > max_env_avg_interv_error_value:
#             max_env_avg_interv_error_value = avg_interv_error
#             max_env_avg_interv_error_key   = (lenv, henv)

#         abstraction_errors[str(T)] = avg_interv_error
#         abstraction_env_errors['ll: '+str(ll_environment.means_)+' hl: '+str(hl_environment.means_)] = avg_interv_error


# max_tau   = max(abstraction_errors, key=abstraction_errors.get)
# max_error = abstraction_errors[max_tau]

# print(f"Abstraction: {max_tau}, Error: {max_error}")
# print('==============================================================================' )
# max_lenv = max_env_avg_interv_error_key[0]
# max_henv = max_env_avg_interv_error_key[1]

# print(f"max LL mean vector = {max_lenv.means_}")
# print(f"max LL covariance = {max_lenv.covariances_}")
# print( )

# print(f"max HL mean vector = {max_henv.means_}")
# print(f"max HL covariance = {max_henv.covariances_}")
# print('==============================================================================' )
# print(f"max environment, average interventional abstraction error = {max_env_avg_interv_error_value}")

In [11]:
def update_mu_L(T, mu_L, mu_H, LLmodels, HLmodels, lambda_L, hat_mu_L, eta):
    grad_mu_L = torch.zeros_like(mu_L, dtype=torch.float32) 
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float() 
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float() 

        grad_mu_L += torch.matmul(V_i.T, torch.matmul(V_i, mu_L.float()) - torch.matmul(H_i, mu_H.float())) 
    
    grad_mu_L = (2 / n) * grad_mu_L - 2 * lambda_L * (mu_L - hat_mu_L)
    mu_L = mu_L + (eta * grad_mu_L)
    return mu_L

def update_mu_H(T, mu_L, mu_H, LLmodels, HLmodels, lambda_H, hat_mu_H, eta):
    grad_mu_H = torch.zeros_like(mu_H, dtype=torch.float32)  
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  

        grad_mu_H -= torch.matmul(H_i.T, torch.matmul(V_i, mu_L.float()) - torch.matmul(H_i, mu_H.float()))
    
    grad_mu_H = (2 / n) * grad_mu_H - 2 * lambda_H * (mu_H - hat_mu_H)
    
    mu_H = mu_H + (eta * grad_mu_H)
    return mu_H


def update_Sigma_L_half(T, Sigma_L, LLmodels, lambda_L, hat_Sigma_L, eta):
    grad_Sigma_L = torch.zeros_like(Sigma_L)
    
    # Term 1: (2/n) * sum_i(V_i^T * V_i)
    term1 = torch.zeros_like(Sigma_L)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism())
        V_i = T @ L_i.float()
        term1 = term1 + torch.matmul(V_i.T, V_i)

    # Term 2: -2 * lambda_L * (Sigma_L^(1/2) - hat_Sigma_L^(1/2)) * Sigma_L^(-1/2)
    Sigma_L_sqrt = oput.sqrtm_svd(Sigma_L)  # Compute the square root of Sigma_L
    #Sigma_L_sqrt = torch.linalg.matrix_power(Sigma_L, 0.5)

    hat_Sigma_L_sqrt = oput.sqrtm_svd(hat_Sigma_L)  # Compute the square root of hat_Sigma_L

    term2 = -2 * lambda_L * (Sigma_L_sqrt - hat_Sigma_L_sqrt) @ torch.inverse(Sigma_L_sqrt)

    # Combine terms
    grad_Sigma_L = (2 / n) * term1 + term2

    # Update Sigma_L
    Sigma_L_half = Sigma_L + eta * grad_Sigma_L
    #Sigma_L_half  = diagonalize(Sigma_L_half)
    return Sigma_L_half


def update_Sigma_L(T, Sigma_L_half, LLmodels, Sigma_H, HLmodels, lambda_param):
    Sigma_L_final = torch.zeros_like(Sigma_L_half, dtype=torch.float32)  
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()  
        V_i = T @ L_i  
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()  
        
        Sigma_L_half      = Sigma_L_half.float()
        V_Sigma_V         = torch.matmul(V_i, torch.matmul(Sigma_L_half, V_i.T))
        sqrtm_V_Sigma_V   = oput.sqrtm_svd(V_Sigma_V)
        prox_Sigma_L_half = torch.matmul(oput.prox_operator(sqrtm_V_Sigma_V, lambda_param), oput.prox_operator(sqrtm_V_Sigma_V, lambda_param).T)
        ll_term           = torch.matmul(torch.matmul(torch.linalg.pinv(V_i), prox_Sigma_L_half), torch.linalg.pinv(V_i).T)

        Sigma_H   = Sigma_H.float()  
        H_Sigma_H = torch.matmul(H_i, torch.matmul(Sigma_H, H_i.T)).float()
        hl_term   = torch.norm(oput.sqrtm_svd(H_Sigma_H), p='fro')

        Sigma_L_final = Sigma_L_final + (ll_term * hl_term)

    Sigma_L_final =  Sigma_L_final * (2 / n)
    Sigma_L_final = oput.diagonalize(Sigma_L_final)

    return Sigma_L_final


def update_Sigma_H_half(T, Sigma_H, HLmodels, lambda_H, hat_Sigma_H, eta):
    grad_Sigma_H = torch.zeros_like(Sigma_H)
    term1 = torch.zeros_like(Sigma_H)
    for n, iota in enumerate(Ill):
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()

        term1 = term1 + torch.matmul(H_i.T, H_i)

    Sigma_H_sqrt     = oput.sqrtm_svd(Sigma_H)  
    hat_Sigma_H_sqrt = oput.sqrtm_svd(hat_Sigma_H) 

    term2 = -2 * lambda_H * (Sigma_H_sqrt - hat_Sigma_H_sqrt) @ torch.inverse(Sigma_H_sqrt)

    grad_Sigma_H = (2 / n) * term1 + term2

    Sigma_H_half = Sigma_H + eta * grad_Sigma_H
    return Sigma_H_half

def check_for_invalid_values(matrix):
    if torch.isnan(matrix).any() or torch.isinf(matrix).any():
        #print("Matrix contains NaN or Inf values!")
        return True
    return False

def handle_nans(matrix, replacement_value=0.0):
    # Replace NaNs with a given value (default is 0)
    if torch.isnan(matrix).any():
        print("Warning: NaN values found! Replacing with zero.")
        matrix = torch.nan_to_num(matrix, nan=replacement_value)
    return matrix


def update_Sigma_H(T, Sigma_H_half, LLmodels, Sigma_L, HLmodels, lambda_param):
    if check_for_invalid_values(Sigma_L):
        print("Sigma_L contains NaN or Inf values!")
    Sigma_H_final = torch.zeros_like(Sigma_H_half)
    for n, iota in enumerate(Ill):
        L_i = torch.from_numpy(LLmodels[iota].compute_mechanism())
        V_i = T @ L_i.float()
        H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()

        H_Sigma_H         = torch.matmul(H_i, torch.matmul(Sigma_H_half, H_i.T))
        sqrtm_H_Sigma_H   = oput.sqrtm_svd(H_Sigma_H)
        prox_Sigma_H_half = torch.matmul(oput.prox_operator(sqrtm_H_Sigma_H, lambda_param), oput.prox_operator(sqrtm_H_Sigma_H, lambda_param).T)
        hl_term           = torch.matmul(torch.matmul(torch.inverse(H_i), prox_Sigma_H_half), torch.inverse(H_i).T)  
        
        V_Sigma_V = torch.matmul(V_i, torch.matmul(Sigma_L, V_i.T))
        ll_term   = torch.norm(oput.sqrtm_svd(V_Sigma_V))

        Sigma_H_final = Sigma_H_final + (ll_term * hl_term)
    
    Sigma_H_final = Sigma_H_final * (2 / n)
    Sigma_H_final = oput.diagonalize(Sigma_H_final)
    
    return Sigma_H_final

In [12]:
def check_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta):
    # Constraint 1: epsilon^2 - ||mu_L - hat_mu_L||_2^2 - ||Sigma_L^{1/2} - hat_Sigma_L^{1/2}||_2^2 >= 0
    constraint_L = epsilon**2 - (torch.norm(mu_L - hat_mu_L)**2) - (torch.norm(oput.sqrtm_svd(Sigma_L) - oput.sqrtm_svd(hat_Sigma_L))**2)
    
    # Constraint 2: delta^2 - ||mu_H - hat_mu_H||_2^2 - ||Sigma_H^{1/2} - hat_Sigma_H^{1/2}||_2^2 >= 0
    constraint_H = delta**2 - (torch.norm(mu_H - hat_mu_H)**2) - (torch.norm(oput.sqrtm_svd(Sigma_H) - oput.sqrtm_svd(hat_Sigma_H))**2)
    
    # Return whether constraints are satisfied (i.e., >= 0) and the constraint violations
    return constraint_L, constraint_H


def enforce_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta):
    constraint_L, constraint_H = check_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
    
    # Clip values if constraints are violated
    if constraint_L < 0:
        #print(f"Constraint for mu_L and Sigma_L violated. Fixing...")
        mu_L = hat_mu_L + torch.clamp(mu_L - hat_mu_L, min=-epsilon, max=epsilon)
        Sigma_L = hat_Sigma_L + torch.clamp(Sigma_L - hat_Sigma_L, min=-epsilon, max=epsilon)
    
    if constraint_H < 0:
        #print(f"Constraint for mu_H and Sigma_H violated. Fixing...")
        mu_H = hat_mu_H + torch.clamp(mu_H - hat_mu_H, min=-delta, max=delta)
        Sigma_H = hat_Sigma_H + torch.clamp(Sigma_H - hat_Sigma_H, min=-delta, max=delta)
    
    return mu_L, Sigma_L, mu_H, Sigma_H

In [13]:
def optimize_max(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, lambda_L, lambda_H, lambda_param, eta, num_steps_max, epsilon, delta, seed):
    
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    for t in range(num_steps_max): 
        mu_L         = update_mu_L(T, mu_L, mu_H, LLmodels, HLmodels, lambda_L, hat_mu_L, eta)
        mu_H         = update_mu_H(T, mu_L, mu_H, LLmodels, HLmodels, lambda_H, hat_mu_H, eta)
        Sigma_L_half = update_Sigma_L_half(T, Sigma_L, LLmodels, lambda_L, hat_Sigma_L, eta)
        Sigma_L      = update_Sigma_L(T, Sigma_L_half, LLmodels, Sigma_H, HLmodels, lambda_param)
        Sigma_H_half = update_Sigma_H_half(T, Sigma_H, HLmodels, lambda_H, hat_Sigma_H, eta)
        Sigma_H      = update_Sigma_H(T, Sigma_H_half, LLmodels, Sigma_L, HLmodels, lambda_param)
        
        mu_L, Sigma_L, mu_H, Sigma_H = enforce_constraints(mu_L, Sigma_L, mu_H, Sigma_H, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, epsilon, delta)
        
        obj = 0
        
        for i, iota in enumerate(Ill):
            L_i = torch.from_numpy(LLmodels[iota].compute_mechanism())
            V_i = T @ L_i.float()
            H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()
                        
            L_i_mu_L = V_i @ mu_L
            H_i_mu_H = H_i @ mu_H
            term1 = torch.norm(L_i_mu_L.float() - H_i_mu_H.float())**2
            
            V_Sigma_V = V_i.float() @ Sigma_L.float() @ V_i.T.float()
            H_Sigma_H = H_i.float() @ Sigma_H.float() @ H_i.T.float()

            term2 = torch.trace(V_Sigma_V)
            term3 = torch.trace(H_Sigma_H)
            
            sqrtVSV = oput.sqrtm_svd(V_Sigma_V)
            sqrtHSH = oput.sqrtm_svd(H_Sigma_H)

            term4 = -2*torch.norm(oput.sqrtm_svd(sqrtVSV) @ oput.sqrtm_svd(sqrtHSH), 'nuc')
            
            obj = obj + (term1 + term2 + term3 + term4)
        
        obj = obj/i
        
        #print(f"Max step {t+1}/{num_steps_max}, Objective: {obj.item()}")

    return obj, mu_L, Sigma_L, mu_H, Sigma_H

def optimize_min(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, num_steps_min, optimizer_T, seed):

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    objective_T = 0 # Reset objective at the start of each step
    for step in range(num_steps_min):
        objective_T = 0  # Reset objective at the start of each step
        for n, iota in enumerate(Ill):
            L_i = torch.from_numpy(LLmodels[iota].compute_mechanism()).float()
            H_i = torch.from_numpy(HLmodels[omega[iota]].compute_mechanism()).float()

            L_i_mu_L = L_i @ mu_L  
            H_i_mu_H = H_i @ mu_H 

            term1 = torch.norm(T @ L_i_mu_L - H_i_mu_H) ** 2
            term2 = torch.trace(T @ L_i @ Sigma_L @ L_i.T @ T.T)
            term3 = torch.trace(H_i @ Sigma_H @ H_i.T)
            
            L_i_Sigma_L = T @ L_i @ Sigma_L @ L_i.T @ T.T
            H_i_Sigma_H = H_i @ Sigma_H @ H_i.T

            term4 = -2 * torch.norm(oput.sqrtm_svd(L_i_Sigma_L) @ oput.sqrtm_svd(H_i_Sigma_H), 'nuc')

            objective_T += term1 + term2 + term3 + term4

        objective_T = objective_T/n

        optimizer_T.zero_grad() # Clear previous gradients
        objective_T.backward(retain_graph=True)  # Backpropagate to compute gradients
        optimizer_T.step()      # Update T using the optimizer

        #print(f"Min step {step+1}/{num_steps_min}, Objective: {objective_T.item()}")

    return objective_T, T  # Return both the objective and T

In [14]:
def optimize_min_max(mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, 
                     hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, 
                     epsilon, delta, lambda_L, lambda_H, lambda_param, 
                     eta, max_iter, num_steps_min, num_steps_max, tol_max, tol, seed):
    
    j = 0
    torch.manual_seed(seed) 
    torch.cuda.manual_seed_all(seed)

    T           = torch.randn(mu_H.shape[0], mu_L.shape[0], requires_grad=True)
    optimizer_T = torch.optim.Adam([T], lr=0.01)

    previous_objective = float('inf')  # Initialize with a large number
    previous_objective_theta = float('inf')  # Initialize with a large number
    objective_theta = torch.tensor(float('inf'))
    for epoch in range(max_iter):
        #print('##########################################')
        print(f"Epoch {epoch+1}/{max_iter}\n")
        #print("MINIMIZING T")

        # ---- Minimize T ----
        objective_T, T = optimize_min(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, num_steps_min, optimizer_T, seed)
        
        print()
        #print("MAX mu_L, Sigma_L, mu_H, Sigma_H")

        if not abs(previous_objective_theta - objective_theta.item()) < tol_max:
            objective_theta, mu_L, Sigma_L, mu_H, Sigma_H = optimize_max(T, mu_L, Sigma_L, mu_H, Sigma_H, LLmodels, HLmodels, hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H,
                                                                         lambda_L, lambda_H, lambda_param, eta, num_steps_max, epsilon, delta, seed)
        else:
            if j==0:
                print('MAX step skipped')
                print(mu_L, Sigma_L, mu_H, Sigma_H)
                j+=1
            
        previous_objective_theta = objective_theta.item()
        # Check for convergence by comparing the difference in objective values
        criterion = abs(previous_objective - objective_T.item())
        print(f"Objective difference: {criterion}")
        
        if criterion < tol:
            print(f"Convergence reached at epoch {epoch+1} with objective {objective_T.item()}")
            break

        # Update previous objective for the next check
        previous_objective = objective_T.item()
        #print('##########################################')

    print("Final T:", T)
    print("Final mu_L:", mu_L)
    print("Final Sigma_L:", Sigma_L)
    print("Final mu_H:", mu_H)
    print("Final Sigma_H:", Sigma_H)

    return mu_L, Sigma_L, mu_H, Sigma_H, T


In [15]:
hat_mu_L    = torch.from_numpy(mu_U_ll_hat).float()
hat_Sigma_L = torch.from_numpy(Sigma_U_ll_hat).float()

hat_mu_H    = torch.from_numpy(mu_U_hl_hat).float()
hat_Sigma_H = torch.from_numpy(Sigma_U_hl_hat).float()

l = hat_mu_L.shape[0]
h = hat_mu_H.shape[0]


# Gelbrich initialization
ll_moments      = mut.sample_moments_U(mu_hat = mu_U_ll_hat, Sigma_hat = Sigma_U_ll_hat, bound = epsilon, num_envs = 1)
mu_L0, Sigma_L0 = ll_moments[0]
#mu_L0, Sigma_L0 = torch.from_numpy(mu_L0), torch.from_numpy(Sigma_L0)

hl_moments      = mut.sample_moments_U(mu_hat = mu_U_hl_hat, Sigma_hat = Sigma_U_hl_hat, bound = delta, num_envs = 1)
mu_H0, Sigma_H0 = hl_moments[0]
#mu_H0, Sigma_H0 = torch.from_numpy(mu_H0), torch.from_numpy(Sigma_H0)

In [16]:
mu_L    = torch.from_numpy(mu_L0).float()
Sigma_L = torch.from_numpy(Sigma_L0).float()
mu_H    = torch.from_numpy(mu_H0).float()
Sigma_H = torch.from_numpy(Sigma_H0).float()

mu_L, Sigma_L, mu_H, Sigma_H, T = optimize_min_max(mu_L, Sigma_L, mu_H, Sigma_H, 
                                                    LLmodels, HLmodels, 
                                                    hat_mu_L, hat_Sigma_L, hat_mu_H, hat_Sigma_H, 
                                                    epsilon=0.5, delta=0.5, lambda_L=0.8, lambda_H=0.7, lambda_param=0.9, 
                                                    eta=0.01, max_iter=500,  num_steps_min=5, num_steps_max=5, tol_max=1e-5, tol=1e-5, seed=42)



Epoch 1/500


Objective difference: inf
Epoch 2/500


MAX step skipped
tensor([-0.1889,  0.3310,  0.0501], grad_fn=<AddBackward0>) tensor([[0.4787, 0.0000, 0.0000],
        [0.0000, 1.4741, 0.0000],
        [0.0000, 0.0000, 1.4901]], grad_fn=<AddBackward0>) tensor([0.4611, 0.4914], grad_fn=<AddBackward0>) tensor([[1.3855, 0.0000],
        [0.0000, 1.4701]], grad_fn=<AddBackward0>)
Objective difference: 1.6845197677612305
Epoch 3/500


Objective difference: 0.19249367713928223
Epoch 4/500


Objective difference: 0.17441654205322266
Epoch 5/500


Objective difference: 0.16641926765441895
Epoch 6/500


Objective difference: 0.16215252876281738
Epoch 7/500


Objective difference: 0.158613920211792
Epoch 8/500


Objective difference: 0.15427398681640625
Epoch 9/500


Objective difference: 0.14841902256011963
Epoch 10/500


Objective difference: 0.14085781574249268
Epoch 11/500


Objective difference: 0.1317894458770752
Epoch 12/500


Objective difference: 0.12165462970733643
Epoch 13/500




KeyboardInterrupt: 