In [1]:
import numpy as np
import pandas as pd
from scipy.stats import normaltest, mannwhitneyu, ttest_ind
import plotly.graph_objects as go
import plotly.express as px
from causalbench.data.datasets import read_group_arc_num_as_dataframe
from IPython.display import display

In [2]:
pd.options.display.max_columns = None
df_pcstable = pd.read_csv('data_for_evaluation/pcstable_feedback_19Mar.csv', sep=';')
df = pd.read_csv('data_for_evaluation/feeback_castle4__04Mar.csv',  sep=';')
df = df[df.algo_name != 'pc']
df = pd.concat([df, df_pcstable], ignore_index=True)
df

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time
0,sim-02.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.5000,0.8333,0.7500,4.0,6.0,0.4000,0.6667,0.5000,0.0,0.16,Sat Mar 4 11:48:58 2023
1,sim-08.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5,0.14,Sat Mar 4 11:48:58 2023
2,sim-03.Network1_amp.continuous,0.541667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.4286,1.0000,0.7500,3.0,7.0,0.4444,0.6667,0.5333,0.0,0.15,Sat Mar 4 11:48:58 2023
3,sim-06.Network1_amp.continuous,0.416667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.5000,0.8333,0.6250,0.0,0.15,Sat Mar 4 11:48:58 2023
4,sim-05.Network1_amp.continuous,0.458333,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.6250,0.8333,1.2500,5.0,8.0,0.3077,0.6667,0.4211,0.0,0.17,Sat Mar 4 11:48:58 2023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4311,sim-45.Network9_cont_amp.continuous,0.500000,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7.0,13.0,0.4286,0.6000,0.5000,0.0,0.81,Sun Mar 19 19:44:10 2023
4312,sim-58.Network9_cont_amp.continuous,0.542373,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,8.0,13.0,0.4286,0.6000,0.5000,0.0,0.54,Sun Mar 19 19:44:10 2023
4313,sim-50.Network9_cont_amp.continuous,0.449153,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7.0,13.0,0.4286,0.6000,0.5000,0.0,0.67,Sun Mar 19 19:44:10 2023
4314,sim-51.Network9_cont_amp.continuous,0.466102,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.4167,0.7000,0.1923,6.0,12.0,0.4375,0.7000,0.5385,0.0,0.58,Sun Mar 19 19:44:10 2023


In [3]:
#add notears and notearslowrank data
df1 = pd.read_csv('data_for_evaluation/notears_feedback_12Mar.csv',  sep=';')
df2 = pd.read_csv('data_for_evaluation/notearslowrank_feedback_12Mar.csv', sep=';')
df = pd.concat([df, df1, df2], ignore_index=True)

In [4]:
def extract_group(name):
    
    if '-' in name:#case of feedback
        return name.split('.')[1] 
    elif 'real' in name:
        return name
    elif 'numdata' in name:
        return name
    elif 'dream' in name:
        return name
    elif '_' in name:
        return name.split('_')[0]
    else:
        return name

In [5]:
group_arc_num_df = read_group_arc_num_as_dataframe()
group_arc_num_df

Unnamed: 0,group,arc_num
0,Network1_amp,6
1,Network2_amp,7
2,Network3_amp,7
3,Network4_amp,19
4,Network5_amp,5
...,...,...
90,dream4_1,176
91,dream4_2,249
92,dream4_3,195
93,dream4_4,211


In [6]:
def get_arc_num_by_name(name:str):
    group = extract_group(name)
    row = group_arc_num_df.loc[group_arc_num_df['group'] == group]
    return row['arc_num'].values[0]

In [7]:
def normalize_shd(shd, arc_num):
    return 1-(shd/(arc_num*2))

In [8]:
df['group'] = df['dataset_name'].apply(extract_group)

In [9]:
df['arc_num'] = df['dataset_name'].apply(get_arc_num_by_name)

In [10]:
def exclude_dataset_which_has_nan_in_metric(dataframe, metric:str):
    df_nan_rows = dataframe[dataframe[metric].isna()]
    to_exclude_dataset_names = df_nan_rows['dataset_name'].tolist()
    print(f"There're {len(to_exclude_dataset_names)} datasets excluded for {metric}.")
    df_without_nan = dataframe[~dataframe['dataset_name'].isin(to_exclude_dataset_names)]
    return df_without_nan
    

In [11]:
df['normalized_shd'] = normalize_shd(df['shd'], df['arc_num'])

In [12]:
df_without_nan_norm_shd = exclude_dataset_which_has_nan_in_metric(df, 'normalized_shd')

There're 8 datasets excluded for normalized_shd.


In [13]:
df_without_nan_norm_shd
# df_for_scatter_plot_norm_shd = pd.DataFrame()
df_for_scatter_plot_norm_shd = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','normalized_shd']].median()
df_for_scatter_plot_f1 = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','F1']].median()
df_for_scatter_plot_gscore = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','gscore']].median()
#df_for_scatter_plot_norm_shd.sort_values(by=['varsortability', 'normalized_shd'], ascending=False)

In [14]:
df_without_nan_norm_shd

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time,group,arc_num,normalized_shd
0,sim-02.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.5000,0.8333,0.7500,4.0,6.0,0.4000,0.6667,0.5000,0.0,0.16,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.666667
1,sim-08.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5,0.14,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.916667
2,sim-03.Network1_amp.continuous,0.541667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.4286,1.0000,0.7500,3.0,7.0,0.4444,0.6667,0.5333,0.0,0.15,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.750000
3,sim-06.Network1_amp.continuous,0.416667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.5000,0.8333,0.6250,0.0,0.15,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.833333
4,sim-05.Network1_amp.continuous,0.458333,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.6250,0.8333,1.2500,5.0,8.0,0.3077,0.6667,0.4211,0.0,0.17,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.583333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6469,sim-30.Network9_cont_amp.continuous,0.576271,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.2000,0.2308,8.0,8.0,0.2500,0.2000,0.2222,0.0,158.49,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.600000
6470,sim-43.Network9_cont_amp.continuous,0.415254,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7778,0.2000,0.2692,9.0,9.0,0.2222,0.2000,0.2105,0.0,130.79,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.550000
6471,sim-46.Network9_amp_amp.continuous,0.449153,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.6000,0.4000,0.2308,8.0,10.0,0.4000,0.4000,0.4000,0.0,395.03,Sun Mar 12 12:40:52 2023,Network9_amp_amp,10,0.600000
6472,sim-11.Network9_cont_amp.continuous,0.491525,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.3000,0.3462,11.0,12.0,0.2500,0.3000,0.2727,0.0,206.72,Sun Mar 12 12:40:53 2023,Network9_cont_amp,10,0.450000


In [15]:
def test():
    df_for_scatter_plot_norm_shd = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','normalized_shd']].median()
    df_for_scatter_plot_f1 = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','F1']].median()
    df_for_scatter_plot_gscore = df_without_nan_norm_shd.groupby(['algo_name', 'group'], as_index=False, sort=False)[['varsortability','gscore']].median()

In [27]:
def show_scatter_plot(dataframe, x_lable, y_lable, color, symbol):
    fig = px.scatter(dataframe, x=x_lable, y=y_lable,color=color, symbol=symbol)
    fig.show()
show_scatter_plot(df_for_scatter_plot_norm_shd, 'varsortability', 'normalized_shd', color="algo_name", symbol="algo_name")  
show_scatter_plot(df_for_scatter_plot_f1, 'varsortability', 'F1', color="algo_name", symbol="algo_name")
show_scatter_plot(df_for_scatter_plot_gscore, 'varsortability', 'gscore', color="algo_name", symbol="algo_name")

In [28]:
show_scatter_plot(df_without_nan_norm_shd,'varsortability', 'normalized_shd', color="algo_name", symbol="algo_name")

In [29]:
temp = df_without_nan_norm_shd.groupby(['varsortability', 'algo_name'], as_index=False).agg({'normalized_shd': lambda x: np.median(x)})
temp
show_scatter_plot(temp, 'varsortability', 'normalized_shd', color="algo_name", symbol="algo_name")

In [19]:
algos = ['pc', 'ges', 'icalingam', 'directlingam','notears', 'notearslowrank']
def show_violin_plots(dataframe, metric:str):
    df = dataframe
    for i in range(0, len(algos), 4):
        stop =  i + 4 if i + 4 < len(algos) else len(algos)
        fig = go.Figure()
        for algo in algos[i:stop]:
                fig.add_trace(go.Violin(x=df['algo_name'][df['algo_name'] == algo],
                                    y=df[metric][df['algo_name'] == algo],
                                    name=algo,
                                    box_visible=True,
                                    meanline_visible=True))
        fig.update_layout(font_size=26)
        fig.update_layout(showlegend=False) 
        fig.update_yaxes(title_text=metric)        
        fig.show()
show_violin_plots(df_without_nan_norm_shd, 'normalized_shd')     

In [20]:
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_norm_shd['normalized_shd'][df_without_nan_norm_shd['algo_name'] == algo].describe()
print(summary.round(2)) 

            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.69     0.67       0.73          0.78     0.70            0.70
std       0.14     0.15       0.11          0.10     0.10            0.11
min       0.19     0.25       0.39          0.42     0.40            0.35
25%       0.60     0.56       0.67          0.70     0.62            0.61
50%       0.70     0.67       0.72          0.79     0.70            0.70
75%       0.79     0.79       0.80          0.85     0.75            0.76
max       1.00     1.00       1.00          1.00     1.00            1.00


In [21]:
df_without_nan_gscore = exclude_dataset_which_has_nan_in_metric(df, 'gscore')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_gscore['gscore'][df_without_nan_gscore['algo_name'] == algo].describe()
print(summary.round(2)) 
#print(df[df['gscore'].isna()])

There're 8 datasets excluded for gscore.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.14     0.09       0.23          0.40     0.12            0.12
std       0.22     0.17       0.22          0.22     0.18            0.18
min       0.00     0.00       0.00          0.00     0.00            0.00
25%       0.00     0.00       0.00          0.20     0.00            0.00
50%       0.00     0.00       0.20          0.40     0.00            0.00
75%       0.20     0.11       0.40          0.57     0.20            0.20
max       1.00     0.86       0.88          0.88     0.83            0.88


In [22]:
df_without_nan_recall = exclude_dataset_which_has_nan_in_metric(df, 'recall')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_recall['recall'][df_without_nan_recall['algo_name'] == algo].describe()
print(summary.round(2))

There're 8 datasets excluded for recall.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.62     0.68       0.41          0.49     0.34            0.35
std       0.20     0.17       0.17          0.17     0.16            0.16
min       0.00     0.17       0.00          0.00     0.00            0.00
25%       0.50     0.58       0.30          0.40     0.21            0.22
50%       0.60     0.68       0.40          0.50     0.33            0.33
75%       0.75     0.80       0.56          0.60     0.43            0.44
max       1.00     1.00       0.88          0.89     0.83            0.88


In [23]:
df_without_nan_precision = exclude_dataset_which_has_nan_in_metric(df, 'precision')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_precision['precision'][df_without_nan_precision['algo_name'] == algo].describe()
print(summary.round(2))
#print(df[df['precision'].isna()])

There're 21 datasets excluded for precision.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1065.00  1065.00    1065.00       1065.00  1065.00         1065.00
mean      0.50     0.45       0.66          0.84     0.53            0.51
std       0.17     0.14       0.24          0.19     0.24            0.23
min       0.00     0.10       0.00          0.00     0.00            0.00
25%       0.40     0.35       0.50          0.71     0.38            0.33
50%       0.50     0.43       0.67          0.88     0.50            0.50
75%       0.60     0.54       0.80          1.00     0.67            0.67
max       1.00     0.88       1.00          1.00     1.00            1.00


In [24]:

df_without_nan_fpr = exclude_dataset_which_has_nan_in_metric(df, 'fpr')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_fpr['fpr'][df_without_nan_fpr['algo_name'] == algo].describe()
print(summary.round(2))
#print(df[df['fpr'].isna()])

There're 8 datasets excluded for fpr.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.39     0.46       0.17          0.08     0.25            0.29
std       0.23     0.29       0.19          0.16     0.23            0.25
min       0.00     0.00       0.00          0.00     0.00            0.00
25%       0.23     0.30       0.05          0.00     0.11            0.13
50%       0.33     0.40       0.12          0.04     0.20            0.21
75%       0.47     0.56       0.21          0.11     0.33            0.40
max       1.33     2.33       1.33          1.67     2.00            1.67


In [25]:
df_without_nan_runtime = exclude_dataset_which_has_nan_in_metric(df, 'runtime_second')
summary = pd.DataFrame()
for algo in algos:
    summary[algo] = df_without_nan_runtime['runtime_second'][df_without_nan_runtime['algo_name'] == algo].describe()
print(summary.round(2))
#print(df[df['fpr'].isna()])

There're 8 datasets excluded for runtime_second.
            pc      ges  icalingam  directlingam  notears  notearslowrank
count  1071.00  1071.00    1071.00       1071.00  1071.00         1071.00
mean      0.51     1.29       0.50          0.19    36.81          112.62
std       0.72     1.31       0.44          0.13    45.12          117.91
min       0.01     0.09       0.02          0.04     5.76           10.29
25%       0.05     0.23       0.07          0.07    16.73           40.54
50%       0.22     0.68       0.51          0.16    26.10           75.62
75%       0.60     1.96       0.77          0.25    39.58          138.24
max       3.76     5.90       1.84          0.68   460.65         1194.83


In [26]:
algo_df = df.groupby('algo_name')['normalized_shd'].apply(list)

In [27]:
def normality_test(data_dict, data_key):
    data = data_dict[data_key]
    data = data.dropna()
    print(len(data.to_numpy()))
    value, p = normaltest(data.to_numpy())
    print(f"Normality tested on {data_key}. value: {value}, p:{p}")
    if p >= 0.05:
        print(f'It is likely that the distribution of {data_key} datasets is normal.\n')
    else:
         print(f'It is unlikely that the distribution of {data_key} datasets is normal.\n') 

In [28]:
def gen_dict(algos:list, value_to_analyse:str):
    result = {}
    for algo in algos:
        result[algo] = df[value_to_analyse][df['algo_name'] == algo]
    return result 

In [29]:
algos_dict = gen_dict(algos, 'normalized_shd')

for algo in algos_dict.keys():
    normality_test(algos_dict, algo)

1079
Normality tested on pc. value: 3.1006691098900268, p:0.21217697709434644
It is likely that the distribution of pc datasets is normal.

1071
Normality tested on ges. value: 18.02748895504528, p:0.00012172520430010903
It is unlikely that the distribution of ges datasets is normal.

1079
Normality tested on icalingam. value: 3.9744912665944465, p:0.1370724537725664
It is likely that the distribution of icalingam datasets is normal.

1079
Normality tested on directlingam. value: 4.920462679885263, p:0.0854151887366505
It is likely that the distribution of directlingam datasets is normal.

1079
Normality tested on notears. value: 3.1875885008968545, p:0.20315333292654988
It is likely that the distribution of notears datasets is normal.

1079
Normality tested on notearslowrank. value: 0.4302878678601874, p:0.8064253598522079
It is likely that the distribution of notearslowrank datasets is normal.



## stability test

In [30]:
stability_df = pd.DataFrame(df, columns=['group', 'algo_name', 'normalized_shd', 'F1'])
stability_df

Unnamed: 0,group,algo_name,normalized_shd,F1
0,Network1_amp,ges,0.666667,0.5000
1,Network1_amp,ges,0.916667,0.8000
2,Network1_amp,ges,0.750000,0.5333
3,Network1_amp,ges,0.833333,0.6250
4,Network1_amp,ges,0.583333,0.4211
...,...,...,...,...
6469,Network9_cont_amp,notearslowrank,0.600000,0.2222
6470,Network9_cont_amp,notearslowrank,0.550000,0.2105
6471,Network9_amp_amp,notearslowrank,0.600000,0.4000
6472,Network9_cont_amp,notearslowrank,0.450000,0.2727


In [31]:
groups = stability_df.group.unique()
groups

array(['Network1_amp', 'Network2_amp', 'Network3_amp', 'Network5_amp',
       'Network5_cont', 'Network4_amp', 'Network5_cont_p3n7',
       'Network5_cont_p7n3', 'Network6_amp', 'Network6_cont',
       'Network7_amp', 'Network7_cont', 'Network8_amp_amp',
       'Network8_amp_cont', 'Network8_cont_amp', 'Network9_amp_amp',
       'Network9_amp_cont', 'Network9_cont_amp'], dtype=object)

In [34]:
algos = ['pc', 'ges', 'icalingam', 'directlingam']
for group in groups:
    grouped_stability_df = stability_df[(stability_df == group).any(axis=1)]
    fig = go.Figure()
    for algo in algos:
            fig.add_trace(go.Violin(x=grouped_stability_df['algo_name'][grouped_stability_df['algo_name'] == algo],
                                y=grouped_stability_df['normalized_shd'][grouped_stability_df['algo_name'] == algo],
                                name=algo,
                                box_visible=True,
                                meanline_visible=True))
    
    fig.update_layout(font_size=26)
#     fig.update_layout(title_text=group)
    fig.update_layout(showlegend=False) 
    fig.update_yaxes(title_text="normalized_shd")
    fig.show()

In [33]:
for group in groups:
    summary = pd.DataFrame()
    data_to_analyse = stability_df[(stability_df == group).any(axis=1)]
    for algo in algos:
        summary[algo] = data_to_analyse['normalized_shd'][data_to_analyse['algo_name'] == algo].describe()
    summary['avg'] = summary.mean(axis=1)     
    print(f'Summary of {group}:')    
    print(summary.round(2)) 
    print('\n')

Summary of Network1_amp:
          pc    ges  icalingam  directlingam    avg
count  60.00  60.00      60.00         60.00  60.00
mean    0.82   0.79       0.78          0.86   0.81
std     0.11   0.11       0.11          0.09   0.10
min     0.50   0.50       0.50          0.67   0.54
25%     0.75   0.75       0.67          0.83   0.75
50%     0.83   0.83       0.75          0.83   0.81
75%     0.92   0.83       0.83          0.92   0.88
max     1.00   1.00       1.00          1.00   1.00


Summary of Network2_amp:
          pc    ges  icalingam  directlingam    avg
count  60.00  60.00      60.00         60.00  60.00
mean    0.84   0.79       0.84          0.83   0.82
std     0.10   0.11       0.09          0.10   0.10
min     0.64   0.50       0.64          0.50   0.57
25%     0.79   0.71       0.79          0.79   0.77
50%     0.86   0.79       0.86          0.86   0.84
75%     0.93   0.86       0.93          0.93   0.91
max     1.00   1.00       1.00          1.00   1.00


Summary of

In [20]:
df

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time,group,arc_num,normalized_shd
0,sim-13.Network1_amp.continuous,0.666667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.4000,0.8333,0.5000,3.0,5.0,0.6667,0.6667,0.6667,0.3333,0.14,Sat Mar 4 11:48:53 2023,Network1_amp,6,0.750000
1,sim-41.Network1_amp.continuous,0.208333,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5000,0.17,Sat Mar 4 11:48:53 2023,Network1_amp,6,0.916667
2,sim-59.Network1_amp.continuous,0.291667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.2000,1.0000,0.2500,1.0,5.0,0.8333,0.8333,0.8333,0.6667,0.12,Sat Mar 4 11:48:53 2023,Network1_amp,6,0.916667
3,sim-10.Network1_amp.continuous,0.166667,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.5000,0.8333,0.7500,3.0,6.0,0.5000,0.5000,0.5000,0.0000,0.17,Sat Mar 4 11:48:53 2023,Network1_amp,6,0.750000
4,sim-04.Network1_amp.continuous,0.250000,5,500,pc,"{'variant': 'original', 'alpha': 0.05, 'ci_tes...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.6000,1.0000,0.7500,0.3333,0.16,Sat Mar 4 11:48:53 2023,Network1_amp,6,0.833333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6469,sim-30.Network9_cont_amp.continuous,0.576271,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.2000,0.2308,8.0,8.0,0.2500,0.2000,0.2222,0.0000,158.49,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.600000
6470,sim-43.Network9_cont_amp.continuous,0.415254,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7778,0.2000,0.2692,9.0,9.0,0.2222,0.2000,0.2105,0.0000,130.79,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.550000
6471,sim-46.Network9_amp_amp.continuous,0.449153,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.6000,0.4000,0.2308,8.0,10.0,0.4000,0.4000,0.4000,0.0000,395.03,Sun Mar 12 12:40:52 2023,Network9_amp_amp,10,0.600000
6472,sim-11.Network9_cont_amp.continuous,0.491525,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.3000,0.3462,11.0,12.0,0.2500,0.3000,0.2727,0.0000,206.72,Sun Mar 12 12:40:53 2023,Network9_cont_amp,10,0.450000


In [35]:
df_pcstable

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time
0,sim-59.Network1_amp.continuous,0.291667,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.2000,1.0000,0.2500,1,5,0.8333,0.8333,0.8333,0.6667,0.36,Sun Mar 19 19:43:54 2023
1,sim-41.Network1_amp.continuous,0.208333,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.1667,1.1667,0.2500,1,6,0.6667,1.0000,0.8000,0.5000,0.48,Sun Mar 19 19:43:54 2023
2,sim-58.Network1_amp.continuous,0.416667,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.3333,1.0000,0.5000,2,6,0.7143,0.8333,0.7692,0.5000,0.37,Sun Mar 19 19:43:54 2023
3,sim-30.Network1_amp.continuous,0.375000,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.1667,1.1667,0.2500,1,6,0.6667,1.0000,0.8000,0.5000,0.37,Sun Mar 19 19:43:54 2023
4,sim-18.Network1_amp.continuous,0.250000,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.1667,1.1667,0.2500,1,6,0.6667,1.0000,0.8000,0.5000,0.43,Sun Mar 19 19:43:54 2023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1074,sim-45.Network9_cont_amp.continuous,0.500000,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7,13,0.4286,0.6000,0.5000,0.0000,0.81,Sun Mar 19 19:44:10 2023
1075,sim-58.Network9_cont_amp.continuous,0.542373,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,8,13,0.4286,0.6000,0.5000,0.0000,0.54,Sun Mar 19 19:44:10 2023
1076,sim-50.Network9_cont_amp.continuous,0.449153,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5385,0.6000,0.2692,7,13,0.4286,0.6000,0.5000,0.0000,0.67,Sun Mar 19 19:44:10 2023
1077,sim-51.Network9_cont_amp.continuous,0.466102,9,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.4167,0.7000,0.1923,6,12,0.4375,0.7000,0.5385,0.0000,0.58,Sun Mar 19 19:44:10 2023


In [36]:
df

Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time,group,arc_num,normalized_shd
0,sim-02.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.5000,0.8333,0.7500,4.0,6.0,0.4000,0.6667,0.5000,0.0,0.16,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.666667
1,sim-08.Network1_amp.continuous,0.250000,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.1667,1.1667,0.2500,1.0,6.0,0.6667,1.0000,0.8000,0.5,0.14,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.916667
2,sim-03.Network1_amp.continuous,0.541667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.4286,1.0000,0.7500,3.0,7.0,0.4444,0.6667,0.5333,0.0,0.15,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.750000
3,sim-06.Network1_amp.continuous,0.416667,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.2857,1.1667,0.5000,2.0,7.0,0.5000,0.8333,0.6250,0.0,0.15,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.833333
4,sim-05.Network1_amp.continuous,0.458333,5,500,ges,"{'criterion': 'bic', 'method': 'scatter', 'k':...",gCastle,No Error,0.6250,0.8333,1.2500,5.0,8.0,0.3077,0.6667,0.4211,0.0,0.17,Sat Mar 4 11:48:58 2023,Network1_amp,6,0.583333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6469,sim-30.Network9_cont_amp.continuous,0.576271,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.2000,0.2308,8.0,8.0,0.2500,0.2000,0.2222,0.0,158.49,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.600000
6470,sim-43.Network9_cont_amp.continuous,0.415254,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7778,0.2000,0.2692,9.0,9.0,0.2222,0.2000,0.2105,0.0,130.79,Sun Mar 12 12:40:50 2023,Network9_cont_amp,10,0.550000
6471,sim-46.Network9_amp_amp.continuous,0.449153,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.6000,0.4000,0.2308,8.0,10.0,0.4000,0.4000,0.4000,0.0,395.03,Sun Mar 12 12:40:52 2023,Network9_amp_amp,10,0.600000
6472,sim-11.Network9_cont_amp.continuous,0.491525,9,500,notearslowrank,"{'w_init': None, 'max_iter': 15, 'h_tol': 1e-0...",gCastle,No Error,0.7500,0.3000,0.3462,11.0,12.0,0.2500,0.3000,0.2727,0.0,206.72,Sun Mar 12 12:40:53 2023,Network9_cont_amp,10,0.450000


In [47]:
pc_network2 = df[df.group == "Network2_amp"]
pc_network2 = pc_network2[pc_network2.algo_name == 'pc']
pc_network2


Unnamed: 0,dataset_name,varsortability,N_variables,N_obs,algo_name,algo_param,library_name,Error,fdr,tpr,fpr,shd,nnz,precision,recall,F1,gscore,runtime_second,experiment_time,group,arc_num,normalized_shd
3257,sim-11.Network2_amp.continuous,0.636364,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.0,1.2857,0.0,0.0,5.0,1.0,1.0,1.0,1.0,0.47,Sun Mar 19 19:43:54 2023,Network2_amp,7,1.0
3266,sim-20.Network2_amp.continuous,0.30303,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5,1.1429,1.3333,4.0,8.0,0.6,0.8571,0.7059,0.2857,0.5,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.714286
3269,sim-09.Network2_amp.continuous,0.272727,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.2857,1.2857,0.6667,2.0,7.0,0.7,1.0,0.8235,0.5714,0.42,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.857143
3273,sim-08.Network2_amp.continuous,0.30303,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.2,1.1429,0.3333,2.0,5.0,0.6667,0.5714,0.6154,0.2857,0.52,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.857143
3275,sim-14.Network2_amp.continuous,0.424242,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.2,1.1429,0.3333,2.0,5.0,0.6667,0.5714,0.6154,0.2857,0.51,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.857143
3283,sim-06.Network2_amp.continuous,0.393939,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.1667,1.2857,0.3333,1.0,6.0,0.8333,0.7143,0.7692,0.5714,0.5,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.928571
3284,sim-05.Network2_amp.continuous,0.30303,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.4286,1.1429,1.0,3.0,7.0,0.5556,0.7143,0.625,0.1429,0.53,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.785714
3285,sim-04.Network2_amp.continuous,0.454545,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5,1.1429,1.3333,4.0,8.0,0.6,0.8571,0.7059,0.2857,0.56,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.714286
3286,sim-12.Network2_amp.continuous,0.30303,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.2857,1.2857,0.6667,2.0,7.0,0.7,1.0,0.8235,0.5714,0.51,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.857143
3287,sim-07.Network2_amp.continuous,0.272727,5,500,pc,"{'variant': 'stable', 'alpha': 0.05, 'ci_test'...",gCastle,No Error,0.5,1.1429,1.3333,4.0,8.0,0.4444,0.5714,0.5,0.0,0.57,Sun Mar 19 19:43:54 2023,Network2_amp,7,0.714286


In [63]:
grouped = pc_network2.groupby(['normalized_shd'])['dataset_name'].apply(list).reset_index(name='new')
grouped

Unnamed: 0,normalized_shd,new
0,0.642857,[sim-25.Network2_amp.continuous]
1,0.714286,"[sim-20.Network2_amp.continuous, sim-04.Networ..."
2,0.785714,"[sim-05.Network2_amp.continuous, sim-24.Networ..."
3,0.857143,"[sim-09.Network2_amp.continuous, sim-08.Networ..."
4,0.928571,"[sim-06.Network2_amp.continuous, sim-10.Networ..."
5,1.0,"[sim-11.Network2_amp.continuous, sim-21.Networ..."


In [73]:
grouped.iloc[[3]].new.values[0]

['sim-09.Network2_amp.continuous',
 'sim-08.Network2_amp.continuous',
 'sim-14.Network2_amp.continuous',
 'sim-12.Network2_amp.continuous',
 'sim-03.Network2_amp.continuous',
 'sim-17.Network2_amp.continuous',
 'sim-19.Network2_amp.continuous',
 'sim-01.Network2_amp.continuous',
 'sim-18.Network2_amp.continuous',
 'sim-22.Network2_amp.continuous',
 'sim-23.Network2_amp.continuous',
 'sim-32.Network2_amp.continuous',
 'sim-34.Network2_amp.continuous',
 'sim-36.Network2_amp.continuous',
 'sim-42.Network2_amp.continuous',
 'sim-43.Network2_amp.continuous',
 'sim-44.Network2_amp.continuous',
 'sim-48.Network2_amp.continuous',
 'sim-55.Network2_amp.continuous',
 'sim-58.Network2_amp.continuous']