In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from brisc.manuscript_analysis import viral_library as virlib
from brisc.manuscript_analysis import start_density_sim as start_sim
from brisc.manuscript_analysis import rabies_cell_counting as rv_count
from brisc.manuscript_analysis import starter_cell_counting as sc_count
from brisc.manuscript_analysis import overview_image
from brisc.manuscript_analysis.utils import despine
import numpy as np
from pathlib import Path
import tifffile as tf

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm

# Register the shared Arial file with matplotlib and make it the default font
# for regular and math text in every figure of this notebook.
arial_font_path = "/nemo/lab/znamenskiyp/home/shared/resources/fonts/arial.ttf"
fm.fontManager.addfont(arial_font_path)
arial_prop = fm.FontProperties(fname=arial_font_path)
matplotlib.rcParams.update(
    {
        "font.family": arial_prop.get_name(),
        "mathtext.default": "regular",  # make math mode also Arial
        "pdf.fonttype": 42,  # for pdfs
    }
)

import flexiznam as flz

In [None]:
# Root of the shared storage; an alternative root is kept below for reference.
DATA_ROOT = Path("/nemo/lab/znamenskiyp")
# DATA_ROOT = Path("Z:")

# Folder handed to `virlib.load_library_data` for every library
data_path = DATA_ROOT.joinpath(
    "home/shared/projects/barcode_diversity_analysis/collapsed_barcodes/"
)

# Panel label -> positional arguments passed to `load_library_data` after `data_path`
library_specs = {
    "Plasmid library": ("PBC20", 1, "bowtie"),
    "Virus library": ("RV35", 2, "bowtie"),
}
libraries = {
    label: virlib.load_library_data(data_path, *spec)
    for label, spec in library_specs.items()
}

In [None]:
# Virus libraries used in the rescue-scaling panels, keyed by their legend label.
# Insertion order matters: it is paired with the colour lists in the figure cell.
libraries_scale = {}
for scale_label, library_name in [
    ("2 plates #1", "RV31"),
    ("2 plates #2", "RV32"),
    ("12 plates", "RV35"),
]:
    libraries_scale[scale_label] = virlib.load_library_data(
        data_path, library_name, 2, "bowtie"
    )

In [None]:
# Confocal image used for the starter-cell panel
PROJECT = "rabies_barcoding"
MOUSE = "BRYC64.2h"
IMAGE_FILE = "Slide_3_section_1.czi"
PROJECT_ROOT = DATA_ROOT.joinpath("home/shared/projects", PROJECT)
confocal_data = PROJECT_ROOT.joinpath(MOUSE, "zeiss_confocal", IMAGE_FILE)

# `load_confocal_image` returns the metadata first and the image second
confocal_loaded = sc_count.load_confocal_image(confocal_data)
starter_img_metadata, starter_img = confocal_loaded

In [None]:
# Report the acquisition system and the scaling stored in the image metadata
mdata = starter_img_metadata["ImageDocument"]["Metadata"]
print(
    f"Microscope: {mdata['Information']['Instrument']['Microscopes']['Microscope']['System']}"
)
# One entry per scaling item, keyed by its "Id"; values are multiplied by 1e6
# (presumably metres -> micrometres; TODO confirm against the metadata spec).
scale = {
    distance_item["Id"]: distance_item["Value"] * 1e6
    for distance_item in mdata["Scaling"]["Items"]["Distance"]
}
print(f"Scale: {scale}")

In [None]:
# NOTE: this overwrites the MOUSE constant set in the confocal-loading cell.
MOUSE = "BRYC64.2i"

# Both downsampled volumes live in the same cellfinder registration folder
registration_dir = PROJECT_ROOT / MOUSE / "cellfinder_results_010/registration"
mcherry_file = registration_dir / "downsampled_channel_0.tiff"
background_file = registration_dir / "downsampled.tiff"

mcherry, background = (
    tf.imread(tiff_file) for tiff_file in (mcherry_file, background_file)
)

In [None]:
# Load distances between rabies cells
# Injection centre handed to the helper (coordinate system not visible here;
# presumably voxel indices of the registered volume, TODO confirm).
injection_center = np.array([673, 205, 890])
distance_kwargs = dict(
    inj_center=injection_center,
    project="rabies_barcoding",
    mouse="BRYC64.2i",
    processed=DATA_ROOT.joinpath("home/shared/projects"),
)
voxel_distances_sorted, cell_distances_sorted = rv_count.rv_cortical_cell_distances(
    **distance_kwargs
)

In [None]:
# Load max projection of local vs tail vein injection.
# Set `recompute_max_proj` to True to rebuild the cached projections from the
# full-size stacks; otherwise the cached TIFFs are read from disk.
recompute_max_proj = False

# Doing the filtering before the projection is a bit slow. Save it
taillocal_projections = flz.get_processed_path("becalia_rabies_barseq/tail_vs_local")
mouse_names = dict(tail="BRAC10946.1f", local="BRAC10946.1c")
if recompute_max_proj:
    from skimage.morphology import white_tophat, black_tophat, disk
    from scipy.ndimage import median_filter, gaussian_filter

    projection_window = np.array(
        [[-630, 630], [-630, 630]]
    )  # part around injection center to keep
    shift = dict(local=[0, 0], tail=[-20, -80])  # to align more border of the brains

    # Each entry is stacked along a new first axis, so it has shape
    # (channel, z, y, x) with channel 0 = ch3 ("red") and channel 1 = ch2 ("cyan").
    full_size = dict()
    for where, mouse in mouse_names.items():
        red = tf.imread(taillocal_projections / f"{mouse}_injection_site_ch3.tif")
        cyan = tf.imread(taillocal_projections / f"{mouse}_injection_site_ch2.tif")
        print(f"Full size image is {cyan.shape}")
        full_size[where] = np.stack([red, cyan])

    def subtract_background(image, radius=50, light_bg=False):
        """Remove slowly varying background from `image` with a top-hat filter.

        Args:
            image: 2D image to filter.
            radius: Radius of the disk structuring element.
            light_bg: If True use a black top-hat (light background), otherwise a
                white top-hat.

        Returns:
            The top-hat filtered image.
        """
        str_el = disk(radius)
        if light_bg:
            return black_tophat(image, str_el)
        else:
            return white_tophat(image, str_el)

    cell_pos = sc_count.load_cell_click_data(return_px=True)
    projected_images = {}
    for where in ["local", "tail"]:
        print(f"Projecting {where}")
        # find injection center using cell positions
        cell_pixels = cell_pos[where]
        center = np.nanmean(cell_pixels, axis=0)[:2] + shift[where]
        # z range spanned by the clicked cells (third column of the positions)
        cell_slices = np.array(
            [cell_pixels[:, 2].min(), cell_pixels[:, 2].max()]
        ).astype(int)
        xpart = (center[0] + projection_window[0]).astype(int)
        ypart = (center[1] + projection_window[1]).astype(int)
        img = full_size[where][..., ypart[0] : ypart[1], xpart[0] : xpart[1]]
        # img is (channel, z, y, x): the in-plane axes are (2, 3) and z is axis 1.
        # Filtering along axis 0 would mix the two channels instead of z-planes.
        print("Median filter image in xy")
        img = median_filter(img, footprint=disk(5), axes=(2, 3))
        print("Median filter image in z")
        img = median_filter(img, footprint=np.ones(3), axes=(1,))
        if False:
            # Optional in-plane smoothing, currently disabled
            print("Gaussian filter image in xy")
            img = gaussian_filter(img, sigma=(2, 2), axes=(2, 3))
        # Project over the z range containing cells, then move channels last
        max_proj = np.nanmax(img[:, cell_slices[0] : cell_slices[1], ...], axis=1)
        max_proj = np.moveaxis(max_proj, 0, 2)
        for chan, radius in enumerate([50, 50]):
            print(f"Subtracting background for {chan}")
            max_proj[..., chan] = subtract_background(
                max_proj[..., chan], radius=radius
            )
        target = taillocal_projections / f"{where}_filtered_max_projection.tif"
        print(f"Writing {target}")
        tf.imwrite(target, max_proj)
        projected_images[where] = max_proj
    print("Projection done")
else:
    # Read the projections cached by a previous run of the branch above
    projected_images = {}
    for where in ["local", "tail"]:
        projected_images[where] = tf.imread(
            taillocal_projections / f"{where}_filtered_max_projection.tif"
        )

In [None]:
# Plot Fig.1
# Shared text sizes and line style for every panel
fontsize_dict = dict(title=7, label=8, tick=6, legend=6)
line_width, line_alpha = 1.2, 1

# Where (and whether) the finished figure is written
save_fig = True
save_path = DATA_ROOT.joinpath("home/shared/presentations/becalick_2025/")
figname = "fig1_plasmid_barcoding_schema_library"

# Square canvas whose side is given in centimetres (matplotlib expects inches)
cm = 1 / 2.54
figure_side = 17.4 * cm
fig = plt.figure(figsize=(figure_side, figure_side), dpi=150)

# Top row of Fig. 1: four library panels sharing the same text and line styling.
library_panel_style = dict(
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_alpha=line_alpha,
    line_width=line_width,
)
# Extra settings common to both "unique label fraction" panels
unique_fraction_style = dict(
    stride=50,
    log_scale=True,
    min_max_percent_unique_range=(0.5, 1.0),
    show_legend=False,
    **library_panel_style,
)
# One colour per library, in the insertion order of the corresponding dict
plasmid_virus_colors = ("dodgerblue", "darkorange")
rescue_scale_colors = ("orchid", "darkorchid", "darkorange")

# 1) Plot the plasmid and virus abundance histograms
ax_abundance = fig.add_axes([0.08, 0.8, 0.13, 0.13])
im = virlib.plot_barcode_counts_and_percentage(
    libraries,
    colors=list(plasmid_virus_colors),
    ax=ax_abundance,
    **library_panel_style,
)

# 2) Plot the plasmid and virus histograms
ax_unique = fig.add_axes([0.33, 0.8, 0.13, 0.13])
im = virlib.plot_unique_label_fraction(
    libraries,
    max_cells=1e6,
    colors=list(plasmid_virus_colors),
    ax=ax_unique,
    **unique_fraction_style,
)
ax_unique.set_xticks([1, 1e3, 1e6])

# 3) Plot the virus rescue scaling abundance histograms
ax_scaling = fig.add_axes([0.58, 0.8, 0.13, 0.13])
im = virlib.plot_barcode_counts_and_percentage(
    libraries_scale,
    colors=list(rescue_scale_colors),
    ax=ax_scaling,
    **library_panel_style,
)

# Unique label fraction for the rescue-scaling libraries
ax_scaling_unique = fig.add_axes([0.83, 0.8, 0.13, 0.13])
im = virlib.plot_unique_label_fraction(
    libraries_scale,
    max_cells=1e4,
    colors=list(rescue_scale_colors),
    ax=ax_scaling_unique,
    **unique_fraction_style,
)


# Add probability of spread simulation graph
# Axes rectangles are given as [left, bottom, width, height] in figure fractions.
prob_spread_starters = fig.add_axes([0.08, 0.45, 0.13, 0.18])
start_sim.plot_starter_spread_sim(
    ax=prob_spread_starters,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Coronal slice panel: a large axes plus a smaller one overlapping its lower-left
# corner. The small axes is created second, so it is drawn on top.
ax_coronal_local_rabies1 = fig.add_axes([0.33, 0.36, 0.40, 0.40])
ax_coronal_local_rabies2 = fig.add_axes([0.32, 0.37, 0.15, 0.15])

# `mcherry` and `background` are the downsampled cellfinder volumes loaded earlier
rv_count.plot_rv_coronal_slice(
    (ax_coronal_local_rabies1, ax_coronal_local_rabies2), mcherry, background
)

# Confocal image of starter cells, with a 20 um scale bar
ax_starter_detection = fig.add_axes([0.74, 0.555, 0.2, 0.15])
sc_count.plot_starter_confocal(ax_starter_detection, starter_img, starter_img_metadata)

# NOTE(review): pixel_size_um is hard-coded here and in print_image_stats below;
# presumably it matches the "Scale" printed from the image metadata. TODO confirm.
overview_image.add_scalebar(
    ax_starter_detection,
    downsample_factor=1,
    pixel_size_um=0.207,
    length_um=20,
    bar_height_px=15,
    margin_px=10,
)

overview_image.print_image_stats(
    "starter_confocal",
    starter_img,
    pixel_size_um=0.207,
    downsample_factor=1,
)

# Rabies cell density panel, using the distances computed in an earlier cell
ax_presynaptic_density = fig.add_axes([0.85, 0.445, 0.1, 0.1])
im = rv_count.plot_rabies_density(
    ax=ax_presynaptic_density,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    processed=DATA_ROOT / "home/shared/projects",
    linewidth=line_width,
    voxel_distances_sorted=voxel_distances_sorted,
    cell_distances_sorted=cell_distances_sorted,
)

# Deliberately disabled panel: the condition is a literal False, so this never runs.
if False:
    # Cell density vs PhP.eB dilution. Not included
    ax_starter_density = fig.add_axes([0.08, 0.20, 0.13, 0.13])
    im = sc_count.plot_starter_dilution_densities(
        ax_starter_density,
        label_fontsize=fontsize_dict["label"],
        tick_fontsize=fontsize_dict["tick"],
        processed=DATA_ROOT / "home/shared/projects",
    )

# Tail vein vs Local injection
# Colours passed to the scatter and pairwise-distance helpers; the manual legend
# further down pairs index 0 with "Intracerebral" and index 1 with "Intravenous".
inj_colors = ["yellowgreen", "midnightblue"]
# Two image panels, a scatter panel, a pairwise-distance panel and a thin axes
# that only hosts the manually drawn legend.
ax_local = fig.add_axes([0.01, 0.16, 0.17, 0.21])
ax_tail = fig.add_axes([0.20, 0.16, 0.17, 0.21])
ax_scatter = fig.add_axes([0.44, 0.21, 0.20, 0.15])
ax_pairwise_dstion = fig.add_axes([0.72, 0.21, 0.24, 0.15])
ax_legend = fig.add_axes([0.42, 0.35, 0.4, 0.04])
# Images are the filtered max projections loaded (or recomputed) in an earlier cell.
# vmin/vmax each hold two values, presumably one per image channel (TODO confirm).
sc_count.plot_tail_vs_local_images(
    local_img=projected_images["local"],
    tail_img=projected_images["tail"],
    ax_local=ax_local,
    ax_tail=ax_tail,
    vmax=[450, 550],
    vmin=[100, 0],
    xl=[100, 1260 - 100],
    yl=None,
    scale_size=250,
)
# White titles in the top-left corner of each image (axes-fraction coordinates)
ax_local.text(
    0.02,
    0.99,
    "Intracerebral\nAAV-Cre injection",
    color="w",
    transform=ax_local.transAxes,
    fontsize=fontsize_dict["tick"],
    verticalalignment="top",
    horizontalalignment="left",
)
ax_tail.text(
    0.02,
    0.99,
    "Intravenous\nAAV-Cre injection",
    color="w",
    transform=ax_tail.transAxes,
    fontsize=fontsize_dict["tick"],
    verticalalignment="top",
    horizontalalignment="left",
)
# `scatters` is not used again in this cell
scatters = sc_count.plot_taillocal_scatter(
    ax_scatter, colors=inj_colors, fontsize_dict=fontsize_dict, alpha=0.5, s=4
)
sc_count.plot_pairwise_dist_distri(
    ax_pairwise_dstion, colors=inj_colors, fontsize_dict=fontsize_dict
)
# Hand-drawn legend: one coloured rectangle and its label per injection type,
# stacked at y = 0, 1 inside the otherwise invisible `ax_legend` axes.
rec_shape = [0.05, 0.1]
labels = ["Intracerebral", "Intravenous"]
rec_width, rec_height = rec_shape
for ilab, (lab, lab_color) in enumerate(zip(labels, inj_colors)):
    rec = plt.Rectangle(
        [-rec_width / 2, ilab - rec_height / 2],
        rec_width,
        rec_height,
        color=lab_color,
    )
    ax_legend.add_artist(rec)
    ax_legend.text(
        rec_width * 0.7,
        ilab,
        lab,
        fontsize=fontsize_dict["legend"],
        verticalalignment="center",
    )

# Hide the frame and both axis lines so only rectangles and text remain
ax_legend.set_frame_on(False)
for legend_axis in (ax_legend.xaxis, ax_legend.yaxis):
    legend_axis.set_visible(False)
ax_legend.set_xlim(-0.2, 0.5)
ax_legend.set_ylim(-0.8, 1.8)

# Write the figure as PDF (600 dpi) and PNG (figure dpi) when saving is enabled
if save_fig:
    save_path.mkdir(parents=True, exist_ok=True)
    pdf_target = f"{save_path/figname}.pdf"
    png_target = f"{save_path/figname}.png"
    fig.savefig(pdf_target, format="pdf", dpi=600)
    fig.savefig(png_target, format="png")
    print(f"Figure saved as {save_path/figname}")

In [None]:
# What happens if we subsample the cells to have the same number for both conditions?
colors = ["yellowgreen", "midnightblue"]
kwargs = dict(alpha=0.5, s=4)

# Fixed seed so the subsampled panels are identical on every run
rng = np.random.default_rng(0)

# Top row: all clicked cells
ax = plt.subplot(2, 2, 1)
clicked_cells = sc_count.load_cell_click_data(relative=True)
sc_count.plot_taillocal_scatter(
    ax, colors, fontsize_dict, clicked_cells=clicked_cells, **kwargs
)
ax = plt.subplot(2, 2, 2)
sc_count.plot_pairwise_dist_distri(
    ax, colors, fontsize_dict, clicked_cells=clicked_cells
)

# Bottom row: the same number of cells drawn from each condition
ax = plt.subplot(2, 2, 3)
# Reloaded so the subsample starts from untouched arrays, in case the plotting
# helpers above modify them in place (TODO confirm and drop the reload).
clicked_cells = sc_count.load_cell_click_data(relative=True)

n_php = clicked_cells["tail"].shape[0]
# Two thirds of the tail-vein cells, capped by the number of local cells so that
# sampling without replacement is always possible.
n_sub = min(n_php * 2 // 3, clicked_cells["local"].shape[0])
# replace=False: drawing with replacement would duplicate cells and add spurious
# zero pairwise distances to the distance distribution.
subphp = rng.choice(clicked_cells["tail"], n_sub, replace=False)
subsample = rng.choice(clicked_cells["local"], n_sub, replace=False)
print(clicked_cells["local"].shape)
click_sub = dict(tail=subphp, local=subsample)
print(click_sub["local"].shape)
sc_count.plot_taillocal_scatter(
    ax, colors, fontsize_dict, clicked_cells=click_sub, **kwargs
)
ax = plt.subplot(2, 2, 4)
sc_count.plot_pairwise_dist_distri(ax, colors, fontsize_dict, clicked_cells=click_sub)

In [None]:
from iss_preprocess.pipeline.ara_registration import load_coordinate_image
from iss_preprocess.io import get_roi_dimensions, load_ops, get_processed_path
import pandas as pd

# Not referenced again in the visible part of this notebook
atlas_coverage_path = "becalia_rabies_barseq/BRAC8498.3e/chamber_07/"

# One coordinate image per (chamber, roi), in chamber-major order
area_images = []

for ch in ["07", "08", "09", "10"]:
    # Use a dedicated name: `data_path` is the library folder used by the
    # library-loading cells and must not be overwritten here.
    chamber_path = f"becalia_rabies_barseq/BRAC8498.3e/chamber_{ch}/"
    for roi in range(1, 11):
        area_img = load_coordinate_image(
            chamber_path,
            roi,
            full_scale=False,
        )
        area_images.append(area_img)

# Load the barcoded cells table and keep the cells that have barcodes assigned
barseq_path = get_processed_path("becalia_rabies_barseq").parent.parent
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_26"
# NOTE(security): read_pickle executes arbitrary code from the file; only use it
# on files produced by our own pipeline.
barcoded_cells_df = pd.read_pickle(
    barseq_path
    / f"processed/becalia_rabies_barseq/BRAC8498.3e/analysis/{error_correction_ds_name}_cell_barcode_df.pkl"
)
barcoded_cells_df = barcoded_cells_df[barcoded_cells_df["all_barcodes"].notnull()]
# N x 3 array built from the "ara_x", "ara_y", "ara_z" columns
barcoded_cells = barcoded_cells_df[["ara_x", "ara_y", "ara_z"]].to_numpy()

In [None]:
import numpy as np
from scipy.spatial import ConvexHull
import plotly.graph_objects as go
from pathlib import Path


def collect_plane_pixels(planes, sample_step: int = 10) -> np.ndarray:
    """Subsample coordinate images and stack their pixels into one point list.

    Args:
        planes: Iterable of arrays whose last axis has length 3 (one 3D
            coordinate per pixel), indexed as ``P[row, col]``.
        sample_step: Keep every ``sample_step``-th pixel along the first two axes.

    Returns:
        Array of shape (n_points, 3) with the kept pixels of all planes, in the
        order of ``planes``. An empty ``planes`` yields a (0, 3) array instead of
        an error, so callers can check the point count themselves.
    """
    pix = [P[::sample_step, ::sample_step].reshape(-1, 3) for P in planes]
    if not pix:
        # np.concatenate raises on an empty list; return an empty point list
        return np.empty((0, 3))
    return np.concatenate(pix, axis=0)


def in_hull(points: np.ndarray, hull: ConvexHull, tol: float = 1e-12) -> np.ndarray:
    """Vectorised point-in-hull test.

    A point is inside (or on) the hull when it satisfies every facet inequality
    ``normal @ x + offset <= tol`` stored in ``hull.equations``.

    Args:
        points: Array of shape (n_points, ndim), or a single point of shape
            (ndim,), where ndim is the dimension of the hull.
        hull: Convex hull to test against.
        tol: Tolerance on the facet inequalities, so that points lying on the
            hull surface are counted as inside.

    Returns:
        Boolean array of shape (n_points,); True where the point is inside/on
        ``hull``. A single 1-D point gives an array of shape (1,).
    """
    # Each row of hull.equations is [normal..., offset]; slice relative to the
    # last column so hulls of any dimension work.
    A, b = hull.equations[:, :-1], hull.equations[:, -1]
    # Promote a single point to one row: otherwise the offset broadcast below
    # would silently produce a per-facet result instead of a per-point one.
    points = np.atleast_2d(points)
    return np.all(A @ points.T + b[:, None] <= tol, axis=0)


# Whether the plot further down also shows the (unshifted) barcoded cells
plot_barcoded = False

# Gather pixel points from in situ ara coord images and make a convex hull of the imaged volume
sample_step = 50
pix_coords = collect_plane_pixels(area_images, sample_step)
pix_coords = pix_coords[
    ~np.all(pix_coords == 0, axis=1)
]  # drop (0,0,0) failed atlas points
if len(pix_coords) < 4:
    raise RuntimeError("Not enough valid pixel coordinates to build a hull.")
hull = ConvexHull(pix_coords)
# Centre is the mean of the hull vertices, not of all sampled pixels
hull_center = pix_coords[hull.vertices].mean(axis=0)
print("Current hull centre:", hull_center)

# Shift the in situ hull so it is centred on the injection site of the 2P data
# 2P 64.2h inj_center in ara coords 8.0, 1.1, 8.2
# NOTE(review): the values quoted in the comment above differ from
# `target_center` below in the 2nd and 3rd coordinates; confirm which is right.
target_center = np.array([8.0, 1.7307591, 8.533215], dtype=float)
shift_vec = target_center - hull_center
shifted_pix_coords = pix_coords + shift_vec
hull = ConvexHull(shifted_pix_coords)

# load 2P detected rabies cell coords
project = "rabies_barcoding"
mouse = "BRYC64.2h"
# Built from DATA_ROOT so that switching the storage root also applies here
processed = DATA_ROOT / "home/shared/projects"

points_file = processed / project / mouse / "cellfinder_results_010/points/abc4d.npy"
# NOTE(security): allow_pickle=True can execute code stored in the file; only
# load files produced by our own pipeline.
points = np.load(points_file, allow_pickle=True)
# Keep the first three columns and rescale by 1e-3 (presumably um -> mm so the
# points share units with the hull coordinates; TODO confirm).
points = points[:, :3] * 0.001  # Nx3 array of XYZ coordinates

# Find which 2P detect rabies cells are inside the shifted hull
inside_mask = in_hull(points, hull)
inside_cells = points[inside_mask]
outside_cells = points[~inside_mask]

print(f"Cells inside hull : {inside_cells.shape[0]:,}")
print(f"Cells outside hull: {outside_cells.shape[0]:,}")

# Interactive 3D view: shifted hull surface plus the cells inside/outside of it
# Hull surface: one triangle per row of `hull.simplices`, indexing the shifted points
tri = hull.simplices
hull_x, hull_y, hull_z = shifted_pix_coords.T[:3]
mesh = go.Mesh3d(
    x=hull_x,
    y=hull_y,
    z=hull_z,
    i=tri[:, 0],
    j=tri[:, 1],
    k=tri[:, 2],
    opacity=0.45,
    color="lightgrey",
    name="Convex hull",
)

# Cells inside the hull in red, cells outside in blue
inside_scatter = go.Scatter3d(
    x=inside_cells[:, 0],
    y=inside_cells[:, 1],
    z=inside_cells[:, 2],
    mode="markers",
    marker={"size": 3, "color": "red", "opacity": 0.1},
    name=f"Inside cells ({len(inside_cells)})",
)
outside_scatter = go.Scatter3d(
    x=outside_cells[:, 0],
    y=outside_cells[:, 1],
    z=outside_cells[:, 2],
    mode="markers",
    marker={"size": 3, "color": "blue", "opacity": 0.2},
    name=f"Outside cells ({len(outside_cells)})",
)

hull_traces = [mesh, inside_scatter, outside_scatter]
if plot_barcoded:
    # --- add a green scatter trace for barcoded rabies cells (not shifted)
    barcoded_scatter = go.Scatter3d(
        x=barcoded_cells[:, 0],
        y=barcoded_cells[:, 1],
        z=barcoded_cells[:, 2],
        mode="markers",
        marker={"size": 3, "color": "green", "opacity": 0.01},
        name=f"Barcoded cells ({len(barcoded_cells)})",
    )
    hull_traces.append(barcoded_scatter)

fig = go.Figure(data=hull_traces)
fig.update_layout(
    scene={
        "xaxis_title": "X",
        "yaxis_title": "Y",
        "zaxis_title": "Z",
        "aspectmode": "data",
    },
    height=800,
    margin={"l": 0, "r": 0, "t": 10, "b": 0},
)

fig.show()