In [None]:
import itertools

import random
import numpy as np
import pandas as pd
import cvxpy as cp
from tqdm import tqdm

from scipy.spatial import distance
from cvxpy.error import SolverError
from cvxpy import ECOS, SCS
import seaborn as sns
import ot
from pgmpy import inference
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance
import plotly.graph_objects as go
import plotly.graph_objs as go
import plotly.express as px

from plotly.subplots import make_subplots
from mpltern.datasets import get_triangular_grid
import scipy.optimize as optimize
from scipy.spatial.distance import cdist
from src.examples import smokingmodels as sm
from scipy.spatial.distance import squareform,pdist
from scipy.optimize import linprog
from scipy import stats
from scipy.stats import wasserstein_distance

from IPython.utils import io
import warnings

import joblib
import modularized_utils as ut
import abstraction_metrics as ams
import matplotlib.pylab as pl

import get_results
import params

# Seed every RNG this notebook draws from: numpy's legacy global RNG and the
# stdlib `random` module (used by get_square_grid / get_triangular_grid_random),
# so that Restart & Run All reproduces the same numbers.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

# Suppress all warnings to keep cell output readable.
warnings.filterwarnings(action='ignore')
np.set_printoptions(precision=4, suppress=True)

In [None]:
def get_square_grid(n):
    """Draw ``n`` random 4-component points whose components sum to ~1.

    For every sample a random permutation of the four output slots is drawn
    first.  Three values are then obtained by successive stick breaking (each
    one uniform on what is left), rounded to 4 decimals, and the fourth value
    is the rounded remainder.  The permutation decides which output list
    receives each value.

    Returns:
        tuple: four lists ``(a, b, c, d)``, each of length ``n``.
    """
    columns = ([], [], [], [])
    for _ in range(n):
        slots = random.sample(range(4), 4)

        first = round(random.uniform(0, 1), 4)
        second = round(random.uniform(0, 1 - first), 4)
        third = round(random.uniform(0, 1 - first - second), 4)
        fourth = round(1 - (first + second + third), 4)
        values = (first, second, third, fourth)

        for column, slot in zip(columns, slots):
            column.append(values[slot])

    a, b, c, d = columns
    return a, b, c, d

In [None]:
def create_comb(a, b, c):
    """Encode a (κ, λ, μ) weight triple as the string key ``'a-b-c'``.

    Each weight is rendered with ``str``; the key identifies one grid
    combination (``parse_comb`` decodes it back into floats).
    """
    return '-'.join(str(weight) for weight in (a, b, c))

def parse_comb(input_string):
    """Decode a ``create_comb`` key ``'a-b-c'`` back into ``[a, b, c]`` floats.

    A plain ``split('-')`` breaks when a weight was printed in scientific
    notation (``str(1e-05) == '1e-05'``) or carries a sign (``'-0.0'``),
    because those ``-`` characters are not separators.  Pieces are therefore
    re-joined whenever a ``-`` follows an exponent marker or starts a number.

    Args:
        input_string: key of the form produced by ``create_comb``.

    Returns:
        list: the three weights as floats.

    Raises:
        ValueError: if the string does not decode to exactly three numbers.
    """
    values = []
    pending = ''
    for piece in input_string.split('-'):
        pending += piece
        if pending == '' or pending[-1] in 'eE':
            # This '-' is a sign or part of an exponent, not a separator.
            pending += '-'
            continue
        values.append(float(pending))
        pending = ''
    if pending or len(values) != 3:
        raise ValueError(f"Malformed combination string: {input_string!r}")
    return values

In [None]:
def leave_one_out_grid(pairs, dropped_pair, experiment, combination, df, cf, cota_version):
    """Solve and load COTA results for one (κ, λ, μ) combination with one pair held out.

    Parameters
    ----------
    pairs : list
        Pair objects exposing ``iota_base`` / ``iota_abst``.  The list itself
        is not mutated (a copy is filtered).
    dropped_pair : pair or None
        The pair left out; ``None`` keeps every pair.
    experiment : str
        Experiment identifier forwarded to ``get_results`` and the ``ut`` loaders.
    combination : str
        ``'κ-λ-μ'`` key as produced by ``create_comb``.
    df, cf :
        Metric and cost names, forwarded unchanged to ``get_results``.
    cota_version : {'avg_plan', 'avg_map'}
        Which COTA variant to run.

    Returns
    -------
    Whatever the matching ``ut.load_grid_results_looo*`` loader returns.

    Raises
    ------
    ValueError
        If ``cota_version`` is not supported.  This is checked before any
        poset building or solving, so a typo fails fast instead of surfacing
        as an ``UnboundLocalError`` after all the work.
    """
    supported = ('avg_plan', 'avg_map')
    if cota_version not in supported:
        raise ValueError(f"Unknown cota_version {cota_version!r}; expected one of {supported}")

    omega = {pair.iota_base: pair.iota_abst for pair in pairs}

    # Work on a copy so the caller's list keeps every pair.
    hold_pairs = pairs.copy()
    if dropped_pair is not None:
        hold_pairs.remove(dropped_pair)
        hold_omega = ut.drop1omega(omega, dropped_pair.iota_base)
    else:
        hold_omega = omega

    I_relevant = list(hold_omega.keys())

    struc, _tree = ut.build_poset(I_relevant)
    chains = ut.to_chains(hold_pairs, struc)

    kk, ll, mm = parse_comb(combination)

    args = [hold_pairs, [chains], kk, ll, mm, df, cf]

    if cota_version == 'avg_plan':
        get_results.results_grid_looo(args, experiment, dropped_pair)
        return ut.load_grid_results_looo(experiment, combination, dropped_pair)

    get_results.results_grid_looo_aggregated(args, experiment, dropped_pair)
    return ut.load_grid_results_looo_aggregated(experiment, combination, dropped_pair)

In [None]:
def compute_ae_grid_looo(exp, pairs, combo, metric, cost, cota_version):
    """Run the leave-one-out COTA fit for every simulation in ``pairs``.

    ``pairs`` is indexed by simulation; each entry is a list of pair objects.
    Within a simulation, every pair whose base intervention differs from
    ``{None: None}`` is held out in turn.  Pairs whose base intervention
    equals ``{None: None}`` (presumably the no-intervention case) are keyed
    by ``None``, meaning nothing is dropped.

    Returns:
        list: one dict per simulation mapping the dropped pair (or ``None``)
        to the result of ``leave_one_out_grid``.
    """
    per_simulation = []
    for sim_index in range(len(pairs)):
        sim_pairs = pairs[sim_index]
        held_out_results = {}
        for candidate in sim_pairs:
            is_interventional = candidate.iota_base.intervention != {None: None}
            key = candidate if is_interventional else None
            held_out_results[key] = leave_one_out_grid(
                sim_pairs, key, exp, combo, metric, cost, cota_version
            )
        per_simulation.append(held_out_results)
    return per_simulation

In [None]:
def get_looo_results(maps, pairs, looo, error, cota_version):
    """Average abstraction error of the fitted maps across simulations.

    Parameters
    ----------
    maps : list of dict
        One dict per simulation, as returned by ``compute_ae_grid_looo``:
        dropped pair (or ``None``) -> stored result.  For ``'avg_plan'`` the
        map is the first element of the stored entry.
    pairs : list
        Per-simulation lists of pair objects (same indexing as ``maps``).
    looo : bool
        If true, each pair is evaluated with the map fitted without it.
        Otherwise every pair uses the first map of the first simulation.
    error : {'jsd', 'wass'}
        Distance between the pushforward and the abstract distribution.
    cota_version : {'avg_plan', 'avg_map'}
        COTA variant that produced ``maps``.

    Returns
    -------
    (mean, std) over simulations of the per-simulation mean error across pairs.

    Raises
    ------
    ValueError
        If ``error`` is not supported.  Previously an unknown name silently
        contributed 0 and reported a perfect score.
    """
    if error not in ('jsd', 'wass'):
        raise ValueError(f"Unknown error measure {error!r}; expected 'jsd' or 'wass'")

    no_looo_map = maps[0][list(maps[0].keys())[0]]
    if cota_version == 'avg_plan':
        no_looo_map = no_looo_map[0]

    n_sims = len(maps)
    avg_list = []
    for n in range(n_sims):
        ae = 0
        for pair in pairs[n]:
            # Pairs without a base intervention were stored under the key None.
            key = None if pair.iota_base.intervention == {None: None} else pair

            tau_dict = maps[n][key]
            if cota_version == 'avg_plan':
                tau_dict = tau_dict[0]

            tau = ams.to_tuples(tau_dict if looo else no_looo_map, 'stochastic')
            pushforward = ams.stochastic_pushforward(tau, pair.base_dict, list(pair.abst_dict.keys()))

            if error == 'jsd':
                ae += distance.jensenshannon(pushforward, pair.abst_distribution)
            else:
                # NOTE(review): scipy's wasserstein_distance(u_values, v_values) treats its
                # arguments as *samples*, not as probability vectors over a shared support,
                # so two different pmfs that are permutations of each other score 0 here.
                # The weighted form would be wasserstein_distance(support, support,
                # u_weights=pushforward, v_weights=pair.abst_distribution) for a chosen
                # ordering/ground metric of the abstract states -- TODO confirm intent.
                ae += wasserstein_distance(pushforward, pair.abst_distribution)

        # NOTE: raises ZeroDivisionError if a simulation has no pairs.
        avg_list.append(ae / len(pairs[n]))

    return np.mean(avg_list), np.std(avg_list)

In [None]:
def get_triangular_grid_random(n):
    """Draw ``n`` random 3-component points whose components sum to ~1.

    Per sample, a random permutation of the three output slots is drawn.
    Then one value is uniform on [0, 1] and the next is uniform on the
    remainder, both rounded to 4 decimals; the third is the rounded
    remainder.  The permutation decides which output list receives each value.

    Returns:
        tuple: three lists ``(a, b, c)``, each of length ``n``.
    """
    columns = ([], [], [])
    for _ in range(n):
        slots = random.sample(range(3), 3)

        first = round(random.uniform(0, 1), 4)
        second = round(random.uniform(0, 1 - first), 4)
        values = (first, second, round(1 - (first + second), 4))

        for column, slot in zip(columns, slots):
            column.append(values[slot])

    a, b, c = columns
    return a, b, c

In [None]:
def _cota_combinations(n_grid):
    """Return the 'κ-λ-μ' keys to evaluate for a grid of resolution ``n_grid``."""
    baseline = '1.0-0.0-0.0'
    if n_grid == 1:
        return [baseline]

    k_list, l_list, m_list = get_triangular_grid(n_grid)
    combos = [create_comb(kappa, lmbda, mu) for kappa, lmbda, mu in zip(k_list, l_list, m_list)]
    # The κ = 1 corner is the OT baseline, evaluated separately with n_grid == 1.
    # Tolerate its absence (e.g. if float formatting ever differs) instead of crashing.
    if baseline in combos:
        combos.remove(baseline)
    return combos


def run_cota(exp, pairs, metric, cost, n_grid, looo, error, cota_version):
    """Evaluate COTA over a grid of (κ, λ, μ) weights and rank the results.

    Parameters
    ----------
    exp, pairs, metric, cost, looo, error, cota_version :
        Forwarded to ``compute_ae_grid_looo`` / ``get_looo_results``.
    n_grid : int
        Resolution passed to mpltern's ``get_triangular_grid``.  ``1`` means
        "evaluate only the baseline combination 1.0-0.0-0.0".

    Returns
    -------
    pandas.DataFrame
        Columns ``κ, λ, μ, e(α), std``, one row per combination that solved,
        sorted by ``e(α)`` ascending.  Combinations raising ``SolverError``
        are reported and skipped.
    """
    grid_dict, stds = {}, {}
    for combo in tqdm(_cota_combinations(n_grid)):
        try:
            ms = compute_ae_grid_looo(exp, pairs, combo, metric, cost, cota_version)
            grid_dict[combo], stds[combo] = get_looo_results(ms, pairs, looo, error, cota_version)
        except SolverError:
            print(f"SolverError occurred for combo: {combo}. Skipping this combo.")

    # Only combos that solved are in grid_dict, in grid order, so every column
    # below is built from the same key sequence and stays aligned.
    weights = [parse_comb(combo) for combo in grid_dict]
    data_dict = {
        'κ': [w[0] for w in weights],
        'λ': [w[1] for w in weights],
        'μ': [w[2] for w in weights],
        'e(α)': list(grid_dict.values()),
        'std': [stds[combo] for combo in grid_dict],
    }

    df_simplex = pd.DataFrame(data_dict)
    return df_simplex.sort_values(by='e(α)', ascending=True)

In [None]:
def get_solution(sorted_simplex, m, c, cota_version):
    """Return the grid point with the lowest abstraction error.

    Parameters
    ----------
    sorted_simplex : pandas.DataFrame
        Output of ``run_cota`` (columns ``κ, λ, μ, e(α), std``).
    m, c, cota_version :
        Unused; kept so existing call sites keep working.

    Returns
    -------
    tuple
        ``(min_error, κ, λ, μ)`` read from the row minimizing ``e(α)``.
    """
    best = sorted_simplex['e(α)'].idxmin()
    return (
        sorted_simplex.loc[best, 'e(α)'],
        sorted_simplex.loc[best, 'κ'],
        sorted_simplex.loc[best, 'λ'],
        sorted_simplex.loc[best, 'μ'],
    )

In [None]:
# Experiments evaluated in this notebook, selected by position in params.experiments.
_EXPERIMENT_INDICES = (
    1,  # synth1
    9,  # synth1T
    7,  # battery
    6,  # lucas
)
relevant_experiments = [params.experiments[idx] for idx in _EXPERIMENT_INDICES]

# Grid resolution handed to run_cota (and on to get_triangular_grid) per experiment.
n_grid_experiments = dict(
    synth1=14,
    synth1T=14,
    battery_discrete=14,
    little_lucas=12,
    synth1Tinv=14,
)

## COTA

In [None]:
# For each COTA variant, error measure, experiment, metric and cost: sweep the
# (κ, λ, μ) simplex and report the best grid point with ± 1.96·std across
# simulations (presumably a 95% normal-approximation interval).
errors = ['jsd', 'wass']
versions = ['avg_plan', 'avg_map']

for version in versions:
    print('COTA: ', version)
    print('----------------')
    for looo in [True]:  # add False to also evaluate with the non-LOOO map
        for error in errors:
            print(error)
            for exp in relevant_experiments:
                print('Experiment: ', exp)
                pairs = ut.load_pairs(exp)
                n_grid = n_grid_experiments[exp]

                for metric in ['fro', 'jsd']:
                    for cost in ['Omega', 'Hamming']:
                        print(f"{metric}--{cost}")
                        df_simplex = run_cota(exp, pairs, metric, cost, n_grid, looo, error, version)
                        min_val, kappa, lam, mu = get_solution(df_simplex, metric, cost, version)

                        # run_cota already stored the std of every grid point; read it
                        # instead of re-solving the optimal combination from scratch.
                        s = df_simplex.loc[df_simplex['e(α)'].idxmin(), 'std']

                        print(f"e(τ*) = {min_val} ± {s*1.96} for κ: {kappa}, λ: {lam}, μ: {mu}")

                        print()
                print()
            print()
    print()

## Baselines

In [None]:
# Baselines: the κ = 1.0, λ = 0.0, μ = 0.0 corner of each COTA variant
# (n_grid = 1 evaluates only that combination).
errors = ['jsd', 'wass']
versions = ['avg_plan', 'avg_map']
baseline_names = {'avg_plan': 'Pairwise OT', 'avg_map': 'Map OT'}

for version in versions:
    equiv = baseline_names[version]
    print(f"Baseline: {equiv} (COTA {version} for κ = 1.0, λ = 0.0, μ = 0.0)")
    print('-------------------------------------------------------------------')
    for looo in [True]:  # add False to also evaluate with the non-LOOO map
        for error in errors:
            print('e(τ): ', error)
            for exp in relevant_experiments:
                print('Experiment: ', exp)
                pairs = ut.load_pairs(exp)

                for metric in ['fro']:
                    for cost in ['Omega', 'Hamming']:
                        print(f"{cost}")
                        df_simplex = run_cota(exp, pairs, metric, cost, 1, looo, error, version)
                        min_val, _, _, _ = get_solution(df_simplex, metric, cost, version)

                        # Reuse the std run_cota already computed rather than re-solving.
                        s = df_simplex.loc[df_simplex['e(α)'].idxmin(), 'std']

                        print(f"e(τ) = {min_val} ± {s*1.96}")
                        print()
            print()
    print()