In [None]:
# Reload imported modules before each cell runs, so edits to the analysis
# helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Standard library
from pathlib import Path

# Third-party
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Project packages
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import load
import iss_preprocess as iss

# Embed TrueType (Type 42) fonts in saved PDFs so text stays editable.
# Set after all imports so that no later import can override it.
matplotlib.rcParams["pdf.fonttype"] = 42

In [None]:
# Load the per-cell barcode table and restrict it to cells in the valid areas.
processed_path = iss.io.get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/cell_barcode_df.pkl"
)

cell_barcode_df = load.load_cell_barcode_data(
    processed_path,
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
)

# Column used both to select cells and to group them into matrix rows/columns
GROUPING = "area_acronym_ancestor_rank1"
# Number of presynaptic-barcode shuffles used to build the null distribution
N_PERMUTATIONS = 1000
# NOTE(review): seeding only makes the shuffles reproducible if conn_mat draws
# from NumPy's global RNG — TODO confirm against shuffle_and_compute_connectivity.
RANDOM_SEED = 0

VISP_LAYERS = ["VISp1", "VISp2/3", "VISp4", "VISp5", "VISp6a", "VISp6b"]

# Select groups to analyse cells by
starter_filtering_dict = {
    # Areas
    GROUPING: list(VISP_LAYERS),
    # Cell types
    # "Annotated_clusters": [
    #     "L2/3 IT 1", "L2/3 IT 2", "L4 IT", "L4/5 IT", "L5 PT",
    #     "L6 CT", "L6 IT", "Pvalb", "Sst", "Vip"
    # ]
}

presyn_filtering_dict = {
    # Areas (uncomment entries to include presynaptic cells outside VISp)
    GROUPING: [
        *VISP_LAYERS,
        # "VISal",
        # "VISl",
        # "VISli",
        # "VISpm",
        # "RSP",
        # "AUD",
        # "TEa",
        # "TH",
    ],
    # Cell types
    # "Annotated_clusters": [
    #     "L2/3 IT 1", "L2/3 IT 2", "L2/3 IT ENT", "L4 IT", "L4/5 IT",
    #     "L5 PT", "L6 CT", "L6 CT ENT", "L6 IT", "L6b", "Lamp5",
    #     "Pvalb", "Sst", "Vip"
    # ]
}

# Filter using the dicts above
cell_barcode_df = conn_mat.make_minimal_df(
    cell_barcode_df, starter_filtering_dict, presyn_filtering_dict
)

# Shuffle the barcodes assigned to presynaptic cells to build a null distribution.
# Both runs share these settings and differ only in `output_fraction`.
shuffle_kwargs = dict(
    n_permutations=N_PERMUTATIONS,
    shuffle_starters=False,
    shuffle_presyn=True,
    starter_grouping=GROUPING,
    presyn_grouping=GROUPING,
)

np.random.seed(RANDOM_SEED)

# Input-fraction run. Its outputs get their own names so the output-fraction
# run below does not silently overwrite them.
(
    input_shuffled_cell_barcode_dfs,
    input_shuffled_matrices,
    mean_input_fraction_dfs,
    starter_input_fractions,
) = conn_mat.shuffle_and_compute_connectivity(
    cell_barcode_df, output_fraction=False, **shuffle_kwargs
)

# Output-fraction run. `shuffled_matrices` from this run is the null used by
# the bubble-plot cell, paired with `observed_confusion_matrix` computed below
# with the same `output_fraction=True` setting.
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    output_fraction_dfs,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    cell_barcode_df, output_fraction=True, **shuffle_kwargs
)

# Observed (unshuffled) connectivity for the same groupings
input_observed_confusion_matrix, mean_input_fraction, fractions_df = (
    conn_mat.compute_connectivity_matrix(
        cell_barcode_df,
        starter_grouping=GROUPING,
        presyn_grouping=GROUPING,
        output_fraction=False,
    )
)

observed_confusion_matrix, output_fraction, _ = conn_mat.compute_connectivity_matrix(
    cell_barcode_df,
    starter_grouping=GROUPING,
    presyn_grouping=GROUPING,
    output_fraction=True,
)

In [None]:
# Reorder data if necessary (set either list to None to keep the matrix's own order).
# Labels absent from the observed matrix are dropped below.
row_order = [
    "VISp1",
    "VISp2/3",
    "VISp4",
    "VISp5",
    "VISp6a",
    "VISp6b",
    "VISal",
    "VISl",
    "VISli",
    "VISpm",
    "RSP",
    "AUD",
    "TEa",
    "TH",
]
col_order = [
    "VISp1",
    "VISp2/3",
    "VISp4",
    "VISp5",
    "VISp6a",
    "VISp6b",
    "VISl",
    "VISpm",
]

# Choose what data is plotted on the bubble plot
observed_cm = observed_confusion_matrix.copy()
all_null_matrices = shuffled_matrices

if row_order is None:
    row_order = list(observed_cm.index)
if col_order is None:
    col_order = list(observed_cm.columns)

row_order = [r for r in row_order if r in observed_cm.index]
col_order = [c for c in col_order if c in observed_cm.columns]
subset_observed_cm = observed_cm.loc[row_order, col_order]
n_rows, n_cols = subset_observed_cm.shape

null_array = np.array(all_null_matrices)
# The null matrices are indexed positionally with observed_cm's label positions,
# so they must have the same shape. NOTE(review): this cannot check that their
# row/column ORDER also matches observed_cm — assumed from the shuffle function.
assert null_array.shape[1:] == observed_cm.shape, (
    f"null matrices have shape {null_array.shape[1:]}, "
    f"observed matrix has shape {observed_cm.shape}"
)
row_indices = [observed_cm.index.get_loc(r) for r in row_order]
col_indices = [observed_cm.columns.get_loc(c) for c in col_order]
subset_null_array = null_array[:, row_indices][:, :, col_indices]

# Compute p-values and log ratio of observed connectivity vs mean for bubble plots.
# The 1e-9 avoids division by zero when the null mean is 0.
mean_null = subset_null_array.mean(axis=0)
ratio_matrix = subset_observed_cm.values / (mean_null + 1e-9)
ratio_matrix = pd.DataFrame(
    ratio_matrix, index=subset_observed_cm.index, columns=subset_observed_cm.columns
)
# NOTE(review): cells with an observed value of 0 give log10(0) = -inf here
# (NumPy warns); check how bubble_plot renders them.
log_ratio_matrix = np.log10(ratio_matrix)

pval_df = conn_mat.compute_empirical_pvalues(
    subset_observed_cm, subset_null_array, two_sided=True
)
pval_df = pval_df.loc[log_ratio_matrix.index, log_ratio_matrix.columns]
# FDR correction
rejected_df, pval_corrected_df = conn_mat.benjamini_hochberg(pval_df, alpha=0.05)

In [None]:
# Bootstrap by resampling starter cells only (presynaptic cells are not resampled).
# Outputs use `boot_` names so this cell does not overwrite the permutation-null
# variables (`shuffled_matrices`, `subset_null_array`, ...) that the bubble-plot
# cell reads; otherwise re-running that cell afterwards would silently use
# bootstrap matrices as its null distribution.
(
    bootstrapped_results,
    boot_matrices,
    boot_mean_input_fractions,
    boot_starter_input_fractions,
) = boot.repeated_hierarchical_bootstrap_in_parallel(
    cell_barcode_df,
    n_permutations=1000,
    resample_starters=True,
    resample_presynaptic=False,
)

# Observed matrices that the bootstrap replicates are unified against
counts_df, mean_input_frac_df, fractions_df = conn_mat.compute_connectivity_matrix(
    cell_barcode_df,
    starter_grouping="area_acronym_ancestor_rank1",
    presyn_grouping="area_acronym_ancestor_rank1",
)
unified_shuffled = boot.unify_dfs(counts_df, boot_matrices)
unified_input_fracs = boot.unify_dfs(mean_input_frac_df, boot_mean_input_fractions)

boot_null_array, boot_observed_input_frac = conn_mat.plot_null_histograms_square(
    mean_input_frac_df,
    unified_input_fracs,
    bins=30,
    row_label_fontsize=14,
    col_label_fontsize=14,
    # x_axis_lims=(0,0.3)
)

# 95% percentile interval of the bootstrapped mean input fractions
lower_df, upper_df = boot.compute_percentile_matrices(
    unified_input_fracs, lower_p=2.5, upper_p=97.5
)

In [None]:
# Plot Fig.1
fontsize_dict = {"title": 7, "label": 7, "tick": 5, "legend": 5}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54  # inches per centimetre; matplotlib figure sizes are in inches
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=350)

save_path = Path("/nemo/lab/znamenskiyp/home/shared/presentations/becalick_2025")
save_fig = False
figname = "matrices"


# Coord format: [left, bottom, width, height] in figure fractions.
# The last panel starts below the figure (bottom=-0.3), so saving uses
# bbox_inches="tight" to avoid clipping it.
ax_locations = [
    [0.05, 0.67, 0.25, 0.25],  # Row 1, Col 1
    [0.46, 0.67, 0.25, 0.25],  # Row 1, Col 2
    [0.05, 0.31, 0.25, 0.25],  # Row 2, Col 1
    [0.46, 0.31, 0.25, 0.25],  # Row 2, Col 2
    [0.05, 0.0, 0.25, 0.25],  # Row 3, Col 1
    [0.50, -0.3, 0.22, 0.5],  # Row 3, Col 2
]


def plot_matrix_panel(ax, values, input_fraction, odds_ratio):
    """Draw one area-by-area connectivity panel on `ax` with the figure's styling.

    `values` is passed as the first argument of
    `conn_mat.plot_area_by_area_connectivity`; the observed confusion matrix and
    `fractions_df` are taken from the notebook namespace. Returns that function's
    return value.
    """
    return conn_mat.plot_area_by_area_connectivity(
        values,
        observed_confusion_matrix,
        fractions_df,
        input_fraction=input_fraction,
        sum_fraction=False,
        ax=ax,
        transpose=False,
        odds_ratio=odds_ratio,
        label_fontsize=fontsize_dict["label"],
        tick_fontsize=fontsize_dict["tick"],
        line_width=line_width,
    )


# Raw counts. Pass the observed mean input fraction, as the input-fraction panel
# does, not the list of shuffled `mean_input_fraction_dfs`.
ax_counts = fig.add_axes(ax_locations[0])
plot_matrix_panel(ax_counts, mean_input_fraction, input_fraction=False, odds_ratio=False)

# Input fraction
ax_input_fraction = fig.add_axes(ax_locations[1])
plot_matrix_panel(
    ax_input_fraction, mean_input_fraction, input_fraction=True, odds_ratio=False
)

# Output fraction
ax_output_fraction = fig.add_axes(ax_locations[2])
plot_matrix_panel(ax_output_fraction, output_fraction, input_fraction=True, odds_ratio=False)

# Output fraction odds ratio
ax_output_odds_ratio = fig.add_axes(ax_locations[3])
filtered_confusion_matrix = plot_matrix_panel(
    ax_output_odds_ratio, output_fraction, input_fraction=True, odds_ratio=True
)

# Bubble plot: log10(observed / shuffled mean) with FDR-corrected p-values
ax_bubble_plot_input_frac = fig.add_axes(ax_locations[4])
conn_mat.bubble_plot(
    log_ratio_matrix,
    pval_corrected_df,
    alpha=0.05,
    size_scale=200,
    ax=ax_bubble_plot_input_frac,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# Bootstrapped confidence intervals of the mean input fraction
ax_confidence_intervals = fig.add_axes(ax_locations[5])
boot.plot_confidence_intervals(
    mean_input_frac_df,
    lower_df,
    upper_df,
    ax_confidence_intervals,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

if save_fig:
    for extension in ("pdf", "png"):
        fig.savefig(save_path / f"{figname}.{extension}", bbox_inches="tight")