In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from brisc.manuscript_analysis import barcodes_in_cells as bc_cells
from brisc.manuscript_analysis import match_to_library as match_lib
from brisc.manuscript_analysis import sensitivity as sens
from brisc.manuscript_analysis import mcherry_intensity as mcherry_int
from brisc.manuscript_analysis import distance_between_cells as dist_cells
from brisc.manuscript_analysis import overview_image

from pathlib import Path
import pandas as pd
import numpy as np
import itertools
from multiprocessing import Pool
from tqdm import tqdm
import functools

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm

# Register Arial from the shared resources folder and make it the default font,
# for regular text as well as mathtext; pdf.fonttype 42 embeds TrueType fonts in PDFs.
arial_font_path = "/nemo/lab/znamenskiyp/home/shared/resources/fonts/arial.ttf"  # update path as needed
fm.fontManager.addfont(arial_font_path)
arial_prop = fm.FontProperties(fname=arial_font_path)
matplotlib.rcParams.update(
    {
        "font.family": arial_prop.get_name(),
        "mathtext.default": "regular",  # make math mode also Arial
        "pdf.fonttype": 42,  # for pdfs
    }
)

from iss_preprocess.io import get_processed_path

In [None]:
# Root folders of the data. The Windows drive-letter roots that used to sit here
# (barseq_path = Path("Y:/"), main_path = Path("Z:/")) were dead code, always
# overwritten by the lines below; presumably they are the same shares mapped as
# drives and can be substituted when running from such a machine.
barseq_path = get_processed_path("becalia_rabies_barseq").parent.parent
main_path = Path("/nemo/lab/znamenskiyp/")
print(barseq_path)
print(main_path)

# Error-corrected barcode dataset analysed throughout this notebook
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_26"

In [None]:
# Load in situ / random barcode library matches and the library itself
# (redo=False presumably reuses previously computed results), then the
# per-cell barcode table.
(
    in_situ_barcode_matches,
    random_barcode_matches,
    rv35_library,
) = match_lib.load_data(
    redo=False,
    barseq_path=barseq_path,
    main_path=main_path,
    error_correction_ds_name=error_correction_ds_name,
)

cells_df = pd.read_pickle(
    barseq_path
    / f"processed/becalia_rabies_barseq/BRAC8498.3e/analysis/{error_correction_ds_name}_cell_barcode_df.pkl"
)
# Keep only cells with a main barcode. `.copy()` makes the filtered frame
# independent of the loaded one, so adding columns below does not raise
# SettingWithCopyWarning.
cells_df = cells_df[cells_df["main_barcode"].notna()].copy()
cells_df["n_unique_barcodes"] = cells_df["all_barcodes"].apply(len)

In [None]:
# Add section info
from iss_preprocess.io.load import load_section_position

# Attach to every cell the absolute section index of its (chamber, roi).
# NOTE(review): positions are loaded through the chamber_07 path but filtered per
# chamber below, so presumably this returns all chambers — confirm.
sec_infos = load_section_position("becalia_rabies_barseq/BRAC8498.3e/chamber_07")
cells_df["absolute_section"] = np.nan
for chamber_name, chamber_cells in cells_df.groupby("chamber"):
    # e.g. "chamber_07" -> 7
    chamber_number = int(chamber_name.split("_")[1])
    in_chamber = sec_infos["chamber"] == chamber_number
    roi_to_section = sec_infos.loc[in_chamber].set_index(
        "chamber_position", drop=False
    )
    section_of_cells = roi_to_section.loc[chamber_cells["roi"], "absolute_section"]
    cells_df.loc[chamber_cells.index, "absolute_section"] = section_of_cells.values

In [None]:
# One row per distinct barcode seen in any cell, with how many starter and how
# many presynaptic cells carry it.
all_barcodes = list(set(itertools.chain.from_iterable(cells_df["all_barcodes"].values)))
barcodes_df = pd.DataFrame({"barcode": all_barcodes})


def _count_cells_per_barcode(barcode_lists):
    """Count, for each barcode, how many entries of `barcode_lists` contain it.

    Each cell contributes at most once per barcode (same as testing
    ``barcode in cell_barcodes``), so duplicates within one cell's collection
    do not inflate the count. Returns a Series indexed by barcode.
    """
    unique_per_cell = barcode_lists.map(lambda bcs: list(set(bcs)))
    # explode turns empty collections into NaN, which value_counts drops
    return unique_per_cell.explode().value_counts()


# `.eq(True)` / `.eq(False)` rather than truthiness, so cells whose is_starter is
# missing are counted in neither group. A single pass over the cells replaces the
# previous per-barcode scan of every cell (O(barcodes x cells)).
_starter_counts = _count_cells_per_barcode(
    cells_df.loc[cells_df["is_starter"].eq(True), "all_barcodes"]
)
_presynaptic_counts = _count_cells_per_barcode(
    cells_df.loc[cells_df["is_starter"].eq(False), "all_barcodes"]
)
barcodes_df["n_starters"] = (
    barcodes_df["barcode"].map(_starter_counts).fillna(0).astype(int)
)
barcodes_df["n_presynaptic"] = (
    barcodes_df["barcode"].map(_presynaptic_counts).fillna(0).astype(int)
)

In [None]:
def _hamming_distance(str1, str2):
    return sum(c1 != c2 for c1, c2 in zip(str1, str2))


# Define a function to calculate the minimum edit distance
def _calculate_min_edit_distance_worker(insitu_bc, lib_10bp_seq_ref, rv35_library_ref):
    edit_distances = np.fromiter(
        (_hamming_distance(insitu_bc, lib_bc) for lib_bc in lib_10bp_seq_ref), int
    )
    min_edit_distance_idx = np.argmin(edit_distances)
    min_edit_distance = edit_distances[min_edit_distance_idx]
    lib_bc_sequence = rv35_library_ref.loc[min_edit_distance_idx, "10bp_seq"]
    lib_bc_count = rv35_library_ref.loc[min_edit_distance_idx, "counts"]
    return min_edit_distance, lib_bc_sequence, lib_bc_count


# For every barcode found in more than one starter cell, find the closest RV35
# library barcode (Hamming distance on the first 10 bp) and its library count.
processed_path = barseq_path / "processed/becalia_rabies_barseq/BRAC8498.3e/"
barcode_library_sequence_path = (
    main_path
    / "home/shared/projects/barcode_diversity_analysis/collapsed_barcodes/RV35/RV35_bowtie_ed2.txt"
)
# The raw library gets its own name so that `rv35_library` (returned by
# match_lib.load_data above) is not clobbered and never needs to be reloaded.
rv35_library_raw = pd.read_csv(
    barcode_library_sequence_path, sep="\t", header=None
).rename(columns={0: "counts", 1: "sequence"})
rv35_library_raw["10bp_seq"] = rv35_library_raw["sequence"].str.slice(0, 10)
lib_10bp_seq = np.array(rv35_library_raw["10bp_seq"])

multiple_starter_bcs = barcodes_df[barcodes_df["n_starters"] > 1]["barcode"].values
uni_starter_bcs = barcodes_df[barcodes_df["n_starters"] == 1]["barcode"].values

in_situ_barcodes = pd.DataFrame(multiple_starter_bcs, columns=["sequence"])

# Bind the library arguments once; the worker only receives a barcode per task.
partial_worker = functools.partial(
    _calculate_min_edit_distance_worker,
    lib_10bp_seq_ref=lib_10bp_seq,
    rv35_library_ref=rv35_library_raw,
)
# With the default chunksize=1 each barcode is dispatched as its own task along
# with the bound library; larger chunks amortise that overhead.
with Pool() as pool:
    results = list(
        tqdm(
            pool.imap(partial_worker, in_situ_barcodes["sequence"], chunksize=8),
            total=len(in_situ_barcodes),
            desc="Calculating edit distances",
        )
    )

# One row per multi-starter barcode: minimum distance, closest library sequence
# and its library count. Building a DataFrame (instead of zip(*results)) also
# works when there are no multi-starter barcodes.
match_columns = ["min_edit_distance", "ham_lib_bc_sequence", "lib_bc_counts"]
in_situ_barcodes = pd.concat(
    [in_situ_barcodes, pd.DataFrame(results, columns=match_columns)], axis=1
)

In [None]:
# Define useful functions

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from scipy.stats import gaussian_kde


def safe_log10(series):
    """Return log10 of the entries of `series` that are >= 1.

    Entries below 1 (zeros, negatives) and NaNs are dropped before the log is
    taken, so the result can be shorter than `series`; it keeps the original
    index of the retained entries. Entries equal to 1 are kept (log10(1) == 0).
    """
    kept = series[series >= 1]  # drops 0 (log undefined), negatives and NaN
    return np.log10(kept.astype(float))


def bootstrap_ecdf_ci(values, x_grid, n_boot=1000, ci=95, random_state=None):
    """Percentile-bootstrap confidence band for an empirical CDF.

    `values` are expected to be already log-transformed. Each of `n_boot`
    resamples (with replacement, same size as `values`) has its ECDF evaluated
    on `x_grid`; the band is the pointwise (100 - ci) / 2 and
    100 - (100 - ci) / 2 percentiles across resamples.

    Returns:
        (lower, upper), each an array with one value per point of `x_grid`.
    """
    generator = np.random.default_rng(random_state)
    n_values = values.size
    tail = (100 - ci) / 2
    curves = np.empty((n_boot, x_grid.size))
    for row in range(n_boot):
        resample = np.sort(generator.choice(values, size=n_values, replace=True))
        # fraction of resampled values <= each grid point
        curves[row] = np.searchsorted(resample, x_grid, side="right") / n_values
    return (
        np.percentile(curves, tail, axis=0),
        np.percentile(curves, 100 - tail, axis=0),
    )


def bootstrap_kde_ci(
    values, x_grid, n_boot=1000, bw_method=0.3, ci=95, random_state=None
):
    """Percentile-bootstrap confidence band for a Gaussian KDE.

    `values` are expected to be already log-transformed. Each of `n_boot`
    resamples (with replacement, same size as `values`) is smoothed with
    ``gaussian_kde(..., bw_method=bw_method)`` and evaluated on `x_grid`; the
    band is the pointwise percentile interval at confidence level `ci`.

    Returns:
        (lower, upper), each an array with one value per point of `x_grid`.
    """
    generator = np.random.default_rng(random_state)
    n_values = values.size
    tail = (100 - ci) / 2
    densities = np.empty((n_boot, x_grid.size))
    for row in range(n_boot):
        resample = generator.choice(values, size=n_values, replace=True)
        densities[row] = gaussian_kde(resample, bw_method=bw_method).evaluate(x_grid)
    lower, upper = np.percentile(densities, [tail, 100 - tail], axis=0)
    return lower, upper

In [None]:
# Log10-transform library read counts once, so every panel below uses the same
# variables: all in situ matches, multi-starter barcodes, and the library itself.
log_all = safe_log10(in_situ_barcode_matches["ham_lib_bc_counts"])
log_multi = safe_log10(in_situ_barcodes["lib_bc_counts"])
lib_counts = rv35_library["counts"].astype(float)
log_lib = safe_log10(lib_counts)
print(lib_counts.sum())

In [None]:
# KDEs of log10(library read count) on one common grid, for all in situ
# barcodes, multi-starter barcodes, and the library.
bw_method = 0.3
x_grid_kde = np.linspace(0, 6, 400)
kde_all = gaussian_kde(log_all, bw_method=bw_method)
kde_multi = gaussian_kde(log_multi, bw_method=bw_method)
# Library barcodes are weighted by their read counts. safe_log10 drops counts
# < 1, so take the weights from the same index to keep both arrays aligned.
lib_weights = lib_counts.loc[log_lib.index]
kde_lib = gaussian_kde(log_lib, bw_method=bw_method, weights=lib_weights)

dens_all = kde_all.evaluate(x_grid_kde)
dens_multi = kde_multi.evaluate(x_grid_kde)
dens_lib = kde_lib.evaluate(x_grid_kde)
diff_all = dens_all - dens_lib
diff_multi = dens_multi - dens_lib

# 95 % bootstrap CI of the multi-starter KDE, shifted by the library KDE to give
# a band for (multi-starter KDE - library KDE)
ci_low_kde, ci_up_kde = bootstrap_kde_ci(
    log_multi.values,
    x_grid_kde,
    n_boot=10_000,
    bw_method=bw_method,
    ci=95,
    random_state=42,
)
ci_low_diff = ci_low_kde - dens_lib
ci_up_diff = ci_up_kde - dens_lib

In [None]:
# Compare the pairwise distance between starters and starters with same barcode
from tqdm import tqdm
from brisc.manuscript_analysis.distance_between_cells import compute_pairwise_distances

# Split starter cells by whether they carry a barcode shared with other starters
# (multi) and/or a barcode found in no other starter (uni); a cell can be in both.
multi_starter_bcs = set(multiple_starter_bcs)
uni_starter_bcs = set(uni_starter_bcs)
starter_df = cells_df[cells_df.is_starter]
is_multistart = starter_df.all_barcodes.map(
    lambda bcs: not multi_starter_bcs.isdisjoint(bcs)
)
multi_start = starter_df[is_multistart]
is_unistart = starter_df.all_barcodes.map(
    lambda bcs: not uni_starter_bcs.isdisjoint(bcs)
)
uni_start = starter_df[is_unistart]

n_starter_barcodes = (barcodes_df.n_starters > 0).sum()
print(f"{n_starter_barcodes} barcodes, with {len(starter_df)} starters")
print(f"{len(multi_starter_bcs)} multistart barcode, with {len(multi_start)} starters")
print(f"{len(uni_starter_bcs)} unistart barcode, with {len(uni_start)} starters")
n_both = (is_unistart & is_multistart).sum()
print(f"{n_both} starters with both unique and multi barcodes")

pairwise_starters = compute_pairwise_distances(starter_df)
pairwise_uni = compute_pairwise_distances(uni_start)

# Distances between the starter cells that share each multi-starter barcode
pairwise_multi, section_diff = [], []
used = set()
for bc in multi_starter_bcs:
    cells = starter_df[starter_df.all_barcodes.map(lambda bcs: bc in bcs)]
    used.update(cells.index)
    distances, section_info = compute_pairwise_distances(
        cells, return_section_info=True
    )
    pairwise_multi.append(distances)
    section_diff.append(section_info)
pairwise_multi = np.hstack(pairwise_multi)
section_diff = np.hstack(section_diff)

In [None]:
from brisc.manuscript_analysis.distance_between_cells import (
    bootstrap_barocdes_in_multiple_cells,
)
from brisc.manuscript_analysis.distance_between_cells import distances_between_starters

# Observed distances between starters sharing a barcode vs. carrying other
# barcodes, then n_boot resampled versions of the same quantities
# (presumably a permutation/bootstrap null — see bootstrap_barocdes_in_multiple_cells).
n_boot = 5000
observed_distances = distances_between_starters(
    multi_starter_bcs, starter_df, verbose=True
)
dist2same, sectiondiff, dist2others, bc_used, cell_used = observed_distances
resampled_distances = bootstrap_barocdes_in_multiple_cells(
    starter_df, multi_starter_bcs, n_permutations=n_boot, random_state=12
)
dist2same_boot, sectiondiff_boot, dist2others_boot = resampled_distances

In [None]:
# Look if median distance is different
from brisc.manuscript_analysis.bootstrapping import calc_pval

def _nanmedians(samples):
    """Return the NaN-ignoring median of each array in `samples` as a 1-D array."""
    return np.hstack([np.nanmedian(sample) for sample in samples])


# Bootstrap distributions of the median distance to same- vs. other-barcode starters
median_distances = dict(
    med2same=_nanmedians(dist2same_boot),
    med2other=_nanmedians(dist2others_boot),
)
# Excluding adjacent sections takes one extra step: in each bootstrap sample,
# drop the same-barcode pairs whose section difference is exactly +/-1.
is_adjacent = [np.abs(sdiff) == 1 for sdiff in sectiondiff_boot]
dist2same_excludingadj_boot = [
    dst[~is_adj] for dst, is_adj in zip(dist2same_boot, is_adjacent)
]
median_distances["med2same_excluding_adjacent"] = _nanmedians(
    dist2same_excludingadj_boot
)

# p-values from the bootstrap median distributions (test defined in
# brisc.manuscript_analysis.bootstrapping.calc_pval)
p_val_exl = calc_pval(
    median_distances["med2other"],
    median_distances["med2same_excluding_adjacent"],
    n_boot,
)
p_val_same = calc_pval(
    median_distances["med2other"], median_distances["med2same"], n_boot
)
print(
    f"Median distance to other barcode: {np.nanmedian(dist2others):.2f} mm ({len(dist2others)} distances)"
)
print(
    f"Median distance to same barcode: {np.nanmedian(dist2same):.2f} mm, p-value: {p_val_same:.3f} ({len(dist2same)} distances)"
)
not_adj = np.abs(sectiondiff) != 1
print(
    f"Median distance to same barcode, excluding adjacent: {np.nanmedian(dist2same[not_adj]):.2f} mm, p-value: {p_val_exl:.3f} ({not_adj.sum()} distances)"
)

In [None]:
from brisc.manuscript_analysis.utils import despine

# Styling and layout constants for the supplementary figure
colors = ["#e78ac3", "#8da0cb", "#fc8d62", "#66c2a5", "#a6d854"]

save_fig = True
fontsize_dict = dict(title=8, label=8, tick=6, legend=6)
pad_dict = dict(label=1, tick=1, legend=5)
hist_linewidth = 0.5
linewidth = 1.2
line_alpha = 1
save_path = main_path / "home/shared/presentations/becalick_2025"
figname = "suppfig_orphan_origin"

cm = 1 / 2.54  # inches per centimetre; figsize below is given in cm

# KDE-panel x ticks: positions are log10 of the read count corresponding to each
# fraction of all library reads; labels show the fraction itself.
total_read_in_library = np.sum(rv35_library["counts"])
xticklab = np.array([1e-8, 1e-5, 1e-2])
xtick = np.log10(xticklab * total_read_in_library)
xt_labels = [rf"$10^{{{exponent}}}$" for exponent in (-8, -5, -2)]

# Figure: three panels side by side over a full-figure background axes
fig = plt.figure(figsize=(16 * cm, 5 * cm), dpi=600)

# Axes rectangles are [left, bottom, width, height] in figure coordinates
ax_frame = fig.add_axes([0.0, 0, 1, 1])
ax_frame.set(xticks=[], yticks=[])

panel_bottom, panel_width, panel_height = 0.16, 0.22, 0.75
ax_kde = fig.add_axes([0.07, panel_bottom, panel_width, panel_height])
ax_diff = fig.add_axes([0.37, panel_bottom, panel_width, panel_height])
ax_pairwise = fig.add_axes([0.69, panel_bottom, panel_width, panel_height])
# KDE panel: read-count distributions of library vs. in situ barcodes
kde_curves = [
    (
        dens_lib,
        dict(color="black", linestyle="--", zorder=10, label="Library barcodes"),
    ),
    (dens_all, dict(color=colors[0], label="All in situ barcodes")),
    (
        dens_multi,
        dict(color=colors[1], label="Barcodes found in\nmultiple starter cells"),
    ),
]
# Handles are kept for the figure-level legend
lib_line, all_line, multi_line = (
    ax_kde.plot(x_grid_kde, density, linewidth=linewidth, **curve_style)[0]
    for density, curve_style in kde_curves
)

# Bootstrap 95 % CI for multi-starter KDE
ax_kde.fill_between(
    x_grid_kde,
    ci_low_kde,
    ci_up_kde,
    color=colors[1],
    alpha=0.20,
    zorder=0,
    label="__no_label__",
)

kde_label_kwargs = dict(fontsize=fontsize_dict["label"], labelpad=pad_dict["label"])
ax_kde.set_xlabel("Proportion of unique reads", **kde_label_kwargs)
ax_kde.set_ylabel("Density", **kde_label_kwargs)
ax_kde.set_xlim(xtick[0], xtick[-1])
ax_kde.set_ylim(0, 0.7)

ax_kde.set_xticks(xtick, labels=xt_labels)
ax_kde.tick_params(axis="both", which="major", labelsize=fontsize_dict["tick"])

# Difference panel: each in situ KDE minus the library KDE, with the bootstrap
# band for the multi-starter difference
diff_label_kwargs = dict(fontsize=fontsize_dict["label"], labelpad=pad_dict["label"])

ax_diff.plot(
    x_grid_kde, diff_multi, lw=linewidth, color=colors[1],
    label="Multiple starter – Library",
)
ax_diff.fill_between(
    x_grid_kde, ci_low_diff, ci_up_diff, color=colors[1], alpha=0.20, zorder=0,
    label="95 % CI",
)
ax_diff.plot(
    x_grid_kde, diff_all, lw=linewidth, color=colors[0],
    label="All in situ – Library",
)
ax_diff.axhline(0, lw=0.8, ls="--", color="grey", zorder=-1)

ax_diff.set_yticks([-0.2, -0.1, 0, 0.1, 0.2])
ax_diff.set_ylim(-0.2, 0.2)
ax_diff.set_xlabel("Proportion of unique reads", **diff_label_kwargs)
ax_diff.set_ylabel(r"$\Delta$ density", **diff_label_kwargs)

ax_diff.set_xticks(xtick, labels=xt_labels)
ax_diff.set_xlim(xtick[0], xtick[-1])
ax_diff.tick_params(axis="both", which="major", labelsize=fontsize_dict["tick"])

# ECDF panel (disabled by default). Set `plot_ecdf = True` to draw it. The axes
# it draws on is created inside the branch, so enabling the panel no longer
# fails on an undefined `ax_cdf`.
# NOTE(review): this rectangle overlaps the diff and pairwise panels — adjust the
# layout before enabling.
plot_ecdf = False
if plot_ecdf:
    ax_cdf = fig.add_axes([0.53, 0.1, 0.25, 0.20])  # left, bottom, width, height
    sns.ecdfplot(
        x=log_all,
        label="All in situ barcodes",
        ax=ax_cdf,
        lw=linewidth,
        color="deepskyblue",
    )
    sns.ecdfplot(
        x=log_multi,
        label="Multiple starter barcodes",
        ax=ax_cdf,
        lw=linewidth,
        color="mediumorchid",
    )

    # Bootstrap 95 % CI for multi-starter ECDF
    x_grid_ecdf = np.linspace(0, 6, 400)
    ci_lower_ecdf, ci_upper_ecdf = bootstrap_ecdf_ci(
        log_multi.values, x_grid_ecdf, n_boot=10_000, ci=95, random_state=42
    )
    ax_cdf.fill_between(
        x_grid_ecdf,
        ci_lower_ecdf,
        ci_upper_ecdf,
        color="mediumorchid",
        alpha=0.30,
        zorder=0,
        label="Multiple starter (95 % CI)",
    )

    # Library ECDF is weighted by read counts; weights restricted to the entries
    # safe_log10 kept so both arrays have the same length.
    sns.ecdfplot(
        x=log_lib,
        weights=lib_counts.loc[log_lib.index],
        label="Library barcodes",
        ax=ax_cdf,
        lw=linewidth,
        linestyle="--",
        color="black",
    )

    ax_cdf.set_xlabel("Abundance", fontsize=fontsize_dict["label"])
    ax_cdf.set_ylabel("CDF", fontsize=fontsize_dict["label"])
    ax_cdf.set_xlim(0, 6)
    ax_cdf.set_xticks(np.arange(0, 7))
    ax_cdf.xaxis.set_major_formatter(FuncFormatter(lambda x, _: rf"$10^{{{int(x)}}}$"))
    ax_cdf.tick_params(axis="both", which="major", labelsize=fontsize_dict["tick"])

# Pairwise-distance panel: distance distributions between starter cells that
# share a barcode vs. carry different barcodes, with median markers and
# significance brackets derived from the computed bootstrap p-values.


def get_kde(x, data, bw_method=0.2):
    """Evaluate a Gaussian KDE of `data` (bandwidth `bw_method`) at points `x`."""
    kde = gaussian_kde(data, bw_method=bw_method)
    return kde.evaluate(x)


def _pval_to_stars(p_value):
    """Conventional significance annotation for a p-value."""
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return "n.s."


x = np.arange(0, 2, 0.01)
(pw_d2o_line,) = ax_pairwise.plot(
    x,
    get_kde(x, dist2others),
    label="Different barcode",
    color=colors[2],
    linewidth=linewidth,
)
(pw_same_line,) = ax_pairwise.plot(
    x,
    get_kde(x, dist2same),
    label="Same barcode",
    color=colors[3],
    linewidth=linewidth,
)
(multi_noadj_line,) = ax_pairwise.plot(
    x,
    get_kde(x, dist2same[not_adj]),
    label="Same barcode - \nexcluding adjacent sections",
    color=colors[4],
    linewidth=linewidth,
)

# Median of each distribution, marked by a triangle above the curves
y_median = 1.4
for col, vals in zip(
    [colors[2], colors[4], colors[3]],
    [dist2others, dist2same[not_adj], dist2same],
):
    ax_pairwise.scatter(
        np.nanmedian(vals),
        y_median,
        color=col,
        marker="v",
        ec="none",
        s=20,
        zorder=2,
        clip_on=False,
    )

# Brackets from the "different barcode" median to each "same barcode" median,
# annotated with stars from the p-values computed above (not hard-coded).
median_other = np.nanmedian(dist2others)
brackets = [
    (np.nanmedian(dist2same), p_val_same, 0.08, 0.12),
    (np.nanmedian(dist2same[not_adj]), p_val_exl, 0.04, 0.075),
]
for median_same, p_value, line_offset, text_offset in brackets:
    ax_pairwise.plot(
        [median_other, median_same],
        [y_median + line_offset, y_median + line_offset],
        color="k",
        lw=0.5,
        ms=2,
        zorder=0,
        clip_on=False,
    )
    ax_pairwise.text(
        median_other - (median_other - median_same) / 2,
        y_median + text_offset,
        _pval_to_stars(p_value),
        fontsize=fontsize_dict["tick"],
        horizontalalignment="center",
        verticalalignment="top",
    )

ax_pairwise.set_ylabel(
    "Density", fontsize=fontsize_dict["label"], labelpad=pad_dict["label"]
)
ax_pairwise.set_xlabel(
    "Distance between starter cells (mm)",
    fontsize=fontsize_dict["label"],
    labelpad=pad_dict["label"],
)
ax_pairwise.set_xticks([0, 1, 2], labels=[0, 1, 2], fontsize=fontsize_dict["tick"])
ax_pairwise.set_yticks(
    [0, 0.5, 1, 1.5], labels=[0, 0.5, 1, 1.5], fontsize=fontsize_dict["tick"]
)
ax_pairwise.set_xlim(0, 2)
ax_pairwise.set_ylim(0, 1.5)
for panel in (ax_pairwise, ax_kde, ax_diff):
    despine(panel)

# Two figure-level legends sharing one style: KDE curves anchored at (0.2, 0.7)
# and pairwise-distance curves anchored at (0.8, 0.6), in figure coordinates.
legend_kwargs = dict(
    loc="lower left",
    fontsize=fontsize_dict["legend"],
    frameon=False,
    handlelength=1,
    handletextpad=0.5,
    ncols=1,
)
for legend_handles, legend_anchor in (
    ([lib_line, all_line, multi_line], (0.2, 0.7)),
    ([pw_d2o_line, pw_same_line, multi_noadj_line], (0.8, 0.6)),
):
    fig.legend(handles=legend_handles, bbox_to_anchor=legend_anchor, **legend_kwargs)

# Write the figure as both PDF and PNG into save_path (created if missing)
if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    pdf_file = save_path / f"{figname}.pdf"
    for out_file in (pdf_file, save_path / f"{figname}.png"):
        fig.savefig(out_file, dpi=600)
    print(f"Figure saved as {pdf_file}")

In [None]:
# Estimate how many multi-starter barcodes are in excess of the library
# expectation at high read counts. The cut-off is the last grid point where the
# multi-starter KDE falls below the library KDE; everything from there on is
# integrated (Riemann sum over the uniform grid).
negative_idx = np.flatnonzero(diff_multi < 0)
if negative_idx.size == 0:
    raise ValueError(
        "diff_multi is never negative, so no high-abundance cut-off can be defined"
    )
high_cutoff = negative_idx[-1]
cutoff_value = x_grid_kde[high_cutoff]
print(f"Cut-off value: {10**cutoff_value/lib_counts.sum():.2e}")
xscale = np.diff(x_grid_kde)[0]  # grid spacing
nhigh = (log_multi > cutoff_value).sum()
print(f"Barcodes > cutoff: {nhigh}/{len(log_multi)}, {nhigh/len(log_multi)*100:.2f}%")
print(f"Same from density: {dens_multi[high_cutoff:].sum()*xscale* len(log_multi):.1f}")
pexcess = diff_multi[high_cutoff:].sum() * xscale
print(
    f"Excess: {pexcess * len(log_multi):.0f} barcodes, {pexcess*100:.1f}% of multi-bcs"
)

# Explicit figure/axes so this plot never draws onto a previously current figure
fig_cutoff, ax_cutoff = plt.subplots()
ax_cutoff.plot(x_grid_kde, diff_multi)
ax_cutoff.axhline(0, color="k")
ax_cutoff.axvline(cutoff_value, color="darkorchid")
ax_cutoff.plot(x_grid_kde[high_cutoff:], diff_multi[high_cutoff:], color="k")
ax_cutoff.fill_between(x_grid_kde[high_cutoff:], diff_multi[high_cutoff:])
ax_cutoff.set(
    xlabel="log10(library read count)",
    ylabel=r"$\Delta$ density (multi-starter - library)",
);