In [1]:
import itertools

import random
import numpy as np
import pandas as pd
import cvxpy as cp
from tqdm import tqdm

from scipy.spatial import distance
from cvxpy.error import SolverError
from cvxpy import ECOS, SCS
import seaborn as sns
import ot
from pgmpy import inference
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance
import plotly.graph_objects as go
import plotly.graph_objs as go
import plotly.express as px

from plotly.subplots import make_subplots
from mpltern.datasets import get_triangular_grid
import scipy.optimize as optimize
from scipy.spatial.distance import cdist
from src.examples import smokingmodels as sm
from scipy.spatial.distance import squareform,pdist
from scipy.optimize import linprog
from scipy import stats
from scipy.stats import wasserstein_distance

from IPython.utils import io
import warnings

import joblib
import modularized_utils as ut
import abstraction_metrics as ams
import matplotlib.pylab as pl

import get_results
import params

np.random.seed(0)

warnings.filterwarnings(action='ignore')
np.set_printoptions(precision=4,suppress=True)

In [2]:
def create_comb(a, b, c):
    return str(a)+'-'+str(b)+'-'+str(c)

def parse_comb(input_string):
    a, b, c = input_string.split('-')
    return [float(a), float(b), float(c)]

In [3]:
def leave_one_out_grid(pairs, dropped_pair, experiment, combination, df, cf, cota_version):
    
    omega = {}
    for pair in pairs:
        omega[pair.iota_base] = pair.iota_abst
    
    hold_pairs = pairs.copy()
   
    if dropped_pair != None:
        hold_pairs.remove(dropped_pair)
        hold_omega  = ut.drop1omega(omega, dropped_pair.iota_base)
        
    else:
        hold_omega = omega
        
    I_relevant  = list(hold_omega.keys())

    struc, tree = ut.build_poset(I_relevant)
    chains      = ut.to_chains(hold_pairs, struc)
    
    combin = parse_comb(combination)
    kk, ll, mm = combin[0], combin[1], combin[2]
    
    args   = [hold_pairs, [chains], kk, ll, mm, df, cf]
    
    if cota_version == 'avg_plan':
        get_results.results_grid_looo(args, experiment, dropped_pair)
        looo_grid_results = ut.load_grid_results_looo(experiment, combination, dropped_pair)
        
    elif cota_version == 'avg_map':
        get_results.results_grid_looo_aggregated(args, experiment, dropped_pair)
        looo_grid_results = ut.load_grid_results_looo_aggregated(experiment, combination, dropped_pair)
    
    return looo_grid_results

In [4]:
def compute_ae_grid_looo(exp, pairs, combo, metric, cost, cota_version):
    
    looo_results_grid = []
    for n in range(len(pairs)):

        results_grid = {}
        for i, pair in enumerate(pairs[n]):

            if pair.iota_base.intervention != {None: None}:
                dropped_pair = pair
            else:
                dropped_pair = None

            results_grid[dropped_pair] = leave_one_out_grid(pairs[n], dropped_pair, exp, combo, metric, cost, cota_version)

        looo_results_grid.append(results_grid)
    
    return looo_results_grid

In [5]:
def get_looo_results(maps, pairs, looo, error, cota_version):
    
    no_looo_map = maps[0][list(maps[0].keys())[0]]
    
    if cota_version == 'avg_plan':
        no_looo_map = no_looo_map[0]
    
    n_sims = len(maps)
    avg_list = []
    for n in range(n_sims):

        ae = 0
        for pair in pairs[n]:
            #print('base: ', pair.iota_base.intervention, 'abst: ', pair.iota_abst.intervention)
            if pair.iota_base.intervention == {None: None}:
                p = None
            else:
                p = pair
            
            tau_dict = maps[n][p]
            
            if cota_version == 'avg_plan':
                tau_dict = tau_dict[0]
                
            if looo == True:
                tau = ams.to_tuples(tau_dict, 'stochastic')
            else:
                tau = ams.to_tuples(no_looo_map, 'stochastic')

            pushforward = ams.stochastic_pushforward(tau, pair.base_dict, list(pair.abst_dict.keys()))
            
            if error == 'jsd':
                d = distance.jensenshannon(pushforward, pair.abst_distribution)
                ae += d
            elif error == 'wass':
                d = wasserstein_distance(pushforward, pair.abst_distribution)
                ae += d

        avg_list.append(ae/len(pairs[n]))

    return np.mean(avg_list), np.std(avg_list)

In [6]:
def run_cota(exp, pairs, metric, cost, n_grid, looo, error, cota_version):    

    track_MI = False
    if n_grid != 1:
        
        k_list, l_list, m_list = get_triangular_grid(n_grid)
        conv_combinations = []
        for elements in zip(k_list, l_list, m_list):
            kappa, lmbda, mu = elements[0], elements[1], elements[2]
            conv_combinations.append(create_comb(kappa, lmbda, mu))
            
    else:
        conv_combinations = ['1.0-0.0-0.0']

    grid_dict, stds = {}, {}
    excluded_combos     = []
    for combo in tqdm(conv_combinations):
        try:
            ms               = compute_ae_grid_looo(exp, pairs, combo, metric, cost, cota_version)
            grid_dict[combo], stds[combo] = get_looo_results(ms, pairs, looo, error, cota_version)
        except SolverError:
            print(f"SolverError occurred for combo: {combo}. Skipping this combo.")
            excluded_combos.append(combo)

    conv_combinations = [combo for combo in conv_combinations if combo not in excluded_combos]
    k_list, l_list, m_list = [], [], []
    for combo in conv_combinations:
        kappa, lmbda, mu = map(float, combo.split('-'))
        k_list.append(kappa)
        l_list.append(lmbda)
        m_list.append(mu)

    f_values    = list(grid_dict.values())
    std_values  = list(stds.values())

    data_dict = {
        'κ': k_list,
        'λ': l_list,
        'μ': m_list,
        'e(α)': f_values,
        'std': std_values
    }

    df_simplex = pd.DataFrame(data_dict)

    sorted_simplex = df_simplex.sort_values(by='e(α)', ascending=True)
    
    return sorted_simplex

In [7]:
def get_solution(sorted_simplex, m, c, cota_version):

    mean = np.mean(list(sorted_simplex['e(α)']))
    std  = np.std(list(sorted_simplex['e(α)']))

    min_index = sorted_simplex['e(α)'].idxmin()

    min_a = sorted_simplex.loc[min_index, 'κ']
    min_b = sorted_simplex.loc[min_index, 'λ']
    min_c = sorted_simplex.loc[min_index, 'μ']

    min_value = sorted_simplex.loc[min_index, 'e(α)']
    
    lookup_values = {'κ': 1.0, 'λ': 0.0, 'μ': 0.0}

    pwise = sorted_simplex.loc[(sorted_simplex['κ'] == lookup_values['κ']) & (sorted_simplex['λ'] == lookup_values['λ']) & (sorted_simplex['μ'] == lookup_values['μ'])]

    pw_mean = pwise['e(α)'][0]
    pw_std  = pwise['std'][0]

    return min_value, min_a, min_b, min_c

In [8]:
relevant_experiments = [params.experiments[1], #synth1
                        params.experiments[9], #synth1T
                        params.experiments[7], #battery
                        params.experiments[6]] #lucas

n_grid_experiments   = {'synth1': 14,
                        'synth1T': 14,
                        'battery_discrete': 14,
                        'little_lucas': 12,
                        'synth1Tinv': 14
                       }

In [9]:
relevant_experiments = [params.experiments[1]]

In [None]:
errors    = ['jsd']#, 'wass']
versions  = ['avg_plan']#, 'avg_map']

for version in versions:
    print('COTA: ', version)
    print('----------------')
    for looo in [True]: #, False]:
        for error in errors:
            print(error)
            for exp in relevant_experiments:
                print('Experiment: ', exp)
                pairs  = ut.load_pairs(exp)
                n_grid = n_grid_experiments[exp]

                for metric in ['fro', 'jsd']:
                    for cost in ['Omega', 'Hamming']:
                        print(f"{metric}--{cost}")
                        df_simplex       = run_cota(exp, pairs, metric, cost, n_grid, looo, error, version)
                        min_val, k, l, m = get_solution(df_simplex, metric, cost, version)
                        
                        min_combo = create_comb(k, l, m)
                        ms   = compute_ae_grid_looo(exp, pairs, min_combo, metric, cost, version)
                        v, s = get_looo_results(ms, pairs, looo, error, version)
                        
                        print(f"e(τ*) = {min_val} ± {s*1.96} for κ: {k}, λ: {l}, μ: {m}")
                        
                        print( )
                print( )
            print( )
    print( )

COTA:  avg_plan
----------------
jsd
Experiment:  synth1
fro--Omega


 82%|██████████████████████████████████▍       | 86/105 [01:15<00:16,  1.14it/s]

# COTA

In [9]:
errors    = ['jsd', 'wass']
versions  = ['avg_plan', 'avg_map']

for version in versions:
    print('COTA: ', version)
    print('----------------')
    for looo in [True]: #, False]:
        for error in errors:
            print(error)
            for exp in relevant_experiments:
                print('Experiment: ', exp)
                pairs  = ut.load_pairs(exp)
                n_grid = n_grid_experiments[exp]

                for metric in ['fro', 'jsd']:
                    for cost in ['Omega', 'Hamming']:
                        print(f"{metric}--{cost}")
                        df_simplex       = run_cota(exp, pairs, metric, cost, n_grid, looo, error, version)
                        min_val, k, l, m = get_solution(df_simplex, metric, cost, version)
                        
                        min_combo = create_comb(k, l, m)
                        ms   = compute_ae_grid_looo(exp, pairs, min_combo, metric, cost, version)
                        v, s = get_looo_results(ms, pairs, looo, error, version)
                        
                        print(f"e(τ*) = {min_val} ± {s*1.96} for κ: {k}, λ: {l}, μ: {m}")
                        
                        print( )
                print( )
            print( )
    print( )

COTA:  avg_plan
----------------
jsd
Experiment:  synth1
fro--Omega


100%|█████████████████████████████████████████| 105/105 [43:45<00:00, 25.00s/it]


e(τ*) = 0.011333459551271968 ± 0.005081186835427099 for κ: 0.3846153846153846, λ: 0.46153846153846145, μ: 0.15384615384615374

fro--Hamming


100%|█████████████████████████████████████████| 105/105 [40:26<00:00, 23.11s/it]


e(τ*) = 0.012753006864678723 ± 0.0038820245884312204 for κ: 0.46153846153846145, λ: 0.5384615384615384, μ: 0.0

jsd--Omega


100%|███████████████████████████████████████| 105/105 [1:43:22<00:00, 59.08s/it]


e(τ*) = 0.00897379600081668 ± 0.004468844159639562 for κ: 0.6153846153846154, λ: 0.3846153846153846, μ: 0.0

jsd--Hamming


100%|███████████████████████████████████████| 105/105 [1:24:26<00:00, 48.25s/it]


e(τ*) = 0.035049979567718326 ± 0.005805373148281864 for κ: 0.6153846153846154, λ: 0.23076923076923073, μ: 0.15384615384615374



wass
Experiment:  synth1
fro--Omega


100%|█████████████████████████████████████████| 105/105 [44:57<00:00, 25.69s/it]


e(τ*) = 0.010440875165478045 ± 0.0030794764802298488 for κ: 0.3076923076923077, λ: 0.5384615384615384, μ: 0.15384615384615374

fro--Hamming


100%|█████████████████████████████████████████| 105/105 [40:03<00:00, 22.89s/it]


e(τ*) = 0.011372602389263716 ± 0.002020223029299705 for κ: 0.46153846153846145, λ: 0.5384615384615384, μ: 0.0

jsd--Omega


100%|███████████████████████████████████████| 105/105 [1:21:20<00:00, 46.48s/it]


e(τ*) = 0.00927460823225766 ± 0.0018261130803632272 for κ: 0.6923076923076923, λ: 0.3076923076923077, μ: 0.0

jsd--Hamming


100%|███████████████████████████████████████| 105/105 [1:15:05<00:00, 42.91s/it]


e(τ*) = 0.016744373478933802 ± 0.0025521976932291117 for κ: 0.6153846153846154, λ: 0.23076923076923073, μ: 0.15384615384615374




COTA:  avg_map
----------------
jsd
Experiment:  synth1
fro--Omega


100%|█████████████████████████████████████████| 105/105 [38:13<00:00, 21.84s/it]


e(τ*) = 0.015123437164394604 ± 0.021926233521647978 for κ: 0.6153846153846154, λ: 0.3846153846153846, μ: 0.0

fro--Hamming


100%|█████████████████████████████████████████| 105/105 [34:14<00:00, 19.57s/it]


e(τ*) = 0.09705467818511497 ± 0.009544632696508813 for κ: 0.07692307692307687, λ: 0.9230769230769231, μ: 0.0

jsd--Omega


100%|███████████████████████████████████████| 105/105 [1:19:27<00:00, 45.41s/it]


e(τ*) = 0.014857535284800348 ± 0.022003139266672533 for κ: 1.0, λ: 0.0, μ: 0.0

jsd--Hamming


100%|███████████████████████████████████████| 105/105 [1:15:19<00:00, 43.04s/it]


e(τ*) = 0.13965350910531624 ± 0.011942158472430207 for κ: 0.5384615384615384, λ: 0.3846153846153846, μ: 0.07692307692307687



wass
Experiment:  synth1
fro--Omega


100%|█████████████████████████████████████████| 105/105 [38:15<00:00, 21.87s/it]


e(τ*) = 0.17141374122099118 ± 0.0019639033996429185 for κ: 0.9230769230769231, λ: 0.07692307692307687, μ: 0.0

fro--Hamming


100%|█████████████████████████████████████████| 105/105 [34:16<00:00, 19.59s/it]


e(τ*) = 0.17545775082769283 ± 0.001309767414390407 for κ: 0.7692307692307692, λ: 0.23076923076923073, μ: 0.0

jsd--Omega


100%|███████████████████████████████████████| 105/105 [1:19:29<00:00, 45.43s/it]


e(τ*) = 0.17142012518893618 ± 0.001966357801969946 for κ: 0.9230769230769231, λ: 0.07692307692307687, μ: 0.0

jsd--Hamming


100%|███████████████████████████████████████| 105/105 [1:15:27<00:00, 43.12s/it]


e(τ*) = 0.17523740951302746 ± 0.0011152228440724484 for κ: 0.9230769230769231, λ: 0.07692307692307687, μ: 0.0






# Baselines

In [13]:
errors  = ['jsd', 'wass']
versions = ['avg_plan', 'avg_map']

for version in versions:
    if version == 'avg_plan':
        equiv = 'Pairwise OT'
    elif version == 'avg_map':
        equiv = 'Map OT'
    print(f"Baseline: {equiv} (COTA {version} for κ = 1.0, λ = 0.0, μ = 0.0)")
    print('-------------------------------------------------------------------')
    for looo in [True]:#, False]:
        #print('LOOO: ', looo)
        for error in errors:
            print('e(τ): ', error)
            for exp in relevant_experiments:
                print('Experiment: ', exp)
                pairs  = ut.load_pairs(exp)
                n_grid = n_grid_experiments[exp]

                for metric in ['fro']:
                    for cost in ['Omega', 'Hamming']:
                        print(f"{cost}")
                        df_simplex       = run_cota(exp, pairs, metric, cost, 1, looo, error, version)
                        min_val, k, l, m = get_solution(df_simplex, metric, cost, version)
                        
                        min_combo = create_comb(k, l, m)
                        ms   = compute_ae_grid_looo(exp, pairs, min_combo, metric, cost, version)
                        v, s = get_looo_results(ms, pairs, looo, error, version)
                        
                        print(f"e(τ) = {min_val} ± {s*1.96}")
                        print( )
            print( )
    print( )

Baseline: Pairwise OT (COTA avg_plan for κ = 1.0, λ = 0.0, μ = 0.0)
-------------------------------------------------------------------
e(τ):  jsd
Experiment:  synth1
Omega


100%|█████████████████████████████████████████████| 1/1 [00:22<00:00, 22.93s/it]


e(τ) = 0.013365752169570484 ± 0.005703081927422738

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:20<00:00, 20.11s/it]


e(τ) = 0.0874006202661453 ± 0.01080361781155952

Experiment:  synth1T
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/it]


e(τ) = 0.27984442215420646 ± 0.01484203485043764

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.04s/it]


e(τ) = 0.2422169188030463 ± 0.0028675295887772697

Experiment:  battery_discrete
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.06s/it]


e(τ) = 0.43049939388373837 ± 0.0

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:02<00:00,  2.98s/it]


e(τ) = 0.263570442776218 ± 0.0

Experiment:  little_lucas
Omega


100%|████████████████████████████████████████████| 1/1 [05:04<00:00, 304.66s/it]


e(τ) = 0.3029088851885152 ± 0.01063226897072012

Hamming


100%|████████████████████████████████████████████| 1/1 [05:10<00:00, 310.25s/it]


e(τ) = 0.38574553250813265 ± 0.0038071570033329167


e(τ):  wass
Experiment:  synth1
Omega


100%|█████████████████████████████████████████████| 1/1 [00:20<00:00, 20.82s/it]


e(τ) = 0.012825246305755576 ± 0.0032061953066875146

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:18<00:00, 18.51s/it]


e(τ) = 0.03613869354989748 ± 0.0046215134217675105

Experiment:  synth1T
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.49s/it]


e(τ) = 0.0919390714915905 ± 0.005909152450128166

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:02<00:00,  2.96s/it]


e(τ) = 0.06770152500575735 ± 0.0013537575532277823

Experiment:  battery_discrete
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.05s/it]


e(τ) = 0.02761852584995019 ± 0.0

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.02s/it]


e(τ) = 0.027348418053273205 ± 0.0

Experiment:  little_lucas
Omega


100%|████████████████████████████████████████████| 1/1 [05:07<00:00, 307.29s/it]


e(τ) = 0.04495985734823442 ± 0.000942682387552874

Hamming


100%|████████████████████████████████████████████| 1/1 [05:24<00:00, 324.13s/it]


e(τ) = 0.04734271323217171 ± 0.0010117178115785542



Baseline: Map OT (COTA avg_map for κ = 1.0, λ = 0.0, μ = 0.0)
-------------------------------------------------------------------
e(τ):  jsd
Experiment:  synth1
Omega


100%|█████████████████████████████████████████████| 1/1 [00:23<00:00, 23.85s/it]


e(τ) = 0.018897398819199825 ± 0.022088567526157163

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:20<00:00, 20.79s/it]


e(τ) = 0.17129022552765405 ± 0.004998944519066082

Experiment:  synth1T
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.94s/it]


e(τ) = 0.2506531994365807 ± 0.005681344274833881

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.39s/it]


e(τ) = 0.22948784338721048 ± 0.004988407367197299

Experiment:  battery_discrete
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.79s/it]


e(τ) = 0.40861044788908335 ± 0.0

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.39s/it]


e(τ) = 0.2265346588215272 ± 0.0

Experiment:  little_lucas
Omega


100%|████████████████████████████████████████████| 1/1 [05:46<00:00, 346.93s/it]


e(τ) = 0.28026096486866725 ± 0.010740829188340089

Hamming


100%|████████████████████████████████████████████| 1/1 [05:51<00:00, 351.25s/it]


e(τ) = 0.341848891041967 ± 0.005577075255617967


e(τ):  wass
Experiment:  synth1
Omega


100%|█████████████████████████████████████████████| 1/1 [00:23<00:00, 23.84s/it]


e(τ) = 0.1714544646591796 ± 0.0020770761586346286

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:20<00:00, 20.93s/it]


e(τ) = 0.17869706822877063 ± 0.0010427843940799678

Experiment:  synth1T
Omega


100%|█████████████████████████████████████████████| 1/1 [00:04<00:00,  4.13s/it]


e(τ) = 0.14090439871764213 ± 0.0004478217304412343

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.41s/it]


e(τ) = 0.12950024421337603 ± 0.00017244029033372732

Experiment:  battery_discrete
Omega


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/it]


e(τ) = 0.06044807191610426 ± 0.0

Hamming


100%|█████████████████████████████████████████████| 1/1 [00:03<00:00,  3.34s/it]


e(τ) = 0.05372263651470203 ± 0.0

Experiment:  little_lucas
Omega


100%|████████████████████████████████████████████| 1/1 [05:50<00:00, 350.96s/it]


e(τ) = 0.06095940563605817 ± 0.0003017924996866767

Hamming


100%|████████████████████████████████████████████| 1/1 [05:09<00:00, 309.21s/it]


e(τ) = 0.060346193891111975 ± 0.00030435280111836384



