In [None]:
import pandas as pd
import numpy as np
import pickle
from pathlib import Path

from IPython.display import display
from lib.util_plot import *
from lib.constants import *

import lib.VIS_L23_preprocessing.vis_L23_constants as VIS
from lib.pandas_impl import *
from lib.pandas_stats_impl import *
from lib.pandas_stats_VIS import VISAggregateStatistics
from lib.multilevel_analysis import MultilevelAnalysis
from models import *

from lib.parameter_inference import ParameterDomain, ParameterInference

#### Prepare data

In [None]:
# Input data location and the synapse table read further below (relative to the working directory).
data_folder = Path.cwd().joinpath('data', 'VIS')
synapse_file = data_folder.joinpath('synapses_grid-25000_aggregated.csv')

# Output locations for this evaluation run and its plots.
eval_folder = Path.cwd().joinpath('data', 'eval', 'VIS_SBI_example_24-12-17')
plot_folder = eval_folder.joinpath("plots")

# Create the output folders if they do not exist yet.
eval_folder.mkdir(parents=True, exist_ok=True)
plot_folder.mkdir(parents=True, exist_ok=True)

In [None]:
def merge_inh_celltypes(df_summary):
    """Return a sorted copy of ``df_summary`` with merged cell-type index levels.

    For the ``post_celltype`` and ``pre_celltype`` index levels a companion level
    ``post_celltype_merged`` / ``pre_celltype_merged`` is appended in which every
    value greater than 1 is replaced by 2. A merged level that is already part of
    the index is left untouched. The input frame is not modified.
    """
    merged = df_summary.copy()

    # "post" is handled before "pre" so the appended levels keep their established order.
    for side in ("post", "pre"):
        merged_level = f"{side}_celltype_merged"
        if merged_level in merged.index.names:
            continue

        celltypes = merged.index.get_level_values(f"{side}_celltype").values.copy()
        celltypes[celltypes > 1] = 2
        merged.loc[:, merged_level] = celltypes
        merged = merged.set_index(merged_level, append=True)

    return merged.sort_index()

In [None]:
# Load the aggregated synapse table.
df_synapses = pd.read_csv(synapse_file)

# Drop self connections (pre == post); rows whose pre_id_mapped is -1 are always kept.
keep_row = (df_synapses.pre_id_mapped == -1) | (df_synapses.pre_id_mapped != df_synapses.post_id_mapped)
df_synapses = df_synapses[keep_row].reset_index(drop=True)

# Move the grouping columns into the index and append the merged cell-type levels.
index_columns = [
    "pre_celltype",
    "post_celltype",
    "pre_id_mapped",
    "post_id_mapped",
    "post_compartment",
    "overlap_volume",
]
df_synapses_indexed = merge_inh_celltypes(df_synapses.set_index(index_columns))

df_synapses_indexed = df_synapses_indexed.sort_index()

# Cell output: total synapse count after filtering.
df_synapses_indexed.synapse_count.sum()

In [None]:
# Build the index data that the statistics, analysis and model objects below are constructed from.
# NOTE(review): compile_index_data comes from a star import; eval_folder is presumably used as an
# output/cache location — confirm against the helper.
index_data = compile_index_data(eval_folder, df_synapses_indexed)

#### Run default models

In [None]:
# Statistics object with synapse counts, motifs and clusters all switched on.
statistics = VISAggregateStatistics(index_data, compute_syncounts=True, compute_motifs=True, compute_clusters=True)
# Analysis driver over the indexed synapse table (num_realizations=100).
multilevel_analysis = MultilevelAnalysis(index_data, df_synapses_indexed, statistics, num_realizations=100)
# NOTE(review): presumably produces the MODEL_NULL and EMPIRICAL entries that later cells read — confirm.
multilevel_analysis.run_null_and_empirical()

#### Model definition

In [None]:
class SimulationModel:
    """Scale a reference model's values by per-group specificity parameters.

    Rows of the summary table are grouped by ``groupby_fields``. Every group whose key is
    registered in ``parameter_domain`` has its reference values multiplied by the matching
    entry of the parameter vector; all other groups keep a factor of 1.
    """

    def __init__(self, name, index_data, groupby_fields, parameter_domain, sequential=False):
        """Store the model configuration.

        Args:
            name: Model name; also used to derive ``prior_descriptor`` (``"<name>_prior"``)
                and ``posterior_descriptor`` (``"<name>_posterior"``).
            index_data: Stored on the instance; not read by ``compute``.
            groupby_fields: Non-empty list of column/index-level names used for grouping.
            parameter_domain: Object exposing ``keys`` (container of group keys) and
                ``get_parameter_column_index(key)``.
            sequential: If True, ``compute`` additionally groups by a reference group column.
        """
        self.index_data = index_data
        self.groupby_fields = groupby_fields
        self.parameter_domain = parameter_domain
        self.sequential = sequential
        self.name = name
        self.prior_descriptor = f"{name}_prior"
        self.posterior_descriptor = f"{name}_posterior"

    def compute(self, df_summary, reference_model_descriptor, parameters, group_index_column_reference=None):
        """Apply the specificity parameters to the reference model values.

        Args:
            df_summary: Table holding the reference model column and the grouping fields.
            reference_model_descriptor: Name of the column with the reference model values.
            parameters: Indexable parameter vector; entry ``i`` belongs to the parameter whose
                column index (as reported by the parameter domain) is ``i``.
            group_index_column_reference: Extra grouping column; required when the model is
                sequential, ignored otherwise.

        Returns:
            Tuple ``(values_model, specificity_values, group_indices_model)``, each with one
            entry per row of ``df_summary``: the scaled values, the factor applied to the row,
            and the running index of the group the row belongs to.

        Raises:
            AssertionError: If ``groupby_fields`` is empty, or if the model is sequential and
                no ``group_index_column_reference`` is given.
        """
        assert len(self.groupby_fields)

        values_reference_model = df_summary[reference_model_descriptor].values

        if self.sequential:
            assert group_index_column_reference is not None
            groupby_fields = self.groupby_fields + [group_index_column_reference]
        else:
            groupby_fields = self.groupby_fields
        indices = df_summary.groupby(groupby_fields).indices

        specificity_values = np.ones(len(df_summary))
        values_model = values_reference_model.copy()
        group_indices_model = np.ones(len(df_summary)).astype(int)

        for group_index, (group_key, global_indices) in enumerate(indices.items()):
            group_indices_model[global_indices] = group_index

            # In sequential mode the last key element is the reference group index, which is
            # not part of the parameter key. Strip it before BOTH the membership test and the
            # lookup; testing the full key would never match and leave every factor at 1.
            parameter_key = group_key[:-1] if self.sequential else group_key

            if parameter_key in self.parameter_domain.keys:
                parameter_col_idx = self.parameter_domain.get_parameter_column_index(parameter_key)
                specificity_value = parameters[parameter_col_idx]
            else:
                # Groups without a registered parameter are passed through unchanged.
                specificity_value = 1

            values_model[global_indices] = specificity_value * values_reference_model[global_indices]
            specificity_values[global_indices] = specificity_value

        return values_model, specificity_values, group_indices_model

#### Run model with sampled parameters

In [None]:
# Parameters of model P: one entry per (pre, post) pairing of VIS.E / VIS.I, each added with
# the arguments 0 and 10 (presumably lower/upper bound — confirm against ParameterDomain).
parameter_domain_P = ParameterDomain()
for group_key, label in [
    ((VIS.E, VIS.E), r"$\alpha_{EE}$"),
    ((VIS.E, VIS.I), r"$\alpha_{EI}$"),
    ((VIS.I, VIS.E), r"$\alpha_{IE}$"),
    ((VIS.I, VIS.I), r"$\alpha_{II}$"),
]:
    parameter_domain_P.add_parameter(group_key, 0, 10, label=label)

# Inference object for model P and 500 parameter samples for the prior simulations.
parameter_inference_P = ParameterInference(
    parameter_domain_P,
    num_posterior_samples=500,
    max_epochs=100,
)
prior_parameters_P = parameter_inference_P.sample_parameters(500)

In [None]:
# Parameters of model PS: VIS.E as first key element, VIS.E / VIS.I as second, and one entry
# per compartment constant (SOMA, DEND, AIS); each added with the arguments 0 and 2
# (presumably lower/upper bound — confirm against ParameterDomain).
parameter_domain_PS = ParameterDomain()
for group_key, label in [
    ((VIS.E, VIS.E, VIS.SOMA[0]), r"$\alpha_{EE_S}$"),
    ((VIS.E, VIS.E, VIS.DEND[0]), r"$\alpha_{EE_D}$"),
    ((VIS.E, VIS.E, VIS.AIS[0]), r"$\alpha_{EE_A}$"),
    ((VIS.E, VIS.I, VIS.SOMA[0]), r"$\alpha_{EI_S}$"),
    ((VIS.E, VIS.I, VIS.DEND[0]), r"$\alpha_{EI_D}$"),
    ((VIS.E, VIS.I, VIS.AIS[0]), r"$\alpha_{EI_A}$"),
]:
    parameter_domain_PS.add_parameter(group_key, 0, 2, label=label)

# Inference object for model PS and 2000 parameter samples for the prior simulations.
parameter_inference_PS = ParameterInference(
    parameter_domain_PS,
    num_posterior_samples=500,
    max_epochs=100,
)
prior_parameters_PS = parameter_inference_PS.sample_parameters(2000)

In [None]:
# Default (non-parameterised) models: P groups by merged pre/post cell type, S by compartment.
model_P_impl = Model(index_data, ["pre_celltype_merged", "post_celltype_merged"])
model_S_impl = Model(index_data, ["post_compartment"])

# Parameterised counterparts used for the simulation-based inference.
model_P_sim_impl = SimulationModel(
    name="sim-P",
    index_data=index_data,
    groupby_fields=["pre_celltype_merged", "post_celltype_merged"],
    parameter_domain=parameter_domain_P,
    sequential=False,
)
model_PS_sim_impl = SimulationModel(
    name="sim-PS",
    index_data=index_data,
    groupby_fields=["pre_celltype_merged", "post_celltype_merged", "post_compartment"],
    parameter_domain=parameter_domain_PS,
)

In [None]:
# Run the default P model (grouped by merged pre/post cell type).
# NOTE(review): the arguments look like (reference descriptor, model, output descriptor),
# i.e. MODEL_NULL as reference and MODEL_P as result name — confirm against MultilevelAnalysis.run_model.
multilevel_analysis.run_model(
    MODEL_NULL,
    model_P_impl,
    MODEL_P
)

In [None]:
# NOTE(review): model_S_impl is already created with identical arguments in the cell that
# builds the model objects; this line re-creates it and looks redundant — confirm before removing.
model_S_impl = Model(index_data, ["post_compartment"])

# Run the compartment-level model S with MODEL_P as first argument and MODEL_PS as last
# (presumably reference and output descriptor — confirm against MultilevelAnalysis.run_model).
multilevel_analysis.run_model(
    MODEL_P,
    model_S_impl,
    MODEL_PS
)

In [None]:
# Run the parameterised sim-P model with the prior parameter samples; the result is addressed
# later via model_P_sim_impl.prior_descriptor ("sim-P_prior").
# NOTE(review): MODEL_NULL is presumably the reference model the parameters scale — confirm.
multilevel_analysis.run_model_with_parameters(
    MODEL_NULL,
    model_P_sim_impl,
    model_P_sim_impl.prior_descriptor,
    prior_parameters_P
)

In [None]:
# Run the parameterised sim-PS model with the prior parameter samples; the result is addressed
# later via model_PS_sim_impl.prior_descriptor ("sim-PS_prior").
# NOTE(review): MODEL_P is presumably the reference model the parameters scale — confirm.
multilevel_analysis.run_model_with_parameters(
    MODEL_P,
    model_PS_sim_impl,
    model_PS_sim_impl.prior_descriptor,
    prior_parameters_PS
)

#### Infer posterior distribution

In [None]:
# Statistics for the cell-type selection: empirical values (x_0) and the sim-P prior runs (x_model).
x_0 = multilevel_analysis.stats.to_numpy(SELECTION_CELLTYPE, EMPIRICAL)
x_model = multilevel_analysis.stats.to_numpy(SELECTION_CELLTYPE, model_P_sim_impl.prior_descriptor)

# Run the inference; the return value is discarded and the samples are read from the object below.
_ = parameter_inference_P.infer_parameters(x_model, x_0)

# Cell output: mean of the posterior samples over axis 0 (presumably one value per parameter — confirm).
posterior_parameters_P = parameter_inference_P.samples_posterior
posterior_parameters_P.mean(axis = 0)

In [None]:
# Plot the posterior of model P and save it under plot_folder (helper name suggests PNG and SVG output).
fig, _ = parameter_inference_P.plot_posterior(figsize=(10,10))
image = savefig_png_svg(fig, plot_folder / "posterior_parameters_P")    
# NOTE(review): inline display is disabled here but enabled in the PS posterior cell — confirm intent.
# display(image)

In [None]:
# Statistics for the excitatory subcellular selection: empirical values (x_0) and the sim-PS prior runs (x_model).
# NOTE(review): x_0 / x_model reuse the names from the model-P inference cell and overwrite them.
x_0 = multilevel_analysis.stats.to_numpy(SELECTION_EXC_SUBCELLULAR, EMPIRICAL)
x_model = multilevel_analysis.stats.to_numpy(SELECTION_EXC_SUBCELLULAR, model_PS_sim_impl.prior_descriptor)

# Run the inference; the return value is discarded and the samples are read from the object below.
_ = parameter_inference_PS.infer_parameters(x_model, x_0)

# Cell output: mean of the posterior samples over axis 0 (presumably one value per parameter — confirm).
posterior_parameters_PS = parameter_inference_PS.samples_posterior
posterior_parameters_PS.mean(axis = 0)

In [None]:
# Plot the posterior of model PS, save it under plot_folder and show the saved image inline.
fig, _ = parameter_inference_PS.plot_posterior(figsize=(10,10))
image = savefig_png_svg(fig, plot_folder / "posterior_parameters_PS")   
display(image) 

#### Run model with posterior parameters

In [None]:
# Re-run the parameterised sim-P model, now with the posterior samples; the result is addressed
# via model_P_sim_impl.posterior_descriptor ("sim-P_posterior") in the plotting cells.
multilevel_analysis.run_model_with_parameters(
    MODEL_NULL,
    model_P_sim_impl,
    model_P_sim_impl.posterior_descriptor,
    posterior_parameters_P
)

In [None]:
# Re-run the parameterised sim-PS model, now with the posterior samples; the result is addressed
# via model_PS_sim_impl.posterior_descriptor ("sim-PS_posterior") in the plotting cells.
multilevel_analysis.run_model_with_parameters(
    MODEL_P,
    model_PS_sim_impl,
    model_PS_sim_impl.posterior_descriptor,
    posterior_parameters_PS
)

In [None]:
# Persist the complete analysis object so later sessions can reload it without re-running.
filename = eval_folder / "multilevel_analysis.pkl"
with filename.open("wb") as file:
    pickle.dump(multilevel_analysis, file)

# Store the posterior samples of both models as plain-text arrays next to the pickle.
np.savetxt(eval_folder / "posterior_parameters_P", posterior_parameters_P)
np.savetxt(eval_folder / "posterior_parameters_PS", posterior_parameters_PS)

#### Plot connectivity statistics

In [None]:
# Initialise the project's plot settings (helper from the star imports).
initPlotSettings(False)

# Seaborn categorical palettes from which the model colours below are picked.
# NOTE(review): `sns` is not imported explicitly in this notebook; it presumably arrives
# through one of the star imports — confirm.
# NOTE(review): the DARK and COLORBLIND palettes are not referenced in the visible cells.
COLORS_CATEGORICAL_MUTED = sns.color_palette("muted") 
COLORS_CATEGORICAL_DARK = sns.color_palette("dark") 
COLORS_CATEGORICAL_COLORBLIND = sns.color_palette("colorblind") 
COLORS_CATEGORICAL_BRIGHT = sns.color_palette("bright") 
COLORS_CATEGORICAL_PASTEL = sns.color_palette("pastel") 

# Colours per dataset: empirical data, null model (H0) and up to three further models.
# NOTE(review): C_H0 is assigned from COLOR_EMPIRICAL while C_EMPIRICAL comes from
# COLORS_CATEGORICAL[2] — confirm this naming is intended.
# NOTE(review): only C_EMPIRICAL, C_H0 and C_M1 are used in the visible plotting cells.
C_EMPIRICAL = COLORS_CATEGORICAL[2]
C_H0 = COLOR_EMPIRICAL
C_M1 = COLORS_CATEGORICAL_MUTED[7]
C_M2 = COLORS_CATEGORICAL_BRIGHT[7]
C_M3 = COLORS_CATEGORICAL_PASTEL[7]

In [None]:
# Cell-type selection statistics: empirical data, null model and sim-P posterior runs.
datasets = [
    multilevel_analysis.stats.to_numpy(SELECTION_CELLTYPE, descriptor)
    for descriptor in (EMPIRICAL, MODEL_NULL, model_P_sim_impl.posterior_descriptor)
]

img = plot_bar_chart(
    datasets,
    #dataset_labels = [STR_EMPIRICAL, STR_NULL, STR_P],
    colors=[C_EMPIRICAL, C_H0, C_M1],
    x_labels=SELECTION_CELLTYPE,
    y_lim=(1, 10**5),
    use_log=True,
    error_bars=True,
    adjust_left=0.2,
    fig_size=figsize_mm_to_inch(60, 40),
    filename=plot_folder / f"VIS_{SYNCOUNT_POPULATION}_INH_model_P.png",
)
display(img)

In [None]:
# Excitatory subcellular selection statistics: empirical data, null model and sim-PS posterior runs.
datasets = [
    multilevel_analysis.stats.to_numpy(SELECTION_EXC_SUBCELLULAR, descriptor)
    for descriptor in (EMPIRICAL, MODEL_NULL, model_PS_sim_impl.posterior_descriptor)
]

img = plot_bar_chart(
    datasets,
    #dataset_labels = [STR_EMPIRICAL, STR_NULL, STR_P],
    colors=[C_EMPIRICAL, C_H0, C_M1],
    x_labels=SELECTION_EXC_SUBCELLULAR,
    y_lim=(1, 10**5),
    use_log=True,
    error_bars=True,
    adjust_left=0.2,
    fig_size=figsize_mm_to_inch(100, 40),
    filename=plot_folder / f"VIS_{SYNCOUNT_EXC_SUBCELLULAR}_model_PS.png",
)
display(img)

In [None]:
from lib.pandas_compute import *

# Compare the null model with the sim-P posterior on the cellular (pairwise) level,
# leaving out neuron id -1 and using the merged cell-type columns.
selected_models = [MODEL_NULL, model_P_sim_impl.posterior_descriptor]
df_cellular_pairwise = get_df_cellular(
    multilevel_analysis.df_summary,
    selected_models,
    excluded_neuron_ids=[-1],
    separate_compartment=False,
    pre_celltype_column="pre_celltype_merged",
    post_celltype_column="post_celltype_merged",
)

# Cell output: result of compute_delta_loss for null model vs. sim-P posterior.
compute_delta_loss(
    df_cellular_pairwise,
    MODEL_NULL,
    model_P_sim_impl.posterior_descriptor,
)

In [None]:
dataset_labels = [STR_EMPIRICAL, STR_NULL, STR_P]

# Only the first 15 entries of the cluster selection are plotted.
SELECTION_CLUSTER_15 = SELECTION_CLUSTER[:15]

datasets = [
    multilevel_analysis.stats.to_numpy(SELECTION_CLUSTER_15, descriptor)
    for descriptor in (EMPIRICAL, MODEL_NULL, model_P_sim_impl.posterior_descriptor)
]

# Tick labels: the second "-"-separated token of every selection name.
labels = [cluster_name.split("-")[1] for cluster_name in SELECTION_CLUSTER_15]

filename = plot_folder / f"VIS_{SYNCLUSTERS}_model_specificity.png"
img = plot_line_chart(
    datasets,
    x_labels=labels,
    linestyles=[".-"] * 3,
    colors=[C_EMPIRICAL, C_H0, C_M1],
    y_lim=(-0.1, 0.5 * 10**6),
    capsize=2,
    linewidth=0.8,
    marker_size=4,
    dataset_labels=dataset_labels,
    error_bars=True,
    use_log=True,
    hline_y=None,
    adjust_left=0.18,
    adjust_bottom=0.15,
    fig_size=figsize_mm_to_inch(60, 80),
    filename=filename,
)
display(img)

In [None]:
dataset_labels = [STR_EMPIRICAL, STR_NULL, STR_P]

# Motif selection statistics: empirical data, null model and sim-P posterior runs.
datasets = [
    multilevel_analysis.stats.to_numpy(SELECTION_MOTIF, descriptor)
    for descriptor in (EMPIRICAL, MODEL_NULL, model_P_sim_impl.posterior_descriptor)
]

filename = plot_folder / f"VIS_{MOTIFS}.png"
img = plot_motifs_bar_chart(
    datasets,
    colors=[C_EMPIRICAL, C_H0, C_M1],
    dataset_labels=dataset_labels,
    quantile_low=25,
    quantile_high=75,
    use_mean=False,
    marker_size=4,
    capsize=0,
    y_axis_label="occurrences relative to random",
    error_bars=True,
    y_lim=(-0.5, 10**6),
    use_log=True,
    fig_size=figsize_mm_to_inch(160, 50),
    adjust_left=0.07,
    filename=filename,
)
display(img)