In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

from lib.constants import *
from lib.nodelink_viewer import *
import lib.VIS_L23_preprocessing.vis_L23_constants as VIS
from lib.pandas_compute import *
from lib.matrix_analyzer import *
from lib.multilevel_analysis import *
from lib.pandas_impl import *
from lib.pandas_stats_impl import ConnectomeRealization
from lib.pandas_stats_VIS import *
from lib.util_plot import *

In [None]:
# Evaluation folders of the 5 µm (primary) and 25 µm (comparison) evaluations.
eval_folder = Path.cwd().joinpath('data', 'eval', 'VIS_24-12-17_5mu')
eval_folder2 = Path.cwd().joinpath('data', 'eval', 'VIS_24-12-17_25mu')
meta_folder = Path.cwd().joinpath('data', 'VIS', 'meta')

# All figures of this notebook are written below the 5 µm evaluation folder.
plot_folder = eval_folder.joinpath('plots')
plot_folder_nodelink = plot_folder.joinpath('nodelink')
plot_folder_nodelink.mkdir(parents=True, exist_ok=True)

#### Data preparation

In [None]:
def load_analysis(folder):
    """Load the pickled multilevel analysis stored in ``folder``.

    Parameters
    ----------
    folder : pathlib.Path or str
        Evaluation folder that contains ``multilevel_analysis.pkl``.

    Returns
    -------
    object
        The unpickled object (the cells below read its ``df_summary``).
    """
    # Imported locally: the notebook never imports ``pickle`` explicitly and
    # would otherwise depend on it leaking out of a ``from lib... import *``.
    import pickle

    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    filename = Path(folder) / "multilevel_analysis.pkl"
    with open(filename, "rb") as file:
        return pickle.load(file)

# Both summaries are re-indexed by the same columns.
summary_index_columns = ['pre_celltype', 'post_celltype', 'pre_id_mapped', 'post_id_mapped',
                         'post_compartment', 'overlap_volume', 'post_celltype_merged',
                         'pre_celltype_merged']

multilevel_analysis = load_analysis(eval_folder)
df_summary = (multilevel_analysis.df_summary
              .reset_index()
              .set_index(summary_index_columns))

multilevel_analysis2 = load_analysis(eval_folder2)
df_summary2 = (multilevel_analysis2.df_summary
               .reset_index()
               .set_index(summary_index_columns))

# Model subsets used by the individual analyses below.
selected_models = [MODEL_NULL, MODEL_P, MODEL_PS, MODEL_P_disaggregated, MODEL_PS_disaggregated, MODEL_PSCa, MODEL_PSCb]
selected_models_matrix = [MODEL_NULL, MODEL_P, MODEL_PS]
selected_models_nodelink = [MODEL_NULL, MODEL_P, MODEL_PS]
selected_models_distributions = [MODEL_NULL, MODEL_P, MODEL_PS, MODEL_PSCb]
selected_models_roc = [MODEL_NULL, MODEL_P, MODEL_PS, MODEL_P_disaggregated, MODEL_PS_disaggregated, MODEL_PSCa, MODEL_PSCb]

In [None]:
# All tables are built on the merged cell-type labels; neuron id -1 is excluded.
celltype_columns = dict(pre_celltype_column="pre_celltype_merged",
                        post_celltype_column="post_celltype_merged")

# 5 µm, cellular level
df_cellular = get_df_cellular(df_summary, selected_models, excluded_neuron_ids=[-1],
                              **celltype_columns)
# 5 µm, split by postsynaptic compartment
df_cellular_compartment = get_df_cellular(df_summary, selected_models, excluded_neuron_ids=[-1],
                                          separate_compartment=True, **celltype_columns)
# 25 µm, split by postsynaptic compartment
df_cellular_compartment2 = get_df_cellular(df_summary2, selected_models, excluded_neuron_ids=[-1],
                                           separate_compartment=True, **celltype_columns)

In [None]:
import numpy as np

def create_realizations(expected_syncount, num_samples=1000, rng=None):
    """Count in how many Poisson samples each entry has at least one synapse.

    Parameters
    ----------
    expected_syncount : array-like
        Poisson rate (expected synapse count) per entry.
    num_samples : int
        Number of independent samples to draw.
    rng : numpy.random.Generator, optional
        Source of randomness. ``None`` (default) keeps the original behaviour of
        drawing from numpy's global random state; pass a seeded Generator for
        reproducible results.

    Returns
    -------
    numpy.ndarray
        Float array with the shape of ``expected_syncount``; each value is the
        number of samples (0..num_samples) in which that entry was > 0.
    """
    expected_syncount = np.asarray(expected_syncount)
    poisson = np.random.poisson if rng is None else rng.poisson

    connected = np.zeros(expected_syncount.shape)
    for _ in range(num_samples):
        # One draw per sample keeps memory proportional to the input size.
        connected += poisson(expected_syncount) > 0
    return connected

In [None]:
from matplotlib.colors import TwoSlopeNorm, SymLogNorm

def norm_realizations(realizations, mode, empirical_counts=None,
                      min_quantile = 0.25, mid_quantile = 0.5, max_quantile = 0.75,
                      num_realizations = None):
    """Normalise per-pair realization counts for colour mapping.

    Parameters
    ----------
    realizations : array-like
        Per-pair values to normalise (in this notebook: in how many
        realizations the pair was connected).
    mode : str
        * ``"symlog"`` -- ``2 * SymLogNorm(1)(x) - 1``.
        * ``"twosloped"`` -- ``TwoSlopeNorm`` between the ``min``/``mid``/``max``
          quantiles of ``realizations`` (values clipped to that range), mapped
          to [-1, 1].
        * ``"separated"`` -- empirically connected pairs divided by their
          maximum; unconnected pairs get ``SymLogNorm(1)(x) - 1``.
        * ``"fraction_connected"`` -- ``x / num_realizations`` for empirically
          connected pairs, ``-1`` for all others.
        * ``"fraction_unconnected"`` -- ``1 - x / num_realizations`` for
          empirically unconnected pairs, ``-1`` for all others.
    empirical_counts : array-like, optional
        Empirical synapse counts per pair (``> 0`` means connected); required by
        ``"separated"`` and both ``"fraction_*"`` modes.
    min_quantile, mid_quantile, max_quantile : float
        Quantiles used by ``"twosloped"`` only.
    num_realizations : int, optional
        Total number of realizations; required by both ``"fraction_*"`` modes.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``realizations``.

    Raises
    ------
    ValueError
        If ``mode`` is unknown.
    """
    realizations = np.asarray(realizations)
    # Fractions must not be written into an integer array: np.zeros_like on an
    # int input used to truncate every value in (0, 1) to 0.
    out_dtype = realizations.dtype if np.issubdtype(realizations.dtype, np.floating) else float

    if mode == "symlog":
        norm_fn = SymLogNorm(1)
        return (2 * norm_fn(realizations) - 1).data

    if mode == "twosloped":
        # Quantiles are only needed by this mode; computing them up front made
        # every other mode fail on empty input.
        min_value = np.quantile(realizations, min_quantile)
        max_value = np.quantile(realizations, max_quantile)
        mid_value = np.quantile(realizations, mid_quantile)
        if min_value == mid_value:
            # TwoSlopeNorm requires vmin < vcenter: move the centre 10% towards vmax.
            mid_value = min_value + 0.1 * (max_value - min_value)
        norm_fn = TwoSlopeNorm(vmin=min_value, vcenter=mid_value, vmax=max_value)
        return (2 * norm_fn(np.clip(realizations, min_value, max_value)) - 1).data

    if mode == "separated":
        assert empirical_counts is not None
        mask_connected = np.asarray(empirical_counts) > 0
        normed_array = np.zeros(realizations.shape, dtype=out_dtype)
        # Guards: np.max / SymLogNorm fail on an empty selection.
        if mask_connected.any():
            connected = realizations[mask_connected]
            normed_array[mask_connected] = connected / np.max(connected)
        if not mask_connected.all():
            norm_fn = SymLogNorm(1)
            normed_array[~mask_connected] = norm_fn(realizations[~mask_connected]).data - 1
        return normed_array

    if mode == "fraction_connected":
        assert empirical_counts is not None
        assert num_realizations is not None
        mask_connected = np.asarray(empirical_counts) > 0
        # -1 marks pairs that are unconnected in the empirical connectome.
        normed_array = np.full(realizations.shape, -1, dtype=out_dtype)
        normed_array[mask_connected] = realizations[mask_connected] / num_realizations
        return normed_array

    if mode == "fraction_unconnected":
        assert empirical_counts is not None
        assert num_realizations is not None
        mask_unconnected = np.asarray(empirical_counts) == 0
        # -1 marks pairs that are connected in the empirical connectome.
        normed_array = np.full(realizations.shape, -1, dtype=out_dtype)
        normed_array[mask_unconnected] = 1 - realizations[mask_unconnected] / num_realizations
        return normed_array

    raise ValueError(f"Unknown mode {mode}")


In [None]:
def realize_and_norm_deviation(df, model_name, num_realizations = 1000):
    """Sample realizations of a model and compare them with the empirical data.

    Draws ``num_realizations`` realizations via
    ``ConnectomeRealization(df).generate(model_name)`` and adds three columns
    to ``df`` in place:

    * ``{model_name}_realizations_binary_sum`` -- per row, in how many
      realizations the value was > 0;
    * ``{model_name}_realizations_normed_connected`` -- that count divided by
      ``num_realizations`` for empirically connected rows (``EMPIRICAL > 0``),
      ``-1`` elsewhere;
    * ``{model_name}_realizations_normed_unconnected`` -- ``1 - count /
      num_realizations`` for empirically unconnected rows, ``-1`` elsewhere.

    Returns
    -------
    numpy.ndarray
        Length ``num_realizations``; entry ``i`` is the fraction of empirically
        connected rows whose value in realization ``i`` equals the empirical
        count exactly (NaN if ``df`` has no empirically connected rows).
    """
    generator = ConnectomeRealization(df)

    empirical_counts = df[EMPIRICAL].values
    mask_connected = empirical_counts > 0
    num_connected = int(np.count_nonzero(mask_connected))
    empirical_connected = empirical_counts[mask_connected]

    realizations_binary_sum = np.zeros(df.shape[0])
    # NaN when nothing is connected (previously 0/0 = NaN plus a RuntimeWarning
    # for every realization).
    pct_correct = np.full(num_realizations, np.nan)

    for i in range(num_realizations):
        realization = generator.generate(model_name)
        realizations_binary_sum += (realization > 0).astype(int)
        if num_connected:
            # NOTE(review): this compares exact synapse counts, not binary
            # connectivity (realization > 0) -- confirm which notion of
            # "correct" is intended.
            matches = realization[mask_connected] == empirical_connected
            pct_correct[i] = np.count_nonzero(matches) / num_connected

    realizations_normed_connected = norm_realizations(
        realizations_binary_sum, "fraction_connected",
        empirical_counts=empirical_counts, num_realizations=num_realizations)
    realizations_normed_unconnected = norm_realizations(
        realizations_binary_sum, "fraction_unconnected",
        empirical_counts=empirical_counts, num_realizations=num_realizations)

    df.loc[:, f"{model_name}_realizations_binary_sum"] = realizations_binary_sum
    df.loc[:, f"{model_name}_realizations_normed_connected"] = realizations_normed_connected
    df.loc[:, f"{model_name}_realizations_normed_unconnected"] = realizations_normed_unconnected

    return pct_correct

In [None]:
# Sample every model on all three tables. realize_and_norm_deviation adds its
# result columns to each table in place; only the per-realization scores of the
# 5 µm compartment-level table are kept (for the boxplot below).
pct_correct_by_model = {}

for model in selected_models:
    realize_and_norm_deviation(df_cellular, model)
    pct_correct_by_model[model] = realize_and_norm_deviation(df_cellular_compartment, model)
    realize_and_norm_deviation(df_cellular_compartment2, model)

In [None]:
# Mean, standard deviation and standard error of the mean of the PSCb scores.
values = pct_correct_by_model[MODEL_PSCb]
print(values.mean(), values.std(), values.std() / np.sqrt(values.size))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Per-realization scores of selected models (5 µm, compartment level).
labels = [MODEL_NULL, MODEL_P, MODEL_PS, MODEL_PSCb]
data = [pct_correct_by_model[label] for label in labels]

fig, ax = plt.subplots(figsize=(10, 6))
ax.boxplot(data)
# Tick labels are set separately: boxplot's ``labels`` keyword is deprecated
# since Matplotlib 3.9 (renamed to ``tick_labels``).
ax.set_xticklabels(labels)
ax.set_ylabel("percentage correct")
ax.grid(False)

# Like the other figure cells: save the figure and display the saved image.
img = savefig_png_svg(fig, plot_folder / "realization_boxplot")

display(img)

### Model comparison

#### Plot distributions

In [None]:
from lib.util_plot import *

initPlotSettings(spines_top_right=True)

num_models = len(selected_models_distributions)

# squeeze=False keeps `axes` two-dimensional even when only one model is plotted.
fig, axes = plt.subplots(num_models, 5, figsize=figsize_mm_to_inch(180, 25 * num_models),
                         squeeze=False)

color1 = "orange"  # 5 µm
color2 = "grey"    # 25 µm
alpha = 0.4
num_bins = 20

# Index both evaluations by (pre, post) merged cell type so the EE/EI/IE/II
# subsets can be selected with .loc. The chained calls return new frames, so
# the source tables are not modified.
df_tmp = (df_cellular_compartment.reset_index()
          .set_index(["pre_celltype_merged", "post_celltype_merged"])
          .sort_index())
df_tmp2 = (df_cellular_compartment2.reset_index()
           .set_index(["pre_celltype_merged", "post_celltype_merged"])
           .sort_index())

def get_valid_values(values):
    """Drop the -1 entries (pairs not connected in the empirical connectome)."""
    return values[values > -1]

exc_type, inh_type = VIS.EXC_INH[0], VIS.EXC_INH[1]
# (title suffix, .loc row selector) for the five figure columns.
panels = [
    ("", slice(None)),
    (": $EE$", (exc_type, exc_type)),
    (": $EI$", (exc_type, inh_type)),
    (": $IE$", (inh_type, exc_type)),
    (": $II$", (inh_type, inh_type)),
]
# Drawing order per axis: 25 µm (grey) first, 5 µm (orange) on top.
datasets = [(df_tmp2, color2, r"$25\mu m$"), (df_tmp, color1, r"$5\mu m$")]

for row_idx, model_name in enumerate(selected_models_distributions):
    value_column = f"{model_name}_realizations_normed_connected"

    for df_values, color, label in datasets:
        for col_idx, (_, rows) in enumerate(panels):
            values = df_values.loc[rows, value_column].values
            # Only the first column carries a legend label.
            label_kwargs = {"label": label} if col_idx == 0 else {}
            sns.histplot(get_valid_values(values), ax=axes[row_idx, col_idx], color=color, kde=True,
                         stat='count', alpha=alpha, bins=num_bins, binrange=(0, 1), **label_kwargs)

    formatted_model_name = get_formatted_model_name(model_name)
    for col_idx, (title_suffix, _) in enumerate(panels):
        ax = axes[row_idx, col_idx]
        ax.set_xlim(0, 1)
        ax.set_ylabel("count" if col_idx == 0 else None)
        ax.set_title(r"model {}{}".format(formatted_model_name, title_suffix))
    if row_idx == 0:
        axes[row_idx, 0].legend()

fig.tight_layout()

img = savefig_png_svg(fig, plot_folder / "model_comparison_distributions_5_25mu")

display(img)

In [None]:
from lib.util_plot import *

initPlotSettings(spines_top_right=True)

num_models = len(selected_models_distributions)

# squeeze=False keeps `axes` two-dimensional even when only one model is plotted.
fig, axes = plt.subplots(num_models, 5, figsize=figsize_mm_to_inch(180, 25 * num_models),
                         squeeze=False)

color1 = "orange"  # 5 µm
alpha = 0.4
num_bins = 20

# 5 µm evaluation indexed by (pre, post) merged cell type for .loc selection.
# (The 25 µm table built here before was never used in this figure.)
df_tmp = (df_cellular_compartment.reset_index()
          .set_index(["pre_celltype_merged", "post_celltype_merged"])
          .sort_index())

exc_type, inh_type = VIS.EXC_INH[0], VIS.EXC_INH[1]
# (title suffix, .loc row selector) for the five figure columns.
panels = [
    ("", slice(None)),
    (": $EE$", (exc_type, exc_type)),
    (": $EI$", (exc_type, inh_type)),
    (": $IE$", (inh_type, exc_type)),
    (": $II$", (inh_type, inh_type)),
]

for row_idx, model_name in enumerate(selected_models_distributions):
    value_column = f"{model_name}_realizations_normed_connected"
    formatted_model_name = get_formatted_model_name(model_name)

    for col_idx, (title_suffix, rows) in enumerate(panels):
        ax = axes[row_idx, col_idx]
        values = df_tmp.loc[rows, value_column].values
        # -1 marks pairs not connected in the empirical connectome; drop them.
        valid_values = values[values > -1]
        label_kwargs = {"label": r"$5\mu m$"} if col_idx == 0 else {}
        sns.histplot(valid_values, ax=ax, color=color1, kde=True, stat='count', alpha=alpha,
                     bins=num_bins, binrange=(0, 1), **label_kwargs)

        ax.set_xlim(0, 1)
        ax.set_ylabel("count" if col_idx == 0 else None)
        ax.set_title(r"model {}{}".format(formatted_model_name, title_suffix))

fig.tight_layout()

img = savefig_png_svg(fig, plot_folder / "model_comparison_distributions_5mu")

display(img)

In [None]:
def print_pct_correct(normed_values, threshold = 0.95):
    """Print the fraction of valid entries of `normed_values` above `threshold`.

    Entries equal to -1 are treated as invalid and ignored. Raises
    ZeroDivisionError when no valid entry remains.
    """
    valid = normed_values[normed_values > -1]
    num_above = np.count_nonzero(valid > threshold)
    print(num_above / valid.size)

In [None]:
# For each model: fraction of empirically connected pairs (5 µm table `df_tmp`
# from the cell above) whose connected fraction exceeds the default threshold.
for model in [MODEL_NULL, MODEL_P, MODEL_PS, MODEL_PSCb]:
    print_pct_correct(df_tmp[f"{model}_realizations_normed_connected"].values)

### Matrix plots

#### Prepare matrix plots data

In [None]:
# Neuron ids that occur on the pre- and on the postsynaptic side of the summary.
pre_ids = set(df_summary.index.get_level_values("pre_id_mapped"))
post_ids = set(df_summary.index.get_level_values("post_id_mapped"))
# Set difference instead of set.remove(-1), which raised KeyError if -1 was absent.
all_ids = (pre_ids | post_ids) - {-1}
# Neurons that never occur presynaptically, plus id -1, are passed as
# `ignored_neuron_ids` below.
no_presynaptic = (all_ids - pre_ids) | {-1}

neuron_domain_pre = get_neuron_to_neuron_domain(df_summary, "pre_celltype_merged", "post_celltype_merged",
                                                 celltype_order=[-1, 1, 2], ignored_neuron_ids=no_presynaptic)
# NOTE(review): rows and columns use identical domain arguments -- confirm intended.
neuron_domain_post = get_neuron_to_neuron_domain(df_summary, "pre_celltype_merged", "post_celltype_merged",
                                                  celltype_order=[-1, 1, 2], ignored_neuron_ids=no_presynaptic)

In [None]:
# Switch plot settings for the matrix plots below; the distribution cells above
# used initPlotSettings(spines_top_right=True).
# NOTE(review): presumably this positional False is that same flag -- confirm in lib.util_plot.
initPlotSettings(False)

In [None]:
# Neuron ids marked on the rows (presynaptic) and columns (postsynaptic) of the matrix plots.
row_markers = dict(pre_id_mapped=[26, 248, 396])
col_markers = dict(post_id_mapped=[26, 248, 396])

#### Connected

In [None]:
colormap_name = "BrBG"

for model_name in selected_models_matrix:
    analyzer = ConnectomeMatrixAnalyzer(df_cellular, plot_folder)
    analyzer.set_selection()
    analyzer.set_data_columns(f"{model_name}_realizations_normed_connected")
    analyzer.build_matrix(
        ["pre_celltype_merged", "pre_id_mapped"],
        ["post_celltype_merged", "post_id_mapped"],
        row_domains=neuron_domain_pre,
        col_domains=neuron_domain_post,
        value_label_map=dict(pre_celltype_merged=VIS.CELLTYPE_LABELS,
                             post_celltype_merged=VIS.CELLTYPE_LABELS),
        aggregation_fn="sum",
        default_value=np.nan,
    )
    # NaN (default_value): no overlap.
    analyzer.colormaps[colormap_name].set_bad("lightgrey")
    # Below vmin=0 (the -1 marker): unconnected in the empirically observed connectome.
    analyzer.colormaps[colormap_name].set_under("lightgrey")
    img = analyzer.render_matrix(
        f"VIS_realizations-connected_{model_name}",
        colormap_name=colormap_name, vmin=0, vmax=1,
        row_markers=row_markers, col_markers=col_markers,
        col_separator_lines=True, row_separator_lines=True, high_res=True,
    )
    display(img)

#### Unconnected

In [None]:
colormap_name = "BrBG"

for model_name in selected_models_matrix:
    analyzer = ConnectomeMatrixAnalyzer(df_cellular, plot_folder)
    analyzer.set_selection()
    analyzer.set_data_columns(f"{model_name}_realizations_normed_unconnected")
    analyzer.build_matrix(
        ["pre_celltype_merged", "pre_id_mapped"],
        ["post_celltype_merged", "post_id_mapped"],
        row_domains=neuron_domain_pre,
        col_domains=neuron_domain_post,
        value_label_map=dict(pre_celltype_merged=VIS.CELLTYPE_LABELS,
                             post_celltype_merged=VIS.CELLTYPE_LABELS),
        aggregation_fn="sum",
        default_value=np.nan,
    )
    # NaN (default_value): no overlap.
    analyzer.colormaps[colormap_name].set_bad("lightgrey")
    # Below vmin=0 (the -1 marker): connected in the empirically observed connectome.
    analyzer.colormaps[colormap_name].set_under("lightgrey")
    img = analyzer.render_matrix(
        f"VIS_realizations-unconnected_{model_name}",
        colormap_name=colormap_name, vmin=0, vmax=1,
        row_markers=row_markers, col_markers=col_markers,
        col_separator_lines=True, row_separator_lines=True, high_res=True,
    )
    display(img)

### Node link diagrams

Manually select the presynaptic neuron and the neurons to highlight in the node-link diagrams.

In [None]:
#pd.set_option('display.max_rows', None)
#df_cellular_connected[df_cellular_connected.pre_celltype.isin(VIS.INH_23)].groupby(["pre_id_mapped", "post_compartment", "post_id_mapped"]).agg({EMPIRICAL : "sum"})

In [None]:
selected_pre_id = 26  # alternative: 327

# Highlight colours (JS colour strings) for specific neuron ids in the node-link diagrams.
id_color = {selected_pre_id: rgb_to_js_color(COLOR_INH)}
id_color[408] = rgb_to_js_color(COLOR_EXC)
id_color[255] = rgb_to_js_color(COLOR_EXC2)

Generate node-link diagrams

In [None]:
# Row filter on the empirical synapse count for each diagram variant.
empirical_filters = {
    "connected": lambda counts: counts > 0,
    "unconnected": lambda counts: counts == 0,
}

for model_name in selected_models_nodelink:

    for connected_unconnected in ["connected", "unconnected"]:

        if connected_unconnected not in empirical_filters:
            raise ValueError(f"Unknown connected_unconnected {connected_unconnected}")
        keep = empirical_filters[connected_unconnected]
        value_column = f"{model_name}_realizations_normed_{connected_unconnected}"

        # Outgoing edges of the selected neuron, restricted to the chosen variant.
        is_selected_cellular = (df_cellular.index.get_level_values("pre_id_mapped") == selected_pre_id) & keep(df_cellular[EMPIRICAL])
        df_selected_cellular = df_cellular[is_selected_cellular].copy().reset_index()
        # Every cellular-level edge is drawn with the single compartment VIS.DEND[0].
        df_selected_cellular["post_compartment"] = VIS.DEND[0]

        is_selected_compartment = (df_cellular_compartment.index.get_level_values("pre_id_mapped") == selected_pre_id) & keep(df_cellular_compartment[EMPIRICAL])
        df_selected_cellular_compartment = df_cellular_compartment[is_selected_compartment].copy().reset_index()

        node_styler = PotentialConnectionsNodeStyler([VIS.EXC_INH[0]], [VIS.EXC_INH[1]], highlighted_colors=id_color)
        file_prefix = f"{selected_pre_id}_{model_name}_fraction-{connected_unconnected}"

        # cellular
        edge_styler_cellular = SpecificityEdgeStyler(
            VIS.DEND, VIS.SOMA, VIS.AIS,
            ColorInterpolator(cmap_brbg, vmin=0, vmax=1),
            only_highlighted_multiedge=False, syncount_labels=False, compartment_labels=False)
        sv_cellular = SubnetworkVisualization(
            plot_folder_nodelink, node_styler, edge_styler_cellular,
            pre_celltype_column="pre_celltype_merged", post_celltype_column="post_celltype_merged")
        sv_cellular.create(f"{file_prefix}_cellular", df_selected_cellular, EMPIRICAL, value_column)

        # subcellular
        edge_styler_cellular_compartment = SpecificityEdgeStyler(
            VIS.DEND, VIS.SOMA, VIS.AIS,
            ColorInterpolator(cmap_brbg, vmin=0, vmax=1),
            only_highlighted_multiedge=False, syncount_labels=False, compartment_labels=True)
        sv_cellular_compartment = SubnetworkVisualization(
            plot_folder_nodelink, node_styler, edge_styler_cellular_compartment,
            pre_celltype_column="pre_celltype_merged", post_celltype_column="post_celltype_merged")
        sv_cellular_compartment.create(f"{file_prefix}_cellular-compartment", df_selected_cellular_compartment, EMPIRICAL, value_column)