In [None]:
# Automatically reload edited modules (e.g. brisc.manuscript_analysis) before each cell executes
%load_ext autoreload
%autoreload 2

In [None]:
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import distance_between_cells as dist
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import load
from brisc.manuscript_analysis.flatmap_projection import compute_flatmap_coors
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd

matplotlib.rcParams["pdf.fonttype"] = 42  # embed TrueType (Type 42) fonts so text in saved PDFs stays editable

from iss_preprocess.io import get_processed_path

In [None]:
# Load per-cell barcode / cell-type data for one mouse
MOUSE_NAME = "BRAC8498.3e"
processed_path = get_processed_path(f"becalia_rabies_barseq/{MOUSE_NAME}/analysis")

cells_df = load.load_cell_barcode_data(
    processed_path,
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
    error_correction_ds_name=f"{MOUSE_NAME}_error_corrected_barcodes_26",
)

In [None]:
# Quick look at cluster assignment vs. number of detected genes
cells_df.loc[:, ["best_cluster", "raw_gene_counts"]]

In [None]:
# Flatmap coordinates for every cell, then barcode matching and connection distances
flat_coors = compute_flatmap_coors(cells_df, distance_cutoff=150)
for axis_index, axis_name in enumerate(["x", "y", "z"]):
    cells_df[f"flatmap_{axis_name}"] = flat_coors[:, axis_index]

conn_mat.match_barcodes(cells_df)
dist.add_connection_distances(cells_df, cols=["flatmap_x", "flatmap_y"])

# Collapse fine clusters (values of `best_cluster`) into broad cell classes
mapping = {
    "excitatory": [
        "Car3",
        "L2/3 IT 1",
        "L2/3 IT 2",
        "L2/3 RSP",
        "L4 IT",
        "L4 RSP",
        "L5 IT",
        "L5 NP",
        "L5 PT",
        "L5/6 IT",
        "L6 CT",
        "L6b",
    ],
    "unassigned": ["Unassigned", "VLMC", "Zero_correlation"],
    "Sst": ["Sst"],
    "Vip": ["Vip"],
    "Lamp5": ["Lamp5"],
    "Pvalb": ["Pvalb"],
}
# cluster name -> broad class
flattened_mapping = {
    cluster: cell_class
    for cell_class, clusters in mapping.items()
    for cluster in clusters
}

# Clusters missing from `mapping` become NaN in the new column
cells_df["broad_cell_class"] = cells_df["best_cluster"].map(flattened_mapping)

# Keep only confidently typed interneurons (all thresholds are strict ">")
INTERNEURON_CLASSES = ["Sst", "Vip", "Lamp5", "Pvalb"]
BEST_SCORE_THRESHOLD = 0.3
KNN_AGREE_CONF_THRESHOLD = 0.3
RAW_GENE_COUNTS_THRESHOLD = 2

# .copy(): make an independent frame so columns that downstream helpers may add to it
# don't raise SettingWithCopyWarning against cells_df
interneuron_df = cells_df[
    cells_df["broad_cell_class"].isin(INTERNEURON_CLASSES)
    & (cells_df["best_score"] > BEST_SCORE_THRESHOLD)
    & (cells_df["knn_agree_conf"] > KNN_AGREE_CONF_THRESHOLD)
    & (cells_df["raw_gene_counts"] > RAW_GENE_COUNTS_THRESHOLD)
].copy()

In [None]:
grouping = "broad_cell_class"

# NOTE(review): seeds NumPy's legacy global RNG. This only makes the permutations
# reproducible if conn_mat.shuffle_and_compute_connectivity draws from it — confirm.
np.random.seed(42)

# Shared settings: shuffle presynaptic barcodes only, starters stay fixed
shuffle_kwargs = dict(
    n_permutations=1000,
    shuffle_starters=False,
    shuffle_presyn=True,
    starter_grouping=grouping,
    presyn_grouping=grouping,
)

# Shuffle null for input fractions (shuffled barcode tables/matrices are not needed here)
(
    _,
    _,
    mean_input_fraction_dfs,
    starter_input_fractions,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=False, **shuffle_kwargs
)

# Shuffle null for output fractions
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    output_fraction_dfs,
    _,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=True, **shuffle_kwargs
)

# Observed (unshuffled) connectivity
connectivity_kwargs = dict(starter_grouping=grouping, presyn_grouping=grouping)
_, mean_input_fraction, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=False, **connectivity_kwargs
)
connectivity_matrix, output_fraction, _, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=True, **connectivity_kwargs
)

In [None]:
from tqdm import tqdm

cell_types = ["Sst", "Vip", "Lamp5", "Pvalb"]

# Observed connectivity compared with the presynaptic-shuffle null distributions
input_fraction_log_ratio, input_fraction_pval = conn_mat.compare_to_shuffle(
    *conn_mat.filter_matrices(
        mean_input_fraction,
        np.array(mean_input_fraction_dfs),
        row_order=cell_types,
        col_order=cell_types,
    )
)

output_fraction_log_ratio, output_fraction_pval = conn_mat.compare_to_shuffle(
    *conn_mat.filter_matrices(
        output_fraction,
        np.array(output_fraction_dfs),
        row_order=cell_types,
        col_order=cell_types,
    )
)

counts_df, mean_input_frac_df, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    interneuron_df,
    starter_grouping=grouping,
    presyn_grouping=grouping,
)

# Bootstrap 95% CIs of the mean input fraction. Within each `grouping` class, rows of
# fractions_df (presumably one per starter cell) are resampled with replacement and
# their `cell_types` columns averaged.
nboot = 1000
bootstrap_rng = np.random.default_rng(0)  # seeded so the CIs are reproducible
bootstrap_samples = []
for boot_index in tqdm(range(nboot)):
    resampled_means = {
        starter_class: group.sample(
            n=len(group), replace=True, random_state=bootstrap_rng
        )[cell_types].mean()
        for starter_class, group in fractions_df.groupby(grouping, observed=True)
    }
    # index = fractions_df columns (cell_types), columns = starter class.
    # Align BOTH axes to cell_types by label. groupby order (e.g. alphabetical) need
    # not match cell_types. A class with no rows becomes an all-NaN column instead of
    # a shape mismatch.
    bootstrap_samples.append(
        pd.DataFrame(resampled_means)
        .reindex(index=cell_types, columns=cell_types)
        .to_numpy()
    )
bootstrap_samples = np.array(bootstrap_samples)
lower_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.025, axis=0),
    index=cell_types,
    columns=cell_types,
)
upper_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.975, axis=0),
    index=cell_types,
    columns=cell_types,
)

# Display-name mapping used by reorganise_matrix (identity for interneuron classes)
areas = {
    # "excitatory": "Exc",
    "Sst": "Sst",
    "Vip": "Vip",
    "Lamp5": "Lamp5",
    "Pvalb": "Pvalb",
}
presynaptic_counts = conn_mat.reorganise_matrix(counts_df, areas=areas).sum(axis=1)
starter_counts = fractions_df.value_counts(grouping).rename(index=areas)

In [None]:
# Presynaptic positions relative to their starter, measured in flatmap coordinates.
# `dist` is brisc.manuscript_analysis.distance_between_cells, imported in the first cell.
relative_presyn_coords_flatmap, distances_flatmap, starters_df = (
    dist.determine_presynaptic_distances(cells_df, col_prefix="flatmap_")
)

In [None]:
# Plot figure: raw counts, input-fraction matrix and connectivity schematic

fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

# Windows alternative: Path("Z:/home/shared/presentations/becalick_2025")
save_path = Path("/nemo/lab/znamenskiyp/home/shared/presentations/becalick_2025")

save_fig = True
figname = "interneuron_connections"

# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Input fraction (shares the y axis of the counts panel, so hide its own labels)
ax_input_fraction = fig.add_axes([0.29, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.50, 0.75, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])

# Schematics
ax_schematics = fig.add_axes([0.61, 0.77, 0.16, 0.18])
cax_schematics = fig.add_axes([0.78, 0.75, 0.01, 0.05])
positions = dict(zip(cell_types, [(1.5, 2), (0, 2), (1.5, 0), (0, 0)]))
# Discard the returned figure/axes: rebinding `fig` here would make the save below
# depend on whatever figure the helper returns rather than the one built above.
_, _, cbar = conn_mat.connectivity_diagram_mpl(
    mean_input_fraction,
    lower_df,
    upper_df,
    connection_names=cell_types,
    positions=positions,
    display_names=cell_types,
    node_style=dict(facecolor="Lightgray", radius=0.5, fontsize=fontsize_dict["title"]),
    min_fraction_cutoff=0.2,
    ci_to_alpha=False,
    ci_cmap="plasma_r",
    edge_width_scale=10,
    arrow_head_scale=20,
    arrow_style=dict(connectionstyle="Arc3, rad=-0.2", ec="k"),
    vmin=0,
    vmax=1,
    ax=ax_schematics,
    cax=cax_schematics,
)
cbar.set_ticks([0, 0.5, 1])

if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path / f"{figname}.pdf", format="pdf")
    fig.savefig(save_path / f"{figname}.png", format="png")
    print(f"Figure saved in {save_path / figname}.pdf")

# TODO: long-range connections per cell type. A placeholder loop here only selected
# starters of each broad_cell_class and never used them.

In [None]:
from scipy import stats


def determine_presynaptic_distances(rabies_cell_properties):
    """Determine the distances between starter and presynaptic cells.

    A presynaptic cell is any non-starter that shares a starter's ``main_barcode``.
    Starters without such a partner contribute nothing. Non-starters whose barcode
    matches no starter are ignored.

    Args:
        rabies_cell_properties (pd.DataFrame): One row per cell with columns
            ``is_starter``, ``main_barcode``, ``ara_x``, ``ara_y`` and ``ara_z``.

    Returns:
        all_coords (numpy.ndarray): ``(n, 3)`` float array of presynaptic x, y, z
            coordinates relative to their starter. Rows are grouped by starter in
            row order, and within a starter keep the input row order. The shape is
            ``(0, 3)`` when no starter has a presynaptic partner.
        distances (numpy.ndarray): ``(n,)`` Euclidean norms of ``all_coords``
            multiplied by 1000. They are plotted as µm downstream, so this
            presumably converts mm to µm.
    """
    coord_cols = ["ara_x", "ara_y", "ara_z"]
    starters = rabies_cell_properties.query("is_starter==True")
    non_starters = rabies_cell_properties.query("is_starter==False")

    # Group once by barcode instead of querying the whole table for every starter.
    # groupby drops NaN barcodes, so starters without a barcode never match.
    presyn_coords_by_barcode = {
        barcode: group[coord_cols].to_numpy(dtype=float)
        for barcode, group in non_starters.groupby(
            "main_barcode", sort=False, observed=True
        )
    }

    relative_coords = []
    starter_coords = starters[coord_cols].to_numpy(dtype=float)
    for barcode, start_coords in zip(starters["main_barcode"], starter_coords):
        presy_coords = presyn_coords_by_barcode.get(barcode)
        if presy_coords is not None:
            relative_coords.append(presy_coords - start_coords)

    if relative_coords:
        all_coords = np.vstack(relative_coords)
    else:
        all_coords = np.empty((0, len(coord_cols)))
    distances = np.linalg.norm(all_coords, axis=1) * 1000
    return all_coords, distances


def find_distances_to_non_starter_cell_type(
    annotated_cells_df,
    non_starter_cell_type="Pvalb",
    plot_hist=True,
    cluster_col="Annotated_clusters",
):
    """Find distances from starter cells to a specific non-starter cell type and compare to all other non-starter cells.

    Both comparisons keep every starter cell, so distances are measured from the
    same set of starters. Only the pool of candidate presynaptic cells differs.

    Args:
        annotated_cells_df (pd.DataFrame): DataFrame containing barcode and cell type info.
        non_starter_cell_type (str, optional): Annotation of non-starter cell type. Defaults to 'Pvalb'.
        plot_hist (bool, optional): If True, plot normalised histograms of both
            distance distributions, annotated with a two-sample KS test. Defaults to True.
        cluster_col (str, optional): Column holding the cell-type annotation.
            Defaults to 'Annotated_clusters'.

    Returns:
        all_coords_target_non_starter_df (np.array): 2D array with x, y, z coordinates of target non-starter cells.
        distances_target_non_starter_df (np.array): 1D array of distances from starter cells to target non-starter cells.
        all_coords_other_non_starter_df (np.array): 2D array with x, y, z coordinates of other non-starter cells.
        distances_other_non_starter_df (np.array): 1D array of distances from starter cells to other non-starter cells.
        fig (matplotlib.figure.Figure or None): Histogram figure, or None if plot_hist is False.
    """
    is_non_starter = annotated_cells_df["is_starter"].eq(False)
    cell_labels = annotated_cells_df[cluster_col]
    target_mask = is_non_starter & cell_labels.eq(non_starter_cell_type)
    # .ne() is True for NaN labels, so unannotated non-starters count as "other"
    other_mask = is_non_starter & cell_labels.ne(non_starter_cell_type)

    # Each subset = all starters + one group of non-starters
    target_non_starter_df = annotated_cells_df[~other_mask].copy()
    other_non_starter_df = annotated_cells_df[~target_mask].copy()

    all_coords_target_non_starter_df, distances_target_non_starter_df = (
        determine_presynaptic_distances(target_non_starter_df)
    )
    all_coords_other_non_starter_df, distances_other_non_starter_df = (
        determine_presynaptic_distances(other_non_starter_df)
    )

    fig = None
    if plot_hist:
        fig, ax = plt.subplots(dpi=200)
        hist_kwargs = dict(
            bins=60,
            range=(0, 5000),
            histtype="step",
            linewidth=2,
            density=True,
            alpha=0.9,
        )
        ax.hist(
            distances_other_non_starter_df,
            edgecolor="lightblue",
            label=f"Other non-{non_starter_cell_type} cells",
            **hist_kwargs,
        )
        ax.hist(
            distances_target_non_starter_df,
            edgecolor="orange",
            label=f"{non_starter_cell_type} cells",
            **hist_kwargs,
        )
        ax.set_xlabel("Distance (um)")
        # density=True normalises each histogram to unit area
        ax.set_ylabel("Density")
        # ks_2samp raises on empty samples, so only annotate when both exist
        if len(distances_other_non_starter_df) and len(distances_target_non_starter_df):
            res = stats.ks_2samp(
                distances_other_non_starter_df,
                distances_target_non_starter_df,
                alternative="two-sided",
            )
            # .3g keeps very small p-values visible (".2f" printed them as 0.00)
            ax.text(
                0.5,
                0.5,
                f"KS test p-value: {res.pvalue:.3g}",
                transform=ax.transAxes,
            )
            ax.text(
                0.5,
                0.45,
                f"KS test statistic: {res.statistic:.2f}",
                transform=ax.transAxes,
            )
        ax.legend()

    return (
        all_coords_target_non_starter_df,
        distances_target_non_starter_df,
        all_coords_other_non_starter_df,
        distances_other_non_starter_df,
        fig,
    )

In [None]:
# Distances from every starter to presynaptic cells of each interneuron class.
# Index 1 of the returned tuple is the target-type distance array. Only that is kept,
# so the loop no longer rebinds the global `fig` (to None) or leaks unused arrays.
out = {
    cell_type: find_distances_to_non_starter_cell_type(
        cells_df, non_starter_cell_type=cell_type, plot_hist=False
    )[1]
    for cell_type in cell_types
}

In [None]:
# Plot full figure with bubble plots

fontsize_dict = {"title": 7, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
line_width = 0.9
line_alpha = 1

cm = 1 / 2.54
fig = plt.figure(figsize=(17.4 * cm, 17.4 * cm), dpi=200)

save_path = get_processed_path("presentations/becalick_2025")
save_fig = True
figname = "matrices"

# Raw counts
ax_counts = fig.add_axes([0.07, 0.75, 0.2, 0.2])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(counts_df, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Input fraction (shares the y axis of the counts panel, so hide its own labels)
ax_input_fraction = fig.add_axes([0.29, 0.75, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.50, 0.75, 0.01, 0.05])

conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(mean_input_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
)
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])

# Bootstrap confidence intervals of the mean input fraction
ax_input_fraction_bars = fig.add_axes([0.62, 0.77, 0.25, 0.18])
boot.plot_confidence_intervals(
    conn_mat.reorganise_matrix(mean_input_frac_df, areas=areas),
    conn_mat.reorganise_matrix(lower_df, areas=areas),
    conn_mat.reorganise_matrix(upper_df, areas=areas),
    ax_input_fraction_bars,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    orientation="horizontal",
)

# Bubble plot: input fraction vs. shuffle null
ax_bubble_plot_input_frac = fig.add_axes([0.07, 0.45, 0.2, 0.2])
ax_bubble_plot_input_frac_cb = fig.add_axes([0.29, 0.6, 0.01, 0.05])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(input_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(input_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_input_frac,
    cbax=ax_bubble_plot_input_frac_cb,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# Output fraction
ax_output_fraction = fig.add_axes([0.45, 0.45, 0.2, 0.2])
ax_output_fraction_cb = fig.add_axes([0.66, 0.45, 0.01, 0.05])
conn_mat.plot_area_by_area_connectivity(
    conn_mat.reorganise_matrix(output_fraction, areas=areas),
    starter_counts,
    presynaptic_counts,
    ax_output_fraction,
    cbax=ax_output_fraction_cb,
    cbar_label="Output\nfraction",
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
)

# Bubble plot: output fraction vs. shuffle null
ax_bubble_plot_output_frac = fig.add_axes([0.78, 0.45, 0.2, 0.2])
conn_mat.bubble_plot(
    conn_mat.reorganise_matrix(output_fraction_log_ratio, areas=areas),
    conn_mat.reorganise_matrix(output_fraction_pval, areas=areas),
    alpha=0.05,
    size_scale=250,
    ax=ax_bubble_plot_output_frac,
    show_legend=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
)

# save_fig / figname were set above but the figure was never written out
if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path / f"{figname}.pdf", format="pdf")
    fig.savefig(save_path / f"{figname}.png", format="png")
    print(f"Figure saved in {save_path / figname}.pdf")