In [129]:
import numpy as np
import pandas as pd
from scipy.stats import normaltest, mannwhitneyu, ttest_ind
import plotly.graph_objects as go
import plotly.express as px
from causalbench.data.datasets import read_group_arc_num_as_dataframe
from IPython.display import display

In [130]:
pd.options.display.max_columns = None
df = pd.read_csv('data_for_evaluation/feeback_castle4__04Mar.csv',  sep=';')
df

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time
0,sim-13.Network1_amp.continuous,0.666667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.4000,0.8333,0.5000,3.0,5.0,0.6667,0.6667,0.6667,0.3333,0.14,Sat Mar 4 11:48:53 2023
1,sim-41.Network1_amp.continuous,0.208333,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5000,0.17,Sat Mar 4 11:48:53 2023
2,sim-59.Network1_amp.continuous,0.291667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.2000,1.0000,0.2500,1.0,5.0,0.8333,0.8333,0.8333,0.6667,0.12,Sat Mar 4 11:48:53 2023
3,sim-10.Network1_amp.continuous,0.166667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.5000,0.8333,0.7500,3.0,6.0,0.5000,0.5000,0.5000,0.0000,0.17,Sat Mar 4 11:48:53 2023
4,sim-04.Network1_amp.continuous,0.250000,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.6000,1.0000,0.7500,0.3333,0.16,Sat Mar 4 11:48:53 2023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4311,sim-56.Network9_cont_amp.continuous,0.415254,9,500,directlingam,"{'prior_knowledge': None, 'measure': 'pwling',...",gCastle,No Error,0.1667,0.5000,0.0385,5.0,6.0,0.8333,0.5000,0.6250,0.4000,0.28,Sat Mar 4 11:49:37 2023
4312,sim-45.Network9_cont_amp.continuous,0.500000,9,500,directlingam,"{'prior_knowledge': None, 'measure': 'pwling',...",gCastle,No Error,0.0000,0.5000,0.0000,5.0,5.0,1.0000,0.5000,0.6667,0.5000,0.32,Sat Mar 4 11:49:37 2023
4313,sim-43.Network9_cont_amp.continuous,0.415254,9,500,directlingam,"{'prior_knowledge': None, 'measure': 'pwling',...",gCastle,No Error,0.3333,0.4000,0.0769,6.0,6.0,0.6667,0.4000,0.5000,0.2000,0.31,Sat Mar 4 11:49:37 2023
4314,sim-59.Network9_cont_amp.continuous,0.474576,9,500,directlingam,"{'prior_knowledge': None, 'measure': 'pwling',...",gCastle,No Error,0.4444,0.5000,0.1538,6.0,9.0,0.5556,0.5000,0.5263,0.1000,0.27,Sat Mar 4 11:49:37 2023


In [131]:
#add notears and notearslowrank data
df1 = pd.read_csv('data_for_evaluation/notears_feedback_12Mar.csv',  sep=';')
df2 = pd.read_csv('data_for_evaluation/notearslowrank_feedback_12Mar.csv', sep=';')
df = pd.concat([df, df1, df2], ignore_index=True)

In [132]:
def extract_group(name):
    
    if '-' in name:#case of feedback
        return name.split('.')[1] 
    elif 'real' in name:
        return name
    elif 'numdata' in name:
        return name
    elif 'dream' in name:
        return name
    elif '_' in name:
        return name.split('_')[0]
    else:
        return name

In [133]:
group_arc_num_df = read_group_arc_num_as_dataframe()
group_arc_num_df

Unnamed: 0,group,arc_num
0,Network1_amp,6
1,Network2_amp,7
2,Network3_amp,7
3,Network4_amp,19
4,Network5_amp,5
...,...,...
90,dream4_1,176
91,dream4_2,249
92,dream4_3,195
93,dream4_4,211


In [134]:
def get_arc_num_by_name(name:str):
    group = extract_group(name)
    row = group_arc_num_df.loc[group_arc_num_df['group'] == group]
    return row['arc_num'].values[0]

In [135]:
def normalize_shd(shd, arc_num):
    return 1-(shd/(arc_num*2))

In [136]:
df['group'] = df['dataset_name'].apply(extract_group)

In [137]:
df['arc_num'] = df['dataset_name'].apply(get_arc_num_by_name)

In [153]:
def exclude_dataset_which_has_nan_in_metric(dataframe, metric:str):
    df_nan_rows = dataframe[dataframe[metric].isna()]
    to_exclude_dataset_names = df_nan_rows['dataset_name'].tolist()
    print(f"There're {len(to_exclude_dataset_names)} datasets excluded for {metric}.")
    df_without_nan = dataframe[~dataframe['dataset_name'].isin(to_exclude_dataset_names)]
    return df_without_nan
    

In [154]:
df['normalized_shd'] = normalize_shd(df['shd'], df['arc_num'])

In [155]:
df_without_nan_norm_shd = exclude_dataset_which_has_nan_in_metric(df, 'normalized_shd')

There're 8 datasets excluded for normalized_shd.


In [156]:
algos = ['pc', 'ges', 'icalingam', 'directlingam','notears', 'notearslowrank']
def show_violin_plots(dataframe, metric:str):
    df = dataframe
    for i in range(0, len(algos), 4):
        stop =  i + 4 if i + 4 < len(algos) else len(algos)
        fig = go.Figure()
        for algo in algos[i:stop]:
                fig.add_trace(go.Violin(x=df['algo_name'][df['algo_name'] == algo],
                                    y=df[metric][df['algo_name'] == algo],
                                    name=algo,
                                    box_visible=True,
                                    meanline_visible=True))
        fig.update_layout(font_size=26)
        fig.update_layout(showlegend=False) 
        fig.update_yaxes(title_text=metric)        
        fig.show()
show_violin_plots(df_without_nan_norm_shd, 'normalized_shd')     

In [161]:
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_norm_shd['normalized_shd'][df_without_nan_norm_shd['algo_name'] == algo].describe()
print(summary.round(2)) 

            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.69     0.67       0.73          0.78     0.70            0.70
std       0.15     0.15       0.11          0.10     0.10            0.11
min       0.19     0.25       0.39          0.42     0.40            0.35
25%       0.60     0.56       0.67          0.70     0.62            0.61
50%       0.70     0.67       0.72          0.79     0.70            0.70
75%       0.79     0.79       0.80          0.85     0.75            0.76
max       1.00     1.00       1.00          1.00     1.00            1.00


In [172]:
df_without_nan_gscore = exclude_dataset_which_has_nan_in_metric(df, 'gscore')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_gscore['gscore'][df_without_nan_gscore['algo_name'] == algo].describe()
print(summary.round(2)) 
#print(df[df['gscore'].isna()])

There're 8 datasets excluded for gscore.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.13     0.09       0.23          0.40     0.12            0.12
std       0.22     0.17       0.22          0.22     0.18            0.18
min       0.00     0.00       0.00          0.00     0.00            0.00
25%       0.00     0.00       0.00          0.20     0.00            0.00
50%       0.00     0.00       0.20          0.40     0.00            0.00
75%       0.20     0.11       0.40          0.57     0.20            0.20
max       1.00     0.86       0.88          0.88     0.83            0.88


In [165]:
df_without_nan_recall = exclude_dataset_which_has_nan_in_metric(df, 'recall')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_recall['recall'][df_without_nan_recall['algo_name'] == algo].describe()
print(summary.round(2))

There're 8 datasets excluded for recall.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.62     0.68       0.41          0.49     0.34            0.35
std       0.20     0.17       0.17          0.17     0.16            0.16
min       0.00     0.17       0.00          0.00     0.00            0.00
25%       0.50     0.58       0.30          0.40     0.21            0.22
50%       0.60     0.68       0.40          0.50     0.33            0.33
75%       0.78     0.80       0.56          0.60     0.43            0.44
max       1.00     1.00       0.88          0.89     0.83            0.88


In [175]:
df_without_nan_precision = exclude_dataset_which_has_nan_in_metric(df, 'precision')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_precision['precision'][df_without_nan_precision['algo_name'] == algo].describe()
print(summary.round(2))
#print(df[df['precision'].isna()])

There're 21 datasets excluded for precision.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1065.00  1065.00    1065.00       1065.00  1065.00         1065.00
mean      0.49     0.45       0.66          0.84     0.53            0.51
std       0.16     0.14       0.24          0.19     0.24            0.23
min       0.00     0.10       0.00          0.00     0.00            0.00
25%       0.38     0.35       0.50          0.71     0.38            0.33
50%       0.47     0.43       0.67          0.88     0.50            0.50
75%       0.58     0.54       0.80          1.00     0.67            0.67
max       1.00     0.88       1.00          1.00     1.00            1.00


In [178]:

df_without_nan_fpr = exclude_dataset_which_has_nan_in_metric(df, 'fpr')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_fpr['fpr'][df_without_nan_fpr['algo_name'] == algo].describe()
print(summary.round(2))
#print(df[df['fpr'].isna()])

There're 8 datasets excluded for fpr.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.40     0.46       0.17          0.08     0.25            0.29
std       0.23     0.29       0.19          0.16     0.23            0.25
min       0.00     0.00       0.00          0.00     0.00            0.00
25%       0.25     0.30       0.05          0.00     0.11            0.13
50%       0.35     0.40       0.12          0.04     0.20            0.21
75%       0.50     0.56       0.21          0.11     0.33            0.40
max       1.33     2.33       1.33          1.67     2.00            1.67


In [77]:
algo_df = df.groupby('algo_name')['normalized_shd'].apply(list)

In [78]:
def normality_test(data_dict, data_key):
    data = data_dict[data_key]
    data = data.dropna()
    print(len(data.to_numpy()))
    value, p = normaltest(data.to_numpy())
    print(f"Normality tested on {data_key}. value: {value}, p:{p}")
    if p >= 0.05:
        print(f'It is likely that the distribution of {data_key} datasets is normal.\n')
    else:
         print(f'It is unlikely that the distribution of {data_key} datasets is normal.\n') 

In [18]:
def gen_dict(algos:list, value_to_analyse:str):
    result = {}
    for algo in algos:
        result[algo] = df[value_to_analyse][df['algo_name'] == algo]
    return result 

In [19]:
algos_dict = gen_dict(algos, 'normalized_shd')

for algo in algos_dict.keys():
    normality_test(algos_dict, algo)

1071
Normality tested on pc. value: 3.0961205323450094, p:0.21266007795834024
It is likely that the distribution of pc datasets is normal.

1071
Normality tested on ges. value: 18.02748895504528, p:0.00012172520430010903
It is unlikely that the distribution of ges datasets is normal.

1071
Normality tested on icalingam. value: 3.9410955833026238, p:0.13938048383348445
It is likely that the distribution of icalingam datasets is normal.

1071
Normality tested on directlingam. value: 4.91598240585669, p:0.08560674493848495
It is likely that the distribution of directlingam datasets is normal.



## stability test

In [20]:
stability_df = pd.DataFrame(df, columns=['group', 'algo_name', 'normalized_shd', 'F1'])
stability_df

Unnamed: 0,group,algo_name,normalized_shd,F1
0,Network1_amp,pc,0.750000,0.6667
1,Network1_amp,pc,0.916667,0.8000
2,Network1_amp,pc,0.916667,0.8333
3,Network1_amp,pc,0.750000,0.5000
4,Network1_amp,pc,0.833333,0.7500
...,...,...,...,...
4311,Network9_cont_amp,directlingam,0.750000,0.6250
4312,Network9_cont_amp,directlingam,0.750000,0.6667
4313,Network9_cont_amp,directlingam,0.700000,0.5000
4314,Network9_cont_amp,directlingam,0.700000,0.5263


In [21]:
groups = stability_df.group.unique()
groups

array(['Network1_amp', 'Network2_amp', 'Network3_amp', 'Network4_amp',
       'Network5_amp', 'Network5_cont', 'Network5_cont_p3n7',
       'Network5_cont_p7n3', 'Network6_amp', 'Network6_cont',
       'Network7_amp', 'Network7_cont', 'Network8_amp_amp',
       'Network8_amp_cont', 'Network8_cont_amp', 'Network9_amp_amp',
       'Network9_amp_cont', 'Network9_cont_amp'], dtype=object)

In [22]:
algos = ['pc', 'ges', 'icalingam', 'directlingam']
for group in groups:
    grouped_stability_df = stability_df[(stability_df == group).any(axis=1)]
    fig = go.Figure()
    for algo in algos:
            fig.add_trace(go.Violin(x=grouped_stability_df['algo_name'][grouped_stability_df['algo_name'] == algo],
                                y=grouped_stability_df['normalized_shd'][grouped_stability_df['algo_name'] == algo],
                                name=algo,
                                box_visible=True,
                                meanline_visible=True))
    
    fig.update_layout(font_size=26)
    fig.update_layout(title_text=group)
    fig.update_layout(showlegend=False) 
    fig.update_yaxes(title_text="normalized_shd")
    fig.show()

In [23]:
for group in groups:
    summary = pd.DataFrame()
    data_to_analyse = stability_df[(stability_df == group).any(axis=1)]
    for algo in algos:
        summary[algo] = data_to_analyse['normalized_shd'][data_to_analyse['algo_name'] == algo].describe()
    summary['avg'] = summary.mean(axis=1)     
    print(f'Summary of {group}:')    
    print(summary.round(2)) 
    print('\n')

Summary of Network1_amp:
          pc    ges  icalingam  directlingam    avg
count  60.00  60.00      60.00         60.00  60.00
mean    0.81   0.79       0.78          0.86   0.81
std     0.11   0.11       0.11          0.09   0.11
min     0.50   0.50       0.50          0.67   0.54
25%     0.75   0.75       0.67          0.83   0.75
50%     0.83   0.83       0.75          0.83   0.81
75%     0.92   0.83       0.83          0.92   0.88
max     1.00   1.00       1.00          1.00   1.00


Summary of Network2_amp:
          pc    ges  icalingam  directlingam    avg
count  60.00  60.00      60.00         60.00  60.00
mean    0.84   0.79       0.84          0.83   0.83
std     0.10   0.11       0.09          0.10   0.10
min     0.64   0.50       0.64          0.50   0.57
25%     0.79   0.71       0.79          0.79   0.77
50%     0.86   0.79       0.86          0.86   0.84
75%     0.93   0.86       0.93          0.93   0.91
max     1.00   1.00       1.00          1.00   1.00


Summary of