In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from flexiznam import PARAMETERS

from iss_preprocess.segment import spot_mask_value, count_spots
from iss_preprocess.io import get_processed_path, load_metadata, load_stack
from iss_preprocess.vis import to_rgb
from iss_preprocess.pipeline.segment import get_big_masks

import matplotlib.pyplot as plt
import colorsys

In [None]:
def hsl_palette(categorical_vector, h=(0, 360), c=85, l=65):
    """Make an equally spaced colour palette from the HSL circle.

    Parameters
    ----------
    categorical_vector : array-like
        Category labels. Duplicates are dropped and the remaining labels are
        sorted (``np.unique``) before colours are assigned.
    h : sequence of two numbers, optional
        Start and stop hue in degrees. The sequence is only read, never
        modified.
    c : float, optional
        Percentage (0-100) handed to ``colorsys.hls_to_rgb`` as saturation.
    l : float, optional
        Lightness as a percentage (0-100).

    Returns
    -------
    dict
        Maps every unique label to a ``"#rrggbb"`` string. Empty when
        ``categorical_vector`` is empty.
    """
    labels = np.unique(categorical_vector).tolist()
    n_labels = len(labels)
    if n_labels == 0:
        # Nothing to colour; also avoids the division by zero below.
        return {}

    # Work on local copies so neither the caller's sequence nor the default
    # argument changes between calls.
    h_start, h_stop = h[0], h[1]
    # NOTE(review): for a partial hue range the stop value is pulled in by one
    # step *and* the last sample is dropped below; kept as in the original
    # implementation - confirm this is the intended spacing.
    if (h_stop - h_start) % 360 > 1:
        h_stop = h_stop - 360 / n_labels

    # Sample one extra hue and discard it, so a full circle does not end on
    # the colour it started with.
    hues = np.linspace(start=h_start, stop=h_stop, num=n_labels + 1)[:-1]
    hues = (hues % 360) / 360

    lightness = l / 100
    saturation = c / 100
    hex_codes = []
    for hue in hues.tolist():
        red, green, blue = colorsys.hls_to_rgb(hue, lightness, saturation)
        hex_codes.append(
            "#%02x%02x%02x" % (round(red * 255), round(green * 255), round(blue * 255))
        )

    return dict(zip(labels, hex_codes))

In [None]:
## Registered ex vivo data set, ROI 11
# First entry of `tile_coors` is the ROI number used in the file names below.
tile_coors = (11, 1, 1)
data_path = "ccyp_ex-vivo-reg-pilot/BRBQ77.1f/chamber_1"

# Acquisition prefixes
genes_prefix, anchor_prefix, mcherry_prefix = (
    "genes_round_1_1",
    "anchor_1",
    "mcherry_1",
)

# Root of the processed data, read from the flexiznam parameters
processed_path = Path(PARAMETERS["data_root"]["processed"])

In [None]:
# Suite2p masks of this ROI in anchor coordinates, plus the array returned by
# get_big_masks (presumably enlarged masks; 2.5 is its third argument -
# TODO confirm meaning against iss_preprocess.pipeline.segment).
masks = np.load(
    processed_path.joinpath(data_path, f"masks_{tile_coors[0]}_s2p_old.npy")
)
big_masks = get_big_masks(data_path, masks, 2.5)

In [None]:
# Score cut-offs applied to the spot tables
barcode_dot_threshold = 0.15
spot_score_threshold = 0.1
hyb_score_threshold = 0.8
metadata = load_metadata(data_path=data_path)

# Genes round first, then every hybridisation entry that is not an anchor
spot_acquisitions = ["genes_round"] + [
    hyb for hyb in metadata["hybridisation"] if "anchor" not in hyb
]

# acquisition -> (column to filter on, minimum value, exclusive)
thresholds = {
    "genes_round": ("spot_score", spot_score_threshold),
    "barcode_round": ("dot_product_score", barcode_dot_threshold),
}
thresholds.update(
    {hyb: ("score", hyb_score_threshold) for hyb in spot_acquisitions[1:]}
)

# Load each spot table, keep spots above threshold and pass them through
# spot_mask_value together with the big masks
spots_dict = {}
for prefix in spot_acquisitions:
    print(f"Loading {prefix}", flush=True)
    filt_col, threshold = thresholds[prefix]
    spot_df = pd.read_pickle(
        processed_path.joinpath(data_path, f"{prefix}_spots_{tile_coors[0]}.pkl")
    )
    spot_df = spot_df.loc[spot_df[filt_col] > threshold]
    # modify spots in place
    spots_dict[prefix] = spot_mask_value(big_masks, spot_df)

In [None]:
# Aggregate every spot table with count_spots, one result per acquisition.
# Barcode acquisitions are grouped by "bases", everything else by "gene".
# The score thresholds were already applied when `spots_dict` was built, so
# the `thresholds` dict from the previous cell is not rebuilt here.
spots_in_cells = dict()
for prefix, spot_df in spots_dict.items():
    print(f"Doing {prefix}", flush=True)
    grouping_column = "bases" if prefix.startswith("barcode") else "gene"
    cell_df = count_spots(spots=spot_df, grouping_column=grouping_column)
    spots_in_cells[prefix] = cell_df

In [None]:
# Output folder for the per-cell tables. Only the last level is created
# (parents stays False), so the processed folder itself must already exist.
save_dir = get_processed_path(data_path).joinpath("cells")
save_dir.mkdir(parents=False, exist_ok=True)

In [None]:
# Start the fused table from the genes-round counts. `pop` also removes that
# entry from `spots_in_cells`, so the loop in the next cell only iterates over
# the remaining acquisitions.
# NOTE(review): not idempotent - running this cell a second time raises
# KeyError because the key is already gone, and `fused_df` is left undefined
# if the condition is ever False.
if "genes_round" in spot_acquisitions:
    fused_df = spots_in_cells.pop("genes_round")

In [None]:
# Merge the remaining per-acquisition tables into the fused table. When a
# column name exists in both, the genes-round column is discarded and the
# hybridisation one is kept.
for hyb, hyb_df in spots_in_cells.items():
    overlapping = [gene for gene in hyb_df.columns if gene in fused_df.columns]
    for gene in overlapping:
        print(f"Replacing {gene} with hybridisation")
    # `drop` returns a new frame, so the table popped from `spots_in_cells`
    # is not modified behind the back of earlier cells.
    fused_df = fused_df.drop(columns=overlapping).join(hyb_df, how="outer")
# Rows missing from one side of the outer join come back as NaN; treat them
# as zero counts before casting to integers.
fused_df = fused_df.fillna(0).astype(int)

In [None]:
# Stack every thresholded spot table into one frame.
# NOTE(review): `genes_df` is rebuilt from only "genes_round" and
# "hybridisation_1_1" in a later cell before any visible use, so this
# assignment looks redundant - confirm before removing.
genes_df = pd.concat(spots_dict.values(), axis=0)

In [None]:
# Persist the fused per-cell table (the file is named genes_df_roi<N>.pkl
# although the object written is `fused_df`).
fused_df.to_pickle(save_dir.joinpath(f"genes_df_roi{tile_coors[0]}.pkl"))

In [None]:
# Stitched stack for this ROI from the exvivo_reg figures folder; later cells
# index its third axis as channels.
exvivo_stitched = load_stack(
    processed_path.joinpath(
        data_path,
        "figures",
        "exvivo_reg",
        f"{anchor_prefix}_roi{tile_coors[0]}_stitched_genes2func_old.tif",
    )
)

In [None]:
# Spots used for plotting: genes round plus the "hybridisation_1_1" table,
# without any gene whose name contains "unassigned".
genes_df = pd.concat(
    [spots_dict["genes_round"], spots_dict["hybridisation_1_1"]], axis=0
)
# `.copy()` makes the filtered result an independent frame, so the plotting
# cell can add its "gene_cat" column without a SettingWithCopyWarning.
genes_df = genes_df[~genes_df["gene"].str.contains("unassigned")].copy()

unique_genes = genes_df["gene"].unique().tolist()

# One colour per gene; hsl_palette returns a dict keyed by gene name.
gene_palette = hsl_palette(unique_genes, c=58, l=68)

In [None]:
# Suite2p output folder of the functional recording
imaging_path = processed_path.joinpath(
    "ccyp_ex-vivo-reg-pilot/BRBQ77.1f/S20231103/suite2p_rois_1"
)
# Read Fast.npy from each of the 7 plane folders
plane_traces = [
    np.load(imaging_path / f"plane{plane}" / "Fast.npy") for plane in range(7)
]
# Every plane is cut to one column fewer than plane 0 has, then all planes are
# stacked along the first axis.
l = plane_traces[0].shape[1]
Fast = np.concatenate([traces[:, : l - 1] for traces in plane_traces], axis=0)

In [None]:
# Summary figure: three overview panels on a 2x3 grid (top row) and, on a 2x4
# grid (bottom row), one zoom per selected ROI plus a panel of activity traces.
# The pyplot state machine is used throughout, so statement order matters.

# Field of view of the overview panels, in pixel coordinates of the stitched image
xlims = [3500, 4350]
ylims = [4150, 5000]

# Replaces the dict palette from the previous cell with 10 named colours.
# Spots are coloured by category code modulo 10, so colours repeat once there
# are more than 10 genes.
gene_palette = np.array(
    [
        "mediumorchid",
        "orangered",
        "darkorange",
        "limegreen",
        "aquamarine",
        "deepskyblue",
        "gold",
        "crimson",
        "lightcoral",
        "dodgerblue",
    ]
)
# Categorical copy of the gene column; its integer codes index the palette
genes_df["gene_cat"] = genes_df["gene"].astype("category")
fig = plt.figure(figsize=(12, 8))
fig.set_facecolor("white")

# Top middle panel: channels 6 and 0 of the stitched stack; the second channel
# is assigned black.
# NOTE(review): vmin is a one-element list while vmax is a scalar - confirm
# that to_rgb accepts this combination for a two-channel input.
plt.subplot(2, 3, 2)
im = to_rgb(
    exvivo_stitched[:, :, [6, 0]],
    colors=[[1, 0, 0], [0, 0, 0]],
    vmax=200,
    vmin=[
        80,
    ],
)
plt.imshow(im)
plt.xlim(xlims)
plt.ylim(ylims)
plt.axis("off")

# Top left panel: channel 8 in red and channel 7 in green, each divided by its
# own constant so both can share vmax=1
plt.subplot(2, 3, 1)
to_plot = exvivo_stitched[:, :, [8, 7]].astype(float)
to_plot[:, :, 0] /= 400
to_plot[:, :, 1] /= 2000
im = to_rgb(to_plot, colors=[[1, 0, 0], [0, 1, 0]], vmax=1)

plt.imshow(im)

# Mask ids of the highlighted cells; labelled "Cell 1..3" in list order
roi_list = [523, 518, 797]  # 523
sz = 125  # half-width of the box drawn around each cell, in pixels
for i, roi in enumerate(roi_list):
    # Mean (row, col) of all pixels carrying this id in big_masks
    centre = np.mean(np.argwhere(big_masks == roi), axis=0)
    x = [centre[1] - sz, centre[1] + sz]
    y = [centre[0] + sz, centre[0] - sz]
    # plot a rectangle
    plt.plot([x[0], x[0]], [y[0], y[1]], ":", color="white")
    plt.plot([x[1], x[1]], [y[0], y[1]], ":", color="white")
    plt.plot([x[0], x[1]], [y[0], y[0]], ":", color="white")
    plt.plot([x[0], x[1]], [y[1], y[1]], ":", color="white")
    # Label anchored 10 px inside the (x + sz, y + sz) corner of the box
    plt.text(
        centre[1] + sz - 10,
        centre[0] + sz - 10,
        f"Cell {i+1}",
        fontsize=8,
        horizontalalignment="right",
        verticalalignment="top",
        color="white",
    )
plt.xlim(xlims)
plt.ylim(ylims)
plt.axis("off")

# Top right panel: every spot inside the field of view, coloured by gene
ax = plt.subplot(2, 3, 3)
ax.set_facecolor("black")
# filter spots dataframe for those in ROI
spots_in_view = genes_df.loc[
    (genes_df["x"] > xlims[0])
    & (genes_df["x"] < xlims[1])
    & (genes_df["y"] < ylims[1])
    & (genes_df["y"] > ylims[0])
]
plt.scatter(
    spots_in_view["x"],
    spots_in_view["y"],
    c=gene_palette[spots_in_view["gene_cat"].cat.codes % 10],
    s=1,
)
plt.axis("square")
plt.xlim(xlims)
plt.ylim(ylims)
plt.xticks([])
plt.yticks([])

# Bottom row, panels 5-7 of a 2x4 grid: one zoom per highlighted cell.
# NOTE(review): `to_plot` and `im` are recomputed here but the loop below
# displays exvivo_stitched[:, :, 7] directly, so they appear unused.
sz = 125
to_plot = exvivo_stitched[:, :, [9, 7]].astype(float)
to_plot[:, :, 1] /= 2000
im = to_rgb(to_plot, colors=[[0, 0, 0], [0, 1, 0]], vmax=1)
for i, roi in enumerate(roi_list):
    centre = np.mean(np.argwhere(big_masks == roi), axis=0)
    ax = plt.subplot(2, 4, i + 5)
    # Rebinds the module-level xlims / ylims to a window around this cell
    xlims = [centre[1] - sz, centre[1] + sz]
    ylims = [centre[0] - sz, centre[0] + sz]
    # filter spots dataframe for those in ROI
    spots_in_view = genes_df[
        (genes_df["x"] > xlims[0])
        & (genes_df["x"] < xlims[1])
        & (genes_df["y"] < ylims[1])
        & (genes_df["y"] > ylims[0])
    ]
    ax.imshow(exvivo_stitched[:, :, 7], cmap="Greys_r", vmax=2000)
    # ax.imshow(np.ma.masked_where(big_masks == 0, big_masks), cmap=roi_cycle, alpha=0.4)
    # Four-channel overlay: channels 0 and 3 hold the binary mask of this ROI,
    # channels 1 and 2 are zeroed
    overlay = (big_masks == roi).astype(float)
    overlay = np.concatenate([overlay[:, :, None]] * 4, axis=2)
    overlay[:, :, 1:3] = 0
    ax.imshow(overlay, alpha=0.4)
    # Marker genes drawn as large dots with a legend entry each
    markers = ["Sst", "Vip", "Cux2", "Slc17a7"]
    colors = ["dodgerblue", "limegreen", "darkorange", "gold"]
    for marker, color in zip(markers, colors):
        x = spots_in_view[spots_in_view["gene"] == marker]["x"]
        y = spots_in_view[spots_in_view["gene"] == marker]["y"]
        ax.scatter(x, y, c=color, s=25, label=marker)
    # All spots in the window as small dots, drawn after the marker dots
    plt.scatter(
        spots_in_view["x"],
        spots_in_view["y"],
        c=gene_palette[spots_in_view["gene_cat"].cat.codes % 10],
        s=5,
    )
    # Legend only on the third zoom panel
    if i == 2:
        ax.legend()
    plt.title(f"Cell {i+1}", fontsize=8)
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_xticks([])
    ax.set_yticks([])

# Bottom right panel: traces of the three highlighted cells
plt.subplot(2, 4, 8)
from scipy.stats import zscore

# NOTE(review): `fs` is not used anywhere in this cell.
fs = 4.28305
colors = ["dodgerblue", "limegreen", "darkorange"]
for i, (roi, color) in enumerate(zip(roi_list, colors)):
    # NOTE(review): the mask id is used directly as the row index of `Fast`;
    # confirm mask ids and concatenated suite2p rows share the same numbering.
    this_roi = Fast[roi, 16500:17500]
    # Baseline is the 0.4 quantile of the 1000-sample window; trace becomes (F - F0) / F0
    f0 = np.quantile(this_roi, 0.4)
    this_roi = (this_roi - f0) / f0
    # z-scored traces are stacked with a vertical offset of 10 per cell
    plt.plot(zscore(this_roi) - i * 10, label=f"ROI {roi}", color=color)
    plt.text(
        -50,
        -i * 10,
        f"Cell {i+1}",
        fontsize=8,
        horizontalalignment="right",
        verticalalignment="center",
    )

plt.axis("off")

In [None]:
import matplotlib as mpl
from matplotlib.colors import ListedColormap

# Ten named colours for the ROI overlay; greens and cyans are left out so the
# ROI boundaries keep their contrast.
roi_palette = [
    "salmon", "lightcoral",
    "orange", "gold",
    "mediumorchid", "violet",
    "magenta", "hotpink",
    "dodgerblue", "deepskyblue",
]
# The palette is repeated 32 times (320 entries) and the colormap is built
# with N=256.
roi_cycle = ListedColormap(roi_palette * 32, N=256, name="roi_cycle")

In [None]:
# find the centre of each mask in masks
centres = np.array(
    [np.mean(np.argwhere(masks == mask), axis=0) for mask in np.unique(masks)[1:]]
)

In [None]:
# Overview figure: channel 9 of the stitched stack in grey, the suite2p masks
# as a translucent overlay, all spots in white, four marker genes in colour
# and the id of every mask printed at its centre.
im = exvivo_stitched[:, :, 9]

# NOTE(review): `unique()` keeps order of first appearance, so `[1:]` drops
# whichever mask id occurs first, not necessarily a background id. `all_rois`
# is not used in this cell.
all_rois = genes_df["mask_id"].unique().tolist()[1:]
unique_genes = genes_df["gene"].unique().tolist()

gene_palette = hsl_palette(unique_genes, c=58, l=68)

fig = plt.figure(figsize=(24, 16))
xlims = [3500, 5000]
ylims = [3500, 5000]
# Keep only the spots inside the field of view. The y bounds were previously
# swapped (y > ylims[1] and y < ylims[0]), which always selected nothing, and
# the filter also depended on a `roi` variable left over from another cell.
spots_in_view = genes_df.loc[
    (genes_df["x"] > xlims[0])
    & (genes_df["x"] < xlims[1])
    & (genes_df["y"] > ylims[0])
    & (genes_df["y"] < ylims[1])
]
plt.imshow(im, cmap="Greys_r", vmax=2000)
plt.imshow(np.ma.masked_where(masks == 0, masks), cmap=roi_cycle, alpha=0.4)
# The axes are cropped to xlims / ylims below, so plotting only the spots in
# view shows the same points as plotting the whole table.
plt.scatter(
    x=spots_in_view["x"], y=spots_in_view["y"], c="white", s=6, label="all genes"
)
# Marker genes highlighted on top of the white spots
genes = ["Sst", "Vip", "Pvalb", "Slc17a7"]
for gene in genes:
    gene_spots = spots_in_view[spots_in_view["gene"] == gene]
    plt.scatter(x=gene_spots["x"], y=gene_spots["y"], s=12, label=gene)
# Print each mask id at its centre, skipping masks centred outside the view
for centre, mask in zip(centres, np.unique(masks)[1:]):
    if centre[0] < ylims[0] or centre[0] > ylims[1]:
        continue
    if centre[1] > xlims[1] or centre[1] < xlims[0]:
        continue
    plt.text(centre[1], centre[0], int(mask), color="black", fontsize=12)

plt.xlim(xlims)
plt.ylim(ylims)
plt.axis("off")
plt.tight_layout()
fig.legend()

In [None]:
# Same grey image and translucent mask overlay as the previous figure, shown
# without spots or labels and cropped to the 3500-5000 pixel window.
plt.figure(figsize=(12, 12))
plt.imshow(im, vmax=2000, cmap="Greys_r")
plt.imshow(np.ma.masked_where(masks == 0, masks), alpha=0.4, cmap=roi_cycle)
plt.xlim(3500, 5000)
plt.ylim(3500, 5000)