In [None]:
# Reload edited modules (e.g. brisc.manuscript_analysis) automatically before
# each cell executes, so library changes are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
# Project analysis modules used throughout (connectivity matrices, bootstrap, loading).
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import load

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Embed fonts in saved PDFs as TrueType (Type 42) rather than Type 3.
matplotlib.rcParams["pdf.fonttype"] = 42  # for pdfs

# Maps dataset-relative paths to their processed-data location; its result is
# used as a Path below (e.g. `.mkdir` in the schematic cell).
from iss_preprocess.io import get_processed_path

In [None]:
# Load the cell/barcode table and restrict it to VISp layers.
processed_path = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/cell_barcode_df.pkl"
)

raw_cell_barcode_df = load.load_cell_barcode_data(
    processed_path,
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
)

# Column used in this cell both to filter cells and to group starter and
# presynaptic cells when building connectivity matrices.
GROUPING = "area_acronym_ancestor_rank1"
VISP_LAYERS = ["VISp1", "VISp2/3", "VISp4", "VISp5", "VISp6a", "VISp6b"]
N_PERMUTATIONS = 1000

# Keep only VISp-layer starters and presynaptic cells. Each dict gets its own
# list so the two filters never share a mutable object.
starter_filtering_dict = {GROUPING: list(VISP_LAYERS)}
presyn_filtering_dict = {GROUPING: list(VISP_LAYERS)}
cell_barcode_df = conn_mat.make_minimal_df(
    raw_cell_barcode_df, starter_filtering_dict, presyn_filtering_dict
)

# Null distributions: shuffle the presynaptic barcode assignments (starters kept
# fixed) N_PERMUTATIONS times and recompute input / output fractions each time.
# NOTE(review): no random seed is set here, so the nulls (and p-values derived
# from them) change between runs unless conn_mat seeds internally — confirm.
shuffle_kwargs = dict(
    n_permutations=N_PERMUTATIONS,
    shuffle_starters=False,
    shuffle_presyn=True,
    starter_grouping=GROUPING,
    presyn_grouping=GROUPING,
)
# Only the per-shuffle fraction matrices (third output) are used downstream.
_, _, mean_input_fraction_dfs, _ = conn_mat.shuffle_and_compute_connectivity(
    cell_barcode_df, output_fraction=False, **shuffle_kwargs
)
_, _, output_fraction_dfs, _ = conn_mat.shuffle_and_compute_connectivity(
    cell_barcode_df, output_fraction=True, **shuffle_kwargs
)

# Observed (unshuffled) matrices.
connectivity_matrix, mean_input_fraction, fractions_df = (
    conn_mat.compute_connectivity_matrix(
        cell_barcode_df,
        starter_grouping=GROUPING,
        presyn_grouping=GROUPING,
        output_fraction=False,
    )
)
_, output_fraction, _ = conn_mat.compute_connectivity_matrix(
    cell_barcode_df,
    starter_grouping=GROUPING,
    presyn_grouping=GROUPING,
    output_fraction=True,
)

In [None]:
# Compare each observed fraction matrix against its stack of shuffled matrices,
# giving a log ratio and a p-value per matrix entry.
shuffled_input_fractions = np.array(mean_input_fraction_dfs)
input_fraction_log_ratio, input_fraction_pval = conn_mat.compare_to_shuffle(
    mean_input_fraction, shuffled_input_fractions
)

shuffled_output_fractions = np.array(output_fraction_dfs)
output_fraction_log_ratio, output_fraction_pval = conn_mat.compare_to_shuffle(
    output_fraction, shuffled_output_fractions
)

In [None]:
# Bootstrap the input fractions, resampling starter cells only
# (resample_starters=True, resample_presynaptic=False), for confidence intervals.
# Unused outputs are discarded so the permutation-test variables from the
# shuffle cell (`shuffled_matrices`, `starter_input_fractions`) are not clobbered.
# NOTE(review): no random seed is set; replicates differ between runs unless the
# bootstrapping module seeds internally — confirm.
_, _, bootstrap_input_fractions, _ = boot.repeated_hierarchical_bootstrap_in_parallel(
    cell_barcode_df,
    n_permutations=1000,
    resample_starters=True,
    resample_presynaptic=False,
)

# Observed matrices for the same grouping.
counts_df, mean_input_frac_df, fractions_df = conn_mat.compute_connectivity_matrix(
    cell_barcode_df,
    starter_grouping="area_acronym_ancestor_rank1",
    presyn_grouping="area_acronym_ancestor_rank1",
)
# Presumably aligns each replicate's rows/columns with mean_input_frac_df.
bootstrap_input_fractions = boot.reorder_dfs(
    bootstrap_input_fractions, mean_input_frac_df
)

# 95% percentile interval across bootstrap replicates.
lower_df, upper_df = boot.compute_percentile_matrices(
    bootstrap_input_fractions, lower_p=2.5, upper_p=97.5
)

# NOTE: overwrites mean_input_frac_df with its filtered version. The figure cell
# uses this filtered frame for the CI bar plot together with the unfiltered
# lower_df / upper_df — confirm filter_matrices keeps row/column labels aligned.
mean_input_frac_df, bootstrap_input_fractions = conn_mat.filter_matrices(
    mean_input_frac_df,
    np.array(bootstrap_input_fractions),
)
conn_mat.plot_null_histograms_square(
    mean_input_frac_df,
    bootstrap_input_fractions,
    bins=30,
    row_label_fontsize=14,
    col_label_fontsize=14,
)

In [None]:
# Full VISp layer acronym -> short layer label (e.g. "VISp2/3" -> "2/3").
areas = {f"VISp{layer}": layer for layer in ("1", "2/3", "4", "5", "6a", "6b")}

# Row totals of the reorganised count matrix.
reorganised_counts = conn_mat.reorganise_matrix(counts_df)
presynaptic_counts = reorganised_counts.sum(axis=1)

# Number of rows of fractions_df per layer, relabelled with the short names.
layer_counts = fractions_df.value_counts("area_acronym_ancestor_rank1")
starter_counts = layer_counts.rename(index=areas)

In [None]:
# Plot Fig.1: counts, input/output fractions, bootstrap CIs and shuffle bubble plots.
fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54  # inches per centimetre; figsize is given in inches
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

# NOTE(review): hard-coded absolute network path; only used when save_fig is True.
save_path = Path("Z:/home/shared/presentations/becalick_2025")
save_fig = False
figname = "conn_mat"  # saved as <figname>.pdf and <figname>.png

# Font sizes shared by every panel.
font_kwargs = dict(
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# Axes rectangles are figure coordinates: [left, bottom, width, height].

# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    line_width=line_width,
    **font_kwargs,
)

# Input fraction (y label and tick labels hidden)
ax_input_fraction = fig.add_axes([0.355, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.565, 0.75, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
    **font_kwargs,
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticklabels([])

# Input fraction with bootstrap confidence intervals
ax_input_fraction_bars = fig.add_axes([0.68, 0.77, 0.3, 0.18])
boot.plot_confidence_intervals(
    conn_mat.reorganise_matrix(mean_input_frac_df),
    conn_mat.reorganise_matrix(lower_df),
    conn_mat.reorganise_matrix(upper_df),
    ax_input_fraction_bars,
    line_width=line_width,
    orientation="horizontal",
    **font_kwargs,
)

# Bubble plot: input fraction vs. shuffle null
ax_bubble_plot_input_frac = fig.add_axes([0.07, 0.45, 0.2, 0.2])
ax_bubble_plot_input_frac_cb = fig.add_axes([0.29, 0.6, 0.01, 0.05])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(input_fraction_log_ratio),
    conn_mat.reorganise_matrix(input_fraction_pval),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_input_frac,
    cbax=ax_bubble_plot_input_frac_cb,
    **font_kwargs,
)

# Output fraction
ax_output_fraction = fig.add_axes([0.45, 0.45, 0.2, 0.2])
ax_output_fraction_cb = fig.add_axes([0.66, 0.45, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(output_fraction),
    starter_counts,
    presynaptic_counts,
    ax_output_fraction,
    cbax=ax_output_fraction_cb,
    cbar_label="Output\nfraction",
    input_fraction=True,
    odds_ratio=False,
    line_width=line_width,
    show_counts=False,
    **font_kwargs,
)

# Bubble plot: output fraction vs. shuffle null (legend suppressed)
ax_bubble_plot_output_frac = fig.add_axes([0.78, 0.45, 0.2, 0.2])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(output_fraction_log_ratio),
    conn_mat.reorganise_matrix(output_fraction_pval),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_output_frac,
    show_legend=False,
    **font_kwargs,
)

if save_fig:
    for fmt in ("pdf", "png"):
        fig.savefig(save_path / f"{figname}.{fmt}", format=fmt)

In [None]:
# Connection-summary schematic (graphviz). Saved separately so it can be
# edited by hand to beautify it for the figure.
import graphviz


save_directory = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/connectivity_schematics"
)
save_directory.mkdir(parents=True, exist_ok=True)

# Connections whose mean input fraction is below this are not drawn.
EDGE_THRESHOLD = 0.2

layer_names = ["1", "2/3", "4", "5", "6a", "6b"]
matx_names = ["VISp1", "VISp2/3", "VISp4", "VISp5", "VISp6a", "VISp6b"]
# fdp pins a node at "x,y" when the position string ends in "!".
positions = ["0,6!", "2,5!", "0,4!", "2,3!", "0,2!", "2,1!"]
# Real Graphviz node attributes: the previous `node_color` / `node_shape` keys
# are not Graphviz attributes, so the grey fill and circle shape never applied.
node_properties = dict(
    fillcolor="#aaaaaa", shape="circle", style="filled", height="1", width="1"
)

dot = graphviz.Digraph(
    "connection_matrix",
    comment="The cortical microcircuit",
    engine="fdp",
    graph_attr=dict(bgcolor="#ffffff"),
)
# Node ids are "1".."6", in the order of matx_names.
for node_id, (name, pos) in enumerate(zip(layer_names, positions), start=1):
    dot.node(str(node_id), name, pos=pos, **node_properties)

# Edge opacity encodes bootstrap precision: the narrowest CI among the drawn
# connections is fully opaque; wider CIs are proportionally more transparent.
# The mask uses the same threshold/comparison as the edge loop below.
strong = mean_input_fraction >= EDGE_THRESHOLD
min_ci_width = np.nanmin((upper_df - lower_df)[strong])


def edge_alpha(ci_width, narrowest_ci_width):
    """Return an 8-bit alpha (0-255) inversely proportional to ``ci_width``.

    ``narrowest_ci_width`` maps to 255. A zero-width CI is treated as maximally
    precise (opaque) instead of dividing by zero; a NaN width raises ValueError.
    """
    if ci_width <= 0:
        return 255
    return int(np.clip(np.rint(255 * narrowest_ci_width / ci_width), 0, 255))


for istart, starter_layer in enumerate(matx_names, start=1):
    for ipres, pres_layer in enumerate(matx_names, start=1):
        connection_strength = mean_input_fraction.loc[pres_layer, starter_layer]
        if connection_strength < EDGE_THRESHOLD:
            continue
        conf_range = (
            upper_df.loc[pres_layer, starter_layer]
            - lower_df.loc[pres_layer, starter_layer]
        )
        print(
            f"{pres_layer} --> {starter_layer}: {1-conf_range:.2f} or {1/conf_range:.2f} "
        )
        alpha = edge_alpha(conf_range, min_ci_width)
        dot.edge(
            str(ipres),
            str(istart),
            penwidth=str(connection_strength * 40),
            arrowsize=str(connection_strength),
            color=f"#000000{alpha:02x}",  # black with RGBA alpha
            arrowhead="normal",
        )

dot.render(directory=save_directory)
dot