In [None]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

from lib.constants import *
from lib.nodelink_viewer import *
import lib.VIS_L23_preprocessing.vis_L23_constants as VIS
from lib.pandas_compute import *
from lib.matrix_analyzer import *
from lib.multilevel_analysis import *
from lib.pandas_impl import *
from lib.pandas_stats_VIS import *

In [None]:
# Input/output locations, all resolved relative to the current working directory.
eval_folder = Path.cwd().joinpath('data', 'eval', 'VIS_24-09-13_5mu')
meta_folder = Path.cwd().joinpath('data', 'VIS', 'meta')
plot_folder = eval_folder.joinpath('plots')
plot_folder_nodelink = plot_folder.joinpath('nodelink')
# Creates `plots/` as well, because parents=True.
plot_folder_nodelink.mkdir(parents=True, exist_ok=True)

# Colour-scale limits shared by the delta-loss nodelink diagrams and matrix plots.
vmin, vmax = -0.1, 0.2

#### Data preparation

In [None]:
filename = eval_folder / "multilevel_analysis.pkl"

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load analysis results that were produced locally by this project.
with open(filename, 'rb') as file:
    multilevel_analysis = pickle.load(file)

# Rebuild the index from scratch: flatten whatever index the stored summary has,
# then index by the connection descriptors used throughout this notebook.
# Chained (not in-place) so the cell always derives df_summary from the loaded object.
summary_index_columns = [
    'pre_celltype', 'post_celltype', 'pre_id_mapped', 'post_id_mapped',
    'post_compartment', 'overlap_volume', 'post_celltype_merged',
    'pre_celltype_merged',
]
df_summary = multilevel_analysis.df_summary.reset_index().set_index(summary_index_columns)

In [None]:
# Model pairs for which synapse-count deltas are computed on the summary table.
# The helper's return value is not used; presumably it adds columns to
# df_summary in place -- TODO confirm.
syncount_pairs_summary = [
    (EMPIRICAL, MODEL_NULL),
    (EMPIRICAL, MODEL_P_disaggregated),
    (EMPIRICAL, MODEL_PS_disaggregated),
    (EMPIRICAL, MODEL_PSCb),
    (MODEL_NULL, MODEL_PS_disaggregated),
]
for syncount_reference, syncount_target in syncount_pairs_summary:
    compute_delta_syncount(df_summary, syncount_reference, syncount_target)

# Keep only rows whose pre- and postsynaptic mapped ids are both non-negative.
pre_id_level = df_summary.index.get_level_values("pre_id_mapped")
post_id_level = df_summary.index.get_level_values("post_id_mapped")
df_filtered = df_summary[(pre_id_level >= 0) & (post_id_level >= 0)]
df_filtered.shape

In [None]:
# Models carried into the per-cell table (compartments kept separate).
selected_models = [
    MODEL_NULL,
    MODEL_P,
    MODEL_P_disaggregated,
    MODEL_Pa,
    MODEL_PS,
    MODEL_PS_disaggregated,
    MODEL_PSa,
    MODEL_PSCa,
    MODEL_PSCb,
]
df_cellular = get_df_cellular(
    df_filtered,
    selected_models,
    separate_compartment=True,
)

In [None]:
# Turn the index levels into ordinary columns so they can be used in masks below.
# NOTE: running this cell twice resets the index a second time.
df_cellular = df_cellular.reset_index()

# --- boolean row masks -------------------------------------------------------
# Cell type of the pre-/postsynaptic neuron.
mask_exc_pre = df_cellular.pre_celltype == VIS.EXC_INH[0]
mask_inh_pre = df_cellular.pre_celltype.isin(VIS.INH)
mask_exc_post = df_cellular.post_celltype == VIS.EXC_INH[0]
mask_inh_post = df_cellular.post_celltype.isin(VIS.INH)
# Postsynaptic target compartment.
mask_soma = df_cellular.post_compartment == VIS.SOMA[0]
mask_dend = df_cellular.post_compartment == VIS.DEND[0]
mask_ais = df_cellular.post_compartment == VIS.AIS[0]

# --- any presynaptic type -> excitatory / inhibitory target ------------------
df_any_exc_soma = df_cellular.loc[mask_exc_post & mask_soma]
df_any_exc_dend = df_cellular.loc[mask_exc_post & mask_dend]
df_any_exc_ais = df_cellular.loc[mask_exc_post & mask_ais]
df_any_inh_soma = df_cellular.loc[mask_inh_post & mask_soma]
df_any_inh_dend = df_cellular.loc[mask_inh_post & mask_dend]
df_any_inh_ais = df_cellular.loc[mask_inh_post & mask_ais]

# --- excitatory / inhibitory source -> excitatory target ---------------------
df_exc_exc_soma = df_cellular.loc[mask_exc_pre & mask_exc_post & mask_soma]
df_exc_exc_dend = df_cellular.loc[mask_exc_pre & mask_exc_post & mask_dend]
df_exc_exc_ais = df_cellular.loc[mask_exc_pre & mask_exc_post & mask_ais]
df_inh_exc_soma = df_cellular.loc[mask_inh_pre & mask_exc_post & mask_soma]
df_inh_exc_dend = df_cellular.loc[mask_inh_pre & mask_exc_post & mask_dend]
df_inh_exc_ais = df_cellular.loc[mask_inh_pre & mask_exc_post & mask_ais]

# --- excitatory / inhibitory source -> inhibitory target ---------------------
df_exc_inh_soma = df_cellular.loc[mask_exc_pre & mask_inh_post & mask_soma]
df_exc_inh_dend = df_cellular.loc[mask_exc_pre & mask_inh_post & mask_dend]
df_exc_inh_ais = df_cellular.loc[mask_exc_pre & mask_inh_post & mask_ais]
df_inh_inh_soma = df_cellular.loc[mask_inh_pre & mask_inh_post & mask_soma]
df_inh_inh_dend = df_cellular.loc[mask_inh_pre & mask_inh_post & mask_dend]
df_inh_inh_ais = df_cellular.loc[mask_inh_pre & mask_inh_post & mask_ais]

#### Save empirical connectivity statistics by cell type to CSV

In [None]:
# Minimum summed "empirical" value a (presynaptic neuron, compartment) group
# must reach to be exported.
min_num_synapses = 1

# Sum the empirical column per presynaptic neuron, cell type and target compartment.
pre_cells = (
    df_cellular
    .groupby(["pre_id_mapped", "pre_celltype", "post_compartment"])
    .agg({"empirical" : "sum"})
)
pre_cells_filtered = pre_cells.loc[pre_cells.empirical >= min_num_synapses].reset_index()

# Totals per presynaptic cell type and target compartment.
celltype_compartment_totals = (
    pre_cells_filtered
    .groupby(["pre_celltype", "post_compartment"])
    .agg({"empirical" : "sum"})
)
celltype_compartment_totals.to_csv(meta_folder/"prect_postcompartment.csv")

# Per-neuron table, excluding presynaptic cell type 1.
# NOTE(review): 1 is presumably the excitatory type code -- confirm against VIS constants.
pre_cells_filtered.loc[pre_cells_filtered.pre_celltype != 1].to_csv(
    meta_folder/"preid_prect_postcompartment.csv", index=False
)

### Plot change in probability distributions

In [None]:
def get_plot_values(df, model_descriptor, only_connected=True):
    """Return the observation-probability values of one model as an array.

    Args:
        df: DataFrame with an EMPIRICAL column and a
            ``"<model_descriptor>_observation_probability"`` column.
        model_descriptor: Prefix that selects the model's probability column.
        only_connected: If True, keep only rows whose EMPIRICAL value is
            non-zero; otherwise keep every row.

    Returns:
        The selected column values (``.values`` of the filtered column).
    """
    probability_column = f"{model_descriptor}_observation_probability"
    if only_connected:
        row_selector = df[EMPIRICAL] != 0
    else:
        row_selector = np.ones(len(df), dtype=bool)
    return df.loc[row_selector, probability_column].values

In [None]:
from lib.util_plot import *

initPlotSettings(spines_top_right=True)

fig, axes = plt.subplots(2, 3, figsize=figsize_mm_to_inch(160, 70))

# The two models whose observation-probability distributions are compared.
model_1 = MODEL_NULL
model_2 = MODEL_PSCb #MODEL_PS_disaggregated
model_1_descriptor = STR_NULL
# NOTE(review): model_2 is MODEL_PSCb while its legend label is STR_PS -- confirm the label is intended.
model_2_descriptor = STR_PS

color1 = "grey"
color2 = "orange"
alpha = 0.6
num_bins = 20

# x-range shown in the zoomed-in panel (bottom right) and marked in the top-right panel.
zoom_xrange = (0.0, 0.2)
probability_label = r"$p_{connected\,in\,model\;|\;empirically\,observed}$"


def _plot_probability_histogram(ax, df, model, color, bins, binrange=(0, 1), label=None):
    """Draw one density histogram (with KDE) of a model's observation probabilities on `ax`.

    `label` is only forwarded to seaborn when given, so unlabeled calls are
    identical to calls that never passed a label.
    """
    legend_kwargs = {} if label is None else {"label": label}
    sns.histplot(get_plot_values(df, model), ax=ax, color=color, kde=True, stat='density',
                 alpha=alpha, bins=bins, binrange=binrange, **legend_kwargs)


# (row, column, data, title, number of bins used for model_1)
# The AIS panel uses twice as many bins for model_1 as for model_2.
histogram_panels = [
    (0, 0, df_any_exc_soma, r"$E,I\rightarrow E_S$", num_bins),
    (0, 1, df_any_exc_dend, r"$E,I\rightarrow E_D$", num_bins),
    (0, 2, df_any_exc_ais, r"$E,I\rightarrow E_A$", 2 * num_bins),
    (1, 0, df_any_inh_soma, r"$E,I\rightarrow I_S$", num_bins),
    (1, 1, df_any_inh_dend, r"$E,I\rightarrow I_D$", num_bins),
]

for panel_row, panel_col, panel_df, panel_title, bins_model_1 in histogram_panels:
    panel_ax = axes[panel_row, panel_col]
    # Only the first panel carries legend labels.
    with_legend = (panel_row, panel_col) == (0, 0)

    _plot_probability_histogram(panel_ax, panel_df, model_1, color1, bins=bins_model_1,
                                label=model_1_descriptor if with_legend else None)
    _plot_probability_histogram(panel_ax, panel_df, model_2, color2, bins=num_bins,
                                label=model_2_descriptor if with_legend else None)

    panel_ax.set_xlim(0, 1)
    if panel_row == 1:
        # x-axis label only on the bottom row
        panel_ax.set_xlabel(probability_label)
    # y-axis label only on the left column
    panel_ax.set_ylabel("density" if panel_col == 0 else None)
    panel_ax.set_title(panel_title)
    if with_legend:
        panel_ax.legend()

# Clip the AIS panel so the peak near zero does not dominate the axis.
axes[0, 2].set_ylim(0, 5)

# zoomed-in subplot for exc-ais (model_1 only)
# NOTE(review): the histogram is binned over (0, 0.1) while the axis shows
# zoom_xrange (0, 0.2), and model_2 is not drawn here -- confirm both are intended.
_plot_probability_histogram(axes[1, 2], df_any_exc_ais, model_1, color1,
                            bins=2 * num_bins, binrange=(0, 0.1))
axes[1, 2].set_xlim(*zoom_xrange)
axes[1, 2].set_xlabel(probability_label)
axes[1, 2].set_ylabel(None)
axes[1, 2].set_title(r"$E,I\rightarrow E_A$ (zoomed-in view)")

# Mark the right edge of the zoomed range in the top-right panel with a dashed
# vertical line spanning the full (clipped) y-range.
y_top = axes[0, 2].get_ylim()[1]
axes[0, 2].plot([zoom_xrange[1], zoom_xrange[1]], [0, y_top],
                color='black', linestyle='--', lw=0.8, dashes=(3, 2))


plt.tight_layout()

img = savefig_png_svg(fig, plot_folder / f"histogram_observation_probability_{model_1}_{model_2}_all_pairs")

display(img)

### Node link diagrams (overlapping/connected, loss)

Manually select the presynaptic neurons of interest (by `pre_id_mapped`) for which nodelink diagrams are generated.

In [None]:
# Hand-picked neuron ids; the nodelink loop further down matches each one
# against `pre_id_mapped` to select that neuron's outgoing connections.
cells_of_interest = [2, 5, 25, 26, 30, 37, 38, 54, 82, 173, 210, 282, 298, 327, 337, 338, 353, 421]

In [None]:
# Truncate DataFrame displays to at most 10 rows to keep cell outputs short.
pd.set_option('display.max_rows', 10)

In [None]:
# Synapse-count deltas needed by the nodelink diagrams, now on the per-cell table.
# The helper's return value is not used; presumably it adds columns to
# df_cellular in place -- TODO confirm.
syncount_pairs_cellular = [
    (EMPIRICAL, MODEL_NULL),
    (EMPIRICAL, MODEL_P_disaggregated),
    (EMPIRICAL, MODEL_PS_disaggregated),
    (EMPIRICAL, MODEL_PSCb),
    (MODEL_NULL, MODEL_PS_disaggregated),
    (MODEL_P_disaggregated, MODEL_PS_disaggregated),
]
for syncount_reference, syncount_target in syncount_pairs_cellular:
    compute_delta_syncount(df_cellular, syncount_reference, syncount_target)

# Rows with a non-zero empirical value.
is_connected = df_cellular[EMPIRICAL] != 0
df_cellular_connected = df_cellular[is_connected]

In [None]:
# Quick inspection: all rows whose presynaptic neuron is id 26 (the neuron that
# gets highlighted colours in the nodelink loop), shown by target compartment.
rows_from_neuron_26 = df_cellular.pre_id_mapped == 26
df_selected = df_cellular.loc[rows_from_neuron_26].reset_index(drop=True).copy()
df_selected.post_compartment

Generate nodelink diagrams

In [None]:
# For every selected presynaptic neuron, write several nodelink diagrams
# (delta loss, delta synapse counts, potential connections) via SubnetworkVisualization.
for selected_pre_id in cells_of_interest:

    # Independent copies, so the per-neuron delta-loss columns computed below
    # are computed on these copies and not on the full tables.
    df_selected_connected = df_cellular_connected[(df_cellular_connected.pre_id_mapped == selected_pre_id)].reset_index(drop=True).copy()
    df_selected = df_cellular[(df_cellular.pre_id_mapped == selected_pre_id)].reset_index(drop=True).copy()

    # Neuron 26 gets custom highlight colours for itself and two partner neurons;
    # every other neuron uses the stylers' default colouring.
    if selected_pre_id == 26:
        id_color = {
            selected_pre_id : rgb_to_js_color(COLOR_INH),
            408 : rgb_to_js_color(COLOR_EXC),
            255 : rgb_to_js_color(COLOR_EXC2)
        }
    else:
        id_color = {}

    compute_delta_loss(df_selected_connected, MODEL_NULL, MODEL_PS_disaggregated)
    compute_delta_loss(df_selected, MODEL_NULL, MODEL_PS_disaggregated)

    # Column names are looked up once per iteration and reused below.
    value_column = get_delta_loss_column(MODEL_NULL, MODEL_PS_disaggregated)
    syncount_column = get_delta_syncount_column(MODEL_NULL, MODEL_PS_disaggregated)
    syncount_column_subcellular = get_delta_syncount_column(MODEL_P_disaggregated, MODEL_PS_disaggregated)

    # Delta loss summed per (pre, post, compartment) over connected rows only.
    delta_loss_stats = df_selected_connected.groupby(["pre_id_mapped", "post_id_mapped", "post_compartment"]).agg({
            value_column : "sum"
        })
    print(f"neuron id {selected_pre_id}: delta loss sum", delta_loss_stats[value_column].sum())

    # Pairwise table: compartments are summed per neuron pair, then every row is
    # assigned the dendrite compartment value.
    df_loss_pairwise = df_selected.groupby(["pre_id_mapped", "post_id_mapped", "pre_celltype", "post_celltype"]).agg({
        value_column : "sum",
        EMPIRICAL : "sum",
        get_delta_syncount_column(EMPIRICAL, MODEL_NULL) : "sum",
        syncount_column : "sum",
        syncount_column_subcellular : "sum"}).reset_index()
    df_loss_pairwise["post_compartment"] = VIS.DEND[0]

    print(delta_loss_stats.describe())

    print("colorscale min/max", vmin, vmax)

    # node styler
    node_styler = PotentialConnectionsNodeStyler(VIS.EXC, VIS.INH, highlighted_colors=id_color)

    # delta loss
    color_interpolator = ColorInterpolator(cmap_viridis, vmin=vmin, vmax=vmax)
    specificity_edge_styler = SpecificityEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS, color_interpolator, only_highlighted_multiedge=False, syncount_labels=True)

    color_interpolator_all = ColorInterpolator(cmap_viridis, vmin=vmin, vmax=vmax)
    specificity_edge_styler_all = SpecificityEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS, color_interpolator_all, only_highlighted_multiedge=False, syncount_labels=False, compartment_labels=False)

    subnetworkVisualization_spec = SubnetworkVisualization(plot_folder_nodelink, node_styler, specificity_edge_styler)
    subnetworkVisualization_spec.create(f"{selected_pre_id}_delta_loss", df_selected_connected, EMPIRICAL, value_column)

    subnetworkVisualization_spec_all = SubnetworkVisualization(plot_folder_nodelink, node_styler, specificity_edge_styler_all)
    subnetworkVisualization_spec_all.create(f"{selected_pre_id}_to_all_delta_loss", df_loss_pairwise, EMPIRICAL, value_column)

    # delta synapses (symmetric colour scale)
    vmin_syn = -0.2
    vmax_syn = 0.2

    syncount_edge_styler = SpecificityEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS,
                                                    ColorInterpolator(cmap_coolwarm, vmin=vmin_syn, vmax=vmax_syn),
                                                    only_highlighted_multiedge=False, syncount_labels=True)
    subnetworkVisualization_syncount = SubnetworkVisualization(plot_folder_nodelink, node_styler, syncount_edge_styler)
    subnetworkVisualization_syncount.create(f"{selected_pre_id}_P-PS_delta_synapses", df_selected_connected, EMPIRICAL, syncount_column_subcellular)

    syncount_edge_styler_all = SpecificityEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS,
                                                    ColorInterpolator(cmap_coolwarm, vmin=vmin_syn, vmax=vmax_syn),
                                                    only_highlighted_multiedge=False, syncount_labels=False, compartment_labels=False)
    subnetworkVisualization_syncount = SubnetworkVisualization(plot_folder_nodelink, node_styler, syncount_edge_styler_all)
    subnetworkVisualization_syncount.create(f"{selected_pre_id}_to_all_PS_delta_synapses", df_loss_pairwise, EMPIRICAL, syncount_column)


    # potential connections (no value column is passed for these diagrams)
    potential_edge_styler = PotentialConnectionsEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS, syncount_labels=False, compartment_labels=False,
                                                only_highlighted_multiedge=False, only_connected_multiedge=True, no_multiedge=True, highlighted_colors=id_color)

    potential_edge_styler_multiedge = PotentialConnectionsEdgeStyler(VIS.DEND, VIS.SOMA, VIS.AIS, syncount_labels=True, compartment_labels=False,
                                                only_highlighted_multiedge=False, only_connected_multiedge=True, no_multiedge=False, highlighted_colors=id_color)

    subnetworkVisualization_poten = SubnetworkVisualization(plot_folder_nodelink, node_styler, potential_edge_styler)
    subnetworkVisualization_poten.create(f"{selected_pre_id}_potential", df_selected, EMPIRICAL, None)

    subnetworkVisualization_poten = SubnetworkVisualization(plot_folder_nodelink, node_styler, potential_edge_styler_multiedge)
    subnetworkVisualization_poten.create(f"{selected_pre_id}_potential_multiedge", df_selected_connected, EMPIRICAL, None)


In [None]:
# Show the folder that was passed to SubnetworkVisualization for the nodelink diagrams.
print(plot_folder_nodelink)

### Matrix plots

#### Prepare data and settings for matrix plots

In [None]:
# All neuron ids that appear on either side of a connection in the summary table.
pre_ids = set(df_summary.index.get_level_values("pre_id_mapped"))
post_ids = set(df_summary.index.get_level_values("post_id_mapped"))
all_ids = pre_ids.union(post_ids)
# Drop the -1 id (presumably the placeholder for unmapped neurons -- TODO confirm).
# discard() instead of remove(): no KeyError when the table contains no -1 id.
all_ids.discard(-1)
# Neurons that only ever appear as postsynaptic partners.
no_presynaptic = all_ids - pre_ids

#no_presynaptic.remove(248)

# NOTE(review): the row and column domains are built with identical arguments -- confirm intended.
neuron_domain_pre = get_neuron_to_neuron_domain(df_summary, "pre_celltype_merged", "post_celltype_merged",
                                                 celltype_order=[-1, 1, 2], ignored_neuron_ids=no_presynaptic)
neuron_domain_post = get_neuron_to_neuron_domain(df_summary, "pre_celltype_merged", "post_celltype_merged",
                                                  celltype_order=[-1, 1, 2], ignored_neuron_ids=no_presynaptic)

In [None]:
# Re-initialise the plot settings before the matrix plots.
# NOTE(review): the positional False presumably maps to `spines_top_right`
# (passed by keyword in the histogram cell) -- confirm against lib.util_plot.
initPlotSettings(False)

In [None]:
# Neuron ids marked on the matrix rows and columns; these are the same three ids
# that receive highlight colours in the nodelink diagram of neuron 26.
marked_neuron_ids = [26, 255, 408]

# Separate list objects per dict, as in a literal definition.
row_markers = {"pre_id_mapped" : list(marked_neuron_ids)}
col_markers = {"post_id_mapped" : list(marked_neuron_ids)}

#### Overlapping/connected

In [None]:
# Matrix of the empirical column for every neuron pair in the summary table.
matrix_analyzer = ConnectomeMatrixAnalyzer(df_summary, plot_folder)
matrix_analyzer.set_selection()
matrix_analyzer.set_data_columns(EMPIRICAL)

# Rows: presynaptic neurons grouped by merged cell type; columns: postsynaptic likewise.
# Pairs without data get the default value -1.
matrix_analyzer.build_matrix(
    ["pre_celltype_merged", "pre_id_mapped"],
    ["post_celltype_merged", "post_id_mapped"],
    row_domains=neuron_domain_pre,
    col_domains=neuron_domain_post,
    value_label_map={
        "pre_celltype_merged" : VIS.CELLTYPE_LABELS,
        "post_celltype_merged" : VIS.CELLTYPE_LABELS,
    },
    aggregation_fn="sum",
    default_value=-1,
)

# Values below the colour range are drawn in white.
matrix_analyzer.colormaps["binary"].set_under("white")

overlap_render_options = dict(
    colormap_name="binary", vmin=-1, vmax=1, normalization_function=None,
    row_markers=row_markers, col_markers=col_markers,
    col_separator_lines=True, row_separator_lines=True, high_res=True,
)
matrix_analyzer.render_matrix("VIS-overlapping-connected", **overlap_render_options)

#### Cellular loss matrix

In [None]:
# Per-pair table (compartments not separated), keyed by the merged cell types.
df_cellular_pairwise = get_df_cellular(
    df_summary,
    selected_models,
    separate_compartment=False,
    pre_celltype_column="pre_celltype_merged",
    post_celltype_column="post_celltype_merged",
)

# Delta loss of each model of interest relative to the null model.
for loss_target_model in (MODEL_P_disaggregated, MODEL_PS_disaggregated, MODEL_PSCb):
    compute_delta_loss(df_cellular_pairwise, MODEL_NULL, loss_target_model)

In [None]:
# One delta-loss matrix (relative to the null model) per model.
for model_name in [MODEL_P_disaggregated, MODEL_PS_disaggregated, MODEL_PSCb]:
    delta_loss_column = get_delta_loss_column(MODEL_NULL, model_name)

    matrix_analyzer = ConnectomeMatrixAnalyzer(df_cellular_pairwise, plot_folder)
    matrix_analyzer.set_selection()
    matrix_analyzer.set_data_columns(delta_loss_column)

    # Pairs without data become NaN, which the colormap draws in white (set_bad).
    matrix_analyzer.build_matrix(
        ["pre_celltype_merged", "pre_id_mapped"],
        ["post_celltype_merged", "post_id_mapped"],
        row_domains=neuron_domain_pre,
        col_domains=neuron_domain_post,
        value_label_map={
            "pre_celltype_merged" : VIS.CELLTYPE_LABELS,
            "post_celltype_merged" : VIS.CELLTYPE_LABELS,
        },
        aggregation_fn="sum",
        default_value=np.nan,
    )
    matrix_analyzer.colormaps["viridis"].set_bad("white")

    loss_render_options = dict(
        colormap_name="viridis", vmin=vmin, vmax=vmax,
        row_markers=row_markers, col_markers=col_markers,
        col_separator_lines=True, row_separator_lines=True, high_res=True,
    )
    img = matrix_analyzer.render_matrix(f"VIS-cellular-loss-{model_name}", **loss_render_options)
    display(img)

#### Delta loss matrix

In [None]:
# Same model list as used for the per-cell table above.
selected_models = [
    MODEL_NULL,
    MODEL_P,
    MODEL_P_disaggregated,
    MODEL_Pa,
    MODEL_PS,
    MODEL_PS_disaggregated,
    MODEL_PSa,
    MODEL_PSCa,
    MODEL_PSCb,
]
# Per-pair table built from the filtered summary (non-negative ids only).
df_pairwise = get_df_cellular(
    df_filtered,
    selected_models,
    separate_compartment=False,
    pre_celltype_column="pre_celltype_merged",
    post_celltype_column="post_celltype_merged",
)

# Deltas of every non-reference model against the null model ...
reference_model = MODEL_NULL
for target_model in selected_models:
    if target_model != reference_model:
        compute_delta_loss(df_pairwise, reference_model, target_model)
        compute_delta_syncount(df_pairwise, reference_model, target_model)

# ... and synapse-count deltas of every model (null included) against the empirical data.
for target_model in selected_models:
    compute_delta_syncount(df_pairwise, EMPIRICAL, target_model)

In [None]:
# Per-model loss columns: names containing "loss_model" but not "delta".
model_loss_cols = [
    column_name
    for column_name in df_pairwise.columns
    if "loss_model" in column_name and "delta" not in column_name
]
model_loss_cols

### Change in loss depending on model

In [None]:
def get_aggregate_loss_for_plot(df, loss_columns):
    """Summarise each loss column by its median and quartiles.

    Args:
        df: DataFrame containing every column named in `loss_columns`.
        loss_columns: Column names, each summarised independently.

    Returns:
        Tuple ``(medians, q25, q75, labels)`` of four lists, one entry per
        column in input order. Each label is the column name with ``"loss_"``
        removed, passed through `get_formatted_model_name`, and then with
        ``"Pd"`` shortened to ``"P"`` and ``"PSd"`` to ``"PS"``.
    """
    column_values = [df[name].values for name in loss_columns]

    loss_median = [np.median(values) for values in column_values]
    loss_25 = [np.quantile(values, 0.25) for values in column_values]
    loss_75 = [np.quantile(values, 0.75) for values in column_values]

    labels = []
    for name in loss_columns:
        formatted = get_formatted_model_name(name.replace("loss_",""))
        labels.append(formatted.replace("Pd","P").replace("PSd","PS"))

    return loss_median, loss_25, loss_75, labels

In [None]:
initPlotSettings(False)

# Loss columns shown on the x-axis, in plotting order ('loss_model-PSCa' is left out).
model_cols_plot = [
    'loss_model-null',
    'loss_model-P_disaggregated',
    'loss_model-PS_disaggregated',
    'loss_model-PSCb',
]

# Row masks on the merged cell type of the pre-/postsynaptic neuron.
pre_type_level = df_pairwise.index.get_level_values("pre_celltype_merged")
post_type_level = df_pairwise.index.get_level_values("post_celltype_merged")

pre_exc_mask = pre_type_level == VIS.EXC_INH[0]
pre_inh_mask = pre_type_level == VIS.EXC_INH[1]
post_exc_mask = post_type_level == VIS.EXC_INH[0]
post_inh_mask = post_type_level == VIS.EXC_INH[1]

# Pairwise combinations (available for the optional per-combination curves below).
exc_exc_mask = pre_exc_mask & post_exc_mask
exc_inh_mask = pre_exc_mask & post_inh_mask
inh_exc_mask = pre_inh_mask & post_exc_mask
inh_inh_mask = pre_inh_mask & post_inh_mask

pre_all = pre_exc_mask | pre_inh_mask

fig, ax = plt.subplots(figsize=figsize_mm_to_inch(60,40))

def add_to_plot(df, color, label, linestyle = "-", marker="."):
    """Plot the median loss per model (columns in `model_cols_plot`) of `df` on `ax`.

    Prints the medians and returns the formatted model names used as x positions.
    """
    medians, _q25, _q75, names = get_aggregate_loss_for_plot(df, model_cols_plot)
    print(medians)
    positions = np.arange(len(names))
    ax.plot(positions, medians, linestyle=linestyle, marker=marker, c = color, label=label, lw=1)
    return names

# Curves shown: one per presynaptic cell type.
model_names = add_to_plot(df_pairwise[pre_inh_mask], "grey", "presyn. neuron is inhibitory", "-", "o")
model_names = add_to_plot(df_pairwise[pre_exc_mask], "black", "presyn. neuron is excitatory", "--", "^")

# Optional alternatives (disabled):
#model_names = add_to_plot(df_pairwise, "lightgrey", "all neuron pairs")
#model_names = add_to_plot(df_pairwise[exc_exc_mask], "red", "EE", "-","^")
#model_names = add_to_plot(df_pairwise[exc_inh_mask], "blue", "EI", "--","^")
#model_names = add_to_plot(df_pairwise[inh_exc_mask], "red", "IE", "-","o")
#model_names = add_to_plot(df_pairwise[inh_inh_mask], "blue", "II", "--","o")

ax.legend()
x = np.arange(len(model_names))
ax.set_xticks(x)
ax.set_xticklabels(model_names)
ax.set_xlim((-0.2, len(model_names) - 0.8))
ax.set_ylim(-0.01, 0.13)
ax.set_ylabel("model accuracy \n (loss function)")
fig.subplots_adjust(left = 0.3, bottom=0.25, top=0.9)

img = savefig_png_svg(fig, plot_folder/"loss_model_comparison")
display(img)
