In [None]:
# Reload imported modules (e.g. the brisc analysis package) before each cell
# runs, so edits to those modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import distance_between_cells as dist
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import load
from brisc.exploratory_analysis.plot_summary_for_all_bc import compute_flatmap_coors

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd

matplotlib.rcParams["pdf.fonttype"] = 42  # for pdfs

from iss_preprocess.io import get_processed_path

In [None]:
# Location of the processed cell/barcode table for BRAC8498.3e.
processed_path = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/cell_barcode_df.pkl"
)

# Loader options: areas listed in `areas_to_empty` are handled by the loader's
# emptying logic, only `valid_areas` are kept, and `distance_threshold` is
# forwarded unchanged.
load_kwargs = dict(
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
)
cells_df = load.load_cell_barcode_data(processed_path, **load_kwargs)

# Flatmap coordinates: one column per axis (flatmap_x, flatmap_y, flatmap_z).
flat_coors = compute_flatmap_coors(cells_df)
for axis_index, axis_label in enumerate(("x", "y", "z")):
    cells_df[f"flatmap_{axis_label}"] = flat_coors[:, axis_index]

# Match barcodes across cells, then add per-cell connection distances computed
# from the in-plane flatmap coordinates only (x and y).
conn_mat.match_barcodes(cells_df)
dist.add_connection_distances(cells_df, cols=["flatmap_x", "flatmap_y"])

In [None]:
# Restrict the analysis to putative excitatory cells with a cortical layer
# annotation and only local connections, then build layer-by-layer
# connectivity matrices for the observed data and for barcode-shuffled controls.
non_excitatory_cell_types = ["Vip", "Sst", "Pvalb", "Lamp5", "VLMC"]
layers = ["L2/3", "L4", "L5", "L6a", "L6b"]
grouping = "cortical_layer"  # alternative grouping: "area_acronym_ancestor_rank1"
n_permutations = 1000
# Cells whose furthest connection distance is >= this value are dropped.
# NOTE(review): distances come from the flatmap x/y columns; confirm units.
max_connection_distance = 100

non_excitatory = cells_df["best_cluster"].isin(non_excitatory_cell_types)
has_layer = cells_df["cortical_layer"].notnull()
# .copy() so the L1 relabelling below writes to an independent frame instead of
# a slice of the unfiltered table (avoids SettingWithCopyWarning).
cells_df = cells_df[~non_excitatory & has_layer].copy()

# L1 cells are pooled with L2/3.
cells_df.loc[cells_df["cortical_layer"] == "L1", "cortical_layer"] = "L2/3"
cells_df = cells_df[cells_df["cortical_layer"].isin(layers)]

# Local connections only: keep cells whose furthest connection is short.
furthest_connection = cells_df["distances"].apply(
    lambda connection_distances: np.max(connection_distances)
)
cells_df = cells_df[furthest_connection < max_connection_distance]

# NOTE(review): seeds numpy's global RNG so the shuffles are reproducible if
# conn_mat draws from it - confirm; no effect if it uses its own generator.
np.random.seed(0)

shuffle_kwargs = dict(
    n_permutations=n_permutations,
    shuffle_starters=False,
    shuffle_presyn=True,
    starter_grouping=grouping,
    presyn_grouping=grouping,
)
# Shuffle the barcodes assigned to presynaptic cells: input-fraction null.
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    mean_input_fraction_dfs,
    starter_input_fractions,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    cells_df, output_fraction=False, **shuffle_kwargs
)
# Same shuffle scheme, output-fraction null.
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    output_fraction_dfs,
    _,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    cells_df, output_fraction=True, **shuffle_kwargs
)

# Observed (unshuffled) matrices.
matrix_kwargs = dict(starter_grouping=grouping, presyn_grouping=grouping)
connectivity_matrix, mean_input_fraction, fractions_df, _ = (
    conn_mat.compute_connectivity_matrix(
        cells_df, output_fraction=False, **matrix_kwargs
    )
)
connectivity_matrix, output_fraction, _, _ = conn_mat.compute_connectivity_matrix(
    cells_df, output_fraction=True, **matrix_kwargs
)

In [None]:
def _compare_fraction_to_shuffle(observed, shuffled_dfs, order):
    """Order observed and shuffled matrices by `order` and compare them.

    The shuffled matrices are stacked into one array first. The result of
    ``conn_mat.compare_to_shuffle`` is returned as-is; below it is unpacked
    into a (log ratio, p-value) pair.
    """
    shuffled_stack = np.array(shuffled_dfs)
    return conn_mat.compare_to_shuffle(
        *conn_mat.filter_matrices(
            observed,
            shuffled_stack,
            row_order=order,
            col_order=order,
        )
    )


input_fraction_log_ratio, input_fraction_pval = _compare_fraction_to_shuffle(
    mean_input_fraction, mean_input_fraction_dfs, layers
)
output_fraction_log_ratio, output_fraction_pval = _compare_fraction_to_shuffle(
    output_fraction, output_fraction_dfs, layers
)

In [None]:
# Bootstrap confidence intervals for the mean input-fraction matrix:
# rows of fractions_df are resampled with replacement within each starter layer.
from tqdm.auto import tqdm

counts_df, mean_input_frac_df, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    cells_df,
    starter_grouping="cortical_layer",  # alternative: "area_acronym_ancestor_rank1"
    presyn_grouping="cortical_layer",
)

nboot = 1000
bootstrap_seed = 0
bootstrap_rng = np.random.default_rng(bootstrap_seed)


def _bootstrap_input_fractions(fractions, layer_order, rng):
    """Return one bootstrap replicate of the mean input-fraction matrix.

    Rows of `fractions` (presumably one per starter cell) are resampled with
    replacement within each `cortical_layer` group, and the `layer_order`
    columns are averaged.

    Returns a DataFrame whose index is the `layer_order` columns and whose
    columns are the starter layers. Both axes are reindexed to `layer_order`,
    so a starter layer with no rows becomes a NaN column instead of silently
    shifting the labels.
    """
    per_starter_layer = {
        starter_layer: group.sample(n=len(group), replace=True, random_state=rng)[
            layer_order
        ].mean()
        for starter_layer, group in fractions.groupby("cortical_layer", observed=True)
    }
    return pd.DataFrame(per_starter_layer).reindex(
        index=layer_order, columns=layer_order
    )


bootstrap_samples = np.array(
    [
        _bootstrap_input_fractions(fractions_df, layers, bootstrap_rng).to_numpy()
        for _boot in tqdm(range(nboot), desc="bootstrap")
    ]
)
# 95% percentile intervals, labelled to match the reindexed replicates.
lower_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.025, axis=0), index=layers, columns=layers
)
upper_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.975, axis=0), index=layers, columns=layers
)

In [None]:
# Map cortical-layer labels to the short tick labels used in the figures.
areas = {
    "L2/3": "2/3",
    "L4": "4",
    "L5": "5",
    "L6a": "6a",
    "L6b": "6b",
}
# Row sums of the relabelled counts matrix (rows presumably = presynaptic layer).
# Uses the same `areas` relabelling as the matrices and starter_counts, so all
# three are indexed by the same short labels.
presynaptic_counts = conn_mat.reorganise_matrix(counts_df, areas=areas).sum(axis=1)
# Number of fractions_df rows per cortical layer, relabelled with `areas`.
starter_counts = fractions_df.value_counts("cortical_layer").rename(index=areas)

In [None]:
# Plot Fig.1: raw counts, input/output fraction matrices, bootstrap confidence
# intervals and comparisons against the barcode-shuffled null.
fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54  # centimetres -> inches for figsize
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

# NOTE(review): machine-specific absolute path; make configurable before sharing.
save_path = Path("Z:/home/shared/presentations/becalick_2025")
save_fig = True
figname = "conn_mat"  # files written: <save_path>/conn_mat.pdf and .png

# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Input fraction (y label/ticks removed; the counts panel to its left has them)
ax_input_fraction = fig.add_axes([0.29, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.50, 0.75, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])

# Bootstrap 95% confidence intervals of the input fractions
ax_input_fraction_bars = fig.add_axes([0.62, 0.77, 0.25, 0.18])
boot.plot_confidence_intervals(
    conn_mat.reorganise_matrix(mean_input_frac_df, areas=areas),
    conn_mat.reorganise_matrix(lower_df, areas=areas),
    conn_mat.reorganise_matrix(upper_df, areas=areas),
    ax_input_fraction_bars,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    orientation="horizontal",
)

# Bubble plot: input fraction vs. shuffled null
ax_bubble_plot_input_frac = fig.add_axes([0.07, 0.45, 0.2, 0.2])
ax_bubble_plot_input_frac_cb = fig.add_axes([0.29, 0.6, 0.01, 0.05])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(input_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(input_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_input_frac,
    cbax=ax_bubble_plot_input_frac_cb,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# Output fraction
ax_output_fraction = fig.add_axes([0.45, 0.45, 0.2, 0.2])
ax_output_fraction_cb = fig.add_axes([0.66, 0.45, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(output_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_output_fraction,
    cbax=ax_output_fraction_cb,
    cbar_label="Output\nfraction",
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
)

# Bubble plot: output fraction vs. shuffled null (shares the input plot's legend)
ax_bubble_plot_output_frac = fig.add_axes([0.78, 0.45, 0.2, 0.2])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(output_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(output_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_output_frac,
    show_legend=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

if save_fig:
    # Create the output directory so savefig does not fail on a fresh location.
    save_path.mkdir(parents=True, exist_ok=True)
    for extension in ("pdf", "png"):
        fig.savefig(save_path / f"{figname}.{extension}", format=extension)

In [None]:
# Spatial spread of presynaptic cells around starters, per starter layer, with
# bootstrap CIs on the median in-plane distance.
from brisc.manuscript_analysis import distance_between_cells as dist_cells

(
    relative_presyn_coords_flatmap,
    distancess_flatmap,
    starters_df,
) = dist_cells.determine_presynaptic_distances(cells_df, col_prefix="flatmap_")

nboot = 1000
dist_boot_seed = 0
dist_rng = np.random.default_rng(dist_boot_seed)


def _stack_presyn(starters, column):
    """Stack the per-starter coordinate arrays stored in `starters[column]`.

    Reproduces the `np.hstack(...)[0]` indexing used for these columns.
    NOTE(review): assumes each entry carries a leading singleton axis, so the
    result is a 2-D (n_presyn, n_dims) array - confirm.
    """
    return np.hstack(starters[column].values)[0]


def get_median_dist(starters):
    """Median in-plane distance (relative AP/ML components) to presynaptic cells."""
    relative = _stack_presyn(starters, "presynaptic_coors_relative")
    relative_ap = relative[:, 0]
    relative_ml = relative[:, 1]
    distances = ((relative_ml**2 + relative_ap**2) ** 0.5).astype(float)
    return np.nanmedian(distances)


# Named median_dist (not `dist`) so the `dist` module alias imported at the top
# of the notebook is not shadowed, keeping earlier cells re-runnable.
median_dist = np.full(len(layers), np.nan)
dist_boot = np.full((len(layers), nboot), np.nan)

fig_dist = plt.figure(figsize=(10, 5))
for i, layer in enumerate(layers):
    ax = fig_dist.add_subplot(2, 6, i + 1)
    this_layer = starters_df[starters_df["cortical_layer"] == layer]
    if this_layer.empty:
        # No starters in this layer: leave NaN instead of crashing in np.hstack.
        ax.axis("off")
        continue

    relative_ml = _stack_presyn(this_layer, "presynaptic_coors_relative")[:, 1]
    absolute_depth = _stack_presyn(this_layer, "presynaptic_coors")[:, 2]

    for iboot in range(nboot):
        resampled = this_layer.sample(
            n=len(this_layer), replace=True, axis=0, random_state=dist_rng
        )
        dist_boot[i, iboot] = get_median_dist(resampled)

    # Presynaptic cells (black) vs. starter depths at ML offset 0 (red).
    ax.plot(relative_ml, absolute_depth, ",k", alpha=0.5)
    ax.plot(np.zeros(len(this_layer)), this_layer["flatmap_z"], ".", color="r")
    ax.invert_yaxis()
    ax.axis("equal")
    ax.axis("off")
    ax.set_xlim([-100, 100])
    ax.set_ylim([100, 0])

    median_dist[i] = get_median_dist(this_layer)

# Median distance per starter layer with 95% bootstrap percentile interval.
distance_summary = pd.DataFrame(
    {
        "median_distance": median_dist,
        "ci_lower": np.quantile(dist_boot, 0.025, axis=1),
        "ci_upper": np.quantile(dist_boot, 0.975, axis=1),
    },
    index=layers,
)
distance_summary


In [None]:
# Connection-summary panel, rendered separately so it can be edited by hand
# to beautify it. Edges are drawn for input fractions >= MIN_INPUT_FRACTION;
# edge width scales with input fraction and opacity with the inverse width of
# the bootstrap 95% CI (narrowest drawn CI -> fully opaque).
import graphviz

MIN_INPUT_FRACTION = 0.2

save_directory = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/connectivity_schematics"
)
save_directory.mkdir(parents=True, exist_ok=True)

# Node labels and fixed fdp positions, in the same order as `layers`.
layer_names = ["2/3", "4", "5", "6a", "6b"]
positions = ["2,5!", "0,4!", "2,3!", "0,2!", "2,1!"]
# NOTE(review): not passed to dot.node() below. If it is ever applied, graphviz
# expects "fillcolor"/"shape" rather than "node_color"/"node_shape".
node_properties = dict(node_color="#aaaaaa", node_shape="circle",
                       style="filled", height="1", width="1")

dot = graphviz.Digraph(
    "connection_matrix",
    comment="The cortical microcircuit",
    engine="fdp",
    graph_attr=dict(bgcolor="#ffffff"),
)
for i_layer, (name, pos) in enumerate(zip(layer_names, positions)):
    dot.node(f"{i_layer+1}", name, pos=pos)

# Look everything up by the cortical-layer labels used to build the matrices
# (rows = presynaptic layer, columns = starter layer), aligned to `layers`.
input_fraction_matrix = mean_input_fraction.reindex(index=layers, columns=layers)
ci_width = (upper_df - lower_df).reindex(index=layers, columns=layers)
drawn = input_fraction_matrix >= MIN_INPUT_FRACTION
max_alpha = np.nanmax((1 / ci_width).where(drawn).to_numpy())

for istart, starter_layer in enumerate(layers):
    for ipres, pres_layer in enumerate(layers):
        connection_strength = input_fraction_matrix.loc[pres_layer, starter_layer]
        # `not >=` also skips NaN strengths, which would give penwidth "nan".
        if not connection_strength >= MIN_INPUT_FRACTION:
            continue
        conf_range = ci_width.loc[pres_layer, starter_layer]
        # Narrower CI -> more opaque; encoded as the alpha byte of #RRGGBBAA.
        scaled_alpha = int(np.clip((1 / conf_range) * 255 / max_alpha, 0, 255))
        col = f"#000000{scaled_alpha:02x}"
        dot.edge(
            str(ipres + 1),
            str(istart + 1),
            penwidth=str(connection_strength * 40),
            arrowsize=str(connection_strength),
            color=col,
            arrowhead="normal",
        )

dot.render(directory=save_directory)
dot