In [1]:
import numpy as np
import pandas as pd
from scipy.stats import normaltest, mannwhitneyu, ttest_ind
import plotly.graph_objects as go
import plotly.express as px
from causalbench.data.datasets import read_group_arc_num_as_dataframe
from IPython.display import display

In [2]:
# Show every column when displaying the (wide) results table.
pd.set_option('display.max_columns', None)

In [3]:
# PC-stable results (semicolon-separated CSV) used to replace the original PC runs below.
df_pcstable = pd.read_csv(
    filepath_or_buffer='data_for_evaluation/pcstable_feedback_19Mar.csv',
    sep=';',
)

In [4]:
# Replace pc_origin with pc_stable: drop every row whose algo_name is 'pc' from this
# results file and append the PC-stable rows loaded into df_pcstable instead.
df = pd.read_csv('data_for_evaluation/feeback_castle4__04Mar.csv', sep=';')
df = pd.concat(
    [df.loc[df['algo_name'] != 'pc'], df_pcstable],
    ignore_index=True,
)
df

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time
0,sim-02.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.5000,0.8333,0.7500,4.0,6.0,0.4000,0.6667,0.5000,0.0,0.16,Sat Mar 4 11:48:58 2023
1,sim-08.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5,0.14,Sat Mar 4 11:48:58 2023
2,sim-03.Network1_amp.continuous,0.541667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.4286,1.0000,0.7500,3.0,7.0,0.4444,0.6667,0.5333,0.0,0.15,Sat Mar 4 11:48:58 2023
3,sim-06.Network1_amp.continuous,0.416667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.5000,0.8333,0.6250,0.0,0.15,Sat Mar 4 11:48:58 2023
4,sim-05.Network1_amp.continuous,0.458333,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.6250,0.8333,1.2500,5.0,8.0,0.3077,0.6667,0.4211,0.0,0.17,Sat Mar 4 11:48:58 2023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4311,sim-45.Network9_cont_amp.continuous,0.500000,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7.0,13.0,0.4286,0.6000,0.5000,0.0,0.81,Sun Mar 19 19:44:10 2023
4312,sim-58.Network9_cont_amp.continuous,0.542373,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,8.0,13.0,0.4286,0.6000,0.5000,0.0,0.54,Sun Mar 19 19:44:10 2023
4313,sim-50.Network9_cont_amp.continuous,0.449153,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7.0,13.0,0.4286,0.6000,0.5000,0.0,0.67,Sun Mar 19 19:44:10 2023
4314,sim-51.Network9_cont_amp.continuous,0.466102,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.4167,0.7000,0.1923,6.0,12.0,0.4375,0.7000,0.5385,0.0,0.58,Sun Mar 19 19:44:10 2023


In [5]:
# Add the notears and notearslowrank results to the combined frame.
# NOTE: not idempotent -- re-running this cell alone appends the same rows again;
# re-run from the cell that loads the base results instead.
df_notears = pd.read_csv('data_for_evaluation/notears_feedback_12Mar.csv', sep=';')
df_notearslowrank = pd.read_csv(
    'data_for_evaluation/notearslowrank_feedback_12Mar.csv',
    sep=';',
)
df = pd.concat(
    [df, df_notears, df_notearslowrank],
    ignore_index=True,
)

In [6]:
def extract_group(name):
    """Map a dataset name to the group key used to look up its arc count.

    Rules, checked in this order:

    - Names containing ``-`` (feedback simulations such as
      ``sim-02.Network1_amp.continuous``) map to their second ``.``-separated
      field (``Network1_amp``). Such a name without a ``.`` raises IndexError.
    - Names containing ``real``, ``numdata`` or ``dream`` are returned unchanged.
    - Any other name containing ``_`` maps to the text before the first ``_``.
    - Everything else is returned unchanged.
    """
    if '-' in name:  # case of feedback
        return name.split('.')[1]

    keeps_full_name = any(marker in name for marker in ('real', 'numdata', 'dream'))
    if keeps_full_name or '_' not in name:
        return name

    prefix, _, _ = name.partition('_')
    return prefix

In [7]:
# Lookup table for get_arc_num_by_name; this notebook reads only its 'group' and
# 'arc_num' columns from it.
group_arc_num_df = read_group_arc_num_as_dataframe()

In [8]:
def get_arc_num_by_name(name: str):
    """Look up the arc count for a dataset via its group.

    The dataset name is mapped to a group with ``extract_group`` and the group
    is looked up in the module-level ``group_arc_num_df`` table.

    Parameters
    ----------
    name : str
        Dataset name as stored in the ``dataset_name`` column.

    Returns
    -------
    The ``arc_num`` value of the first row whose ``group`` matches.

    Raises
    ------
    KeyError
        If no row of ``group_arc_num_df`` matches the derived group.
    """
    group = extract_group(name)
    matches = group_arc_num_df.loc[group_arc_num_df['group'] == group, 'arc_num']
    if matches.empty:
        # Name the offending group/dataset instead of failing with a bare
        # IndexError from indexing an empty array.
        raise KeyError(f"No arc_num found for group {group!r} (dataset {name!r})")
    return matches.iloc[0]

In [9]:
def normalize_shd(shd, arc_num):
    """Rescale a structural Hamming distance into a higher-is-better score.

    Computes ``1 - shd / (2 * arc_num)``: 1.0 when ``shd`` is 0, decreasing as
    ``shd`` grows. The result is not clipped and becomes negative once
    ``shd > 2 * arc_num``. Works element-wise on pandas Series as well as scalars.
    """
    denominator = 2 * arc_num
    return 1 - shd / denominator

In [10]:
# Add the dataset group, its looked-up arc count and the normalized SHD to every row.
# assign evaluates the callables in order, so normalized_shd sees the new arc_num.
df = df.assign(
    group=lambda frame: frame['dataset_name'].apply(extract_group),
    arc_num=lambda frame: frame['dataset_name'].apply(get_arc_num_by_name),
    normalized_shd=lambda frame: normalize_shd(frame['shd'], frame['arc_num']),
)

In [11]:
def exclude_dataset_which_has_nan_in_metric(dataframe, metric: str):
    """Drop every row of any dataset that has a NaN in ``metric``.

    If any row of a dataset has a missing ``metric`` value, all rows sharing
    that ``dataset_name`` are removed, i.e. the dataset is dropped for every
    algorithm. The number of distinct excluded datasets is printed.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Results table with a ``dataset_name`` column and a ``metric`` column.
    metric : str
        Name of the metric column to check for NaN.

    Returns
    -------
    pandas.DataFrame
        The rows of ``dataframe`` (original index kept) whose dataset has no
        NaN in ``metric``.
    """
    nan_mask = dataframe[metric].isna()
    # A dataset name appears once per algorithm run, so de-duplicate before
    # counting; otherwise the message reports NaN rows, not excluded datasets.
    to_exclude_dataset_names = dataframe.loc[nan_mask, 'dataset_name'].unique()
    print(f"There're {len(to_exclude_dataset_names)} datasets excluded for {metric}.")
    return dataframe[~dataframe['dataset_name'].isin(to_exclude_dataset_names)]

In [12]:
def show_scatter_plot(dataframe, x_lable, y_lable, color, symbol):
    """Display a plotly scatter plot of column ``y_lable`` against ``x_lable``.

    ``color`` and ``symbol`` name the columns used to colour and mark the
    points. The figure is shown immediately; nothing is returned, so calling
    this as the last line of a cell does not render the figure twice.
    """
    px.scatter(
        dataframe,
        x=x_lable,
        y=y_lable,
        color=color,
        symbol=symbol,
    ).show()

In [13]:
# One NaN-filtered frame per metric: the datasets with missing values differ by
# metric (F1 has many more NaNs than normalized_shd or gscore).
df_without_nan_norm_shd = exclude_dataset_which_has_nan_in_metric(df, 'normalized_shd')
df_without_nan_f1 = exclude_dataset_which_has_nan_in_metric(df, 'F1')
df_without_nan_gscore = exclude_dataset_which_has_nan_in_metric(df, 'gscore')

There're 8 datasets excluded for normalized_shd.
There're 148 datasets excluded for F1.
There're 8 datasets excluded for gscore.


## Normalized SHD

In [21]:
# Median normalized SHD per (varsortability, algorithm) pair.
df_for_scatter_plot_norm_shd = (
    df_without_nan_norm_shd
    .groupby(['varsortability', 'algo_name'], as_index=False)
    .agg({'normalized_shd': lambda values: np.median(values)})
)
show_scatter_plot(
    df_for_scatter_plot_norm_shd,
    'varsortability',
    'normalized_shd',
    color='algo_name',
    symbol='algo_name',
)

In [15]:
# Slice into three varsortability bands: low [0.2, 0.4), mid [0.4, 0.6), high [0.6, ...).
# NOTE(review): rows with varsortability < 0.2 fall into none of the bands -- confirm intended.
df_for_scatter_plot_norm_shd_low_varsor = df_without_nan_norm_shd.query('0.2 <= varsortability < 0.4')
df_for_scatter_plot_norm_shd_mid_varsor = df_without_nan_norm_shd.query('0.4 <= varsortability < 0.6')
df_for_scatter_plot_norm_shd_high_varsor = df_without_nan_norm_shd.query('varsortability >= 0.6')

### normalized shd low varsortability

In [16]:
# Highest normalized SHD per (runtime, algorithm) pair in the low-varsortability band.
norm_shd_low_varsor_runtime = (
    df_for_scatter_plot_norm_shd_low_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'normalized_shd': 'max'})
)
show_scatter_plot(
    norm_shd_low_varsor_runtime,
    'runtime_second',
    'normalized_shd',
    color='algo_name',
    symbol='algo_name',
)

In [17]:
# Same data drawn as one line per algorithm (points joined in ascending runtime order).
fig_norm_shd_low_varsor_runtime = px.line(
    norm_shd_low_varsor_runtime,
    x='runtime_second',
    y='normalized_shd',
    color='algo_name',
)
fig_norm_shd_low_varsor_runtime.show()

### normalized shd mid varsortability

In [18]:
# Highest normalized SHD per (runtime, algorithm) pair in the mid-varsortability band.
norm_shd_mid_varsor_runtime = (
    df_for_scatter_plot_norm_shd_mid_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'normalized_shd': 'max'})
)
show_scatter_plot(
    norm_shd_mid_varsor_runtime,
    'runtime_second',
    'normalized_shd',
    color='algo_name',
    symbol='algo_name',
)

In [90]:
# Same data drawn as one line per algorithm (points joined in ascending runtime order).
fig_norm_shd_mid_varsor_runtime = px.line(
    norm_shd_mid_varsor_runtime,
    x='runtime_second',
    y='normalized_shd',
    color='algo_name',
)
fig_norm_shd_mid_varsor_runtime.show()

### normalized shd high varsortability

In [91]:
# Highest normalized SHD per (runtime, algorithm) pair in the high-varsortability band.
norm_shd_high_varsor_runtime = (
    df_for_scatter_plot_norm_shd_high_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'normalized_shd': 'max'})
)
show_scatter_plot(
    norm_shd_high_varsor_runtime,
    'runtime_second',
    'normalized_shd',
    color='algo_name',
    symbol='algo_name',
)

In [92]:
# Same data drawn as one line per algorithm (points joined in ascending runtime order).
fig_norm_shd_high_varsor_runtime = px.line(
    norm_shd_high_varsor_runtime,
    x='runtime_second',
    y='normalized_shd',
    color='algo_name',
)
fig_norm_shd_high_varsor_runtime.show()

## G-Score

In [19]:
# Median G-score per (varsortability, algorithm) pair.
df_for_scatter_plot_gscore = (
    df_without_nan_gscore
    .groupby(['varsortability', 'algo_name'], as_index=False)
    .agg({'gscore': lambda values: np.median(values)})
)
show_scatter_plot(
    df_for_scatter_plot_gscore,
    'varsortability',
    'gscore',
    color='algo_name',
    symbol='algo_name',
)

In [94]:
# Slice into three varsortability bands: low [0.2, 0.4), mid [0.4, 0.6), high [0.6, ...).
# NOTE(review): rows with varsortability < 0.2 fall into none of the bands -- confirm intended.
df_for_scatter_plot_gscore_low_varsor = df_without_nan_gscore.query('0.2 <= varsortability < 0.4')
df_for_scatter_plot_gscore_mid_varsor = df_without_nan_gscore.query('0.4 <= varsortability < 0.6')
df_for_scatter_plot_gscore_high_varsor = df_without_nan_gscore.query('varsortability >= 0.6')

### gscore low varsortability

In [95]:
# Highest G-score per (runtime, algorithm) pair in the low-varsortability band.
gscore_low_varsor_runtime = (
    df_for_scatter_plot_gscore_low_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'gscore': 'max'})
)
show_scatter_plot(
    gscore_low_varsor_runtime,
    'runtime_second',
    'gscore',
    color='algo_name',
    symbol='algo_name',
)

### gscore mid varsortability

In [96]:
# Highest G-score per (runtime, algorithm) pair in the mid-varsortability band.
gscore_mid_varsor_runtime = (
    df_for_scatter_plot_gscore_mid_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'gscore': 'max'})
)
show_scatter_plot(
    gscore_mid_varsor_runtime,
    'runtime_second',
    'gscore',
    color='algo_name',
    symbol='algo_name',
)

### gscore high varsortability

In [97]:
# Highest G-score per (runtime, algorithm) pair in the high-varsortability band.
gscore_high_varsor_runtime = (
    df_for_scatter_plot_gscore_high_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'gscore': 'max'})
)
show_scatter_plot(
    gscore_high_varsor_runtime,
    'runtime_second',
    'gscore',
    color='algo_name',
    symbol='algo_name',
)

## F1-Score

In [20]:
# Median F1 per (varsortability, algorithm) pair. Must use the F1-specific
# NaN-filtered frame: the gscore filter keeps rows whose F1 is NaN, and
# np.median propagates NaN, so those groups would plot as NaN.
df_for_scatter_plot_f1 = (
    df_without_nan_f1
    .groupby(['varsortability', 'algo_name'], as_index=False)
    .agg({'F1': lambda values: np.median(values)})
)
show_scatter_plot(df_for_scatter_plot_f1, 'varsortability', 'F1', color="algo_name", symbol="algo_name")

In [67]:
# Slice into three varsortability bands: low [0.2, 0.4), mid [0.4, 0.6), high [0.6, ...).
# NOTE(review): rows with varsortability < 0.2 fall into none of the bands -- confirm intended.
df_for_scatter_plot_f1_low_varsor = df_without_nan_f1.query('0.2 <= varsortability < 0.4')
df_for_scatter_plot_f1_mid_varsor = df_without_nan_f1.query('0.4 <= varsortability < 0.6')
df_for_scatter_plot_f1_high_varsor = df_without_nan_f1.query('varsortability >= 0.6')

### F1 low varsortability

In [69]:
# Highest F1 per (runtime, algorithm) pair in the low-varsortability band.
f1_low_varsor_runtime = (
    df_for_scatter_plot_f1_low_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'F1': 'max'})
)
show_scatter_plot(
    f1_low_varsor_runtime,
    'runtime_second',
    'F1',
    color='algo_name',
    symbol='algo_name',
)

### F1 mid varsortability

In [70]:
# Highest F1 per (runtime, algorithm) pair in the mid-varsortability band.
f1_mid_varsor_runtime = (
    df_for_scatter_plot_f1_mid_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'F1': 'max'})
)
show_scatter_plot(
    f1_mid_varsor_runtime,
    'runtime_second',
    'F1',
    color='algo_name',
    symbol='algo_name',
)

### F1 high varsortability

In [71]:
# Highest F1 per (runtime, algorithm) pair in the high-varsortability band.
f1_high_varsor_runtime = (
    df_for_scatter_plot_f1_high_varsor
    .groupby(['runtime_second', 'algo_name'], as_index=False)
    .agg({'F1': 'max'})
)
show_scatter_plot(
    f1_high_varsor_runtime,
    'runtime_second',
    'F1',
    color='algo_name',
    symbol='algo_name',
)