## Import packages

In [None]:
import scanpy as sc
import numpy as np
# import imageio.v3 as iio

from skimage.io import imread_collection, imsave
from skmisc.loess import loess
from sklearn import decomposition

from umap import UMAP
import leidenalg
import igraph as ig

from matplotlib import pyplot as plt
from pathlib import Path

## Define helper functions

In [None]:
def illustrate_roi_box(
    seq_rgb,
    x_min=400,
    x_max=1000,
    y_min=300,
    y_max=500,
    line_width=5,
):
    """Show the first frame in grayscale with a red box around the ROI.

    Parameters
    ----------
    seq_rgb : sequence of arrays
        Frames indexed as (row, column, channel); only ``seq_rgb[0]`` is used.
        Values are divided by 255 (assumes 8-bit pixel data -- TODO confirm).
    x_min, x_max : int
        Exclusive column bounds of the ROI.
    y_min, y_max : int
        Exclusive row bounds of the ROI.
    line_width : int
        Thickness, in pixels, of the box drawn just outside the ROI.
    """
    # Grayscale background: average the color channels of the first frame,
    # then replicate the result into three identical channels.
    image = np.array(seq_rgb[0], dtype="float32") / 255
    image = np.repeat(np.mean(image, axis=2, keepdims=True), 3, 2)

    n_rows, n_cols = image.shape[:2]
    roi = _open_rectangle_mask(n_rows, n_cols, y_min, y_max, x_min, x_max)
    padded = _open_rectangle_mask(
        n_rows,
        n_cols,
        y_min - line_width,
        y_max + line_width,
        x_min - line_width,
        x_max + line_width,
    )

    # The box is the padded rectangle with the ROI itself cut out.
    box = padded & ~roi

    # Saturate the first (red) channel on the box. Assigning 1.0 instead of
    # adding 1 keeps the image inside the [0, 1] range imshow expects.
    overlay = image.copy()
    overlay[box, 0] = 1.0

    plt.subplots(1, 1, dpi=300)
    plt.imshow(overlay)
    plt.grid(False)
    plt.show()


def _open_rectangle_mask(n_rows, n_cols, row_lo, row_hi, col_lo, col_hi):
    """Return an (n_rows, n_cols) bool mask that is True strictly inside the bounds."""
    rows = np.arange(n_rows)
    cols = np.arange(n_cols)
    in_rows = (rows > row_lo) & (rows < row_hi)
    in_cols = (cols > col_lo) & (cols < col_hi)
    return in_rows[:, None] & in_cols[None, :]

In [None]:
def hist_compare_std(
    X0,
    X1,
    image_shape,
    labels=("observed", "shuffled"),
    title="",
    dpi=250,
    s=1,
    alpha=0.8,
    image_vmax=100,
    show_dispersion=False,
    dispersion_baseline=50,
):
    """Plot the per-column standard deviations of two matrices for comparison.

    Three figures are always shown: overlaid histograms of the standard
    deviations, the two standard-deviation maps reshaped to ``image_shape``,
    and a mean-vs-standard-deviation scatter plot. With ``show_dispersion``
    two more figures show the coefficient of variation of ``X0``.

    Parameters
    ----------
    X0, X1 : arrays
        2-D matrices; statistics are taken along axis 0.
    image_shape : tuple of int
        Shape used to turn a per-column statistic back into an image.
    labels : sequence of two str
        Legend and title labels for ``X0`` and ``X1``.
    title : str
        Title of the histogram figure.
    dpi : int
        Resolution of the side-by-side image figure.
    s, alpha : float
        Scatter marker size and opacity (``alpha`` is also used by the
        histograms).
    image_vmax : float
        Upper limit of the gray scale used for the images (lower limit is 0).
    show_dispersion : bool
        Also plot the coefficient of variation (std / mean) of ``X0``.
    dispersion_baseline : float
        Columns of ``X0`` whose mean is below this value get dispersion 0.
    """
    # Compute each statistic once; every figure below reuses them.
    X0_mean, X0_std = X0.mean(axis=0), X0.std(axis=0)
    X1_mean, X1_std = X1.mean(axis=0), X1.std(axis=0)

    # Histograms share the bin edges derived from X0 so they are comparable.
    _, bins, _ = plt.hist(X0_std, bins=100, alpha=alpha, label=labels[0])
    plt.hist(X1_std, bins=bins, alpha=alpha, color="red", label=labels[1])
    plt.xlabel("Standard deviation")
    plt.ylabel("Number of pixels")
    plt.legend()
    plt.title(title)
    plt.show()

    # Standard-deviation maps, side by side on a common gray scale.
    _, ax = plt.subplots(1, 2, dpi=dpi)
    kwargs_image = dict(cmap="gray", vmax=image_vmax, vmin=0, interpolation="none")
    for axis, std, label in zip(ax, (X0_std, X1_std), labels):
        plt.sca(axis)
        plt.imshow(std.reshape(image_shape), **kwargs_image)
        plt.title(f"{label} std dev")
    plt.show()

    # Mean vs. standard deviation for both matrices.
    kwargs_scatter = dict(s=s, alpha=alpha)
    for mean, std, color, label in (
        (X0_mean, X0_std, "blue", labels[0]),
        (X1_mean, X1_std, "red", labels[1]),
    ):
        plt.scatter(mean, std, color=color, label=label, **kwargs_scatter)
    plt.xlabel("Mean")
    plt.ylabel("Standard deviation")
    plt.legend()
    plt.show()

    if show_dispersion:
        # Divide only where the mean reaches the baseline and is non-zero;
        # every other column keeps dispersion 0 without dividing by zero.
        usable = (X0_mean >= dispersion_baseline) & (X0_mean != 0)
        dispersion = np.zeros(X0_std.shape, dtype=float)
        np.divide(X0_std, X0_mean, out=dispersion, where=usable)

        plt.hist(
            dispersion,
            bins=100,
            alpha=alpha,
            label=labels[0],
        )
        plt.xlabel("Dispersion, coefficient of variation")
        plt.show()

        # Rescale to 0-255 unless every value is zero (avoids 0 / 0).
        # NOTE(review): the image is still drawn with vmax=image_vmax, so
        # values above that limit saturate -- confirm this is intended.
        peak = dispersion.max()
        scaled = 255 * dispersion / peak if peak > 0 else dispersion
        plt.imshow(scaled.reshape(image_shape), **kwargs_image)
        plt.title("Dispersion")
        plt.show()


def export_frames(X, path_out, image_shape):
    """Write each row of ``X`` to ``path_out`` as an 8-bit PNG.

    Parameters
    ----------
    X : iterable of arrays
        One flattened frame per item; each is reshaped to ``image_shape``.
    path_out : str or pathlib.Path
        Output directory. It is created, including missing parents, if needed.
    image_shape : tuple of int
        Shape of a single frame.

    Files are named by frame index, zero-padded to three digits.
    """
    path_out = Path(path_out)
    path_out.mkdir(parents=True, exist_ok=True)
    for i, X_frame in enumerate(X):
        # Shift by half the 8-bit range to approximately center the values
        # for visualization, then clamp to the valid range before casting.
        frame = X_frame.reshape(image_shape) + 255 / 2
        frame = np.clip(frame, 0, 255).astype(np.uint8)
        imsave(path_out / f"{i:03d}.png", frame)


def compare_std(
    X, image_shape, path_out_centered=None, shuffle_method="within_frame", image_vmax=100,
):
    """Compare pixel standard deviations of ``X`` with a shuffled control.

    The comparison is plotted twice with ``hist_compare_std``: once on ``X``
    as given and once after subtracting each column's mean.

    Parameters
    ----------
    X : array
        2-D matrix; rows are shuffled or centered as described below.
    image_shape : tuple of int
        Shape used to display a per-column statistic as an image.
    path_out_centered : str or pathlib.Path, optional
        If given (and ``image_shape`` is not None), the centered rows are
        exported there as PNG frames.
    shuffle_method : {"within_frame", "within_pixel"}
        ``"within_frame"`` permutes the values inside each row;
        ``"within_pixel"`` permutes the values inside each column.
    image_vmax : float
        Upper gray-scale limit forwarded to ``hist_compare_std``.

    Raises
    ------
    ValueError
        If ``shuffle_method`` is not one of the supported names. Previously
        an unknown name silently produced an unshuffled "control".
    """
    # Pixel standard deviation, observed vs shuffled values.
    # Shows there is structure in the data.
    hist_compare_std(
        X,
        _shuffled_copy(X, shuffle_method),
        image_shape=image_shape,
        image_vmax=image_vmax,
        title="Pixel standard deviation",
    )

    # After centering, pixel standard deviation, observed vs shuffled values.
    # Shows that the structure is partially but not completely due to the mean.
    # Out-of-place subtraction also works when X has an integer dtype.
    X_temp = X - X.mean(axis=0)
    hist_compare_std(
        X_temp,
        _shuffled_copy(X_temp, shuffle_method),
        image_shape=image_shape,
        image_vmax=image_vmax,
        title="Pixel standard deviation after centering",
    )

    # Optionally export the centered frames.
    if path_out_centered is not None and image_shape is not None:
        export_frames(X_temp, Path(path_out_centered), image_shape)


def _shuffled_copy(X, shuffle_method):
    """Return a copy of ``X`` with values permuted inside each row or column."""
    X_shuffle = X.copy()
    if shuffle_method == "within_frame":
        lines = X_shuffle
    elif shuffle_method == "within_pixel":
        # Rows of the transposed view are the columns of the copy.
        lines = X_shuffle.T
    else:
        raise ValueError(
            f"Unknown shuffle_method {shuffle_method!r}; "
            "expected 'within_frame' or 'within_pixel'"
        )
    for line in lines:
        np.random.shuffle(line)  # in place, writes through to X_shuffle
    return X_shuffle

In [None]:
def average_image(seq):
    """Return the element-wise mean of the frames in ``seq``, scaled to peak 1.

    Parameters
    ----------
    seq : iterable of arrays
        Frames of identical shape.

    Returns
    -------
    numpy.ndarray
        The mean frame divided by its maximum. If the maximum is 0 (for
        example when every frame is black) the mean is returned unscaled
        instead of dividing by zero.

    Raises
    ------
    ValueError
        If ``seq`` contains no frames.
    """
    frames = list(seq)
    if not frames:
        raise ValueError("average_image() needs at least one frame, got none")

    img = np.stack(frames, axis=0).mean(axis=0)
    peak = img.max()
    if peak == 0:
        return img
    return img / peak

## Load and preprocess image data


In [None]:
# Load the PNG frames found in ../data/frames
seq_rgb = imread_collection("../data/frames/*.png", conserve_memory=True)

# Keep only the frames in [first_frame, last_frame)
first_frame = 1
last_frame = 190
seq_rgb = seq_rgb[first_frame:last_frame]

# Drop border_width rows (the black border) from the top and bottom
border_width = 23
seq_rgb = [im[border_width:-border_width, :, :] for im in seq_rgb]

# Keep the uncropped frames, in color and as channel-averaged grayscale
seq_rgb_full = list(seq_rgb)
seq_bw_full = [im.mean(axis=2) for im in seq_rgb_full]

# Crop to the region of interest: rows 300-499, columns 400-999
seq_rgb = [im[300:500, 400:1000] for im in seq_rgb]

# Grayscale version of the ROI frames (mean over the color channels)
seq_bw = [im.mean(axis=2) for im in seq_rgb]

# One flattened frame per row: 2-D matrices of frames x pixels
X_full = np.vstack([im.ravel() for im in seq_bw_full])
X = np.vstack([im.ravel() for im in seq_bw])

In [None]:
# Report the number of frames (observations) and pixels (features) in each
# matrix. np.prod is used because np.product was removed in NumPy 2.0.
print("Full image")
print(f"Number of observations: {len(X_full)}")
print(f"Number of features: {np.prod(X_full[0].shape):,}")

print("\nROI")
print(f"Number of observations: {len(X)}")
print(f"Number of features: {np.prod(X[0].shape):,}")

## View image and ROI


In [None]:
# Matrix shapes as (frames, pixels): full image first, then the ROI
print(X_full.shape)
print(X.shape)

In [None]:
# First full frame with the ROI outlined
illustrate_roi_box(seq_rgb_full)

# First ROI frame in color
plt.imshow(seq_rgb[0])
plt.show()

# Same frame in grayscale, on a fixed 0-255 intensity scale
fig, ax = plt.subplots(dpi=300)
plt.imshow(seq_bw[0], vmin=0, vmax=255, cmap="gray")
plt.show()

## View X, the 2-d matrix of video frames * pixels

In [None]:
# Display X as an image: one row per frame, one column per pixel
fig, ax = plt.subplots(dpi=300)
ax.grid(False)
ax.imshow(
    X, cmap="gray", vmin=0, vmax=255, aspect="auto", interpolation="none"
)
ax.set_xlabel("Pixel #")
ax.set_ylabel("Frame #")
plt.show()

## Characterize variance vs shuffled control

In [None]:
# Full image version
# NOTE(review): the export directory is an absolute, machine-specific path;
# compare_std writes the mean-centered frames there.
path_out_centered = "/home/cameron.cowan/data/presentations/iob_retreat_2023/data/frames_bw_full_centered"
image_shape = seq_bw_full[0].shape
compare_std(
    X_full,
    image_shape=image_shape,
    path_out_centered=path_out_centered,
    shuffle_method="within_frame",
)

In [None]:
# ROI version (uses compare_std's default shuffle_method, "within_frame")
# NOTE(review): the export directory is an absolute, machine-specific path.
path_out_centered = "/home/cameron.cowan/data/presentations/iob_retreat_2023/data/frames_bw_roi_centered"
image_shape = seq_bw[0].shape
compare_std(X, image_shape=image_shape, path_out_centered=path_out_centered)

## Shuffle within pixels, use ROI

In [None]:
# Per-pixel mean and variance across frames, on a log10 scale.
# NOTE(review): a pixel whose mean or variance is 0 yields -inf here --
# confirm no such pixel exists in X before fitting.
Xlog_mean = np.log10(X.mean(axis=0))
Xlog_var = np.log10(X.var(axis=0))

# Fit a LOESS curve to the mean-variance relationship
argsort = np.argsort(Xlog_mean)
l = loess(Xlog_mean[argsort], Xlog_var[argsort], span=0.5, degree=2)

# Plot the mean-variance relationship and the LOESS curve
plt.scatter(Xlog_mean, Xlog_var, s=0.3, alpha=0.1, label="pixel")
x0 = Xlog_mean[argsort]
plt.plot(x0, l.predict(x0).values, color="red", label="LOESS prediction")
plt.xlabel("log10(mean)")
plt.ylabel("log10(variance)")
plt.legend()
plt.show()

# Expected standard deviation from the LOESS curve. The curve predicts
# log10(variance), and log10(std) = log10(variance) / 2, so the halving must
# happen in log space (taking sqrt of the log value is not equivalent and is
# undefined for negative logs).
Xlog_var_expected = l.predict(Xlog_mean).values
Xlog_std_expected = Xlog_var_expected / 2
X_std_expected = 10**Xlog_std_expected

# Residuals of the mean-variance relationship relative to the prediction
Xlog_var_adj = Xlog_var - Xlog_var_expected

# Plot the residuals; both axes are in log10 units
plt.scatter(Xlog_mean, Xlog_var_adj, s=0.3, alpha=0.1, label="pixel")
plt.xlabel("log10(mean)")
plt.ylabel("log10(observed variance) - log10(expected variance)")
plt.legend()
plt.show()


## Z-score the data

In [None]:
# Center every pixel on its own mean across frames
X_centered = X - X.mean(axis=0)

# Z-score with the single standard deviation of the whole matrix
Z = X_centered / X.std()
Z_var = Z.var(axis=0)

# Z-score with the per-pixel expected standard deviation
Z_adj = X_centered / X_std_expected
Z_adj_var = Z_adj.var(axis=0)

# Per-pixel variance of each z-score against the pixel mean
pixel_mean = X.mean(axis=0)
for z_variance, plot_title in (
    (Z_var, 'Z-score with observed global variance'),
    (Z_adj_var, 'Z-score with mean-variance correction'),
):
    plt.scatter(pixel_mean, z_variance, s=0.3, alpha=1, label="pixel")
    plt.xlabel("Mean")
    plt.ylabel("Z, expected std. dev. from mean")
    plt.legend()
    plt.title(plot_title)
    plt.ylim([0, 8])
    plt.show()

In [None]:
# Repeat the observed-vs-shuffled comparison on the corrected z-scores.
# Frame export is disabled (path left commented out); the gray scale is
# capped at 4 instead of the default.
# path_out_zvar = "/home/cameron.cowan/data/presentations/iob_retreat_2023/data/frames_bw_roi_zvar"
image_shape = seq_bw[0].shape
compare_std(
    Z_adj,
    image_shape=image_shape,
    # path_out_centered=path_out_zvar,
    image_vmax=4,
)

## Feature selection based on Z-score outliers

In [None]:
# Pixels whose z-score variance exceeds the threshold
z_thresh = 4
mask_Z_var = Z_var > z_thresh
mask_Z_adj_var = Z_adj_var > z_thresh

# Reference ROI frame
plt.subplots(dpi=150)
plt.imshow(seq_rgb[180], interpolation="none")
plt.show()

# Selected pixels: global z-score first, then the corrected z-score
for mask in (mask_Z_var, mask_Z_adj_var):
    plt.subplots(dpi=150)
    plt.imshow(
        mask.reshape(seq_bw[0].shape),
        cmap="gray",
        interpolation="none",
    )
    plt.show()

In [None]:
# Map of the corrected z-score variance, rescaled so its maximum is 255
plt.subplots(1, 1, dpi=200)
plt.imshow(
    255 * Z_adj_var.reshape(image_shape) / Z_adj_var.max(),
    interpolation="none",
    cmap="gray",
    vmin=0,
    vmax=255,
)
plt.show()

# Distribution of the corrected z-score variance (log-scaled counts)
plt.hist(
    Z_adj_var,
    bins=200,
)
plt.yscale("log")
plt.show()


# Same map with every pixel outside the mask set to 0
Z_temp = Z_adj_var.copy()
Z_temp[~mask_Z_adj_var] = 0
plt.subplots(1, 1, dpi=200)
plt.imshow(
    Z_temp.reshape(image_shape),
    interpolation="none",
    cmap="viridis",
)
# plt.colorbar()
plt.show()

# Keep only the selected (high-variance) pixel columns; rows remain frames
X_hv = X[:, mask_Z_adj_var].copy()

# Transposed for display: pixels down the rows, frames along the columns
plt.subplots(figsize=(7,3),dpi=100)
plt.imshow(X_hv.T, interpolation="none", cmap="gray", vmax=255, vmin=0)
plt.xlabel("Time, frame #")
plt.ylabel("High variance pixels")
plt.axis("tight")
plt.show()

## Linear decomposition and dimensionality reduction. PCA

In [None]:
# PCA on the high-variance pixels. X_hv has one row per frame, so each row
# of `scores` (and each plotted point) is a frame, not a pixel.
pca = decomposition.PCA(n_components=4)
scores = pca.fit_transform(X_hv)

plt.scatter(scores[:, 0], scores[:, 1], s=30, alpha=0.3, label="frame", linewidths=0)
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.legend(loc="upper left")
plt.show()

# Same points joined in row order, i.e. in frame order
plt.plot(scores[:, 0], scores[:, 1], ".-", alpha=.5)
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("Frames connected in time")
plt.show()

# Loadings: one row per component, one column per selected pixel
plt.imshow(pca.components_, interpolation="none", aspect=4)
plt.yticks(range(4), ["PC1", "PC2", "PC3", "PC4"])
plt.xlabel("pixel #")
plt.show()

## Show average image for different PC score ranges

In [None]:
# Average the ROI frames whose PC scores fall in each selected range
seq_rgb_pc1_high = average_image(
    [im for i, im in enumerate(seq_rgb) if scores[i, 0] > 400]
)
seq_rgb_pc1_low = average_image(
    [im for i, im in enumerate(seq_rgb) if scores[i, 0] < -200]
)
seq_rgb_pc2_low = average_image(
    [im for i, im in enumerate(seq_rgb) if scores[i, 1] < -200]
)

# One figure per averaged image
for averaged in (seq_rgb_pc1_high, seq_rgb_pc1_low, seq_rgb_pc2_low):
    plt.subplots(dpi=200)
    plt.imshow(averaged)
    plt.show()



## Non-linear dimensionality reduction. UMAP

In [None]:
# 2-D UMAP embedding of the high-variance pixels. X_hv has one row per frame,
# so every embedded point is a frame, not a pixel.
reducer = UMAP(n_components=2, n_neighbors=30, force_approximation_algorithm=True, random_state=42)
reducer.fit(X_hv)
umap_scores = reducer.transform(X_hv)

plt.scatter(umap_scores[:, 0], umap_scores[:, 1], s=30, alpha=0.5, label="frame", linewidths=0)
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.legend(loc="upper left")
plt.show()

# Same points joined in row order, i.e. in frame order
plt.plot(umap_scores[:, 0], umap_scores[:, 1], '.-', alpha=0.5)
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.show()

## Visualize nearest neighbor graph underlying the UMAP

In [None]:
# NOTE(review): X_hv has one row per frame, so this is a frame count despite
# the name (the plots below label both axes "Frame #").
n_pixels = X_hv.shape[0]

# Binary adjacency matrix: entry (i, j) is 1 when j is listed among the
# nearest neighbors stored on the fitted reducer for i, else 0
# NOTE(review): _knn_indices / _knn_dists are private UMAP attributes.
A_dist = np.zeros((n_pixels, n_pixels))
for i in np.arange(n_pixels):
    A_dist[i, reducer._knn_indices[i]]=1

# Same sparsity pattern, but holding log10(1 + distance) to each neighbor
# (0 everywhere else)
G_binary = 0*np.ones((n_pixels, n_pixels))
for i in np.arange(n_pixels):
    G_binary[i, reducer._knn_indices[i]]=np.log10((1+reducer._knn_dists[i]))

# Neighbor / not-neighbor map
plt.imshow(A_dist, interpolation="none", cmap="viridis", vmax=1)
cb = plt.colorbar()
cb.set_label("Pair are nearest neighbors")
cb.set_ticks([0, 1])
cb.set_ticklabels(["False", "True"])
plt.xlabel("Frame #")
plt.ylabel("Frame #")
plt.show()


# Neighbor distance map; plotted values are log10(1 + distance)
plt.imshow(G_binary, interpolation="none", cmap="viridis")
cb = plt.colorbar()
cb.set_label("log10(distance)")
plt.xlabel("Frame #")
plt.ylabel("Frame #")
plt.show()

## Graph-based clustering. Leiden

In [None]:
# Build a directed graph from the binary neighbor matrix and partition it
# with Leiden (fixed seed).
G = ig.Graph.Weighted_Adjacency(A_dist.tolist(), mode="directed")
part = leidenalg.find_partition(G, leidenalg.RBERVertexPartition, resolution_parameter=.7, seed=42)
# One cluster id per graph vertex, in vertex order
clusters = part.membership

# UMAP embedding colored by cluster id
plt.scatter(umap_scores[:, 0], umap_scores[:, 1], c=clusters, s=30, alpha=0.8, label="pixel", linewidths=0)
cb = plt.colorbar()
cb.set_ticks(np.unique(clusters))
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.grid(False)
plt.show()


## Visualize average frames for each cluster

In [None]:
# Group the ROI frames by cluster id
seq_rgb_clusters = {
    cluster: [im for j, im in enumerate(seq_rgb) if clusters[j] == cluster]
    for cluster in np.unique(clusters)
}

# Average image of each cluster (scaled by average_image)
img_rgb_clusters = {
    key: average_image(seq) for key, seq in seq_rgb_clusters.items()
}

# Mean of all ROI frames, scaled by 255. It does not depend on the cluster,
# so it is computed once instead of on every loop iteration.
background = np.mean(seq_rgb, axis=0) / 255

for key, img in img_rgb_clusters.items():
    plt.subplots(dpi=200)
    img_above_background = img - background
    img_above_background = img_above_background / img_above_background.max()
    plt.imshow(img_above_background)
    plt.show()

## Demo with scanpy

In [None]:
# Run the frames x pixels matrix through a scanpy workflow:
# one observation per frame, one variable per ROI pixel.
adata = sc.AnnData(X, obs={"frame": np.arange(len(seq_bw))})
sc.pp.normalize_total(adata, target_sum=1000)
sc.pp.log1p(adata)
# NOTE(review): scanpy documents flavor="seurat_v3" as expecting raw counts,
# but here it runs after normalize_total and log1p -- confirm this is intended.
sc.pp.highly_variable_genes(adata, n_top_genes=30, flavor="seurat_v3")
# Keep only the selected pixels as the representation used for the neighbor graph
adata.obsm["X_hv"] = adata.X[:, adata.var["highly_variable"]]
sc.pp.neighbors(adata, n_neighbors=10, use_rep="X_hv")
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=0.1)

## View scanpy results

In [None]:
# Draw scanpy's highly-variable summary without showing it yet (show=False),
# then tidy the figure: no grid, no legend on the first axes, tight layout.
sc.set_figure_params(dpi=300)
sc.pl.highly_variable_genes(adata, show=False)
plt.grid(False)
plt.sca(plt.gcf().axes[0])
plt.gca().get_legend().remove()
plt.grid(False)
plt.gcf().tight_layout()

In [None]:
# Boolean highly-variable flag of every pixel, reshaped to the ROI image
hv_flags = adata.var["highly_variable"].values
img_overdispersed = hv_flags.reshape(seq_bw[0].shape)

fig, ax = plt.subplots(dpi=300)
ax.imshow(
    img_overdispersed, cmap="gray", vmin=0, vmax=1, interpolation="none"
)
ax.grid(False)
plt.show()

In [None]:
# Values of the highly-variable pixels, transposed for display:
# pixels down the rows, frames along the columns
fig, ax = plt.subplots(1, 1, dpi=300)
plt.imshow(
    adata[:, adata.var["highly_variable"]].X.T,
    cmap="gray",
    interpolation="none",
)
plt.show()

# Shared histogram bin edges for the two example traces below
edges = np.linspace(0, 0.04, 20)
# Histogram of the first highly-variable pixel across frames
fig, ax = plt.subplots(1, 1, dpi=300)
example_good_trace = np.array(
    adata[:, adata.var["highly_variable"]][:, 0].X
).ravel()
plt.hist(example_good_trace, bins=edges)
plt.ylim(0, 55)
plt.gca().set_aspect(0.0001)
plt.show()

# Histogram of the first pixel that was NOT flagged as highly variable
fig, ax = plt.subplots(1, 1, dpi=300)
example_bad_trace = np.array(
    adata[:, np.logical_not(adata.var["highly_variable"])][:, 0].X
).ravel()
plt.hist(example_bad_trace, bins=edges)
plt.ylim(0, 110)
plt.gca().set_aspect(0.00005)
plt.show()

In [None]:
# Shape (frames, pixels) of the data restricted to pixels not flagged as highly variable
adata[:, np.logical_not(adata.var["highly_variable"])].X.shape

In [None]:
# PCA on the highly-variable pixels, then component pairs colored by frame index
# NOTE(review): newer scanpy releases appear to deprecate use_highly_variable
# in favor of mask_var -- check against the installed version.
sc.pp.pca(adata, n_comps=10, svd_solver="arpack", use_highly_variable=True)
sc.pl.pca(
    adata, color="frame", components=["1,2", "3,4", "5,6", "7,8", "9,10"]
)



In [None]:
# UMAP colored by frame index, then by Leiden cluster, on 9x9 figures;
# the figure size is reset afterwards.
sc.set_figure_params(figsize=(9, 9))
sc.pl.umap(adata, color="frame", cmap="viridis")
sc.pl.umap(adata, color="leiden", cmap="viridis")
sc.set_figure_params(figsize=None)

In [None]:
# Split the ROI frames by Leiden cluster label.
# Labels are read by position from a NumPy array: indexing the Series with a
# bare integer relies on positional fallback, which pandas deprecates when
# the index is not integer-based.
leiden_labels = adata.obs["leiden"].to_numpy()

seq_rgb_leiden0 = [
    im for i, im in enumerate(seq_rgb) if leiden_labels[i] == "0"
]
seq_rgb_leiden1 = [
    im for i, im in enumerate(seq_rgb) if leiden_labels[i] == "1"
]
seq_rgb_leiden2 = [
    im for i, im in enumerate(seq_rgb) if leiden_labels[i] == "2"
]


def average_image(seq):
    """Return the element-wise mean of the frames in ``seq``, scaled to peak 1.

    This repeats the helper of the same name defined earlier in the notebook.

    Parameters
    ----------
    seq : iterable of arrays
        Frames of identical shape.

    Returns
    -------
    numpy.ndarray
        The mean frame divided by its maximum. If the maximum is 0 (for
        example when every frame is black) the mean is returned unscaled
        instead of dividing by zero.

    Raises
    ------
    ValueError
        If ``seq`` contains no frames.
    """
    frames = list(seq)
    if not frames:
        raise ValueError("average_image() needs at least one frame, got none")

    img = np.stack(frames, axis=0).mean(axis=0)
    peak = img.max()
    if peak == 0:
        return img
    return img / peak


# Average image of the frames in each Leiden cluster
img_leiden0 = average_image(seq_rgb_leiden0)
img_leiden1 = average_image(seq_rgb_leiden1)
img_leiden2 = average_image(seq_rgb_leiden2)

In [None]:
# Show the average image of each Leiden cluster, one figure per cluster.
# Every figure, including the last, is flushed explicitly with plt.show().
for img_leiden in (img_leiden0, img_leiden1, img_leiden2):
    fig, ax = plt.subplots(1, 1, dpi=300)
    plt.imshow(img_leiden)
    plt.grid(False)
    plt.show()

In [None]:
plt.rcParams["figure.dpi"] = 200

# Green vs red: cluster 2 minus cluster 0, then the reverse
for difference in (img_leiden2 - img_leiden0, img_leiden0 - img_leiden2):
    fig, ax = plt.subplots(dpi=300)
    plt.imshow(difference)
    plt.grid(False)
    plt.show()

In [None]:
# Partially green vs partially red
# Cluster 2 minus cluster 1
plt.imshow(img_leiden2 - img_leiden1)
plt.grid(False)
plt.show()

# Cluster 0 minus cluster 1
plt.imshow(img_leiden0 - img_leiden1)
plt.grid(False)
plt.show()

In [None]:
# Sub-ROI bounds; not used by the plots in this cell
x_min, x_max = 190, 400
y_min, y_max = 142, 170

# Each cluster's average image minus the sum of the other two
fig, ax = plt.subplots(1, 1, dpi=300)
plt.imshow(img_leiden0 - (img_leiden1 + img_leiden2))
plt.grid(False)
plt.show()

fig, ax = plt.subplots(1, 1, dpi=300)
plt.imshow(img_leiden1 - (img_leiden0 + img_leiden2))
plt.grid(False)
plt.show()

fig, ax = plt.subplots(1, 1, dpi=300)
plt.imshow(img_leiden2 - (img_leiden0 + img_leiden1))
plt.grid(False)
plt.show()

## ROI within ROI to illustrate final scanpy results

In [None]:
# Exclusive bounds of the sub-ROI, in pixel coordinates of the ROI images
x_min, x_max = 190, 400
y_min, y_max = 142, 170

row_idx = np.arange(img_leiden0.shape[0])
col_idx = np.arange(img_leiden0.shape[1])

# Boolean mask of the sub-ROI, True on all channels strictly inside the bounds
mask_x = (col_idx > x_min) & (col_idx < x_max)
mask_y = (row_idx > y_min) & (row_idx < y_max)
roi = np.zeros(img_leiden0.shape, dtype=bool)
roi[np.ix_(mask_y, mask_x)] = True

fig, ax = plt.subplots(dpi=300)
plt.imshow(roi.astype("float32"))
plt.grid(False)
plt.show()

# Illustrate the mask with a box: the rectangle padded by line_width on
# every side, minus the sub-ROI itself, marked on the first channel only
line_width = 5
mask_x = (col_idx > x_min - line_width) & (col_idx < x_max + line_width)
mask_y = (row_idx > y_min - line_width) & (row_idx < y_max + line_width)
roi_box = np.zeros(img_leiden0.shape, dtype=bool)
roi_box[np.ix_(mask_y, mask_x, [0])] = True
roi_box &= ~roi

fig, ax = plt.subplots(dpi=300)
plt.imshow(img_leiden0 + roi_box.astype("float32"))
plt.grid(False)
plt.show()

# Apply ROI: the masked pixels form a rectangle, so they can be reshaped
# back into a (rows, columns, 3) image
roi_shape = [y_max - y_min - 1, x_max - x_min - 1, 3]
roi_leiden0 = img_leiden0[roi].reshape(roi_shape)
roi_leiden1 = img_leiden1[roi].reshape(roi_shape)
roi_leiden2 = img_leiden2[roi].reshape(roi_shape)

# Each cluster's sub-ROI minus the mean of the other two
for target, others in (
    (roi_leiden0, (roi_leiden1, roi_leiden2)),
    (roi_leiden1, (roi_leiden0, roi_leiden2)),
    (roi_leiden2, (roi_leiden0, roi_leiden1)),
):
    fig, ax = plt.subplots(dpi=300)
    plt.imshow(target - np.mean(others, axis=0))
    plt.grid(False)
    plt.show()