In [None]:
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
mpl.rcParams["figure.dpi"] = 300
from skimage.io import imread

plt.style.use("dark_background")
plt.rcParams["figure.figsize"] = [12, 8]
plt.rcParams["figure.dpi"] = 100  # 200 e.g. is really fine, but slower
plt.style.use("classic")
plt.style.use("seaborn-white")

plt.rcParams["figure.figsize"] = [12, 8]
plt.rcParams["figure.dpi"] = 100  # 200 e.g. is really fine, but slower

mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"] = 42
plt.rcParams["font.family"] = "Arial"

import re

import pandas as pd
import scanpy as sc

sc.set_figure_params(dpi=200)


import seaborn as sns
import skimage
from morphometrics.explore.cluster import cluster_features
from morphometrics.explore.dimensionality_reduction import pca
from morphometrics.utils.anndata_utils import table_to_anndata
from skimage.measure import label
from tqdm import tqdm

rng = np.random.default_rng(42)
import scipy.spatial.distance as distance


def colorFader(c1, c2, mix=0):
    """Blend two colours linearly and return the result as a hex string.

    ``mix=0`` gives ``c1`` and ``mix=1`` gives ``c2``. Each RGB channel is computed
    as ``(1 - mix) * c1 + mix * c2``.

    Args:
        c1: start colour, in any format accepted by ``mpl.colors.to_rgb``.
        c2: end colour, same formats as ``c1``.
        mix: weight of ``c2`` in the blend.

    Returns:
        The blended colour as produced by ``mpl.colors.to_hex``.
    """
    rgb_start, rgb_end = (np.array(mpl.colors.to_rgb(c)) for c in (c1, c2))
    return mpl.colors.to_hex((1 - mix) * rgb_start + mix * rgb_end)


# Centimetres -> inches: multiply a length in cm by this to get a matplotlib figsize value.
cm = 1 / 2.54

In [None]:
# Load the morphometric measurements of the actin / ECM perturbation experiment.
h5ad_path = "anndatas/morphometrics_actin_ECM_perturbation.h5ad"
measurement_data = sc.read_h5ad(h5ad_path)

In [None]:
measurement_data  # display the loaded AnnData object (bare last expression of the cell)

In [None]:
# UMAP embedding coloured by perturbation (matrix) condition.
perturbation_palette = {
    "Matrigel": "#17ad97",
    "Agarose": "#98d9d1",
    "No matrix": "#4d4d4d",
}
sc.pl.umap(
    measurement_data,
    color="perturbation",
    palette=perturbation_palette,
    legend_fontsize="x-small",
    frameon=False,
    title="",
    size=14,
)

In [None]:
# PAGA graph coloured by sampling day.
import matplotlib

sc.set_figure_params(dpi=200, vector_friendly=False)
matplotlib.rcParams.update(
    {
        # TrueType (Type 42) rather than Type 3 fonts in PDF/PS exports
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "Arial",
    }
)
sc.pl.paga(
    measurement_data,
    color=["Day"],
    threshold=0.1,
    node_size_scale=7,
    frameon=False,
    title="",
)

In [None]:
# Continuous colormap from near-white to midnight blue for the axis-length-ratio feature.
color_midnight_blue = mpl.colors.LinearSegmentedColormap.from_list(
    "", ["#f7f7f7", "#191970"]
)

sc.pl.umap(
    measurement_data,
    color="Axis_length_ratio_raw",
    cmap=color_midnight_blue,
    legend_fontsize="x-small",
    frameon=False,
    title="",
    size=15,
)

In [None]:
import met_brewer
from matplotlib.colors import LinearSegmentedColormap

black = "#000000"
cm = 1 / 2.54  # centimeters in inches

# Three-stop "Austria" palette; its 2nd and 3rd colours anchor the cluster colours below.
colors_edge = met_brewer.met_brew(name="Austria", n=3, brew_type="continuous")
cmap_edge_colors = mpl.colors.ListedColormap(colors_edge, name="from_list", N=None)
cmap_colors = LinearSegmentedColormap.from_list(
    name="color", colors=[colors_edge[1], colors_edge[2]]
)

# One "Johnson" colour per Leiden cluster; clusters 0-7 are overwritten below.
colors = met_brewer.met_brew(
    name="Johnson",
    n=len(np.unique(measurement_data.obs["leiden"])),
    brew_type="continuous",
)
# NOTE(review): ListedColormap may keep a reference to `colors` rather than a copy, so the
# in-place shading below can also show up in this colormap — confirm which palette is intended.
cmap_brewer_umap = mpl.colors.ListedColormap(colors, name="from_list", N=None)

# Cluster order used for stacking and for `colors_reordered`.
hue_order = [2, 6, 3, 7, 5, 0, 1, 4]
n_required = max(hue_order) + 1
if len(colors) < n_required:
    raise ValueError(
        f"Cluster palette assumes at least {n_required} Leiden clusters, "
        f"found {len(colors)}"
    )
# NOTE(review): clusters with index > 7 keep their raw Johnson colour and are not in hue_order.

# Clusters 2 and 4 take the two anchor colours; the other clusters are darkened copies
# of one anchor (mix = fraction of black blended in by colorFader).
colors[4] = colors_edge[2]
colors[2] = colors_edge[1]
cluster_shades = {  # cluster index: (anchor cluster index, mix towards black)
    6: (2, 0.25),
    3: (2, 0.4),
    0: (4, 0.3),
    7: (4, 0.5),
    5: (4, 0.4),
    1: (4, 0.2),
}
for cluster, (anchor, mix) in cluster_shades.items():
    colors[cluster] = colorFader(colors[anchor], black, mix=mix)

# Black at index 0 followed by the cluster colours (index i + 1 -> cluster i).
cmap_brewer_image = mpl.colors.ListedColormap(
    ["#000000"] + colors, name="from_list", N=None
)
colors_reordered = list(np.array(colors)[hue_order])

# Reset style rcParams to Matplotlib's defaults. rcdefaults() leaves non-style keys such
# as the backend alone, unlike rcParams.update(rcParamsDefault).
mpl.rcdefaults()

In [None]:
# color by clustering
sc.set_figure_params(dpi=200)

sc.pl.umap(
    measurement_data,
    color="leiden",
    palette=colors,
    size=14,
    title="",
    frameon=False,
    legend_fontsize="x-small",
)

In [None]:
# PAGA graph coloured by Leiden cluster.
sc.set_figure_params(vector_friendly=False, dpi=200)

sc.pl.paga(
    measurement_data,
    color=["leiden"],
    threshold=0.1,
    node_size_scale=7,
    frameon=False,
    title="",
)

In [None]:
# Per condition: stacked bars of the fraction of cells in each Leiden cluster on each day.
histo_time = pd.DataFrame(
    {
        "Day": np.array(measurement_data.obs["Day"]).astype(int),
        "leiden": np.array(measurement_data.obs["leiden"]).astype(int),
        "perturbation": np.array(measurement_data.obs["perturbation"]).astype(str),
        "organoid": np.array(measurement_data.obs["organoid"]).astype(str),
    }
)

# Theme and font embedding do not depend on the condition, so set them once.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0), "axes.linewidth": 2})
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"] = 42

for perturbation in ["Matrigel", "No matrix", "Agarose"]:
    # .assign builds a new frame, so no column is written into a filtered slice
    # (which would raise pandas' SettingWithCopyWarning).
    histo_time_perturb = histo_time.loc[
        histo_time["perturbation"] == perturbation
    ].assign(
        # weight = 1 / (number of cells of this condition on that day),
        # so each day's stack sums to 1
        percentage=lambda df: 1 / df.groupby("Day")["perturbation"].transform("count")
    )

    fig, ax = plt.subplots(figsize=(12 * cm, 6 * cm))
    sns.despine(ax=ax, left=True, bottom=True, right=True)
    sns.histplot(
        histo_time_perturb,
        x="Day",
        weights="percentage",
        hue="leiden",
        hue_order=hue_order,
        palette=colors_reordered,
        multiple="stack",
        discrete=True,
        legend=False,
        ax=ax,
    )
    ax.set_title(perturbation)
    plt.setp(ax.collections, alpha=0.7)
    plt.show()

In [None]:
# False color on white background
positions_names = ["Matrigel", "Agarose", "No matrix"]
positions = ["2", "11", "13"]
times = [4, 6, 9]


def first_non_min_value(arr, axis, invalid_val=np.nan):
    mask = arr != arr.min()
    return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)


channel = "GFP"
color = "leiden"
seg_directory = "/3D_Brain_organoids_half_res_morphometrics/"
input_dir = "/3D_one_image_per_day_AGAR_all_all_06_02_2023/"

In [None]:
# False-colour "first hit" projections: for every organoid (position) and selected day,
# scan the segmentation volume along its first axis and colour each pixel by the value
# (Leiden cluster) of the first segmented cell encountered; pixels with no cell stay white.
save_figures = False  # set True to also write each projection as a PNG
projection_dpi = 100  # figure inches = array pixels / this value


def _cell_labels_and_values(cells, color):
    """Return the segmentation label and the `color` value of every cell in `cells`.

    For ``color == "leiden"`` the value is the integer Leiden cluster from ``.obs``;
    otherwise it is the ``.X`` column named `color`.
    """
    labels = cells.obs["label"].to_numpy()
    if color == "leiden":
        values = np.array(cells.obs["leiden"]).astype(int)
    else:
        values = cells[:, color].X
        # .X may be a scipy sparse matrix; densify before flattening.
        values = values.toarray() if hasattr(values, "toarray") else np.asarray(values)
        values = values.ravel()
    return labels, values


def _paint_labels(mask, labels, values, background):
    """Replace every voxel of the label volume `mask` by the value of the cell with that label.

    Labels present in `mask` but absent from `labels` (e.g. the segmentation background)
    get `background`. Assumes each label occurs at most once in `labels`.
    """
    from_values = np.unique(mask)  # sorted, contains every voxel value of `mask`
    to_values = np.full(from_values.shape, background, dtype=float)
    pos = np.minimum(np.searchsorted(from_values, labels), from_values.size - 1)
    found = from_values[pos] == labels
    to_values[pos[found]] = values[found]
    return to_values[np.searchsorted(from_values, mask)]


def _first_hit_projection(volume):
    """Collapse `volume` along axis 0, keeping per pixel the first value != ``volume.min()``.

    Pixels where every value equals the minimum get -1.
    """
    first = first_non_min_value(volume, 0)
    has_hit = ~np.isnan(first)
    depth = np.where(has_hit, first, 0).astype(int)
    top = np.take_along_axis(volume, depth[np.newaxis], axis=0)[0]
    return np.where(has_hit, top, -1.0)


def _compact_clusters(projection):
    """Renumber the Leiden ids in `projection` to 0..k-1, keeping -1 for background.

    Returns the renumbered integer image and the sorted original cluster ids, so that
    value ``i`` in the image is Leiden cluster ``cluster_ids[i]``. Only meaningful when
    the projection holds Leiden ids.
    """
    is_cell = projection >= 0
    cluster_ids, compact_ids = np.unique(projection[is_cell], return_inverse=True)
    compact = np.full(projection.shape, -1, dtype=int)
    compact[is_cell] = compact_ids
    return compact, cluster_ids.astype(int)


def _plot_cluster_projection(compact, cluster_ids, dpi=projection_dpi):
    """Draw `compact` full-bleed: white for -1, ``colors[cluster_ids[i]]`` for value i."""
    cmap = mpl.colors.ListedColormap(
        ["#ffffff"] + list(np.array(colors)[cluster_ids]),
        name="from_list",
        N=None,
    )
    height, width = compact.shape
    # figsize is (width, height) in inches; image rows run vertically.
    fig = plt.figure(figsize=(width / dpi, height / dpi))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(
        compact,
        cmap=cmap,
        alpha=0.8,
        interpolation="nearest",
        # pin the colour scale so -1 -> white and i -> cluster_ids[i]
        vmin=-1,
        vmax=len(cluster_ids) - 1,
    )
    ax.axis("off")
    return fig


for position, positions_name in zip(positions, positions_names):
    for time_point in tqdm(times, desc=positions_name):
        # Index used in the file names for this day (day 4 -> 1, +24 per day).
        # NOTE(review): presumably hourly frames starting on day 4 — confirm.
        frame = (time_point - 4) * 24 + 1
        mask = imread(f"{input_dir}/predictions/image_{channel}_{frame}_{position}.tif")

        cells = measurement_data[
            (measurement_data.obs["organoid"] == int(position))
            & (measurement_data.obs["Day"] == time_point)
        ]
        labels, values = _cell_labels_and_values(cells, color)
        background = -1 if color == "leiden" else values.min() - 1
        volume = _paint_labels(mask, labels, values, background)
        compact, cluster_ids = _compact_clusters(_first_hit_projection(volume))

        fig = _plot_cluster_projection(compact, cluster_ids)
        if save_figures:
            fig.savefig(
                f"figures/morphotypes_over_time/all_clusters_day_{time_point}_{positions_name}.png",
                pad_inches=0,
                bbox_inches="tight",
                dpi=100,
            )
        plt.show()