In [1]:
import itertools
import random
import numpy as np
import pandas as pd
import seaborn as sns
import cvxpy as cp
from tqdm import tqdm

from scipy.spatial import distance
from cvxpy.error import SolverError
from cvxpy import ECOS, SCS
import seaborn as sns
import ot
from pgmpy import inference
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance

import scipy.optimize as optimize
from scipy.spatial.distance import cdist
from src.examples import smokingmodels as sm
from scipy.spatial.distance import squareform,pdist
from scipy.stats import wasserstein_distance

from scipy.optimize import linprog
from scipy import stats
from IPython.utils import io
import warnings

import joblib
import modularized_utils as ut
import abstraction_metrics as ams
import matplotlib.pylab as pl

import get_results
import params

# Seed every RNG the notebook may touch: numpy's legacy global generator and
# Python's `random` module (imported above but previously left unseeded).
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

# NOTE(review): this silences *all* warnings for the rest of the session,
# including numerical ones from numpy/scipy; consider narrowing the filter.
warnings.filterwarnings(action='ignore')
np.set_printoptions(precision=4, suppress=True)

In [2]:
def leave_one_out(pairs, dropped_pair, experiment, mode):
    """Run baseline `mode` with one pair held out and return the stored results.

    Builds the intervention map ``omega`` (``pair.iota_base -> pair.iota_abst``)
    from ``pairs``. When ``dropped_pair`` is given, that pair is removed from a
    copy of ``pairs`` and its base intervention is removed via ``ut.drop1omega``.
    A poset is then built over the remaining base interventions and turned into
    chains, ``get_results.results`` is run, and the value returned by
    ``ut.load_results`` for the same ``(mode, experiment, dropped_pair)`` is
    returned.

    Parameters
    ----------
    pairs : iterable
        Pairs of one simulation. Not modified; a list copy is filtered.
    dropped_pair : pair or None
        Pair to hold out, or ``None`` to use every pair.
    experiment : str
        Experiment key; also selects ``params.lmbdas[experiment]``.
    mode : str
        Baseline name forwarded to ``get_results.results`` and
        ``ut.load_results``.

    Returns
    -------
    object
        Whatever ``ut.load_results(mode, experiment, dropped_pair)`` returns.
    """
    omega = {pair.iota_base: pair.iota_abst for pair in pairs}

    hold_pairs = list(pairs)
    if dropped_pair is not None:
        hold_pairs.remove(dropped_pair)
        hold_omega = ut.drop1omega(omega, dropped_pair.iota_base)
    else:
        hold_omega = omega

    I_relevant = list(hold_omega.keys())

    # build_poset returns (structure, tree); only the structure is needed here.
    struc, _ = ut.build_poset(I_relevant)
    chains = ut.to_chains(hold_pairs, struc)

    args = [hold_pairs, [chains], params.lmbdas[experiment]]

    # The return value of get_results.results is not used; the value handed
    # back to the caller comes from ut.load_results below.
    get_results.results(mode, args, experiment, dropped_pair)

    return ut.load_results(mode, experiment, dropped_pair)

In [3]:
def compute_ae_baselines(mode, pairs, maps, error, looo, n_sims, mappings, costs):
    """Print the mean abstraction error of every baseline (mapping, cost) pair.

    For each simulation ``n`` and each ``(mapping, cost)``:

    1. The stored map is converted with ``ams.to_tuples``.
    2. Each pair's base distribution is pushed forward through it (discrete or
       stochastic pushforward, depending on ``mapping``).
    3. The distance to the pair's abstract distribution is averaged over
       ``pairs[n]``.

    The mean over simulations and ``1.96 * std`` over simulations are printed
    per label.

    Parameters
    ----------
    mode : str
        Baseline key used to index ``maps[n][key][mode]``.
    pairs : list
        ``pairs[n]`` is the list of pairs evaluated in simulation ``n``.
    maps : sequence of dict
        ``maps[n]`` holds the stored results of simulation ``n``. With a truthy
        ``looo`` each pair is scored with its own entry ``maps[n][pair]``.
        Otherwise every pair is scored with the entry under the first key of
        ``maps[n]``.
    error : {'jsd', 'wass'}
        Distance used: Jensen-Shannon distance or scipy's Wasserstein distance.
    looo : bool
        Whether to use the per-pair (leave-one-out) entries.
    n_sims : int
        Ignored; the number of simulations is always ``len(maps)``. Kept only
        for call compatibility.
    mappings, costs : iterable
        Keys used to index the stored results. Labels are the cost names, and
        are prefixed with the mapping name when there is more than one mapping
        (so labels never collide).

    Returns
    -------
    None
        Results are printed only.

    Raises
    ------
    ValueError
        If ``error`` is not ``'jsd'`` or ``'wass'``, or ``maps`` is empty.
    """
    if error == 'jsd':
        metric = distance.jensenshannon
    elif error == 'wass':
        # NOTE(review): scipy's wasserstein_distance treats its arguments as
        # sample *values*, not as probability vectors over a shared support.
        # If a histogram distance was intended, pass the support points as
        # values and the distributions as u_weights/v_weights. Confirm before
        # changing, since it alters the reported numbers.
        metric = wasserstein_distance
    else:
        # Previously an unknown metric silently contributed 0 to every error.
        raise ValueError("error must be 'jsd' or 'wass', got {0!r}".format(error))

    if not maps:
        raise ValueError('maps is empty: nothing to evaluate')

    mappings = list(mappings)
    costs = list(costs)

    # Cost-only labels (as before) would collide across mappings and silently
    # overwrite earlier results, so prefix the mapping name when needed.
    multi_mapping = len(mappings) > 1
    labels = [
        '{0} {1}'.format(mapping, cost) if multi_mapping else '{0}'.format(cost)
        for mapping in mappings
        for cost in costs
    ]

    all_results = []
    for n in range(len(maps)):
        sim_maps = maps[n]
        noloopair = list(sim_maps)[0]
        sim_results = []
        for mapping in mappings:
            for cost in costs:
                ae = 0
                for pair in pairs[n]:
                    # A truthy/falsy test avoids reusing a stale `tau` when
                    # `looo` is neither exactly True nor exactly False.
                    source = pair if looo else noloopair
                    tau = ams.to_tuples(sim_maps[source][mode][0][mapping][cost], mapping)
                    abst_keys = list(pair.abst_dict.keys())
                    if mapping != 'stochastic':
                        pushforward = ams.discrete_pushforward(tau, pair.base_dict, abst_keys)
                    else:
                        pushforward = ams.stochastic_pushforward(tau, pair.base_dict, abst_keys)
                    ae += metric(pushforward, pair.abst_distribution)
                sim_results.append(ae / len(pairs[n]))
        all_results.append(sim_results)

    averages = np.mean(all_results, axis=0)
    std_devs = np.std(all_results, axis=0)

    # The ± term is 1.96 standard deviations across simulations (not a
    # standard error of the mean).
    for label, mean, std in zip(labels, averages, std_devs):
        print('{0}: {1} ± {2}'.format(label, mean, std * 1.96))

In [4]:
def run_baselines(exp, pairs):
    """Fit both baselines ('bary', 'agg') for every simulation, one pair held out.

    Each interventional pair gets its own leave-one-out fit. Pairs whose base
    intervention is ``{None: None}`` share a single fit on all pairs
    (``dropped_pair=None``). That fit is computed once per simulation and mode
    and stored under each such pair, so every pair keeps its own entry.

    Previously the results were keyed by ``dropped_pair``, so all
    observational pairs collapsed onto one ``None`` key. The subsequent
    ``zip`` with ``pairs[n]`` then paired pairs with the wrong results.

    Parameters
    ----------
    exp : str
        Experiment key forwarded to ``leave_one_out``.
    pairs : sequence
        ``pairs[n]`` is the list of pairs of simulation ``n``.

    Returns
    -------
    looo_results : dict
        ``looo_results[mode][n][pair]`` is the stored result used to score
        ``pair`` in simulation ``n``.
    n_sims : int
        Number of simulations.
    mappings, costs : list
        Keys of the no-drop result of simulation 0 for the last mode ('agg').

    Raises
    ------
    ValueError
        If ``pairs`` is empty, or simulation 0 has no pair with base
        intervention ``{None: None}`` (its no-drop fit supplies the keys).
    """
    if len(pairs) == 0:
        raise ValueError('pairs is empty: nothing to run')

    not_run = object()  # sentinel: no-drop fit not computed yet

    looo_results = {}
    for mode in ['bary', 'agg']:
        looo_list = []
        first_full_fit = not_run

        for n, sim_pairs in enumerate(tqdm(pairs)):
            results = {}
            full_fit = not_run
            for pair in sim_pairs:
                if pair.iota_base.intervention != {None: None}:
                    results[pair] = leave_one_out(sim_pairs, pair, exp, mode)
                else:
                    # The no-drop fit does not depend on which pair triggers it.
                    if full_fit is not_run:
                        full_fit = leave_one_out(sim_pairs, None, exp, mode)
                    results[pair] = full_fit
            if n == 0:
                first_full_fit = full_fit
            looo_list.append(results)

        looo_results[mode] = looo_list

        if first_full_fit is not_run:
            raise ValueError(
                'the first simulation has no pair with intervention {None: None}; '
                'cannot read the mapping/cost keys'
            )

        n_sims = len(looo_list)
        first = first_full_fit[mode][0]
        mappings = list(first.keys())
        costs = list(first[mappings[0]].keys())

    return looo_results, n_sims, mappings, costs

In [6]:
# Experiments evaluated below, selected by index from params.experiments.
# NOTE(review): the saved output at the bottom of the notebook only shows
# 'synth1Tinv', so it may come from an older selection; re-run to refresh it.
relevant_experiments = [
    params.experiments[idx]
    for idx in (
        1,  # synth1
        9,  # synth1T
        7,  # battery
        6,  # lucas
    )
]

In [7]:
modes = ['bary', 'agg']
errors = ['wass', 'jsd']

# For each experiment, fit the baselines once and then score every
# (mode, error) combination. The loop variables (pairs, results, ...) stay
# bound to the last experiment after the loop finishes.
for experiment in relevant_experiments:
    print('Experiment: ', experiment)
    print('===============================================')

    pairs = ut.load_pairs(experiment)
    (results, n_sims,
     mappings, costs) = run_baselines(experiment, pairs)

    for mode in modes:
        # Only the leave-one-out maps are scored. With looo=False,
        # compute_ae_baselines scores every pair with one stored map instead.
        for looo in (True,):
            for error in errors:
                print('ERROR: ', error)
                print('Baseline: ', mode)
                compute_ae_baselines(mode=mode, pairs=pairs, maps=results[mode],
                                     error=error, looo=looo, n_sims=n_sims,
                                     mappings=mappings, costs=costs)
                print()
        print()
    print()

Experiment:  synth1Tinv


100%|███████████████████████████████████████████| 50/50 [00:03<00:00, 15.37it/s]
100%|███████████████████████████████████████████| 50/50 [00:10<00:00,  4.59it/s]

ERROR:  wass
Baseline:  bary
Omega: 0.08422377990573754 ± 0.056303104456332326
Hamming: 0.069261984673563 ± 0.05232342103150767

ERROR:  jsd
Baseline:  bary
Omega: 0.24693241239399819 ± 0.16059666022079172
Hamming: 0.26100087473358263 ± 0.196877595797778


ERROR:  wass
Baseline:  agg
Omega: 0.1283099611794554 ± 0.00029635049034922354
Hamming: 0.12204327803751075 ± 0.00021057687825385058

ERROR:  jsd
Baseline:  agg
Omega: 0.15189889643174964 ± 0.0023940697850549815
Hamming: 0.16739796201619803 ± 0.0018975755072508423






