# Plot some connections

Using the latest BRAC85 brain, plot a few example starter cells together with their
presynaptic cells.



In [None]:
# imports and chamber selection
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from flexiznam.config import PARAMETERS
from pathlib import Path
from itertools import cycle
from matplotlib.animation import FuncAnimation

import iss_preprocess as iss

mouse = "BRAC8501.6a"
chamber = "chamber_07"
data_path = f"becalia_rabies_barseq/{mouse}/{chamber}/"
# data_path = 'becalia_rabies_barseq/BRYC64.2j/chamber_12/'

processed_path = iss.io.load.get_processed_path(data_path)
metadata = iss.io.load_metadata(data_path)
ops = iss.io.load_ops(data_path=data_path)
print(data_path)

## Combine barcode spots for error correction

If we error-correct each part separately, we might end up with different barcodes, so we
do it with all spots fused together.

In [None]:
import helper

# Thresholds used by helper.get_barcodes to pre-filter spots before the GMM split
barcode_thresholds = dict(
    mean_intensity_threshold=0.01,
    dot_product_score_threshold=0.2,
    mean_score_threshold=0.75,
)
barcode_spots, gmm, all_barcode_spots = helper.get_barcodes(
    data_path, **barcode_thresholds
)

## Filter barcode spots

We want to remove very low intensity spots or spots with a bad dot product score, and use
a GMM to split the remaining spots into good and bad. This is done in `get_barcodes`;
here we just plot it as a sanity check.

In [None]:
import seaborn as sns

metrics = ["dot_product_score", "spot_score", "mean_intensity", "mean_score"]
d = all_barcode_spots[metrics][:: len(all_barcode_spots) // 5000].copy()
d["labels"] = gmm.predict(d.values)
sns.pairplot(d, hue="labels", plot_kws={"alpha": 0.1})

In [None]:
# Error-correct the barcode sequences (this is a bit slow).
# NOTE(review): the result replaces `barcode_spots`, so re-running this cell applies the
# correction to its own output; re-run the get_barcodes cell first.
corrected_spots = iss.call.correct_barcode_sequences(barcode_spots, 2)
barcode_spots = corrected_spots
barcode_spots.shape

# Gene image

Make a background image using all genes

In [None]:
roi = ops["use_rois"][0]
data_folder = iss.io.get_processed_path(data_path)
genes_spots = pd.read_pickle(data_folder / f"genes_round_spots_{roi}.pkl")
genes_spots.head()

In [None]:
# Remake the genes spot dataframe, this time keeping spots from overlapping tiles
# (keep_all_spots=True), then drop duplicates: a spot is kept only if the median tile
# id of its neighbourhood equals its own tile id.
from scipy.spatial import KDTree


def find_nearest_neighbours(points, ids, k, max_distance=50):
    """Median id of the `k` nearest neighbours (point itself included) of each point.

    Args:
        points (np.ndarray): (n_points, n_dims) coordinates, e.g. spot x/y.
        ids (np.ndarray): (n_points,) integer label of each point (here, tile index).
        k (int): number of neighbours to query, including the point itself.
        max_distance (float): neighbours further away than this are ignored.
            Defaults to 50, the radius previously hard-coded.

    Returns:
        np.ndarray: (n_points,) median neighbour id per point, cast to the dtype of
        `ids` (truncating, as the previous per-point loop did).
    """
    points = np.asarray(points)
    ids = np.asarray(ids)
    if len(points) == 0:
        return np.zeros_like(ids)
    tree = KDTree(points)
    # One vectorised query for all points instead of a Python loop over points.
    dist, idx = tree.query(points, k=k, distance_upper_bound=max_distance)
    # k=1 returns 1D arrays; make both (n_points, k)
    dist = np.reshape(dist, (len(points), -1))
    idx = np.reshape(idx, (len(points), -1))
    # Neighbours missing within max_distance come back with dist=inf and
    # idx == len(points), which would index past the end of `ids`: mask them out.
    found = np.isfinite(dist)
    neighbour_ids = np.where(found, ids[np.minimum(idx, len(ids) - 1)], np.nan)
    # NOTE(review): tile ids are categorical, so the median is a heuristic (the mode
    # may be what is intended); kept as median to preserve the original behaviour.
    return np.nanmedian(neighbour_ids, axis=1).astype(ids.dtype)


REMAKE_GENES_SPOTS = True  # can be slow; set to False to keep the existing pickles
if REMAKE_GENES_SPOTS:
    for roi in ops["use_rois"]:
        all_genes = iss.pipeline.stitch.merge_and_align_spots(
            data_path,
            roi,
            spots_prefix="genes_round",
            reg_prefix="genes_round_4_1",
            ref_prefix="genes_round_4_1",
            keep_all_spots=True,
        )
        points = all_genes[["x", "y"]].values
        tiles = {t: i for i, t in enumerate(all_genes["tile"].unique())}
        ids = all_genes["tile"].map(tiles).values
        nearest_neighbours = find_nearest_neighbours(points, ids, 10)
        points2keep = nearest_neighbours == ids
        all_genes = all_genes[points2keep].copy()
        # overwrites the genes_round_spots pickle with the deduplicated spots
        all_genes.to_pickle(processed_path / f"genes_round_spots_{roi}.pkl")


In [None]:
points = all_genes[["x", "y"]].values
tiles = {t:i for i, t in enumerate(all_genes["tile"].unique())}
ids = all_genes["tile"].map(tiles).values
nearest_neighbours = find_nearest_neighbours(points, ids, 10)
points2keep = nearest_neighbours == ids

In [None]:
# Sanity check: the median of two tile ids need not be an integer (here 0.5)
np.median(np.array([0, 1]))

In [None]:
# Diagnostic plot: tile id of each spot, median tile id of its neighbourhood, and
# which spots are kept (neighbourhood id == own id).
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
panels = [
    ("Tile id", ids, "prism"),
    ("Neighbourhood tile id", nearest_neighbours, "prism"),
    ("Kept", points2keep, "tab10"),
]
for ax, (title, colour_by, cmap) in zip(axes, panels):
    ax.scatter(points[:, 0], points[:, 1], c=colour_by, cmap=cmap)
    ax.set_title(title)
    ax.set_aspect("equal")
    # Axes has no `inverse_yaxis` (AttributeError); `invert_yaxis` flips y as intended
    ax.invert_yaxis()
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_facecolor("black")

In [None]:
# Names of all real genes (codes starting with "unassigned" are dropped), for later use.
# NOTE(review): this rebinds `all_genes`, which earlier held a spots DataFrame.
all_genes = list(
    filter(lambda gene: not gene.startswith("unassigned"), genes_spots.gene.unique())
)

In [None]:
# Example plot of the background gene-spot scatter for the ROI loaded above
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_aspect("equal")
helper.plot_gene_image(ax, genes_spots)
ax.set(xticks=[], yticks=[])
ax.set_facecolor("black")
ax.legend(ncol=2, loc="lower right", fontsize=5)
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
ax.invert_yaxis()

# Find starter cells

The next few sections are used to manually define the shift between the mcherry channel
and the rabies channel. Then we use the overlay to define which mcherry cells are actual
starter cells.

## Select ROI

In [None]:
# ROI inspected in the starter-cell section below (third entry of use_rois)
roi = ops["use_rois"][2]
print("Using roi", roi)

## Load raw data for ROI

In [None]:
# Raw stack and mask for this ROI, plus the mCherry cells clicked manually for it
raw_data = helper.get_raw_data(data_path, roi, mcherry_channel=2)
st, mask = raw_data
mcherry_cells = helper.get_mcherry_cells(data_path, roi)

In [None]:
# make a downsampled stack to have a quick look
from skimage.transform import downscale_local_mean

ds_factor = 10
# downsample the first two axes only; the last axis (channels) is left untouched
downsampled_st = downscale_local_mean(st, (ds_factor, ds_factor, 1))

# per-channel display range from robust percentiles over the first two axes...
vmin = np.percentile(downsampled_st, 0.1, axis=(0, 1))
vmax = np.percentile(downsampled_st, 99.99, axis=(0, 1))
# ...except channel 0, which gets a fixed range
vmin[0], vmax[0] = 1000, 4000
# assumes three channels, shown as red, green and blue — TODO confirm channel order
channel_colors = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
rgb = iss.vis.to_rgb(downsampled_st, colors=channel_colors, vmin=vmin, vmax=vmax)

In [None]:
# DEFINE SHIFTS
# roi -> [dx, dy] added to the clicked mCherry positions (index 0 is applied to x,
# index 1 to y) to align them with the rabies/barcode data
manual_shift = {
    4: [10, 20],
    5: [10, 20],
    14: [220, 50],
    6: [50, 30],
}

# Notes:
# mCherry in ROI 6 has an annoying rotation. Not ideal

## Plot a snippet around each cell

Used to define the shift between mCherry and rabies.

In [None]:
# One panel per clicked mCherry cell: a snippet of the raw stack around the (shifted)
# cell position. Used to tune `manual_shift` by eye.
n_cells = max(len(mcherry_cells), 1)  # avoid a 0-row grid / division by zero
nrow = int(np.ceil(np.sqrt(n_cells)))
ncol = int(np.ceil(n_cells / nrow))

# squeeze=False keeps `axes` 2D even for a 1x1 grid, so `.flat` always works
fig, axes = plt.subplots(nrow, ncol, figsize=(25, 25), squeeze=False)
w = 200
sh = manual_shift.get(roi, [0, 0])

for icell, clicked in enumerate(mcherry_cells):
    ax = axes.flat[icell]
    # add element-wise: if the positions are plain lists, `+` would concatenate
    ctr = np.asarray(clicked) + np.asarray(sh)
    helper.plot_stack_part(ax, ctr, w, st, mask=mask, sh=sh)
    ax.set_title(f"Cell {icell}")

for ax in axes.flat:
    ax.axis("off")
fig.tight_layout()
fig_folder = iss.io.get_processed_path(data_path) / "figures" / "starter_cells"
# parents=True: "figures" may not exist yet
fig_folder.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_folder / f"all_starters_snippet_roi{roi}.png")

## Plot spots on each cells

Used to find which cells are actual starters.

In [None]:
# gene spots and barcode spots (taken from the corrected `barcode_spots`) for this ROI
spots_for_roi = helper.get_spots(data_path, roi, barcode_spots)
genes_spots, roi_barcode_spots = spots_for_roi

In [None]:
# DEFINE ACTUAL STARTER
# roi -> indices (into that ROI's clicked mCherry cells) of the confirmed starter cells
starter_cells = {
    4: [0, 1, 2, 3, 4, 5, 8, 9, 13, 15, 19, 32],
}

In [None]:
# Plot all manually clicked cells
from itertools import cycle
from matplotlib import cm

# One panel per clicked mCherry cell with the nearby barcode spots coloured by
# sequence; cells already listed in `starter_cells` get a dark red title.
n_cells = max(len(mcherry_cells), 1)  # avoid a 0-row grid / division by zero
nrow = int(np.ceil(np.sqrt(n_cells)))
ncol = int(np.ceil(n_cells / nrow))

# squeeze=False keeps `axes` 2D even for a 1x1 grid
fig, axes = plt.subplots(nrow, ncol, figsize=(25, 25), squeeze=False)
w = 100

# `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` remains
seq_colors = cycle(plt.get_cmap("Set1").colors)
# ROIs without confirmed starters yet (not in the dict) simply get no highlighted title
selected_starters = starter_cells.get(roi, [])
sh = manual_shift.get(roi, [0, 0])

for icell, clicked in enumerate(mcherry_cells):
    ax = axes.flat[icell]
    # add element-wise: if the positions are plain lists, `+` would concatenate
    ctr = np.asarray(clicked) + np.asarray(sh)
    helper.plot_stack_part(ax, ctr, w, st, sh=sh, mask=mask)
    valid_barcodes = helper.select_spots(
        roi_barcode_spots, ylim=[ctr[1] - w, ctr[1] + w], xlim=[ctr[0] - w, ctr[0] + w]
    )

    seq_counts = valid_barcodes.corrected_bases.value_counts().sort_values(
        ascending=False
    )
    for iseq, seq in enumerate(seq_counts.index):
        sp = valid_barcodes[valid_barcodes["corrected_bases"] == seq]
        # past the 16 most common sequences, rare ones (< 5 spots) are greyed out
        if (iseq > 15) and (len(sp) < 5):
            label, color = None, "grey"
        else:
            label, color = seq, next(seq_colors)
        ax.scatter(
            sp.x - ctr[0] + w,
            sp.y - ctr[1] + w,
            label=label,
            fc=color,
            ec="k",
            s=20,
            alpha=0.9,
        )

    # skip the legend (and matplotlib's warning) when no spot was labelled
    if ax.get_legend_handles_labels()[1]:
        ax.legend(ncols=1, loc="upper left", bbox_to_anchor=(0.95, 0.95))
    ax.set_title(str(icell), color="darkred" if icell in selected_starters else "black")
for ax in axes.flat:
    ax.axis("off")
fig.tight_layout()
fig_folder = iss.io.get_processed_path(data_path) / "figures" / "starter_cells"
fig_folder.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_folder / f"all_starters_snippet_roi{roi}_sequence.png")

# Generate figures

Now that we know which cells are starters, we can plot them.

In [None]:
# Get the main sequence of each starter cell: the most common corrected barcode among
# the spots within `radius` px of the cell centre.
radius = 50

starter_sequences = dict()
w = radius
for roi, starters in starter_cells.items():
    mcherry_cells = helper.get_mcherry_cells(data_path, roi)
    genes_spots, roi_barcode_spots = helper.get_spots(data_path, roi, barcode_spots)
    # Apply the same manual shift as the snippet plots, so that the circle used here
    # matches the one drawn around the (shifted) cell in those plots.
    roi_shift = np.asarray(manual_shift.get(roi, [0, 0]))

    for icell in starters:
        starter_name = f"starter_{roi}_{icell}"
        center = np.asarray(mcherry_cells[icell]) + roi_shift
        xlim = [center[0] - w, center[0] + w]
        ylim = [center[1] - w, center[1] + w]

        valid_barcodes = helper.select_spots(roi_barcode_spots, xlim, ylim).copy()
        valid_barcodes["distance"] = np.hypot(
            valid_barcodes.x - center[0], valid_barcodes.y - center[1]
        )
        in_circle = valid_barcodes[valid_barcodes.distance < radius]
        if len(in_circle):
            main_seq = in_circle.corrected_bases.value_counts().idxmax()
        else:
            main_seq = ""
            print(f"No spots in circle for {starter_name}")
        starter_sequences[starter_name] = main_seq

## Snippet around starters

In [None]:
from itertools import cycle

from matplotlib.patches import Circle

ROI_TO_DO = 4  # None: every ROI that has starters
STARTER_TO_DO = None  # None: every starter of the ROI
RELOAD = True  # set to False to reuse `st`/`mask` already loaded (for debugging)
SAVE = True
SHOW_SEQUENCES = True  # False: colour barcode spots by dot product score instead

target_dir = processed_path / "figures" / "starter_cells"
w = 100  # half-width of the snippet
panel_titles = ["Raw data", "Genes", "Barcodes"]
for roi, starters in starter_cells.items():
    if (ROI_TO_DO is not None) and (roi != ROI_TO_DO):
        continue
    print(f"Doing roi {roi}")
    if RELOAD:
        print("Loading raw data")
        st, mask = helper.get_raw_data(data_path, roi, mcherry_channel=2)
    print("Loading spots")
    # load corresponding manually clicked cells
    mcherry_cells = helper.get_mcherry_cells(data_path, roi)
    genes_spots, roi_barcode_spots = helper.get_spots(data_path, roi, barcode_spots)
    sh = manual_shift.get(roi, [0, 0])

    for icell in starters:
        starter_name = f"starter_{roi}_{icell}"
        if (STARTER_TO_DO is not None) and (icell != STARTER_TO_DO):
            continue
        print(f"    Doing cell {icell}")
        # add element-wise: if the positions are plain lists, `+` would concatenate
        ctr = np.asarray(mcherry_cells[icell]) + np.asarray(sh)
        xlim = [ctr[0] - w, ctr[0] + w]
        ylim = [ctr[1] - w, ctr[1] + w]

        # spots inside the snippet, moved to snippet coordinates (origin at xlim[0], ylim[0])
        valid_barcodes = helper.select_spots(roi_barcode_spots, xlim, ylim).copy()
        valid_barcodes["x"] = valid_barcodes["x"] - ctr[0] + w
        valid_barcodes["y"] = valid_barcodes["y"] - ctr[1] + w
        valid_genes = helper.select_spots(genes_spots, xlim, ylim).copy()
        valid_genes["x"] = valid_genes["x"] - ctr[0] + w
        valid_genes["y"] = valid_genes["y"] - ctr[1] + w

        main_seq = starter_sequences[starter_name]
        fig = plt.figure(figsize=(15, 7))

        # panel 1: raw data only
        ax0 = fig.add_subplot(1, 3, 1)
        helper.plot_stack_part(
            ax0, ctr, w, st, sh=sh, show_mask=False, show_contours=False
        )

        # panel 2: raw data + gene spots
        ax1 = fig.add_subplot(1, 3, 2)
        helper.plot_stack_part(
            ax1, ctr, w, st, sh=sh, show_mask=False, show_contours=False
        )
        helper.plot_gene_image(
            ax1,
            valid_genes,
            s=50,
            ec="k",
            alpha=0.75,
            layers=valid_genes.gene.unique(),
            ok=[],
        )

        # panel 3: raw data + barcode spots
        ax2 = fig.add_subplot(1, 3, 3)
        helper.plot_stack_part(ax2, ctr, w, st, sh=sh, show_contours=False)
        if SHOW_SEQUENCES:
            # `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` remains
            seq_colors = cycle(plt.get_cmap("Set3").colors)
            seq_counts = valid_barcodes.corrected_bases.value_counts().sort_values(
                ascending=False
            )
            for iseq, seq in enumerate(seq_counts.index):
                sp = valid_barcodes[valid_barcodes["corrected_bases"] == seq]
                # past the 13 most common sequences, singletons are greyed out
                if (iseq > 12) and (len(sp) < 2):
                    label, color = None, "grey"
                else:
                    label, color = seq, next(seq_colors)
                ax2.scatter(
                    sp.x,
                    sp.y,
                    label=label,
                    color=color,
                    s=50,
                    alpha=1,
                    ec="black" if iseq else "white",
                )
            if ax2.get_legend_handles_labels()[1]:
                legend = ax2.legend(
                    ncols=1, loc="upper left", bbox_to_anchor=(0.95, 0.95)
                )
                # highlight the starter's main sequence in the legend
                for text in legend.get_texts():
                    if text.get_text() == main_seq:
                        text.set_color("darkred")
        else:
            sc = ax2.scatter(
                valid_barcodes.x, valid_barcodes.y, c=valid_barcodes.dot_product_score
            )
            fig.colorbar(sc, ax=ax2)

        # circle of radius `radius` (defined with starter_sequences) around the centre
        circle = Circle(
            (w, w),
            radius,
            edgecolor="white",
            facecolor="none",
            linewidth=1,
            ls="--",
            alpha=0.5,
        )
        ax2.add_patch(circle)

        # iterate over the three panels explicitly: `fig.axes` would also include the
        # colorbar axes and run past `panel_titles`
        for panel, title in zip((ax0, ax1, ax2), panel_titles):
            panel.set_title(title)
            panel.set_xticks([])
            panel.set_yticks([])
            panel.set_xlim([0, 2 * w - 1])
            panel.set_ylim([2 * w - 1, 0])
        fig.tight_layout()
        if SAVE:
            target_dir.mkdir(parents=True, exist_ok=True)
            fig.savefig(target_dir / f"{starter_name}_snippet.png")
            fig.savefig(target_dir / f"{starter_name}_snippet.pdf")
        # (a leftover `break` here used to stop after the first starter; use
        # STARTER_TO_DO to restrict to a single cell instead)

## Presynaptic cells

In [None]:
from scipy.stats import gaussian_kde

ROI_TO_DO = None  # None: every ROI that has starters
STARTER_TO_DO = None  # None: every starter
WINDOW = None  # [10000, 5000]; None: show the full extent of the gene spots

fig, ax = plt.subplots(1, 1)
target_dir = processed_path / "figures" / "starter_cells"
target_dir.mkdir(exist_ok=True, parents=True)
for starter_roi, starters in starter_cells.items():
    if (ROI_TO_DO is not None) and (starter_roi != ROI_TO_DO):
        continue
    # load corresponding manually clicked cells
    mcherry_cells = helper.get_mcherry_cells(data_path, starter_roi)
    # same manual shift as the snippet plots, to place the starter among its spots
    starter_shift = np.asarray(manual_shift.get(starter_roi, [0, 0]))

    for icell in starters:
        if (STARTER_TO_DO is not None) and (icell != STARTER_TO_DO):
            continue
        print(f"    Doing cell {icell}")
        starter_name = f"starter_{starter_roi}_{icell}"
        main_seq = starter_sequences[starter_name]
        if not main_seq:
            # no spots in the circle: nothing to plot (limits would be infinite)
            print(f"    No main sequence for {starter_name}, skipping")
            continue

        all_spots_this_starter = barcode_spots[
            barcode_spots["corrected_bases"] == main_seq
        ]
        rois = all_spots_this_starter.roi.unique()
        print(
            f"Found {len(all_spots_this_starter)} barcode spots across {len(rois)} rois"
        )
        if len(rois) == 0:
            continue
        if WINDOW is None:
            # find the maximum/minimum x and y for genes spots of all rois
            xlims = [np.inf, -np.inf]
            ylims = [np.inf, -np.inf]
            for roi in rois:
                genes_spots = pd.read_pickle(
                    iss.io.get_processed_path(data_path)
                    / f"genes_round_spots_{roi}.pkl"
                )
                xlims = [
                    min(xlims[0], genes_spots.x.min()),
                    max(xlims[1], genes_spots.x.max()),
                ]
                ylims = [
                    min(ylims[0], genes_spots.y.min()),
                    max(ylims[1], genes_spots.y.max()),
                ]
            print(f"    Found xlims {xlims} and ylims {ylims}")
        else:
            xlims = [-WINDOW[0], WINDOW[0]]
            ylims = [-WINDOW[1], WINDOW[1]]
        prop = (ylims[1] - ylims[0]) / (xlims[1] - xlims[0])
        fig.set_size_inches(10, 10 * prop)
        for roi in rois:
            print(f"    Plotting presynaptic cells in roi {roi}")
            ax.clear()
            spots_roi = all_spots_this_starter[all_spots_this_starter.roi == roi]
            genes_spots = pd.read_pickle(
                iss.io.get_processed_path(data_path) / f"genes_round_spots_{roi}.pkl"
            )

            helper.plot_gene_image(ax, genes_spots)
            ax.scatter(
                spots_roi.x,
                spots_roi.y,
                marker="o",
                s=10,
                ec="none",
                fc="red",
                alpha=0.4,
            )
            if roi == starter_roi:
                starter_xy = np.asarray(mcherry_cells[icell]) + starter_shift
                ax.scatter(
                    starter_xy[0],
                    starter_xy[1],
                    marker="*",
                    s=50,
                    ec="none",
                    fc="yellow",
                    alpha=1,
                )
            ax.set_facecolor("k")
            ax.set_aspect("equal")
            if WINDOW is not None:
                if len(spots_roi) < 10:
                    print(f"    Not enough spots in roi {roi} to estimate density")
                    x_max_density = np.median(spots_roi.x)
                    y_max_density = np.median(spots_roi.y)
                else:
                    # Estimate the density of the points (gaussian_kde expects (d, N))
                    density = gaussian_kde(spots_roi[["x", "y"]].to_numpy().T)

                    # Evaluate it on a 100x100 grid spanning this ROI's gene spots
                    # (starting at their minimum rather than at 0)
                    x_grid, y_grid = np.mgrid[
                        genes_spots.x.min() : genes_spots.x.max() : 100j,
                        genes_spots.y.min() : genes_spots.y.max() : 100j,
                    ]
                    positions = np.vstack([x_grid.ravel(), y_grid.ravel()])
                    density_values = density(positions)

                    # Centre the window on the grid point of maximum density
                    max_density_index = np.argmax(density_values)
                    x_max_density = positions[0, max_density_index]
                    y_max_density = positions[1, max_density_index]
                xlims = [x_max_density - WINDOW[0], x_max_density + WINDOW[0]]
                ylims = [y_max_density - WINDOW[1], y_max_density + WINDOW[1]]
                print(
                    f"The point with the maximum density is at x={x_max_density}, y={y_max_density}"
                )

            ax.set_xlim(xlims)
            ax.set_ylim(ylims[::-1])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f"Starter cell {starter_name} - Roi {roi}")
            fig.subplots_adjust(left=0, right=1, top=0.97, bottom=0)
            fig.savefig(
                target_dir / f"{starter_name}_presynaptic_cells_roi{roi}.png", dpi=600
            )

In [None]:
# `density_values` only exists if the previous cell ran with WINDOW set and found a
# ROI with at least 10 spots; avoid a NameError otherwise.
(
    density_values.shape
    if "density_values" in globals()
    else "density not computed (WINDOW is None or too few spots)"
)

In [None]:
# One figure per ROI containing spots of the last processed starter's main sequence.
# Reuses `rois`, `all_spots_this_starter`, `xlims` and `ylims` left by the previous cell.
for roi in rois:
    # Load this ROI's own data: previously the last ROI's `genes_spots`/`spots_roi`
    # were reused, so every figure showed the same data under a different title.
    genes_spots = pd.read_pickle(
        iss.io.get_processed_path(data_path) / f"genes_round_spots_{roi}.pkl"
    )
    spots_roi = all_spots_this_starter[all_spots_this_starter.roi == roi]

    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    helper.plot_gene_image(ax, genes_spots)
    ax.scatter(
        spots_roi.x, spots_roi.y, marker="o", s=5, ec="k", fc="darkred", alpha=0.5
    )
    ax.set_facecolor("k")
    ax.set_aspect("equal")
    # NOTE(review): with WINDOW set, xlims/ylims are the window of the last ROI only
    ax.set_xlim(xlims)
    ax.set_ylim(ylims[::-1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Roi {roi}")