# mCherry detection

The detection works in three steps:

- **Preprocessing**: Unmix the background and signal channels.
- **Detection**: Use a simple thresholding to detect the mCherry signal.
- **Filter**: Remove all masks that do not look like a cell using extracted features.

First let's load the libraries and the data.

# Loading data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import iss_preprocess as iss

In [None]:
# Set the path to the data
# `data_path` is handed to the `iss.io` / `iss.pipeline` helpers in the cells below.
data_path = "becalia_rabies_barseq/BRAC8498.3e/chamber_10/"
# Acquisition prefix. The part before the first "_" (lower-cased) is used to build
# the prefix-specific keys read from `ops`.
prefix = "mCherry_1"

# Tile to use to plot example data, should contain mCherry signal
# Unpacked further down as (roi, tilex, tiley).
example_tile = [2, 1, 5]

In [None]:
# Key stem used to look up prefix-specific entries in `ops`:
# everything before the first "_" of the prefix, lower-cased.
short_prefix = prefix.partition("_")[0].lower()

# Folder holding the cell-detection files for this prefix; created if missing.
processed_root = iss.io.get_processed_path(data_path)
target_folder = processed_root / "cells" / f"{prefix}_cells"
target_folder.mkdir(parents=True, exist_ok=True)

ops = iss.io.load_ops(data_path)

In [None]:
# Reference tiles: either listed explicitly in ops or drawn at random.
tile_coors = ops[f"{short_prefix}_ref_tiles"]

if tile_coors == "random":
    # Draw 5 (roi, x, y) tiles; the seed is fixed so the selection is reproducible.
    roi_dim = iss.io.get_roi_dimensions(data_path)
    rng = np.random.default_rng(seed=42)
    tile_coors = []
    for _draw in range(5):
        # Draw order matters for reproducibility: ROI row first, then x, then y.
        roi_row = roi_dim[rng.integers(roi_dim.shape[0])]
        tile_x = rng.integers(roi_row[1] + 1)
        tile_y = rng.integers(roi_row[2] + 1)
        tile_coors.append((roi_row[0], tile_x, tile_y))

# A single tile given as a flat sequence of ints is wrapped into a one-element list.
if isinstance(tile_coors[0], int):
    tile_coors = [tile_coors]

background_channel = ops[f"{short_prefix}_background_channel"]
signal_channel = ops[f"{short_prefix}_signal_channel"]
projection = ops[f"{short_prefix}_projection"]

signal, background = [], []
# NOTE: tiles are loaded in the reverse order of `tile_coors`.
for tile in reversed(tile_coors):
    print(f"Loading tile {tile}")
    full_stack, _ = iss.pipeline.load_and_register_tile(
        data_path,
        tile_coors=tile,
        prefix=prefix,
        filter_r=False,
        projection=projection,
        correct_illumination=True,
    )
    # we have a 4d image with "nrounds" as last dim, we only need the first
    stack = full_stack[..., 0]
    signal.append(stack[..., signal_channel])
    background.append(stack[..., background_channel])

# Preprocessing: unmixing the channels

In [None]:
# Stack all loaded tiles vertically into one tall image per channel.
signal_image = np.vstack(signal)
background_image = np.vstack(background)
# Same layout as the stacked images, each pixel holding the index of its tile.
tile_id = np.vstack(
    [np.full(stack.shape[:2], float(idx)) for idx, _tile in enumerate(signal)]
)

# (title, image, imshow keyword arguments) for the four side-by-side panels.
panels = [
    ("Signal", signal_image, dict(vmin=10, vmax=50)),
    ("Background", background_image, dict(vmin=10, vmax=50)),
    ("Tile ID", tile_id, dict(cmap="tab10", vmax=10, interpolation="none")),
    ("Ratio", signal_image / background_image, dict(cmap="inferno", vmax=10)),
]
for position, (panel_title, panel_image, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.imshow(panel_image, **imshow_kwargs)
    plt.title(panel_title)

for panel_ax in plt.gcf().get_axes():
    panel_ax.axis("off")

In [None]:
# Display old fit or calculate new fit
# REFIT=True recomputes the unmixing coefficients from the tiles loaded above;
# REFIT=False loads the saved coefficients and rebuilds the pixel selection mask
# (`valid_pixel`) only so that it can be displayed in the regression plot.
REFIT = False
if REFIT:
    from iss_preprocess.image.correction import calculate_unmixing_coefficient

    ops = iss.io.load_ops(data_path)
    print("Parameters for unmixing:")
    print(f"Threshold background: {ops['unmixing_threshold_background']}")
    print(f"Threshold signal: {ops['unmixing_threshold_signal']}")
    print(f"Background fudge coef: {ops['unmixing_background_coef']}")
    pure_signal, coef, intercept, valid_pixel = calculate_unmixing_coefficient(
        signal_image=signal_image,
        background_image=background_image,
        background_coef=ops["unmixing_background_coef"],
        threshold_background=ops["unmixing_threshold_background"],
        threshold_signal=ops["unmixing_threshold_signal"],
    )
else:
    seg_folder = iss.io.get_processed_path(data_path) / "cells" / f"{prefix}_cells"
    # Read both arrays inside the context manager so the archive is closed again.
    with np.load(seg_folder / f"unmixing_coef_{short_prefix}.npz") as unmix_param:
        coef = unmix_param["coef"]
        intercept = unmix_param["intercept"]
    # Plain values, not 1-tuples: a trailing comma here would make the
    # `is None` fallbacks below unreachable.
    threshold_background = ops["unmixing_threshold_background"]
    threshold_signal = ops["unmixing_threshold_signal"]
    mixed_signal_flat = signal_image.ravel()
    background_flat = background_image.ravel()
    # Discard pixels at or above 4090 in either channel
    # (presumably saturated values -- TODO confirm the camera range).
    valid_pixel = (background_flat < 4090) & (mixed_signal_flat < 4090)
    if threshold_signal is None:
        threshold_signal = np.percentile(mixed_signal_flat[valid_pixel], 99)
    valid_pixel &= mixed_signal_flat < threshold_signal
    valid_pixel &= background_flat > 0  # do that before computing the median
    if threshold_background is None:
        threshold_background = np.nanmedian(background_flat[valid_pixel])
    valid_pixel &= background_flat > threshold_background
    # also remove pixels that are pure signal
    valid_pixel &= mixed_signal_flat < 20 * background_flat

In [None]:
# Check if the tiles are similar
ref_tiles = ops[f"{short_prefix}_ref_tiles"]
if not isinstance(ref_tiles, str):
    # Explicit tile list in ops: wrap a single flat tile into a list of tiles.
    tile_coors = [ref_tiles] if isinstance(ref_tiles[0], int) else ref_tiles
# Otherwise ops only holds a keyword such as "random": keep the `tile_coors` list
# drawn when the tiles were loaded instead of replacing it with that string.

# The tiles were stacked in the reverse order of `tile_coors`, so index i in
# `tile_id` corresponds to `loaded_tiles[i]`.
loaded_tiles = tile_coors[::-1]
ntiles = len(loaded_tiles)
mixed_signal_flat = signal_image.ravel()
background_flat = background_image.ravel()
tile_id_flat = tile_id.ravel()
# Common axis limits: the smaller of the two channel maxima, capped at 1500.
maxs = [np.nanmax(x) for x in [background_flat, mixed_signal_flat]] + [1500]
lims = [0, min(maxs)]

if False:
    # Optionally if you think there might be one tile throwing off the rest, you can plot
    # the tiles individually

    if ntiles == 1:
        print("Only one tile, no need to check")
    else:
        fig, axes = plt.subplots(
            1, ntiles, figsize=(ntiles * 5, 5), sharex=True, sharey=True
        )
        for i in range(ntiles):
            valid = tile_id_flat == i
            # Subsample to roughly 100000 points; never let the step reach 0,
            # which would make the slices below raise.
            step = max(1, np.sum(valid) // 100000)
            axes[i].scatter(
                background_flat[valid][::step],
                mixed_signal_flat[valid][::step],
                s=1,
                c=tile_id_flat[valid][::step],
                cmap="tab10",
                vmax=10,
                vmin=0,
            )

            axes[i].set_title(f"Tile {i}: {loaded_tiles[i]}")
        for x in axes:
            x.set_aspect("equal")
            x.set_xlim(lims)
            x.set_ylim(lims)

In [None]:
# plot linear regression
plt.figure()
plt.subplot(1, 1, 1, aspect="equal")

# Only every 100th pixel is drawn to keep the scatter plots light.
every_100th = slice(None, None, 100)

# All pixels, coloured by the tile they come from.
plt.scatter(
    background_flat[every_100th],
    mixed_signal_flat[every_100th],
    s=1,
    c=tile_id_flat[every_100th],
    cmap="tab10",
    vmax=10,
)

# Pixels selected by `valid_pixel`, overlaid in translucent black.
selected_background = background_flat[valid_pixel]
selected_signal = mixed_signal_flat[valid_pixel]
plt.scatter(
    selected_background[every_100th],
    selected_signal[every_100th],
    s=1,
    color="k",
    alpha=0.5,
)

# Fitted line over the whole background range.
background_values = np.arange(background_flat.max())
fitted_signal = background_values * coef + intercept
plt.plot(background_values, fitted_signal, color="red")

plt.xlabel("Background")
plt.ylabel("Signal")
plt.title("Linear Regression")
equation = f"y = {coef:.2f}x + {intercept:.2f}"
plt.text(
    0.5,
    0.9,
    equation,
    horizontalalignment="center",
    verticalalignment="center",
    transform=plt.gca().transAxes,
)
plt.xlim(0, 400)
plt.ylim(0, 400)
plt.show()

In [None]:
# Unmix the example tile and show the unmixed signal next to the raw channels.
(roi, tilex, tiley) = example_tile
# Use the configured `prefix` rather than a hard-coded acquisition name so the
# cell follows the settings chosen at the top of the notebook.
unmixed_image, mixed_stack = iss.pipeline.segment.unmix_tile(
    data_path, prefix, (roi, tilex, tiley)
)
# Three layers: unmixed signal (red), then the channels of `mixed_stack`
# (blue and green, see the legend printed below).
rgb = iss.vis.to_rgb(
    np.dstack([unmixed_image, mixed_stack]),
    colors=[(1, 0, 0), (0, 0, 1), (0, 1, 0)],
    vmax=[100, 100, 100],
)
fig = plt.figure(figsize=(20, 15))
plt.subplot(2, 2, 1)
plt.imshow(rgb)
plt.axis("off")
# One panel per colour layer, showing that layer alone.
for i in range(3):
    valid_img = np.zeros_like(rgb)
    valid_img[:, :, i] = rgb[:, :, i]
    ax = plt.subplot(2, 2, i + 2)
    ax.imshow(valid_img)
    plt.axis("off")
plt.tight_layout()
print("Red: unmixed mCherry signal, Blue: Mixed signal, Green: Background")
if False:
    # Optional zoom; y limits are given as (bottom, top) = (1500, 0) like in the
    # other cells so the image is not flipped vertically.
    for x in fig.axes:
        x.set_xlim(1500, 3000)
        x.set_ylim(1500, 0)

# Segmentation: 

## Filter and Threshold

The bright cells have an annoying halo around them. To remove it we compute a difference
of Gaussians, but overweight the larger filter that we subtract.
Then we need to set a threshold that is low enough to include all cells.

### Filtering

In [None]:
import cv2

ops = iss.io.load_ops(data_path)

# filter
# r1 / r2 are passed as the Gaussian sigma of the center and surround blurs below.
r1 = ops["mcherry_r1"]
r2 = ops["mcherry_r2"]
# NOTE(review): the two values read from ops are immediately replaced by these
# hard-coded ones; delete the next two lines to use the ops settings.
r1 = 9
r2 = 31

th = ops["mcherry_detection_threshold"]
# NOTE(review): same here, the ops threshold is replaced by a hard-coded value.
th = 2

print(f"Using radii {r1} and {r2}")
# Odd kernel size spanning about 3 * r2 on each side of the central pixel.
kernel_size = int(np.ceil(3 * r2) * 2 + 1)
center = cv2.GaussianBlur(unmixed_image, (kernel_size, kernel_size), r1)
surround = cv2.GaussianBlur(unmixed_image, (kernel_size, kernel_size), r2)
# Difference of Gaussians with the surround overweighted by a factor 1.5.
filt = center - 1.5 * surround

# 2 x 3 grid of panels; axes[1, 2] is left empty.
fig, axes = plt.subplots(2, 3, figsize=(21, 14))
ax = axes[0, 0]
im = ax.imshow(rgb)
ax.set_title("RGB")
ax.axis("off")

ax = axes[0, 1]
im = ax.imshow(unmixed_image, interpolation="none", cmap="viridis", vmax=50)
ax.set_title("Unmixed")
ax.axis("off")

ax = axes[1, 0]
ax.imshow(center, cmap="inferno", vmax=50)
ax.set_title("Center")
ax.axis("off")
ax = axes[1, 1]
ax.imshow(surround, cmap="inferno", vmax=50)
ax.set_title("Surround")
ax.axis("off")
ax = axes[0, 2]
# Colour range 0 .. 2 * th, so the threshold sits in the middle of the colormap.
cax, cb = iss.vis.plot_matrix_with_colorbar(
    filt.astype(float), ax, vmin=0, vmax=th * 2, cmap="coolwarm"
)
# Presumably `cax` is the colorbar axes, so this marks the threshold on the
# colorbar -- TODO confirm against iss.vis.plot_matrix_with_colorbar.
cax.axhline(th, color="k", linestyle="-")
ax.set_title("Filtered = Center - 1.5*Surround")
ax.axis("off")
# ax.set_xticks([])
# ax.set_yticks([])
# Optional zoom on a sub-region; switch the condition to True to enable it.
if False:
    for x in axes.ravel():
        x.set_xlim(1500, 3000)
        x.set_ylim(1500, 0)
        x.set_xticks([])
        x.set_yticks([])
fig.tight_layout()

### Thresholding

In [None]:
from skimage.morphology import binary_closing

fig, axes = plt.subplots(2, 2, figsize=(20, 20))
(hist_ax, filtered_ax), (binary_ax, closed_ax) = axes

# Distribution of the filtered values (log counts) with the threshold marked.
hist_ax.hist(filt.ravel(), bins=np.arange(200), log=True, color="dodgerblue")
hist_ax.axvline(th, color="k", linestyle="--")

filtered_ax.imshow(filt, vmax=th * 3, cmap="inferno", vmin=0)
filtered_ax.set_title("Filtered Image")

# Pixels strictly above the detection threshold.
binary = filt > th
binary_ax.set_title("Binary Image")
binary_ax.imshow(np.array(binary))

# Morphological closing with a square footprint of int(2.5 / pixel size) pixels.
closing_size = int(2.5 / iss.io.get_pixel_size(data_path))
square_footprint = np.ones((closing_size, closing_size))
binary = binary_closing(binary, footprint=square_footprint)
closed_ax.set_title("After morphological closing")
closed_ax.imshow(binary)

# Zoom every image panel (all but the histogram) on the same sub-region.
for ax in (filtered_ax, binary_ax, closed_ax):
    ax.set_xlim(1500, 3000)
    ax.set_ylim(1500, 0)

In [None]:
# Turn the binary detection into labelled masks plus a table of mask properties.
# Presumably this labels connected components and measures them on `mixed_stack`
# -- see iss.segment.cells.label_bin_image to confirm.
labeled_image, props_df = iss.segment.cells.label_bin_image(
    binary, mixed_stack, roi, tilex, tiley
)

print("Filtering masks")
# Split the masks into kept and rejected ones according to the settings in `ops`.
filtered_masks, filtered_df, rejected_masks = iss.pipeline.segment._filter_masks(
    ops, props_df, labeled_image
)

# NOTE(review): `kwargs` is not used anywhere in this cell.
kwargs = dict(cmap="tab20", interpolation="none", vmin=0, vmax=20)
plt.figure(figsize=(20, 20))
plt.subplot(1, 1, 1)
plt.imshow(unmixed_image, cmap="inferno", vmax=150)
# Outlines: rejected masks in red, kept masks in green.
plt.contour(rejected_masks > 0, colors="r", levels=[0.5])
plt.contour(filtered_masks > 0, colors="g", levels=[0.5])
plt.axis("off")

# Optional zoom on a sub-region; switch the condition to True to enable it.
if False:
    for ax in plt.gcf().get_axes():
        ax.set_xlim(1500, 3000)
        ax.set_ylim(1500, 0)

In [None]:
# Run the pipeline segmentation on the example tile and compare its masks with
# the raw and unmixed images. The configured `prefix` is used everywhere instead
# of a hard-coded acquisition name.
unmixed_image, mixed_stack = iss.pipeline.segment.unmix_tile(
    data_path, prefix, (roi, tilex, tiley)
)
fmasks, filtered_df, rejected_masks = iss.pipeline.segment.segment_mcherry_tile(
    data_path,
    prefix,
    roi,
    tilex,
    tiley,
)


plt.figure(figsize=(20, 16))
plt.subplot(2, 2, 1)
plt.imshow(mixed_stack[..., 0], vmax=200)
plt.contour(fmasks > 0, levels=[0.5], colors="r")
plt.title("Raw mCherry image")
plt.subplot(2, 2, 2)
# Labels are wrapped modulo 20 to cycle through the 20 colours of tab20.
plt.imshow(fmasks % 20, cmap="tab20", interpolation="none")
plt.title("Segmented cells")
plt.subplot(2, 2, 3)
rgb = iss.vis.to_rgb(mixed_stack, colors=[(1, 0, 0), (0, 1, 0)], vmax=[200, 200])
plt.imshow(rgb)
plt.contour(fmasks > 0, levels=[0.5], colors="green")
plt.title("Raw stack, 2 channels")
plt.subplot(2, 2, 4)
plt.imshow(unmixed_image, vmax=200, cmap="inferno")
plt.contour(fmasks > 0, levels=[0.5], colors="green")
plt.title("Unmixed image")
# Optional zoom on a sub-region; switch the condition to True to enable it.
if False:
    for ax in plt.gcf().axes:
        ax.set_xlim(1500, 3000)
        ax.set_ylim(1500, 0)

plt.tight_layout()

# Filtering masks

After the initial detection, we filter things that are not cells.
Here is what it does:

In [None]:
from skimage import measure

# Label the connected components of the thresholded image and measure each one,
# using `mixed_stack` as intensity image.
labeled_image = measure.label(binary)
props = measure.regionprops_table(
    labeled_image,
    intensity_image=mixed_stack,
    properties=(
        "label",
        "area",
        "centroid",
        "eccentricity",
        "major_axis_length",
        "minor_axis_length",
        "intensity_max",
        "intensity_mean",
        "intensity_min",
        "perimeter",
        "solidity",
    ),
)

props_df = pd.DataFrame(props)
# Circularity = 4 * pi * area / perimeter**2.
props_df["circularity"] = 4 * np.pi * props_df["area"] / (props_df["perimeter"] ** 2)
# The intensity image (`mixed_stack`) has one column suffix per channel:
# "-0" and "-1", presumably signal and background -- TODO confirm channel order.
props_df["intensity_ratio"] = (
    props_df["intensity_mean-0"] / props_df["intensity_mean-1"]
)
# Labels are wrapped modulo 20 to cycle through the 20 colours of tab20.
plt.imshow(labeled_image % 20, cmap="tab20", interpolation="none")
plt.xticks([])
plt.yticks([])
plt.suptitle("Segmented cells - before filtering")

In [None]:
def _paint_masks(label_img, props, rejected_flags, prop=None):
    """Paint every mask into a "valid" or a "rejected" label image.

    Args:
        label_img: Labelled image; pixels of a mask hold its label.
        props: Table with one row per mask and at least a "label" column.
        rejected_flags: Boolean Series indexed like `props`, True for rejected masks.
        prop: Optional column of `props`; when given, a third image is filled with
            that property's value for every mask.

    Returns:
        Tuple (valid_img, rejected_img, level_img). Pixels that do not belong to
        the relevant masks are NaN; `level_img` is all NaN when `prop` is None.
    """
    valid_img = np.full(label_img.shape, np.nan)
    rejected_img = np.full(label_img.shape, np.nan)
    level_img = np.full(label_img.shape, np.nan)
    for index, is_rejected in rejected_flags.items():
        label = props.loc[index, "label"]
        mask = label_img == label
        if is_rejected:
            rejected_img[mask] = label
        else:
            valid_img[mask] = label
        if prop is not None:
            level_img[mask] = props.loc[index, prop]
    return valid_img, rejected_img, level_img


# 1 / pixel_size squared, printed as a guide for choosing the area threshold.
print((1 / iss.io.get_pixel_size(data_path)) ** 2)
min_thresholds = dict(area=500)  # ,solidity=0.9, )
max_thresholds = dict()  # solidity=1)#, area=5000, eccentricity=0.99)
max_thresholds["intensity_mean-1"] = 50
all_unvalid = pd.Series(False, index=props_df.index)

# One column per criterion (minimum thresholds first), plus one summary column.
criteria = [(prop, val, True) for prop, val in min_thresholds.items()]
criteria += [(prop, val, False) for prop, val in max_thresholds.items()]
summary_col = len(criteria)

# squeeze=False keeps `axes` 2-D even when there is a single column.
fig, axes = plt.subplots(3, summary_col + 1, figsize=(20, 15), squeeze=False)
label_style = dict(cmap="tab20", interpolation="none")

for icol, (prop, val, is_min) in enumerate(criteria):
    # Masks exactly at the threshold are rejected in both cases; the titles
    # state the same inequalities as the tests.
    if is_min:
        unvalid = props_df[prop] <= val
        valid_title, rejected_title = f"{prop} > {val}", f"{prop} <= {val}"
    else:
        unvalid = props_df[prop] >= val
        valid_title, rejected_title = f"{prop} < {val}", f"{prop} >= {val}"
    all_unvalid |= unvalid

    valid_img, rejected_img, img_lvl = _paint_masks(
        labeled_image, props_df, unvalid, prop
    )
    axes[0, icol].imshow(valid_img % 20, **label_style)
    axes[0, icol].set_title(valid_title)
    axes[1, icol].imshow(rejected_img % 20, **label_style)
    axes[1, icol].set_title(rejected_title)

    # Colour limits are kept as in the original cell. NOTE(review): unlabelled
    # pixels are NaN, so img_lvl.max()/min() are NaN whenever some pixel is
    # unlabelled and the builtin min/max then return the threshold-based value.
    colour_limits = dict(vmax=min(val * 2, img_lvl.max()))
    if not is_min:
        colour_limits["vmin"] = max(val / 2, img_lvl.min())
    im = axes[2, icol].imshow(
        img_lvl, cmap="viridis", interpolation="none", **colour_limits
    )
    cb = plt.colorbar(im, ax=axes[2, icol])
    axes[2, icol].set_title(f"{prop}")

# Summary column: masks passing every criterion vs. masks failing at least one.
valid_img, rejected_img, _ = _paint_masks(labeled_image, props_df, all_unvalid)
axes[0, summary_col].imshow(valid_img % 20, **label_style)
axes[0, summary_col].set_title("All Valid")
axes[1, summary_col].imshow(rejected_img % 20, **label_style)
axes[1, summary_col].set_title("All Rejected")


axes[0, 0].set_ylabel("Valid")
axes[1, 0].set_ylabel("Rejected")
for x in axes.ravel():
    x.set_xticks([])
    x.set_yticks([])

# [optional] Use GMM to cluster the cells

After the initial filtering, cell detection can be improved using a GMM.

As we need more data to train the GMM, we will load the features of all tiles. Please
run segmentation for all tiles before running this.

In [None]:
# Gather the per-tile feature tables written by the segmentation for every tile
# of every ROI.
processed_path = iss.io.get_processed_path(data_path)
mask_dir = processed_path / "cells" / f"{prefix}_cells"
roi_dims = iss.io.get_roi_dimensions(data_path)
all_df = []
# The loop variables are deliberately not called roi / tilex / tiley: those names
# hold the example tile and are read again after the GMM fit.
for roi_id, max_tilex, max_tiley in roi_dims:
    # roi_dims stores the largest tile index, hence the + 1 to get the count.
    for tile_x in range(int(max_tilex) + 1):
        for tile_y in range(int(max_tiley) + 1):
            df_file = mask_dir / f"{prefix}_df_{roi_id}_{tile_x}_{tile_y}.pkl"
            if not df_file.exists():
                print(f"Missing {df_file}")
                continue
            # NOTE: unpickling runs arbitrary code; only load files written by
            # this pipeline.
            all_df.append(pd.read_pickle(df_file))
print(f"Loaded {len(all_df)} dataframes")
if not all_df:
    # pd.concat would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(
        f"No feature dataframe found in {mask_dir}; "
        "run the segmentation for all tiles first"
    )
df = pd.concat(all_df)
print(f"Total cells: {len(df)}")

In [None]:
# Prepare the features and fit a Gaussian mixture on them (features are used
# unscaled; only the intensity ratio is log-transformed).
from sklearn.mixture import GaussianMixture

# log10 of the signal/background ratio, clipped to [0, 100] beforehand.
df["clamped_ratio"] = np.log10(df["intensity_ratio"].clip(0, 100))
# A ratio clipped to 0 gives -inf, which dropna() keeps and GaussianMixture.fit
# rejects as non-finite input. Turn infinities into NaN so those rows are dropped.
df["clamped_ratio"] = df["clamped_ratio"].replace([np.inf, -np.inf], np.nan)
features = [
    "area",
    "circularity",
    "solidity",
    "intensity_mean-1",
    "clamped_ratio",
]
# Not used: 'solidity' 'major_axis_length', 'minor_axis_length' #  'solidity',
# 'intensity_mean-0' ,'eccentricity',

df.dropna(
    subset=features,
    inplace=True,
)

# We define some clusters:
# Each initial guess lists one value per entry of `features`, in the same order.
px_size = iss.io.get_pixel_size(data_path)

# 10th, 50th and 70th percentiles of the "intensity_mean-1" feature.
intensity_th = np.nanpercentile(df["intensity_mean-1"], [10, 50, 70])
# Cells with mid area, high circularity, high solidity, high intensity_mean-0 and low intensity_mean-1
cell_initial_guess = [(10 / px_size) ** 2, 0.9, 1, intensity_th[0], 1]
# Debris cluster for things that are too bright in background
# Two clusters: one for round like cells, one for less round and less bright
debris_initial_guess = [(4 / px_size) ** 2, 0.9, 0.9, intensity_th[1], 0.3]
debris2_initial_guess = [(4 / px_size) ** 2, 0.95, 1, intensity_th[2], 0.3]
initial_unscaled = np.vstack(
    [cell_initial_guess, debris_initial_guess, debris2_initial_guess]
)

n_components = initial_unscaled.shape[0]

print(f"GMM with {n_components} clusters")
gmm = GaussianMixture(
    n_components=n_components,
    means_init=initial_unscaled,
    random_state=42,
    verbose=2,
)

# Fit the model
gmm.fit(df[features])

# Predict the cluster labels
labels = gmm.predict(df[features])

cluster_centers = gmm.means_


df["cluster_label"] = labels
# Rows of the tile currently held in the notebook variables roi / tilex / tiley.
image_df = df[(df["roi"] == roi) & (df["tilex"] == tilex) & (df["tiley"] == tiley)]
df.cluster_label.value_counts()

In [None]:
# plot scatter of clusters
import seaborn as sns

# One matplotlib cycle colour ("C0", "C1", ...) per cluster label.
cluster_palette = {i: f"C{i}" for i in range(n_components)}
pairplot_fig = sns.pairplot(
    df[["cluster_label", *features]],
    diag_kind="hist",
    hue="cluster_label",
    palette=cluster_palette,
    plot_kws={"s": 5, "alpha": 0.3},
)

# overlay the cluster centers on the pairplot
axes = pairplot_fig.axes
feature_names = features
n_features = len(feature_names)
# Row i shows feature i on the y axis, column j shows feature j on the x axis.
# The markers are added to every panel of the grid, diagonal included.
for i, j in np.ndindex(n_features, n_features):
    panel = axes[i, j]
    # Fitted means: filled markers with a black edge.
    for ic, center in enumerate(cluster_centers):
        panel.scatter(center[j], center[i], c=f"C{ic}", s=50, edgecolors="black")
    # Initial guesses: hollow markers in the cluster colour.
    for ic, center in enumerate(initial_unscaled):
        panel.scatter(
            center[j], center[i], s=50, facecolors="none", edgecolors=f"C{ic}"
        )

plt.show()

In [None]:
# Show, for one ROI, how many detected objects of each cluster fall in each tile.
roi = 5
# NOTE: `labels` now holds the cluster ids 0..n_components-1 and no longer the
# per-row GMM predictions.
labels = np.arange(n_components)
d = df.query(f"roi == {roi}")
print("Number of cells per tile:")
# Table with tiley as rows and tilex as columns; tiles without cells get 0.
v = d.groupby(["tiley", "tilex"]).size().unstack().fillna(0)
print(v)
# NOTE(review): `maxval` is computed but not used below (each panel is scaled
# independently) -- possibly meant as a shared `vmax`.
maxval = v.max().max()

# Grid size derived from the largest tile indices present in the table.
nx, ny = d[["tilex", "tiley"]].max().values + 1


for label in labels:
    ax = plt.subplot(1, n_components, label + 1, aspect="equal")
    d = df.query(f"roi == {roi} and cluster_label == {label}")
    npertile = d.groupby(["tiley", "tilex"]).size().unstack().fillna(0)
    # Fill a (ny, nx) image; tiles absent from `npertile` stay at 0.
    img = np.zeros((ny, nx))
    for r, row in npertile.iterrows():
        img[r, row.index] = row.values
    ax.imshow(img, cmap="inferno", vmin=0)
    ax.set_title(f"Cluster {label}")

In [None]:
# Label masks by the cluster label in the example tile
roi, tilex, tiley = 5, 1, 4
tile_df = df.query(f"roi == {roi} & tilex == {tilex} & tiley == {tiley}")
print(f"Number of cells: {len(tile_df)}")
mask_file = mask_dir / f"{prefix}_masks_{roi}_{tilex}_{tiley}.npy"
labeled_image = np.load(mask_file)
cluster_img = np.empty(labeled_image.shape) + np.nan

umixed, mixed = iss.pipeline.segment.unmix_tile(data_path, prefix, (roi, tilex, tiley))
rgb = iss.vis.to_rgb(mixed, colors=[(1, 0, 0), (0, 1, 0)], vmax=[100, 100])
for index, row in tile_df.iterrows():
    label = row.label
    cluster = row.cluster_label
    cluster_img[labeled_image == label] = cluster
plt.figure(figsize=(25, 20))
plt.subplot(2, 1, 1)
plt.imshow(rgb)
m = cluster_img + 1
m[np.isnan(cluster_img)] = 0
if len(tile_df) > 0:
    plt.contour(
        m,
        levels=np.arange(n_components) + 0.5,
        colors=[f"C{i}" for i in range(n_components)],
        linewidths=1,
    )
plt.subplot(2, 1, 2)
plt.imshow(cluster_img, cmap="tab10", interpolation="none", vmax=10)

for x in plt.gcf().axes:
    if False:
        x.axis("off")
        x.set_ylim(500, 0)
        x.set_xlim(2500, 3000)

# Run it all

To run the whole pipeline use the following cell

In [None]:
# Disabled by default: switch the condition to True to segment and stitch the
# mCherry cells of every chamber in the list (currently only chamber 10).
# NOTE: running it rebinds `data_path` to the last chamber processed.
if False:
    for chamber in [f"chamber_{i:02}" for i in [10]]:
        print(f"Processing {chamber}")
        data_path = f"becalia_rabies_barseq/BRAC8498.3e/{chamber}/"
        iss.pipeline.segment_and_stitch_mcherry_cells(data_path, prefix)

# Display example ROI

The following cells display an example ROI with the detected cells before and
after filtering.

In [None]:
# Chambers 07 to 10; only the first one (chamber_07) is displayed here.
# NOTE: this rebinds `data_path` and `ops` for the rest of the notebook.
chambers = [f"chamber_{i:02}" for i in range(7, 11)]
chamber = chambers[0]
data_path = f"becalia_rabies_barseq/BRAC8498.3e/{chamber}/"
# Try stitching masks
roi = 5
# first get the mcherry mask
print(data_path)
ops = iss.io.load_ops(data_path)
print("Stitching mCherry masks")
# Same mask prefix stitched twice, once with projection="" and once with
# projection="corrected", to compare the two versions side by side.
mcherry_stitched_raw = iss.pipeline.stitch_registered(
    data_path,
    prefix="mCherry_1_masks",
    roi=roi,
    projection="",
    # channels=ops['mcherry_signal_channel'],
)
print("Stitching mCherry masks corrected")
mcherry_stitched_corr = iss.pipeline.stitch_registered(
    data_path,
    prefix="mCherry_1_masks",
    roi=roi,
    projection="corrected",
    # channels=ops['mcherry_signal_channel'],
)
print("Stitching mCherry image")
# Stitched image of the mCherry signal channel, used as background for the masks.
mcherry_stitched_img = iss.pipeline.stitch_registered(
    data_path,
    prefix="mCherry_1",
    roi=roi,
    channels=ops["mcherry_signal_channel"],
)

In [None]:
import cv2

fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# Everything is displayed at 20% of the stitched resolution.
downscale = dict(fx=0.2, fy=0.2)
small_img = cv2.resize(
    mcherry_stitched_img[..., 0], (0, 0), interpolation=cv2.INTER_LINEAR, **downscale
)

mask_versions = (mcherry_stitched_raw, mcherry_stitched_corr)
version_names = ("raw", "corrected")
for panel_ax, stitched_masks, version in zip(axes, mask_versions, version_names):
    # Nearest-neighbour resampling keeps the mask labels intact.
    small = cv2.resize(
        stitched_masks, (0, 0), interpolation=cv2.INTER_NEAREST, **downscale
    ).astype(float)
    # Background (label 0) becomes NaN so the image underneath stays visible.
    small[small == 0] = np.nan
    panel_ax.imshow(small_img, cmap="inferno", interpolation="none", vmax=100)
    panel_ax.imshow(small % 20, cmap="tab20", interpolation="none")
    panel_ax.contour((~np.isnan(small)).astype(int), levels=[0.5], colors="g")
    panel_ax.set_title(f"Stitched mCherry masks {version}")
    # panel_ax.axis("off")
    panel_ax.set_xlim(2000, 3500)
    panel_ax.set_ylim(3500, 2000)
fig.tight_layout()