In [None]:
# Reload edited (e.g. brisc) modules automatically before each cell executes
%load_ext autoreload
%autoreload 2

In [None]:
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import distance_between_cells as dist
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import load
from brisc.exploratory_analysis.plot_summary_for_all_bc import compute_flatmap_coors
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd

matplotlib.rcParams["pdf.fonttype"] = 42  # embed TrueType (Type 42) fonts in PDFs instead of Type 3

from iss_preprocess.io import get_processed_path

In [None]:
# Load the cell/barcode table, add flatmap coordinates and barcode matches,
# then assign every cell a broad class and select confidently typed interneurons.
processed_path = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/cell_barcode_df.pkl"
)

cells_df = load.load_cell_barcode_data(
    processed_path,
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
)

flat_coors = compute_flatmap_coors(cells_df, distance_cutoff=150)
cells_df["flatmap_x"] = flat_coors[:, 0]
cells_df["flatmap_y"] = flat_coors[:, 1]
cells_df["flatmap_z"] = flat_coors[:, 2]

# Both helpers are called for their effect on cells_df (return values unused)
conn_mat.match_barcodes(cells_df)
dist.add_connection_distances(cells_df, cols=["flatmap_x", "flatmap_y"])

# Broad cell class -> list of transcriptomic clusters belonging to it
mapping = {
    "excitatory": [
        "Car3",
        "L2/3 IT 1",
        "L2/3 IT 2",
        "L2/3 RSP",
        "L4 IT",
        "L4 RSP",
        "L5 IT",
        "L5 NP",
        "L5 PT",
        "L5/6 IT",
        "L6 CT",
        "L6b",
    ],
    "unassigned": ["Unassigned", "VLMC", "Zero_correlation"],
    "Sst": ["Sst"],
    "Vip": ["Vip"],
    "Lamp5": ["Lamp5"],
    "Pvalb": ["Pvalb"],
}
# Inverted lookup, one entry per cluster: {"Car3": "excitatory", ..., "Sst": "Sst"}
flattened_mapping = {
    cluster: broad_class
    for broad_class, clusters in mapping.items()
    for cluster in clusters
}

# Apply the mapping to create the new column; clusters absent from `mapping`
# end up as NaN.
cells_df["broad_cell_class"] = cells_df["best_cluster"].map(flattened_mapping)

# Interneurons with confident type assignments. `.copy()` makes this an
# independent frame rather than a flagged slice of cells_df, so any column the
# downstream conn_mat helpers may add does not trigger SettingWithCopyWarning.
interneuron_df = cells_df[
    cells_df["broad_cell_class"].isin(["Sst", "Vip", "Lamp5", "Pvalb"])
    & (cells_df["best_score"] > 0.3)
    & (cells_df["knn_agree_conf"] > 0.3)
    & (cells_df["raw_gene_counts"] > 2)
].copy()

In [None]:
grouping = "broad_cell_class"

# Both starters and presynaptic cells are grouped by broad cell class.
group_kwargs = dict(starter_grouping=grouping, presyn_grouping=grouping)
# Permutation null: starters keep their barcodes, only the barcodes assigned to
# presynaptic cells are shuffled.
shuffle_kwargs = dict(
    n_permutations=1000,
    shuffle_starters=False,
    shuffle_presyn=True,
    **group_kwargs,
)

# Shuffled input fractions. The first two outputs are overwritten by the next
# call anyway, so they are discarded here.
(
    _,
    _,
    mean_input_fraction_dfs,
    starter_input_fractions,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=False, **shuffle_kwargs
)

# Shuffled output fractions
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    output_fraction_dfs,
    _,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=True, **shuffle_kwargs
)

# Observed (unshuffled) fractions; the connectivity matrix kept is the one from
# the output-fraction call, as the first one is immediately replaced.
_, mean_input_fraction, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=False, **group_kwargs
)
connectivity_matrix, output_fraction, _, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=True, **group_kwargs
)

In [None]:
cell_types = ["Sst", "Vip", "Lamp5", "Pvalb"]

# Log ratio and p-value of the observed fractions against the shuffled null
input_fraction_log_ratio, input_fraction_pval = conn_mat.compare_to_shuffle(
    *conn_mat.filter_matrices(
        mean_input_fraction,
        np.array(mean_input_fraction_dfs),
        row_order=cell_types,
        col_order=cell_types,
    )
)

output_fraction_log_ratio, output_fraction_pval = conn_mat.compare_to_shuffle(
    *conn_mat.filter_matrices(
        output_fraction,
        np.array(output_fraction_dfs),
        row_order=cell_types,
        col_order=cell_types,
    )
)

counts_df, mean_input_frac_df, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df,
    starter_grouping=grouping,
    presyn_grouping=grouping,
)

from tqdm import tqdm

# Bootstrap 95% CIs of the mean fractions: within each starter class, resample
# starter rows of fractions_df with replacement and take the column means.
nboot = 1000
bootstrap_seed = 0  # fixed so the CIs are reproducible across runs
rng = np.random.default_rng(bootstrap_seed)
bootstrap_samples = []

for boot_idx in tqdm(range(nboot)):
    per_starter_class = []
    for starter_class, group in fractions_df.groupby(grouping, observed=True):
        resampled_mean = group.sample(
            n=len(group), replace=True, random_state=rng
        )[cell_types].mean()
        resampled_mean.name = starter_class
        per_starter_class.append(resampled_mean)
    # groupby yields classes in its own (sorted) order, not cell_types order.
    # The labels are dropped by np.array below and lower_df/upper_df are
    # relabelled with cell_types, so align both axes explicitly here; without
    # this the CI columns are attached to the wrong starter classes.
    bootstrap_samples.append(
        pd.concat(per_starter_class, axis=1).reindex(
            index=cell_types, columns=cell_types
        )
    )
bootstrap_samples = np.array(bootstrap_samples)
lower_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.025, axis=0),
    index=cell_types,
    columns=cell_types,
)
upper_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.975, axis=0),
    index=cell_types,
    columns=cell_types,
)

# Label mapping passed to reorganise_matrix / used to rename starter counts
areas = {
    # "excitatory": "Exc",
    "Sst": "Sst",
    "Vip": "Vip",
    "Lamp5": "Lamp5",
    "Pvalb": "Pvalb",
}
presynaptic_counts = conn_mat.reorganise_matrix(counts_df, areas=areas).sum(axis=1)
starter_counts = fractions_df.value_counts(grouping).rename(index=areas)

In [None]:
# Relative presynaptic coordinates and distances in flatmap space, computed
# from the full cells_df (all classes, not only interneuron_df).
# `dist` is the distance_between_cells module imported in the first cell, so
# it is not re-imported here under a second alias.
relative_presyn_coords_flatmap, distancess_flatmap, starters_df = (
    dist.determine_presynaptic_distances(cells_df, col_prefix="flatmap_")
)

In [None]:
# Plot figure: raw counts, input fractions and a connectivity schematic

fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54  # inches per centimetre (figsize is given in inches)
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

save_path = get_processed_path("presentations/becalick_2025")
save_fig = True
# The bubble-plot figure cell also uses figname "matrices"; use a distinct
# name here so saving one figure does not overwrite the other.
figname = "matrices_schematic"


# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Input fraction (shares the y axis labels of the counts panel)
ax_input_fraction = fig.add_axes([0.29, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.50, 0.75, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])


# Schematics
ax_schematics = fig.add_axes([0.61, 0.77, 0.16, 0.18])
cax_schematics = fig.add_axes([0.78, 0.75, 0.01, 0.05])
positions = {l: p for l, p in zip(cell_types, [(1.5, 2), (0, 2), (1.5, 0), (0, 0)])}
fig, ax, cbar = conn_mat.connectivity_diagram_mpl(
    mean_input_fraction,
    lower_df,
    upper_df,
    connection_names=cell_types,
    positions=positions,
    display_names=cell_types,
    node_style=dict(facecolor="Lightgray", radius=0.5, fontsize=fontsize_dict["title"]),
    min_fraction_cutoff=0.2,
    ci_to_alpha=False,
    ci_cmap="plasma_r",
    edge_width_scale=10,
    arrow_head_scale=20,
    arrow_style=dict(connectionstyle="Arc3, rad=-0.2", ec="k"),
    vmin=0,
    vmax=1,
    ax=ax_schematics,
    cax=cax_schematics,
)
cbar.set_ticks([0, 0.5, 1])

# Long range
# TODO: long-range connection panel is not implemented yet. The previous
# placeholder axes at [0.78, 0.77, 0.2, 0.18] only drew an empty frame over
# cax_schematics, and the loop meant to fill it computed nothing.

if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path / f"{figname}.pdf")


In [None]:
# Quick check of starters_df dimensions (rows, columns)
starters_df.shape

In [None]:

# Inspect the broad class of each row of starters_df (presumably one row per starter)
starters_df.broad_cell_class

In [None]:
# Distribution of flatmap distances between each starter and its presynaptic
# partners, split by the starter's broad interneuron class.
fig = plt.figure()
ax_long_connections = fig.add_subplot(1, 1, 1)

distances_by_type = {}
for cell_type in cell_types:
    st_type = starters_df.query(f"broad_cell_class == '{cell_type}'")
    if st_type.empty:
        # np.hstack raises on an empty sequence; skip classes without starters
        continue
    # NOTE(review): the [0, :, :2] indexing requires the stacked result to be
    # 3-D, i.e. entries shaped like (1, n_presyn, n_dims) — confirm against
    # determine_presynaptic_distances. Only the first two (flatmap x, y) dims are kept.
    rel_pres = np.hstack(st_type.presynaptic_coors_relative.values)[0, :, :2].astype(float)
    dst_flat = np.linalg.norm(rel_pres, axis=1)
    distances_by_type[cell_type] = dst_flat

if distances_by_type:
    # One set of bin edges for every class; separate hist() calls would each
    # pick their own edges, making the density curves non-comparable.
    bin_edges = np.histogram_bin_edges(
        np.concatenate(list(distances_by_type.values())), bins=10
    )
    for cell_type, type_distances in distances_by_type.items():
        ax_long_connections.hist(
            type_distances,
            bins=bin_edges,
            histtype="step",
            log=True,
            label=cell_type,
            density=True,
        )

ax_long_connections.set_xlabel("Starter to presynaptic distance (flatmap units)")
ax_long_connections.set_ylabel("Density")
ax_long_connections.legend(loc="upper right")



In [None]:
# NOTE(review): scratch check — relies on `st_type` left over from the
# histogram loop in the previous cell, i.e. only the last class iterated there
# ("Pvalb", the last entry of cell_types). Not reproducible on its own.
rel_pres = np.hstack(st_type.presynaptic_coors_relative.values)[0,:,:2].astype(float)
dst_flat = np.linalg.norm(rel_pres, axis=1)
dst_flat.shape

In [None]:
# Inspect the relative presynaptic coordinates (first two columns only) from the previous cell
rel_pres

In [None]:
# Plot full figure with bubble plots: counts, input/output fractions,
# bootstrap CIs and log-ratio-vs-shuffle bubble plots.

fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54  # inches per centimetre (figsize is given in inches)
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

save_path = get_processed_path("presentations/becalick_2025")
save_fig = True
figname = "matrices"

# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Input fraction (shares the y axis labels of the counts panel)
ax_input_fraction = fig.add_axes([0.29, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.50, 0.75, 0.01, 0.05])

conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])

# Bootstrap confidence intervals of the input fractions
ax_input_fraction_bars = fig.add_axes([0.62, 0.77, 0.25, 0.18])
boot.plot_confidence_intervals(
    conn_mat.reorganise_matrix(mean_input_frac_df, areas=areas),
    conn_mat.reorganise_matrix(lower_df, areas=areas),
    conn_mat.reorganise_matrix(upper_df, areas=areas),
    ax_input_fraction_bars,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    orientation="horizontal",
)

# Bubble plot input fraction
ax_bubble_plot_input_frac = fig.add_axes([0.07, 0.45, 0.2, 0.2])
ax_bubble_plot_input_frac_cb = fig.add_axes([0.29, 0.6, 0.01, 0.05])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(input_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(input_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_input_frac,
    cbax=ax_bubble_plot_input_frac_cb,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# Output fraction (plotted with the same fraction colour scheme as inputs)
ax_output_fraction = fig.add_axes([0.45, 0.45, 0.2, 0.2])
ax_output_fraction_cb = fig.add_axes([0.66, 0.45, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(output_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_output_fraction,
    cbax=ax_output_fraction_cb,
    cbar_label="Output\nfraction",
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
)

# Bubble plot output fraction (reuses the input bubble plot's colour bar)
ax_bubble_plot_output_frac = fig.add_axes([0.78, 0.45, 0.2, 0.2])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(output_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(output_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_output_frac,
    show_legend=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path / f"{figname}.pdf")

In [None]:
# Connection-summary panel rendered with graphviz. It is saved separately so it
# can be hand-edited (beautified) afterwards, because the generated arrow heads
# look very ugly.

save_directory = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/connectivity_schematics"
)
save_directory.mkdir(exist_ok=True)

layer_names = ["Lamp5", "Sst", "Pvalb", "Vip"]
# Matrix labels are identical to the displayed node names
matx_names = layer_names
# Presumably graphviz "x,y!" (pinned) node positions.
# NOTE(review): six positions for four nodes — the last two look like
# leftovers; confirm connectivity_diagram ignores the extras.
positions = ["0,6!", "2,5!", "0,4!", "2,3!", "0,2!", "2,1!"]
node_properties = {
    "node_color": "#aaaaaa",
    "node_shape": "circle",
    "style": "filled",
    "height": "1",
    "width": "1",
}

dot = conn_mat.connectivity_diagram(
    mean_input_fraction,
    lower_df,
    upper_df,
    node_names=layer_names,
    matx_names=matx_names,
    positions=positions,
    min_fraction_cutoff=0.2,
    node_properties=node_properties,
)
dot.render(directory=save_directory, filename="inhibitory_neurons", format="svg")
print(f"Saved graph in {save_directory}")
dot

In [None]:
# Matplotlib version of the connection summary, saved as SVG next to the
# graphviz render.
fig = plt.figure(figsize=(4 * cm, 4 * cm))
ax = fig.add_axes([0, 0, 1, 1])
layer_names = ["Lamp5", "Sst", "Pvalb", "Vip"]
# connectivity_diagram_mpl is called elsewhere in this notebook with
# connection_names/display_names and a {name: (x, y)} dict of positions. The
# graphviz-style node_names and "x,y!" strings used here previously belong to
# conn_mat.connectivity_diagram instead.
positions = {
    name: xy for name, xy in zip(layer_names, [(0, 2), (2, 2), (0, 0), (2, 0)])
}
# Third return value is the colour bar when cax is given; it is not used here,
# so it is discarded rather than unpacked.
fig, ax, _ = conn_mat.connectivity_diagram_mpl(
    mean_input_fraction,
    lower_df,
    upper_df,
    connection_names=layer_names,
    positions=positions,
    display_names=layer_names,
    node_style=dict(facecolor="Lightgray", radius=0.4, fontsize=fontsize_dict["title"]),
    min_fraction_cutoff=0.2,
    ci_to_alpha=False,
    ci_cmap="plasma",
    edge_width_scale=10,
    arrow_head_scale=10,
    arrow_style=dict(arrowstyle="-", connectionstyle="Arc3, rad=-0.2"),
    vmin=0.2,
    ax=ax,
)
fig.savefig(save_directory / "connection_matrix_interneuron.svg")

In [None]:
# Inspect the font sizes used by the figure cells
fontsize_dict

In [None]:
# Show the directory the connectivity schematics were written to
save_directory