In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from iss_preprocess.io import get_processed_path
from brisc.manuscript_analysis import connectivity_matrices as conn_mat
from brisc.manuscript_analysis import distance_between_cells as dist
from brisc.manuscript_analysis import bootstrapping as boot
from brisc.manuscript_analysis import load
from brisc.exploratory_analysis.plot_summary_for_all_bc import (
    compute_flatmap_coors,
    get_avg_layer_depth,
)

from pathlib import Path
import numpy as np

import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm

# Register the shared Arial font file with matplotlib and use it as the default
# font family for every figure in this notebook.
arial_font_path = "/nemo/lab/znamenskiyp/home/shared/resources/fonts/arial.ttf"
fm.fontManager.addfont(arial_font_path)
arial_prop = fm.FontProperties(fname=arial_font_path)
matplotlib.rcParams["font.family"] = arial_prop.get_name()
# Font type 42 for PDF output.
matplotlib.rcParams["pdf.fonttype"] = 42

In [None]:
# Locations of the processed data and of the error-corrected cell/barcode table.
processed_path = get_processed_path("becalia_rabies_barseq/BRAC8498.3e/analysis")
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_26"
df_path = processed_path / (error_correction_ds_name + "_cell_barcode_df.pkl")

# One-off step, kept disabled: re-assign barcodes to cells, copy the gene/cluster
# columns over from the previous table and write the result to `df_path`.
# Run only once to add gene info
if False:
    from brisc.manuscript_analysis.cell_barcode_assignment import assign_cell_barcodes

    old_df = pd.read_pickle(processed_path / "cell_barcode_df.pkl")
    new_df = assign_cell_barcodes(error_correction_ds_name=error_correction_ds_name)

    col2copy = ["raw_gene_counts", "best_score", "knn_agree_conf", "best_cluster"]
    # Index-on-index left join: every row of the new table is kept.
    new_df = new_df.merge(
        old_df[col2copy], how="left", left_index=True, right_index=True
    )
    new_df.to_pickle(df_path)
    # Report how many cells lack at least one ARA coordinate.
    coords = new_df[["ara_x", "ara_y", "ara_z"]]
    print(coords.isna().any(axis=1).sum())

In [None]:
# Load the cell/barcode table, attach flatmap coordinates (raw and
# layer-normalised) and compute barcode matches and connection distances.
print("Loading data")
cells_df = load.load_cell_barcode_data(
    processed_path,
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
    error_correction_ds_name=error_correction_ds_name,
)

print("Projection on flatmap")
flat_coors = compute_flatmap_coors(cells_df, distance_cutoff=150)
for axis_idx, axis_name in enumerate("xyz"):
    cells_df[f"flatmap_{axis_name}"] = flat_coors[:, axis_idx]

normalised_coors = compute_flatmap_coors(
    cells_df, distance_cutoff=150, thickness_type="normalized_layers"
)
# x/y are stored again under the "_normalised" suffix, even if they match the
# non-normalised values, so later functions can address all three axes with
# one prefix/suffix pair.
for axis_idx, axis_name in enumerate("xyz"):
    cells_df[f"flatmap_{axis_name}_normalised"] = np.array(
        normalised_coors[:, axis_idx]
    )

# Cells with a cortical area label that is neither "TH" nor "hippocampal".
has_area_label = cells_df["cortical_area"].notna()
in_excluded_area = cells_df["cortical_area"].isin(["TH", "hippocampal"])
ctx_cells = cells_df[has_area_label & ~in_excluded_area]
n_without_flatmap = ctx_cells["flatmap_x"].isna().sum()
print(
    f"{n_without_flatmap}/{len(ctx_cells)} cortical cells have no flatmap coordinates"
)

print("Matching barcodes")
# Both calls are used for their effect on `cells_df`; their return values are
# not captured. NOTE(review): presumably they add columns in place - confirm.
conn_mat.match_barcodes(cells_df)
dist.add_connection_distances(cells_df, cols=["flatmap_x", "flatmap_y"])

In [None]:
# Clusters that are dropped from the layer-by-layer (excitatory) analysis.
non_excitatory_cell_types = ["Vip", "Sst", "Pvalb", "Lamp5", "VLMC"]
non_excitatory = cells_df["best_cluster"].isin(non_excitatory_cell_types)

# Broad cell class -> fine clusters (values of `best_cluster`) it contains.
mapping = {
    "excitatory": [
        "Car3",
        "L2/3 IT 1",
        "L2/3 IT 2",
        "L2/3 RSP",
        "L4 IT",
        "L4 RSP",
        "L5 IT",
        "L5 NP",
        "L5 PT",
        "L5/6 IT",
        "L6 CT",
        "L6b",
    ],
    "unassigned": ["Unassigned", "VLMC", "Zero_correlation"],
    "Sst": ["Sst"],
    "Vip": ["Vip"],
    "Lamp5": ["Lamp5"],
    "Pvalb": ["Pvalb"],
}
# Inverted lookup: fine cluster -> broad class.
flattened_mapping = {
    cluster: broad_class
    for broad_class, clusters in mapping.items()
    for cluster in clusters
}

# Apply the mapping to create the new column
cells_df["broad_cell_class"] = cells_df["best_cluster"].map(flattened_mapping)

# Interneurons retained for the cell-type analysis: one of the four inhibitory
# classes that also passes the score, kNN-agreement and gene-count thresholds.
interneuron_df = cells_df[
    cells_df["broad_cell_class"].isin(["Sst", "Vip", "Lamp5", "Pvalb"])
    & (cells_df["best_score"] > 0.3)
    & (cells_df["knn_agree_conf"] > 0.3)
    & (cells_df["raw_gene_counts"] > 2)
]

cells_df = cells_df[~non_excitatory]
layers = ["L2/3", "L4", "L5", "L6a", "L6b"]
# Layer name -> short label passed as `areas=` to the plotting helpers.
areas = {
    "L2/3": "2/3",
    "L4": "4",
    "L5": "5",
    "L6a": "6a",
    "L6b": "6b",
}

# local connections
# Take an explicit copy so the in-place relabelling below writes to an
# independent frame (avoids pandas' SettingWithCopyWarning on a filtered slice).
cells_df = cells_df[cells_df["cortical_layer"].notnull()].copy()
# Cells labelled "L1" are merged into "L2/3".
cells_df.loc[cells_df["cortical_layer"] == "L1", "cortical_layer"] = "L2/3"
cells_df = cells_df[cells_df["cortical_layer"].isin(layers)]
# NOTE: no filter on the maximum connection distance (< 100) is applied here.
# The original cell only evaluated that boolean mask and discarded the result,
# so removing the statement leaves `cells_df` unchanged.

# Layer-by-layer connectivity: shuffled null distributions, observed matrices
# and their comparison (log ratio and p-value) against the shuffles.
grouping = "cortical_layer"  # "area_acronym_ancestor_rank1"
layer_grouping = dict(starter_grouping=grouping, presyn_grouping=grouping)
# Shuffle the barcodes assigned to each cell in the connectivity matrix; only
# the presynaptic side is shuffled, starters are left untouched.
shuffle_settings = dict(
    n_permutations=1000,
    shuffle_starters=False,
    shuffle_presyn=True,
    **layer_grouping,
)

# Shuffles for the input fraction.
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    mean_input_fraction_dfs,
    starter_input_fractions,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    cells_df, output_fraction=False, **shuffle_settings
)

# Shuffles for the output fraction (the first two names are overwritten).
(
    shuffled_cell_barcode_dfs,
    shuffled_matrices,
    output_fraction_dfs,
    _,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    cells_df, output_fraction=True, **shuffle_settings
)

# Observed matrices, once per fraction type.
observed_input = conn_mat.compute_connectivity_matrix(
    cells_df, output_fraction=False, **layer_grouping
)
connectivity_matrix, mean_input_fraction, fractions_df, _ = observed_input

observed_output = conn_mat.compute_connectivity_matrix(
    cells_df, output_fraction=True, **layer_grouping
)
connectivity_matrix, output_fraction, _, _ = observed_output

# Restrict observed and shuffled matrices to `layers` on both axes, then
# compare the observed values to the shuffled ones.
filtered_input = conn_mat.filter_matrices(
    mean_input_fraction,
    np.array(mean_input_fraction_dfs),
    row_order=layers,
    col_order=layers,
)
input_fraction_log_ratio, input_fraction_pval = conn_mat.compare_to_shuffle(
    *filtered_input
)

filtered_output = conn_mat.filter_matrices(
    output_fraction,
    np.array(output_fraction_dfs),
    row_order=layers,
    col_order=layers,
)
output_fraction_log_ratio, output_fraction_pval = conn_mat.compare_to_shuffle(
    *filtered_output
)

# Same computation with the function's default `output_fraction`; these are
# the names used by the bootstrap and the plotting cells.
counts_df, mean_input_frac_df, fractions_df, _ = conn_mat.compute_connectivity_matrix(
    cells_df,
    starter_grouping="cortical_layer",  # "area_acronym_ancestor_rank1",
    presyn_grouping="cortical_layer",  # "area_acronym_ancestor_rank1",
)

# Bootstrap 95% confidence intervals of the mean per-layer fractions: within
# each starter layer, rows of `fractions_df` are resampled with replacement
# and the `layers` columns are averaged.
# NOTE(review): no random seed is set, so the intervals differ between runs.
nboot = 1000
bootstrap_samples = []

for _boot_idx in tqdm(range(nboot)):
    starter_means = []
    for starter_layer, group in fractions_df.groupby("cortical_layer", observed=True):
        # Resample with replacement
        resampled_mean = group.sample(n=len(group), replace=True)[layers].mean()
        resampled_mean.name = starter_layer
        starter_means.append(resampled_mean)
    # One column per starter layer, one row per entry of `layers`. The groups
    # arrive in groupby key order and unobserved layers are absent, so align on
    # labels before stacking; the DataFrames built below label both axes with
    # `layers` positionally. A layer without starters becomes a NaN column
    # instead of causing a shape mismatch.
    boot_matrix = pd.concat(starter_means, axis=1).reindex(
        index=layers, columns=layers
    )
    bootstrap_samples.append(boot_matrix.to_numpy())
# Shape: (nboot, len(layers), len(layers)).
bootstrap_samples = np.array(bootstrap_samples)
lower_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.025, axis=0), index=layers, columns=layers
)
upper_df = pd.DataFrame(
    data=np.quantile(bootstrap_samples, 0.975, axis=0), index=layers, columns=layers
)

# Marginal counts passed to the connectivity-matrix plots.
reorganised_counts = conn_mat.reorganise_matrix(counts_df)
presynaptic_counts = reorganised_counts.sum(axis=1)
starters_per_layer = fractions_df.value_counts("cortical_layer")
starter_counts = starters_per_layer.rename(index=areas)

# Compute distances between cells
(
    relative_presyn_coords_flatmap,
    distancess_flatmap,
    starters_df,
) = dist.determine_presynaptic_distances(
    cells_df, col_prefix="flatmap_", col_suffix="_normalised"
)

# Position of each starter's layer in the L1 ... WM sequence; NaN when the
# layer name is not part of that sequence.
alllayers = ["L1", *layers, "WM"]
layer_rank = {name: rank for rank, name in enumerate(alllayers)}
starters_df["layer_index"] = starters_df.cortical_layer.map(
    lambda name: layer_rank.get(name, np.nan)
)

In [None]:
# Compute shuffles for interneurons
# Same pipeline as for the layers, grouped by broad cell class and run on
# `interneuron_df`.
grouping = "broad_cell_class"
class_grouping = dict(starter_grouping=grouping, presyn_grouping=grouping)
inh_shuffle_settings = dict(
    n_permutations=1000,
    shuffle_starters=False,
    shuffle_presyn=True,
    **class_grouping,
)

# Shuffles for the input fraction.
(
    inh_shuffled_cell_barcode_dfs,
    inh_shuffled_matrices,
    inh_mean_input_fraction_dfs,
    inh_starter_input_fractions,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=False, **inh_shuffle_settings
)

# Shuffles for the output fraction (the first two names are overwritten).
(
    inh_shuffled_cell_barcode_dfs,
    inh_shuffled_matrices,
    inh_output_fraction_dfs,
    _,
    _,
) = conn_mat.shuffle_and_compute_connectivity(
    interneuron_df, output_fraction=True, **inh_shuffle_settings
)

# Observed matrices, once per fraction type.
inh_observed_input = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=False, **class_grouping
)
inh_connectivity_matrix, inh_mean_input_fraction, inh_fractions_df, _ = (
    inh_observed_input
)

inh_observed_output = conn_mat.compute_connectivity_matrix(
    interneuron_df, output_fraction=True, **class_grouping
)
inh_connectivity_matrix, inh_output_fraction, _, _ = inh_observed_output

# Row/column order used for every interneuron matrix.
cell_types = ["Pvalb", "Sst", "Lamp5", "Vip"]

inh_filtered_input = conn_mat.filter_matrices(
    inh_mean_input_fraction,
    np.array(inh_mean_input_fraction_dfs),
    row_order=cell_types,
    col_order=cell_types,
)
inh_input_fraction_log_ratio, inh_input_fraction_pval = conn_mat.compare_to_shuffle(
    *inh_filtered_input
)

inh_filtered_output = conn_mat.filter_matrices(
    inh_output_fraction,
    np.array(inh_output_fraction_dfs),
    row_order=cell_types,
    col_order=cell_types,
)
inh_output_fraction_log_ratio, inh_output_fraction_pval = conn_mat.compare_to_shuffle(
    *inh_filtered_output
)

# Same computation with the function's default `output_fraction`; these are
# the names used by the bootstrap and the plotting cells.
inh_counts_df, inh_mean_input_frac_df, inh_fractions_df, _ = (
    conn_mat.compute_connectivity_matrix(interneuron_df, **class_grouping)
)

# Bootstrap 95% confidence intervals of the mean per-class fractions: within
# each starter cell class, rows of `inh_fractions_df` are resampled with
# replacement and the `cell_types` columns are averaged.
# NOTE(review): no random seed is set, so the intervals differ between runs.
nboot = 1000
inh_bootstrap_samples = []

for _boot_idx in tqdm(range(nboot)):
    class_means = []
    for starter_class, group in inh_fractions_df.groupby(grouping, observed=True):
        # Resample with replacement
        resampled_mean = group.sample(n=len(group), replace=True)[cell_types].mean()
        resampled_mean.name = starter_class
        class_means.append(resampled_mean)
    # groupby yields the groups in sorted key order (alphabetical for plain
    # strings: Lamp5, Pvalb, Sst, Vip), which is not the order of `cell_types`.
    # The DataFrames built below label both axes with `cell_types`
    # positionally, so align on labels first; otherwise the intervals would be
    # attached to the wrong starter classes.
    boot_matrix = pd.concat(class_means, axis=1).reindex(
        index=cell_types, columns=cell_types
    )
    inh_bootstrap_samples.append(boot_matrix.to_numpy())
# Shape: (nboot, len(cell_types), len(cell_types)).
inh_bootstrap_samples = np.array(inh_bootstrap_samples)
inh_lower_df = pd.DataFrame(
    data=np.quantile(inh_bootstrap_samples, 0.025, axis=0),
    index=cell_types,
    columns=cell_types,
)
inh_upper_df = pd.DataFrame(
    data=np.quantile(inh_bootstrap_samples, 0.975, axis=0),
    index=cell_types,
    columns=cell_types,
)

# Identity label map for the four interneuron classes ("excitatory" -> "Exc"
# is intentionally left out), passed as `areas=` to the matrix helpers.
inh_areas = {
    cell_type: cell_type for cell_type in ("Pvalb", "Sst", "Lamp5", "Vip")
}
# Marginal counts passed to the interneuron connectivity-matrix plots.
inh_counts_matrix = conn_mat.reorganise_matrix(inh_counts_df, areas=inh_areas)
inh_presynaptic_counts = inh_counts_matrix.sum(axis=1)
inh_starters_per_class = inh_fractions_df.value_counts(grouping)
inh_starter_counts = inh_starters_per_class.rename(index=inh_areas)

In [None]:
# Figure-wide settings and the empty canvas all panels are added to.
save_fig = True

# Font sizes, paddings and line styling shared by the panels.
fontsize_dict = dict(title=7, label=7, tick=6, legend=6)
pad_dict = dict(label=1, tick=1, legend=5)
line_width = 0.9
line_alpha = 1

# Square 17.4 cm figure; matplotlib expects inches, hence the cm factor.
cm = 1 / 2.54
fig_side = 17.4 * cm
fig = plt.figure(figsize=(fig_side, fig_side), dpi=200)

# save_path = Path("Z:/home/shared/presentations/becalick_2025")
save_path = Path("/nemo/lab/znamenskiyp/home/shared/presentations/becalick_2025")
figname = "matrices"

# Full-figure axes without ticks, drawn as an outer frame.
frame = fig.add_axes([0, 0, 1, 1])
frame.set_xticks([])
frame.set_yticks([])

# Presynaptic scatters
# One panel per starter layer: presynaptic cells at their position relative to
# the starter (column 1 of the relative coordinates) against their absolute
# depth (column 2 of the absolute coordinates), with the starters drawn at x = 0.
w = 0.17
scl = 10  # scale to put distance in um
avg_layer_tops = get_avg_layer_depth()
# Layer borders from 0 through the sorted average layer tops to 1200.
layer_borders = np.hstack([0, np.sort(np.hstack(list(avg_layer_tops.values()))), 1200])
midlayer = np.diff(layer_borders) / 2 + layer_borders[:-1]
ax_layers = []

for il, layer in enumerate(layers):
    axl = fig.add_axes([0.05 + (w + 0.01) * il, 0.8, w, 0.2], aspect="equal")
    ax_layers.append(axl)
    for d in layer_borders[:-1]:
        axl.axhline(d, ls="--", lw=0.5, color="lightgray")
    this_layer = starters_df[starters_df["cortical_layer"] == layer]

    # Pool the presynaptic coordinates of every starter in this layer.
    # NOTE(review): the `[0]` assumes each entry carries a leading singleton
    # axis - confirm against `dist.determine_presynaptic_distances`.
    rel_coords = np.hstack(this_layer["presynaptic_coors_relative"].values)[0]
    abs_coords = np.hstack(this_layer["presynaptic_coors"].values)[0]
    rel_ml = rel_coords[:, 1] * scl
    abs_depth = abs_coords[:, 2] * scl
    # Draw on the panel's own axes rather than pyplot's implicit current axes.
    axl.scatter(
        rel_ml, abs_depth, marker=".", color="darkred", alpha=0.5, ec="w", lw=0.1, s=10
    )
    axl.scatter(
        np.zeros(len(this_layer)),
        this_layer["flatmap_z_normalised"] * scl,
        marker=".",
        color="k",
        s=20,
        alpha=0.3,
        ec="None",
    )
    axl.set_xlim(-800, 800)
    # Inverted y axis: depth increases downwards.
    axl.set_ylim(1000, -50)
    if il == 0:
        axl.set_title(
            f"Starter layer:     {layer[1:]}", fontsize=fontsize_dict["title"]
        )
        # Black bar in the first panel, 200 data units wide.
        rec = plt.Rectangle((-680, 950), 200, 30, color="k")
        axl.add_artist(rec)
    else:
        axl.set_title(layer[1:], fontsize=fontsize_dict["title"])
    axl.set_axis_off()

# Layer names along the right edge of the last panel, at mid-layer depth.
label_ax = ax_layers[-1]
for il, layer in enumerate(["L1"] + layers):
    label_ax.text(
        750,
        midlayer[il],
        layer[1:],
        fontsize=fontsize_dict["legend"],
        horizontalalignment="center",
        verticalalignment="center",
    )


# Text and line styling shared by the three panels in this block.
layer_panel_style = dict(
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Raw counts
ax_counts = fig.add_axes([0.07, 0.59, 0.2, 0.2])
layer_counts_matrix = conn_mat.reorganise_matrix(counts_df, areas=areas)
conn_mat.plot_area_by_area_connectivity(
    layer_counts_matrix,
    starter_counts,
    presynaptic_counts,
    ax_counts,
    input_fraction=False,
    odds_ratio=False,
    **layer_panel_style,
)

# Input fraction
ax_input_fraction = fig.add_axes([0.28, 0.59, 0.2, 0.2])
ax_input_fraction_cb = fig.add_axes([0.49, 0.59, 0.01, 0.05])
layer_input_matrix = conn_mat.reorganise_matrix(mean_input_fraction, areas=areas)
conn_mat.plot_area_by_area_connectivity(
    layer_input_matrix,
    starter_counts,
    presynaptic_counts,
    ax_input_fraction,
    input_fraction=True,
    odds_ratio=False,
    show_counts=False,
    cbax=ax_input_fraction_cb,
    cbar_label="Input\nfraction",
    xlabel="Starter layer",
    ylabel="Presynaptic layer",
    vmin=0,
    vmax=0.40,
    **layer_panel_style,
)
# The y label and ticks are cleared on this panel, which sits directly to the
# right of the raw-count panel.
ax_input_fraction.set_ylabel("")
ax_input_fraction.set_yticks([])

# Confidence interval of input fraction
ax_input_fraction_bars = fig.add_axes([0.59, 0.61, 0.25, 0.18])
ci_matrices = [
    conn_mat.reorganise_matrix(matrix, areas=areas)
    for matrix in (mean_input_frac_df, lower_df, upper_df)
]
boot.plot_confidence_intervals(
    *ci_matrices,
    ax_input_fraction_bars,
    orientation="horizontal",
    **layer_panel_style,
)

# Schematics
# Node-and-arrow diagram of the layer-to-layer input fractions.
ax_schematics = fig.add_axes([0.86, 0.61, 0.15, 0.18])
cax_schematics = fig.add_axes([0.89, 0.58, 0.01, 0.05])
# Node position of each layer in diagram coordinates.
positions = dict(zip(layers, [(2, 6), (0, 5), (2, 4), (0, 3), (2, 2)]))
# Node labels drop the leading "L" of the layer names.
layer_display_names = [name[1:] for name in layers]
layer_node_style = dict(
    facecolor="Lightgray", radius=0.5, fontsize=fontsize_dict["title"]
)
layer_arrow_style = dict(connectionstyle="Arc3, rad=-0.2", ec="none")
fig, ax, cbar = conn_mat.connectivity_diagram_mpl(
    mean_input_fraction,
    lower_df,
    upper_df,
    connection_names=layers,
    positions=positions,
    display_names=layer_display_names,
    node_style=layer_node_style,
    min_fraction_cutoff=0.2,
    ci_to_alpha=False,
    ci_cmap="RdPu_r",
    edge_width_scale=10,
    arrow_head_scale=20,
    arrow_style=layer_arrow_style,
    ax=ax_schematics,
    cax=cax_schematics,
    vmin=0,
    vmax=0.4,
)
# cbar.set_ticks([0, 0.5, 1])

# Settings shared by the two bubble plots (observed vs shuffle).
bubble_settings = dict(
    alpha=0.05,
    size_scale=80,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    vmin=-2,
    vmax=2,
)

# Bubble plot input fraction
ax_bubble_plot_input_frac = fig.add_axes([0.07, 0.29, 0.2, 0.2])
ax_bubble_plot_input_frac_cb = fig.add_axes([0.29, 0.44, 0.01, 0.05])
input_log_ratio_matrix = conn_mat.reorganise_matrix(
    input_fraction_log_ratio, areas=areas
)
input_pval_matrix = conn_mat.reorganise_matrix(input_fraction_pval, areas=areas)
conn_mat.bubble_plot(
    input_log_ratio_matrix,
    input_pval_matrix,
    ax=ax_bubble_plot_input_frac,
    cbax=ax_bubble_plot_input_frac_cb,
    **bubble_settings,
)

# Output fraction
ax_output_fraction = fig.add_axes([0.45, 0.29, 0.2, 0.2])
ax_output_fraction_cb = fig.add_axes([0.66, 0.29, 0.01, 0.05])
# output_fraction = conn_mat.reorganise_matrix(output_fraction)
layer_output_matrix = conn_mat.reorganise_matrix(output_fraction, areas=areas)
conn_mat.plot_area_by_area_connectivity(
    layer_output_matrix,
    starter_counts,
    presynaptic_counts,
    ax_output_fraction,
    cbax=ax_output_fraction_cb,
    cbar_label="Output\nfraction",
    # The output-fraction matrix is drawn with the same display mode as the
    # input-fraction panel, hence `input_fraction=True`.
    input_fraction=True,
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    show_counts=False,
    vmin=0,
    vmax=0.4,
)

# Bubble plot output fraction
ax_bubble_plot_output_frac = fig.add_axes([0.78, 0.29, 0.2, 0.2])
output_log_ratio_matrix = conn_mat.reorganise_matrix(
    output_fraction_log_ratio, areas=areas
)
output_pval_matrix = conn_mat.reorganise_matrix(output_fraction_pval, areas=areas)
conn_mat.bubble_plot(
    output_log_ratio_matrix,
    output_pval_matrix,
    ax=ax_bubble_plot_output_frac,
    show_legend=False,
    **bubble_settings,
)

########################################################################
########### Interneuron connectivity matrices ##########################
########################################################################

# Styling and axis labels shared by both interneuron matrix panels.
inh_panel_style = dict(
    odds_ratio=False,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
    xlabel="Starter cell type",
    ylabel="Presynaptic cell type",
)

# Raw counts
ax_inh_counts = fig.add_axes([0.07, 0.01, 0.2, 0.2])
inh_counts_panel_matrix = conn_mat.reorganise_matrix(inh_counts_df, areas=inh_areas)
conn_mat.plot_area_by_area_connectivity(
    inh_counts_panel_matrix,
    inh_starter_counts,
    inh_presynaptic_counts,
    ax_inh_counts,
    input_fraction=False,
    **inh_panel_style,
)

# Input fraction
ax_inh_input_fraction = fig.add_axes([0.28, 0.01, 0.2, 0.2])
ax_inh_input_fraction_cb = fig.add_axes([0.49, 0.01, 0.01, 0.05])
inh_input_panel_matrix = conn_mat.reorganise_matrix(
    inh_mean_input_fraction, areas=inh_areas
)
conn_mat.plot_area_by_area_connectivity(
    inh_input_panel_matrix,
    inh_starter_counts,
    inh_presynaptic_counts,
    ax_inh_input_fraction,
    input_fraction=True,
    show_counts=False,
    cbax=ax_inh_input_fraction_cb,
    cbar_label="Input\nfraction",
    vmin=0,
    vmax=0.7,
    **inh_panel_style,
)
# The y label and ticks are cleared on this panel, which sits directly to the
# right of the raw-count panel.
ax_inh_input_fraction.set_ylabel("")
ax_inh_input_fraction.set_yticks([])


# Schematics
# Node-and-arrow diagram of the input fractions between interneuron classes.
ax_inh_schematics = fig.add_axes([0.61, 0.07, 0.16, 0.18])
cax_inh_schematics = fig.add_axes([0.78, 0.01, 0.01, 0.05])
# Node position of each cell class in diagram coordinates.
positions = dict(zip(cell_types, [(1.5, 2), (0, 2), (1.5, 0), (0, 0)]))
inh_node_style = dict(
    facecolor="Lightgray", radius=0.5, fontsize=fontsize_dict["title"]
)
inh_arrow_style = dict(connectionstyle="Arc3, rad=-0.2", ec="none")
fig, ax, cbar = conn_mat.connectivity_diagram_mpl(
    inh_mean_input_fraction,
    inh_lower_df,
    inh_upper_df,
    connection_names=cell_types,
    positions=positions,
    display_names=cell_types,
    node_style=inh_node_style,
    min_fraction_cutoff=0.2,
    ci_to_alpha=False,
    ci_cmap="RdPu_r",  # YlOrRd_r
    edge_width_scale=10,
    arrow_head_scale=20,
    arrow_style=inh_arrow_style,
    vmin=0,
    vmax=0.6,
    ax=ax_inh_schematics,
    cax=cax_inh_schematics,
)
# Tick the colour bar only at the two ends of its range.
cbar.set_ticks([0, 0.6])

# Write the assembled figure as PDF and PNG into `save_path`.
if save_fig:
    for file_format in ("pdf", "png"):
        fig.savefig(save_path / f"fig5_conn_mat.{file_format}", format=file_format)