In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Root folder holding the local copy of the data. This is a machine-specific
# absolute path: edit it to match your own setup before running the notebook.
DATA_ROOT = "/Users/blota/Data/brisc"
# If in the lab, set DATA_ROOT to None, we'll find the data with flexiznam

# Optional: None, or the path to arial.ttf. When a path is given, the matplotlib
# options cell below registers that font and uses it as the default font family.
arial_font_path = None  # "/nemo/lab/znamenskiyp/home/shared/resources/fonts/arial.ttf"

In [None]:
# Imports and setting matplotlib options
import numpy as np
from pathlib import Path
import tifffile as tf

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm

from brisc.manuscript_analysis import viral_library as virlib
from brisc.manuscript_analysis import start_density_sim as start_sim
from brisc.manuscript_analysis import rabies_cell_counting as rv_count
from brisc.manuscript_analysis import starter_cell_counting as sc_count
from brisc.manuscript_analysis import overview_image
from brisc.manuscript_analysis.utils import get_path, get_output_folder

# Matplotlib configuration
if arial_font_path is not None:
    # Register the font file, then use its family name as the default font
    fm.fontManager.addfont(arial_font_path)
    arial_prop = fm.FontProperties(fname=arial_font_path)
    plt.rcParams.update(
        {
            "font.family": arial_prop.get_name(),
            "mathtext.default": "regular",  # make math mode also Arial
        }
    )

# Font type used when writing pdfs
matplotlib.rcParams["pdf.fonttype"] = 42

In [None]:
# Load viral library data
data_path = get_path(
    "barcode_diversity_analysis/collapsed_barcodes/", data_root=DATA_ROOT
)

# For each panel label, the positional arguments handed to
# virlib.load_library_data after the data folder.
library_args = {
    "Plasmid library": ("PBC20", 1, "bowtie"),
    "Virus library": ("RV35", 2, "bowtie"),
}
library_scale_args = {
    # "2 plates #1": ("RV31", 2, "bowtie"),
    "2 plates #2": ("RV32", 2, "bowtie"),
    "12 plates": ("RV35", 2, "bowtie"),
}

libraries = {
    label: virlib.load_library_data(data_path, *args)
    for label, args in library_args.items()
}
libraries_scale = {
    label: virlib.load_library_data(data_path, *args)
    for label, args in library_scale_args.items()
}

In [None]:
# Load example confocal image
PROJECT = "rabies_barcoding"
MOUSE = "BRYC64.2h"
IMAGE_FILE = "Slide_3_section_1.czi"
confocal_data = get_path(PROJECT, data_root=DATA_ROOT).joinpath(
    MOUSE, "zeiss_confocal", IMAGE_FILE
)
starter_img_metadata, starter_img = sc_count.load_confocal_image(confocal_data)

# Report which microscope was used and the pixel scaling stored in the metadata
mdata = starter_img_metadata["ImageDocument"]["Metadata"]
microscope_system = mdata["Information"]["Instrument"]["Microscopes"]["Microscope"][
    "System"
]
print(f"Microscope: {microscope_system}")

# One entry per "Id"; values are multiplied by 1e6
# (presumably metres -> micrometres, to be confirmed against the czi metadata)
scale = {}
for distance_item in mdata["Scaling"]["Items"]["Distance"]:
    scale[distance_item["Id"]] = distance_item["Value"] * 1e6
print(f"Scale: {scale}")

In [None]:
# Load bulk rabies images
# Both images live in the same cellfinder registration folder of the mouse
registration_dir = (
    get_path(PROJECT, data_root=DATA_ROOT)
    / MOUSE
    / "cellfinder_results_010/registration"
)
mcherry_file = registration_dir / "downsampled_channel_0.tiff"
background_file = registration_dir / "downsampled.tiff"

mcherry, background = (tf.imread(path) for path in (mcherry_file, background_file))

In [None]:
# Load distances between rabies cells
# Injection centre of each mouse, as a 3-element array
injection_centers = {
    mouse_name: np.array(center)
    for mouse_name, center in (
        ("BRYC64.2h", [567, 144, 864]),
        ("BRYC64.2i", [673, 205, 890]),
    )
}
distance_kwargs = dict(
    inj_center=injection_centers[MOUSE],
    project="rabies_barcoding",
    mouse=MOUSE,
    processed=DATA_ROOT,
    data_root=DATA_ROOT,
)
voxel_distances_sorted, cell_distances_sorted = rv_count.rv_cortical_cell_distances(
    **distance_kwargs
)

In [None]:
# Load max projection of local vs tail vein injection
# The filtering done before the projection is slow, so the result is saved to disk
# and reloaded unless `recompute_max_proj` is switched on.
recompute_max_proj = False

taillocal_projections = get_path(
    "becalia_rabies_barseq/tail_vs_local", data_root=DATA_ROOT
)
mouse_names = {"tail": "BRAC10946.1f", "local": "BRAC10946.1c"}

# Clicked cell positions, loaded three times with different options
cell_pos_px = sc_count.load_cell_click_data(DATA_ROOT, return_px=True)
cell_pos_relative = sc_count.load_cell_click_data(DATA_ROOT, relative=True)
cell_pos_um = sc_count.load_cell_click_data(DATA_ROOT)

if recompute_max_proj:
    # Only needed for the (slow) recomputation, so these imports stay local
    from skimage.morphology import white_tophat, black_tophat, disk
    from scipy.ndimage import median_filter

    projection_window = np.array(
        [[-630, 630], [-630, 630]]
    )  # part around injection center to keep
    shift = dict(local=[0, 0], tail=[-20, -80])  # to align more border of the brains

    # Stack the two channels of each mouse on a new leading axis. The indexing
    # below (z slices on axis 1, y/x crops on the last two axes) treats the
    # result as (channel, z, y, x).
    full_size = dict()
    for where, mouse in mouse_names.items():
        red = tf.imread(taillocal_projections / f"{mouse}_injection_site_ch3.tif")
        cyan = tf.imread(taillocal_projections / f"{mouse}_injection_site_ch2.tif")
        print(f"Full size image is {cyan.shape}")
        full_size[where] = np.stack([red, cyan])

    def subtract_background(image, radius=50, light_bg=False):
        """Remove the background of `image` with a top-hat filter.

        Args:
            image: image to filter.
            radius: radius of the disk used as structuring element.
            light_bg: if True use a black top-hat, otherwise a white top-hat.

        Returns:
            The top-hat filtered image.
        """
        str_el = disk(radius)
        if light_bg:
            return black_tophat(image, str_el)
        return white_tophat(image, str_el)

    projected_images = {}
    for where in ["local", "tail"]:
        print(f"Projecting {where}")
        # find injection center using cell positions
        cell_pixels = cell_pos_px[where]
        center = np.nanmean(cell_pixels, axis=0)[:2] + shift[where]
        # NaN-aware, like the centre above, so a missing click cannot corrupt
        # the slice range.
        # NOTE(review): the upper bound is used as an exclusive slice end below,
        # so the last slice containing a cell is not projected - confirm intended.
        cell_slices = np.array(
            [np.nanmin(cell_pixels[:, 2]), np.nanmax(cell_pixels[:, 2])]
        ).astype(int)
        xpart = (center[0] + projection_window[0]).astype(int)
        ypart = (center[1] + projection_window[1]).astype(int)
        img = full_size[where][..., ypart[0] : ypart[1], xpart[0] : xpart[1]]

        # img is (channel, z, y, x): y/x are axes 2 and 3, z is axis 1.
        # The previous axes=(1, 2) / axes=(0) filtered along z/y and across the
        # two channels instead of in xy and in z. Projections cached on disk
        # were written with the old axes; set recompute_max_proj = True to
        # regenerate them.
        print("Median filter image in xy")
        img = median_filter(img, footprint=disk(5), axes=(2, 3))
        print("Median filter image in z")
        img = median_filter(img, footprint=np.ones(3), axes=(1,))

        # Max projection over the slices that contain cells, channels moved last
        max_proj = np.nanmax(img[:, cell_slices[0] : cell_slices[1], ...], axis=1)
        max_proj = np.moveaxis(max_proj, 0, 2)
        for chan, radius in enumerate([50, 50]):
            print(f"Subtracting background for {chan}")
            max_proj[..., chan] = subtract_background(
                max_proj[..., chan], radius=radius
            )
        target = taillocal_projections / f"{where}_filtered_max_projection.tif"
        print(f"Writing {target}")
        tf.imwrite(target, max_proj)
        projected_images[where] = max_proj
    print("Projection done")
else:
    # Reload the projections saved by a previous recomputation
    projected_images = {
        where: tf.imread(
            taillocal_projections / f"{where}_filtered_max_projection.tif"
        )
        for where in ["local", "tail"]
    }

In [None]:
# Lab-only sanity check: print the folder where flexiznam keeps processed data.
# Per the configuration comment, flexiznam is what locates the data when DATA_ROOT
# is None; with a local DATA_ROOT the lookup is skipped so that the notebook still
# runs top to bottom without a flexiznam setup.
if DATA_ROOT is None:
    import flexiznam as flz

    print(flz.get_processed_path("becalia_rabies_barseq").parent)

In [None]:
# Plot Fig.1
# Shared styling for every panel
fontsize_dict = dict(title=7, label=8, tick=6, legend=6)
line_width, line_alpha = 1.2, 1

# Figure size is written in centimetres and converted with `cm`
cm = 1 / 2.54
figure_side = 17.4 * cm
fig = plt.figure(figsize=(figure_side, figure_side), dpi=150)

# Where and whether the figure is written (used by the save step of this cell)
figname = "fig1_plasmid_barcoding_schema_library"
save_fig = True
save_path = get_output_folder(DATA_ROOT)

# Keyword arguments shared by the four library panels
panel_style = dict(
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_alpha=line_alpha,
    line_width=line_width,
)
# Extra arguments shared by the two unique-label-fraction panels
unique_fraction_style = dict(
    stride=50,
    log_scale=True,
    min_max_percent_unique_range=(0.5, 1.0),
    show_legend=False,
    **panel_style,
)
library_colors = ["dodgerblue", "darkorange"]
# NOTE(review): three colours are listed but `libraries_scale` currently holds two
# libraries (the RV31 entry is commented out); confirm each library gets the
# intended colour.
scaling_colors = ["orchid", "darkorchid", "darkorange"]

# 1) Plasmid and virus barcode abundance
ax_abundance = fig.add_axes([0.08, 0.8, 0.13, 0.13])
im = virlib.plot_barcode_counts_and_percentage(
    libraries, colors=list(library_colors), ax=ax_abundance, **panel_style
)

# 2) Plasmid and virus unique label fraction
ax_unique = fig.add_axes([0.33, 0.8, 0.13, 0.13])
im = virlib.plot_unique_label_fraction(
    libraries,
    max_cells=1e6,
    colors=list(library_colors),
    ax=ax_unique,
    **unique_fraction_style,
)
ax_unique.set_xticks([1, 1e3, 1e6])

# 3) Virus rescue scaling: barcode abundance, then unique label fraction
ax_scaling = fig.add_axes([0.58, 0.8, 0.13, 0.13])
im = virlib.plot_barcode_counts_and_percentage(
    libraries_scale, colors=list(scaling_colors), ax=ax_scaling, **panel_style
)

ax_scaling_unique = fig.add_axes([0.83, 0.8, 0.13, 0.13])
im = virlib.plot_unique_label_fraction(
    libraries_scale,
    max_cells=1e4,
    colors=list(scaling_colors),
    ax=ax_scaling_unique,
    **unique_fraction_style,
)


# Add probability of spread simulation graph
prob_spread_starters = fig.add_axes([0.08, 0.45, 0.13, 0.18])
start_sim.plot_starter_spread_sim(
    ax=prob_spread_starters,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    line_width=line_width,
)

# Coronal slice of the bulk rabies images, drawn on two overlapping axes
ax_coronal_local_rabies1 = fig.add_axes([0.33, 0.36, 0.39, 0.39])
ax_coronal_local_rabies2 = fig.add_axes([0.32, 0.37, 0.15, 0.15])
coronal_axes = (ax_coronal_local_rabies1, ax_coronal_local_rabies2)
rv_count.plot_rv_coronal_slice(
    injection_centers[MOUSE], coronal_axes, mcherry, background
)

# Example confocal image of starter cells, with scale bar and image stats.
# The pixel size is hard-coded here rather than read from the image metadata.
confocal_pixel = dict(pixel_size_um=0.207, downsample_factor=1)
ax_starter_detection = fig.add_axes([0.74, 0.555, 0.2, 0.15])
sc_count.plot_starter_confocal(ax_starter_detection, starter_img, starter_img_metadata)
overview_image.add_scalebar(
    ax_starter_detection,
    length_um=20,
    bar_height_px=7,
    margin_px=10,
    **confocal_pixel,
)
overview_image.print_image_stats("starter_confocal", starter_img, **confocal_pixel)

# Rabies cell density, using the distances loaded earlier in the notebook
ax_presynaptic_density = fig.add_axes([0.85, 0.445, 0.1, 0.1])
processed_root = get_path("rabies_barcoding", DATA_ROOT).parent
im = rv_count.plot_rabies_density(
    inj_center=injection_centers[MOUSE],
    ax=ax_presynaptic_density,
    label_fontsize=fontsize_dict["label"],
    tick_fontsize=fontsize_dict["tick"],
    processed=processed_root,
    linewidth=line_width,
    voxel_distances_sorted=voxel_distances_sorted,
    cell_distances_sorted=cell_distances_sorted,
)


# Tail vein vs Local injection
inj_colors = ["yellowgreen", "midnightblue"]
ax_local = fig.add_axes([0.01, 0.16, 0.17, 0.21])
ax_tail = fig.add_axes([0.20, 0.16, 0.17, 0.21])
ax_scatter = fig.add_axes([0.44, 0.21, 0.20, 0.15])
ax_pairwise_dstion = fig.add_axes([0.72, 0.21, 0.24, 0.15])
ax_legend = fig.add_axes([0.42, 0.35, 0.4, 0.04])

# Max-projection images of the two injection types
sc_count.plot_tail_vs_local_images(
    local_img=projected_images["local"],
    tail_img=projected_images["tail"],
    ax_local=ax_local,
    ax_tail=ax_tail,
    vmax=[450, 550],
    vmin=[100, 0],
    xl=[100, 1260 - 100],
    yl=None,
    scale_size=250,
)
# Caption in the top-left corner of each image
image_captions = (
    (ax_local, "Intracerebral\nAAV-Cre injection"),
    (ax_tail, "Intravenous\nAAV-Cre injection"),
)
for ax_img, caption in image_captions:
    ax_img.text(
        0.02,
        0.99,
        caption,
        color="w",
        transform=ax_img.transAxes,
        fontsize=fontsize_dict["tick"],
        verticalalignment="top",
        horizontalalignment="left",
    )

# Cell positions and pairwise distance distributions
scatters = sc_count.plot_taillocal_scatter(
    cell_pos_relative,
    ax_scatter,
    colors=inj_colors,
    fontsize_dict=fontsize_dict,
    alpha=0.5,
    s=4,
)
sc_count.plot_pairwise_dist_distri(
    cell_pos_relative, ax_pairwise_dstion, colors=inj_colors, fontsize_dict=fontsize_dict
)

# Hand-made legend: one coloured rectangle and one label per injection type.
# The loop used to be nested inside a copy of itself, which added every
# rectangle and label twice; each entry is now drawn exactly once.
rec_shape = [0.05, 0.1]
labels = ["Intracerebral", "Intravenous"]
for ilab, lab in enumerate(labels):
    rec = plt.Rectangle(
        [-rec_shape[0] / 2, ilab - rec_shape[1] / 2],
        rec_shape[0],
        rec_shape[1],
        color=inj_colors[ilab],
    )
    ax_legend.add_artist(rec)
    ax_legend.text(
        rec_shape[0] * 0.7,
        ilab,
        lab,
        fontsize=fontsize_dict["legend"],
        verticalalignment="center",
    )
# The legend axes only hosts the artists above: hide its frame and both axes
ax_legend.set_frame_on(False)
ax_legend.xaxis.set_visible(False)
ax_legend.yaxis.set_visible(False)
ax_legend.set_xlim(-0.2, 0.5)
ax_legend.set_ylim(-0.8, 1.8)

if save_fig:
    # Write the figure as pdf (600 dpi) and png next to each other in save_path
    save_path.mkdir(parents=True, exist_ok=True)
    figure_stem = save_path / figname
    fig.savefig(f"{figure_stem}.pdf", format="pdf", dpi=600)
    fig.savefig(f"{figure_stem}.png", format="png")
    print(f"Figure saved as {figure_stem}")

In [None]:
# What happens if we subsample the cells to have the same number for both conditions?
# (answer: Nothing changes)
colors = ["yellowgreen", "midnightblue"]
kwargs = dict(alpha=0.5, s=4)

# Fixed seed so that the subsample, and therefore the panels, are reproducible
rng = np.random.default_rng(0)

# Draw the same number of cells from both conditions WITHOUT replacement:
# sampling with replacement (the default of `rng.choice`) can pick the same cell
# several times, which adds zero-length pairs to the pairwise distances.
n_php = cell_pos_relative["tail"].shape[0]
n_local = cell_pos_relative["local"].shape[0]
# Capped at the number of local cells so the draw without replacement is possible
n_sub = min(n_php * 2 // 3, n_local)
subphp = rng.choice(cell_pos_relative["tail"], n_sub, replace=False)
subsample = rng.choice(cell_pos_relative["local"], n_sub, replace=False)
click_sub = dict(tail=subphp, local=subsample)
print(cell_pos_relative["local"].shape)
print(click_sub["local"].shape)

# Explicit figure so the panels never land on a figure left over from another cell.
# Top row: all cells. Bottom row: subsampled cells.
fig_sub, axes_sub = plt.subplots(2, 2)
sc_count.plot_taillocal_scatter(
    cell_pos_relative, axes_sub[0, 0], colors, fontsize_dict, **kwargs
)
sc_count.plot_pairwise_dist_distri(
    cell_pos_relative, axes_sub[0, 1], colors, fontsize_dict
)
sc_count.plot_taillocal_scatter(
    click_sub, axes_sub[1, 0], colors, fontsize_dict, **kwargs
)
sc_count.plot_pairwise_dist_distri(click_sub, axes_sub[1, 1], colors, fontsize_dict)