In [15]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import os

import json

In [16]:
# Experiment selection: saved run directories whose names start with one of these
# config names are loaded and evaluated.
# configs_to_evaluate= ['vi_pca_C_g_EOT', 'vi_pca_C_g_REOT_g+p', 'vi_pca_C_g_REOT_CC1_g+p+c', 'vi_pca_C_g_REOT_CC2_g+p+c']
configs_to_evaluate = ['vi_pca_C_g_EOT', 'vi_pca_C_g+p_EOT']

GSE_id = 'GSE232025'
use_all_data = False
run_version = 'v_everything_normalized'
# Alternative run versions (swap one in to compare):
# run_version = 'v_everything_normalized_no_spatial_cost_otregl_0.8'
# run_version = 'v_everything_normalized_no_spatial_cost_otregl_0.5'
# run_version = 'v_everything_normalized_no_spatial_cost_otregl_0.9'

# Root directory holding the experiment figures. Set CONSTRAINED_FM_FIGURES_DIR to
# run this notebook on another machine; the default is the original author's path.
experiment_root = os.environ.get(
    'CONSTRAINED_FM_FIGURES_DIR',
    '/Users/rssantanu/Desktop/codebase/constrained_FM/experiment_figures',
)
# Keep the trailing '/': code may append run directory names directly to base_folder.
base_folder = f'{experiment_root.rstrip("/")}/use_all_data_{use_all_data}_{GSE_id}/{run_version}/'

In [17]:
# Load per-run evaluation metrics (IVP and next-step prediction) from every saved
# run directory matching a config in `configs_to_evaluate`. A run is recorded only
# when both JSON files load cleanly and contain every metric, so downstream cells
# never see half-populated (empty) entries.
all_configs = os.listdir(base_folder)
IVP = {}
next_step_prediction = {}

metric_names = ('wasserstein', 'weighted_wasserstein', 'mmd', 'energy')


def _read_metrics(path, names=metric_names):
    """Return {name: payload[name]} for each metric name read from the JSON file at `path`.

    Raises OSError if the file cannot be opened, ValueError (json.JSONDecodeError)
    if it is not valid JSON, and KeyError/TypeError if a metric is missing or the
    payload is not a mapping.
    """
    with open(path) as fh:  # `with` closes the handle even if parsing fails
        payload = json.load(fh)
    return {name: payload[name] for name in names}


for config in configs_to_evaluate:
    for saved_config in (c for c in all_configs if c.startswith(config)):
        folder_address = os.path.join(base_folder, saved_config, '')
        try:
            ivp_metrics = _read_metrics(os.path.join(folder_address, 'IVP_error.json'))
            next_step_metrics = _read_metrics(os.path.join(folder_address, 'next_step_error.json'))
        except (OSError, ValueError, KeyError, TypeError) as err:
            # Skip unreadable/incomplete runs rather than storing empty dicts that
            # would later fail on metric lookup; report why the run was skipped.
            print(f"Error loading {folder_address}: {err!r}")
            continue
        # Insert only after both files succeeded, keeping the two dicts in sync.
        IVP[saved_config] = ivp_metrics
        next_step_prediction[saved_config] = next_step_metrics

In [18]:
IVP  # per-run IVP metric lists, keyed by saved run directory name

{'vi_pca_C_g_EOT_0_1_2_3_4': {'wasserstein': [0.7615553276877063,
   0.3424755391456084,
   0.43356214268868143,
   0.2987493836207039,
   0.7412663108143633],
  'weighted_wasserstein': [1.7032972078112407,
   0.3448436290757992,
   0.4945201265840544,
   0.5288870353777141,
   1.2580904806021531],
  'mmd': [0.2725270092487335,
   0.029337937012314796,
   0.023281490430235863,
   0.06023034453392029,
   0.20863378047943115],
  'energy': [0.14737123114425615,
   0.16940047714738016,
   0.18778277324433518,
   0.07908379025218239,
   0.13384380062787907]},
 'vi_pca_C_g_EOT_0_1_2': {'wasserstein': [0.4798168971444777,
   0.35982104201503107,
   0.4821663934873119,
   0.3057441881120133,
   0.5431329534814919],
  'weighted_wasserstein': [0.663595973764743,
   0.3533180082501731,
   0.479080660891507,
   0.2749067274708152,
   0.6212008341803128],
  'mmd': [0.23514020442962646,
   0.03464913368225098,
   0.027812371030449867,
   0.07239928096532822,
   0.17294888198375702],
  'energy': [0.1

In [19]:
# Summarise each config as "mean ± std" over its saved runs, taking one entry
# (index `evaluate_on`) from every per-run metric list.
evaluate_on = -1  # index into each per-run metric list; -1 = last entry

summary_metrics = ('wasserstein', 'weighted_wasserstein', 'mmd', 'energy')


def _is_run_of(run_name, config):
    """Return True if `run_name` is `config` itself or `config` followed by `_<int>` groups.

    A plain `run_name.startswith(config)` would also match runs of a different
    config that merely shares the prefix (e.g. `<config>_CC1_...`).
    """
    if run_name == config:
        return True
    if not run_name.startswith(config + '_'):
        return False
    return all(part.isdigit() for part in run_name[len(config) + 1:].split('_'))


def _summarise(results, config, metric, index):
    """Return "mean ± std" of results[run][metric][index] over all runs of `config`.

    Runs that lack `metric` (e.g. an empty entry from a failed load) are skipped.
    Returns None when no run contributes a value, avoiding np.mean([]) -> nan plus a
    RuntimeWarning. np.std here is the population std (ddof=0).
    """
    values = [
        metrics[metric][index]
        for run, metrics in results.items()
        if _is_run_of(run, config) and metric in metrics
    ]
    if not values:
        return None
    return f'{np.mean(values):.5f} ± {np.std(values):.5f}'


error_dict_IVP = {}
error_dict_next_step = {}
for config in configs_to_evaluate:
    # Per-run trace of the IVP Wasserstein value that feeds the summary.
    for run, metrics in IVP.items():
        if _is_run_of(run, config) and 'wasserstein' in metrics:
            print(run, metrics['wasserstein'][evaluate_on])

    error_dict_IVP[config] = {
        m: _summarise(IVP, config, m, evaluate_on) for m in summary_metrics
    }
    error_dict_next_step[config] = {
        m: _summarise(next_step_prediction, config, m, evaluate_on) for m in summary_metrics
    }

vi_pca_C_g_EOT_0_1_2_3_4 0.7412663108143633
vi_pca_C_g_EOT_0_1_2 0.5431329534814919
vi_pca_C_g_EOT 0.6252541721036308
vi_pca_C_g_EOT_0_1_2_3 0.6433012962499282
vi_pca_C_g_EOT_0 0.685789630913079
vi_pca_C_g_EOT_0_1 0.5103492012224158
vi_pca_C_g+p_EOT_0_1_2_3_4 0.7104048345968373
vi_pca_C_g+p_EOT_0 0.7439655922007457
vi_pca_C_g+p_EOT_0_1_2_3 0.6133266601858455
vi_pca_C_g+p_EOT_0_1_2 0.6166932639658831
vi_pca_C_g+p_EOT_0_1 0.6455199606881105
vi_pca_C_g+p_EOT 0.6431117168585158


In [20]:
error_dict_IVP  # per-config "mean ± std" strings of IVP metrics at index evaluate_on

{'vi_pca_C_g_EOT': {'wasserstein': '0.62485 ± 0.07895',
  'weighted_wasserstein': '0.70445 ± 0.25399',
  'mmd': '0.23793 ± 0.04747',
  'energy': '0.15307 ± 0.02393'},
 'vi_pca_C_g+p_EOT': {'wasserstein': '0.66217 ± 0.04850',
  'weighted_wasserstein': '0.62844 ± 0.07967',
  'mmd': '0.25679 ± 0.04097',
  'energy': '0.17228 ± 0.02017'}}

In [21]:
error_dict_next_step  # per-config "mean ± std" strings of next-step metrics at index evaluate_on

{'vi_pca_C_g_EOT': {'wasserstein': '0.70828 ± 0.04294',
  'weighted_wasserstein': '0.73400 ± 0.03721',
  'mmd': '0.32601 ± 0.00275',
  'energy': '0.18637 ± 0.01882'},
 'vi_pca_C_g+p_EOT': {'wasserstein': '0.73551 ± 0.08559',
  'weighted_wasserstein': '0.76408 ± 0.08079',
  'mmd': '0.32370 ± 0.00571',
  'energy': '0.19378 ± 0.02303'}}