In [None]:
import colorsys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from flexiznam import PARAMETERS
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D

from iss_preprocess.io import (
    get_processed_path,
    load_metadata,
    load_stack,
)
from iss_preprocess.pipeline.segment import get_big_masks
from iss_preprocess.segment import count_spots, spot_mask_value
from iss_preprocess.vis import to_rgb

In [None]:
def hsl_palette(categorical_vector, h=(0, 360), c=85, l=65):
    """Make an equally spaced palette of hex colours from the HSL hue circle.

    Args:
        categorical_vector: iterable of category labels. Duplicates are removed
            and the categories are sorted (via ``np.unique``) before colours are
            assigned.
        h: (start, end) of the hue range in degrees. If the range spans a whole
            turn (e.g. the default 0-360) the end point is excluded so the first
            and last categories do not share a hue; otherwise both end points are
            used. The argument is only read, never modified.
        c: saturation, in percent (0-100).
        l: lightness, in percent (0-100).

    Returns:
        dict mapping each unique category to a ``"#rrggbb"`` colour string
        (empty dict for empty input).
    """
    categories = np.unique(categorical_vector).tolist()
    n_categories = len(categories)
    if n_categories == 0:
        return {}

    # Read the bounds into locals so a caller's list is never mutated.
    h_start, h_end = float(h[0]), float(h[1])
    if (h_end - h_start) % 360 < 1:
        # Full circle: the end hue equals the start hue, so drop the end point.
        hues = np.linspace(h_start, h_end, num=n_categories + 1)[:-1]
    else:
        # Partial arc: spread the categories over the arc, end points included.
        hues = np.linspace(h_start, h_end, num=n_categories)
    hues = (hues % 360) / 360

    colours = []
    for hue in hues:
        # colorsys expects (hue, lightness, saturation), each in [0, 1].
        rgb = colorsys.hls_to_rgb(hue, l / 100, c / 100)
        colours.append("#%02x%02x%02x" % tuple(round(v * 255) for v in rgb))

    return dict(zip(categories, colours))

In [None]:
## Loading registered ex vivo data set for the ROI given by tile_coors[0] (currently 12)
processed_path = Path(PARAMETERS["data_root"]["processed"])

data_path = "ccyp_ex-vivo-reg-pilot/BRBQ77.1f/chamber_1"
# NOTE(review): genes_prefix and mcherry_prefix are not referenced in the cells shown here.
genes_prefix = "genes_round_1_1"
anchor_prefix = "anchor_1"
mcherry_prefix = "mcherry_1"
# tile_coors[0] is the ROI number interpolated into every file name this notebook
# reads or writes; the meaning of the other two entries is not visible here.
tile_coors = (12, 1, 1)

In [None]:
# suite2p masks for this ROI, expressed in anchor-round coordinates
masks = np.load(Path(processed_path, data_path, f"masks_{tile_coors[0]}_s2p.npy"))

In [None]:
# Score cut-offs: a spot is kept only if its score is strictly above these.
barcode_dot_threshold = 0.15
spot_score_threshold = 0.1
hyb_score_threshold = 0.8

# NOTE(review): 2.5 is passed straight to get_big_masks (presumably the mask
# expansion amount); its units are not visible here — confirm.
big_masks = get_big_masks(data_path, tile_coors[0], masks, 2.5)

metadata = load_metadata(data_path=data_path)

# Every hybridisation round whose name does not contain "anchor" is loaded
# alongside the genes round.
hyb_rounds = [name for name in metadata["hybridisation"] if "anchor" not in name]
spot_acquisitions = ["genes_round", *hyb_rounds]

# acquisition name -> (score column, minimum score)
thresholds = {
    "genes_round": ("spot_score", spot_score_threshold),
    "barcode_round": ("dot_product_score", barcode_dot_threshold),
}
thresholds.update({name: ("score", hyb_score_threshold) for name in hyb_rounds})

# Load, threshold and tag every acquisition's spot table
spots_dict = {}
for acquisition in spot_acquisitions:
    print(f"Loading {acquisition}", flush=True)
    spots_file = processed_path / data_path / f"{acquisition}_spots_{tile_coors[0]}.pkl"
    raw_spots = pd.read_pickle(spots_file)
    score_column, min_score = thresholds[acquisition]
    kept_spots = raw_spots[raw_spots[score_column] > min_score]
    # Store what spot_mask_value returns for the thresholded spots
    # (presumably the spots tagged with the big_masks value they fall in).
    spots_dict[acquisition] = spot_mask_value(big_masks, kept_spots)

In [None]:
# Count the spots falling in each cell, for every loaded acquisition.
# Barcode rounds are grouped by their base sequence, all others by gene.
# (The score thresholds were already applied when the spots were loaded.)
spots_in_cells = {}
for prefix, spot_df in spots_dict.items():
    print(f"Doing {prefix}", flush=True)
    grouping_column = "bases" if prefix.startswith("barcode") else "gene"
    spots_in_cells[prefix] = count_spots(spots=spot_df, grouping_column=grouping_column)

print(spots_in_cells)

In [None]:
# Output folder for per-cell tables; created if missing (its parent must already exist)
save_dir = Path(get_processed_path(data_path), "cells")
save_dir.mkdir(exist_ok=True)

In [None]:
# Start the fused table from the genes_round counts. Popping it out of
# spots_in_cells leaves only the hybridisation rounds for the merge that follows.
# Guard on the dict being popped (spot_acquisitions always contains "genes_round"),
# so re-running this cell does not raise KeyError and keeps the existing fused_df.
# NOTE(review): if genes_round was never counted, fused_df stays undefined.
if "genes_round" in spots_in_cells:
    fused_df = spots_in_cells.pop("genes_round")

In [None]:
# Merge each hybridisation round into the genes_round counts. A gene measured in
# both is taken from the hybridisation round (the genes_round column is dropped).
for hyb_df in spots_in_cells.values():
    replaced_genes = [gene for gene in hyb_df.columns if gene in fused_df.columns]
    for gene in replaced_genes:
        print(f"Replacing {gene} with hybridisation")
    # drop() returns a new frame, so no earlier table is mutated in place
    fused_df = fused_df.drop(columns=replaced_genes).join(hyb_df, how="outer")

# Cells absent from one of the tables get NaN after the outer join: count them as 0.
fused_df = fused_df.fillna(0).astype(int)

In [None]:
fused_df.to_pickle(Path(save_dir, f"genes_df_roi{tile_coors[0]}.pkl"))  # per-ROI cell x gene counts

In [None]:
# Stitched ex vivo stack for this ROI stored under figures/exvivo_reg
# (file name suggests genes-to-functional registration — confirm)
exvivo_stitched = load_stack(
    Path(
        processed_path,
        data_path,
        "figures",
        "exvivo_reg",
        f"{anchor_prefix}_roi{tile_coors[0]}_stitched_genes2func.tif",
    )
)

In [None]:
# Side-by-side check of mask alignment on a zoomed crop of the stitched stack:
# left panel shows channel 9, right panel channel 5, both overlaid with big_masks.
crop = np.s_[3400:6600, 3200:6400]  # (rows, cols) of the zoomed region
mask_overlay = np.ma.masked_where(big_masks == 0, big_masks)  # hide background (0)

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 8))
for ax, channel, vmax in ((ax0, 9, 2000), (ax1, 5, 300)):
    ax.imshow(exvivo_stitched[crop + (channel,)], cmap="Greys_r", vmax=vmax)
    ax.imshow(mask_overlay[crop], cmap="prism", alpha=0.4)
# NOTE(review): big_masks comes from get_big_masks(..., 2.5); confirm that this
# matches the "5 px" expansion stated in the title.
ax0.set_title("suite2p masks\n expanded 5 px post-transformation")
plt.show()

In [None]:
# Gene-round spots with a gene call (drop the "unassigned" ones)
genes_df = spots_dict["genes_round"]
genes_df = genes_df[~genes_df["gene"].str.contains("unassigned")]

# Hand-picked mask ids (values of big_masks) to inspect in the zoomed plots
roi_list = [833, 760, 889, 828, 878, 873]

unique_genes = genes_df["gene"].unique().tolist()
gene_palette = hsl_palette(unique_genes, c=58, l=68)

# RGB composite of channels 4 and 9 of the ex vivo stack, scaled so both the in vivo
# GCaMP and DAPI signals are visible.
# Work in float32: the in-place "*= 18" below would silently wrap around if the
# stack had an unsigned integer dtype (e.g. uint16 from a tif).
to_plot = exvivo_stitched[:, :, [4, 9]].astype(np.float32)
# Clipping channel 4 at its own minimum is an identity operation, so only channel 9
# gets a floor: a third of its minimum over a reference crop.
to_plot[:, :, 1] = np.clip(
    to_plot[:, :, 1], a_min=to_plot[3200:4000, 3200:4000, 1].min() / 3, a_max=None
)
# Boost channel 4 so it shows up under the shared vmax of 1800
to_plot[:, :, 0] *= 18
im = to_rgb(to_plot, colors=[[0, 1, 1], [1, 1, 1]], vmax=1800)

In [None]:
# Custom colour palette avoiding green/cyan hues so the mask overlay stays
# distinguishable from the cyan/white composite image.
roi_palette = [
    "salmon",
    "lightcoral",
    "orange",
    "gold",
    "mediumorchid",
    "violet",
    "magenta",
    "hotpink",
]
# 32 repeats of the 8 colours give a 256-entry map that cycles through the palette;
# N defaults to len(colors) (= 256), so it is not passed explicitly.
roi_cycle = ListedColormap(32 * roi_palette, name="roi_cycle")

In [None]:
# Zoomed views around the hand-picked ROIs. Spots assigned to the ROI are coloured
# by gene; other spots inside the same window are drawn as white crosses.
half_window = 200  # px, half-width of each zoomed view
mask_overlay = np.ma.masked_where(big_masks == 0, big_masks)  # computed once for all panels

fig = plt.figure(figsize=(24, 16))
for i, roi in enumerate(roi_list):
    # Mean (row, col) of the ROI's pixels.
    # NOTE(review): gives NaN limits (set_xlim then raises) if roi is absent from big_masks.
    centre = np.mean(np.argwhere(big_masks == roi), axis=0)
    ax = fig.add_subplot(2, 3, i + 1)
    xlims = [centre[1] - half_window, centre[1] + half_window]
    # Larger y first so the zoomed image keeps its y axis pointing down
    ylims = [centre[0] + half_window, centre[0] - half_window]
    # Spots inside the window that are not assigned to this ROI
    spots_in_view = genes_df.loc[
        (genes_df["x"] > xlims[0])
        & (genes_df["x"] < xlims[1])
        & (genes_df["y"] > ylims[1])
        & (genes_df["y"] < ylims[0])
        & (genes_df["mask_id"] != roi)
    ]
    spots_in_roi = genes_df.loc[genes_df["mask_id"] == roi]
    ax.imshow(im)
    ax.imshow(mask_overlay, cmap=roi_cycle, alpha=0.4)
    ax.scatter(x=spots_in_view["x"], y=spots_in_view["y"], c="white", marker="x")
    # One scatter call per gene rather than one per spot
    for gene in spots_in_roi["gene"].unique():
        gene_spots = spots_in_roi[spots_in_roi["gene"] == gene]
        ax.scatter(x=gene_spots["x"], y=gene_spots["y"], c=gene_palette[gene], s=12)

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"ROI {roi}", fontsize=24)

plt.tight_layout()
# Explicit handles: with labels only, fig.legend pairs them with the figure's artists
# in drawing order (starting with the white crosses), so colours would not match genes.
legend_handles = [
    Line2D([], [], linestyle="", marker="o", color=colour, label=gene)
    for gene, colour in gene_palette.items()
]
fig.legend(handles=legend_handles);

In [None]:
# Same zoomed views for the first six mask ids that have at least one assigned spot.
# big_masks == 0 is treated as background (masked out of the overlay), so mask_id 0
# presumably marks spots outside every cell; exclude it explicitly instead of dropping
# the first element of unique(), whose order follows the spot table.
all_rois = [mask_id for mask_id in genes_df["mask_id"].unique().tolist() if mask_id != 0]

roi_list = all_rois[:6]

half_window = 200  # px, half-width of each zoomed view
mask_overlay = np.ma.masked_where(big_masks == 0, big_masks)  # computed once for all panels

fig = plt.figure(figsize=(24, 16))
for i, roi in enumerate(roi_list):
    # Mean (row, col) of the ROI's pixels
    centre = np.mean(np.argwhere(big_masks == roi), axis=0)
    ax = fig.add_subplot(2, 3, i + 1)
    xlims = [centre[1] - half_window, centre[1] + half_window]
    # Larger y first so the zoomed image keeps its y axis pointing down
    ylims = [centre[0] + half_window, centre[0] - half_window]
    # Spots inside the window that are not assigned to this ROI
    spots_in_view = genes_df.loc[
        (genes_df["x"] > xlims[0])
        & (genes_df["x"] < xlims[1])
        & (genes_df["y"] > ylims[1])
        & (genes_df["y"] < ylims[0])
        & (genes_df["mask_id"] != roi)
    ]
    spots_in_roi = genes_df.loc[genes_df["mask_id"] == roi]
    ax.imshow(im)
    ax.imshow(mask_overlay, cmap=roi_cycle, alpha=0.4)
    ax.scatter(x=spots_in_view["x"], y=spots_in_view["y"], c="white", marker="x")
    # One scatter call per gene rather than one per spot
    for gene in spots_in_roi["gene"].unique():
        gene_spots = spots_in_roi[spots_in_roi["gene"] == gene]
        ax.scatter(x=gene_spots["x"], y=gene_spots["y"], c=gene_palette[gene], s=12)

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"ROI {roi}", fontsize=24)

plt.tight_layout()
# Explicit handles: with labels only, fig.legend pairs them with the figure's artists
# in drawing order (starting with the white crosses), so colours would not match genes.
legend_handles = [
    Line2D([], [], linestyle="", marker="o", color=colour, label=gene)
    for gene, colour in gene_palette.items()
]
fig.legend(handles=legend_handles);