# Image Segmentation

This is a minimal tutorial on how to use ngio for image segmentation.

## Step 1: Setup

We will first implement a very simple function that segments an image using scikit-image (`skimage`).


In [None]:
# Setup a simple segmentation function
import numpy as np
import skimage


def otsu_threshold_segmentation(image: np.ndarray, max_label: int) -> np.ndarray:
    """Label the connected regions of ``image`` that lie above its Otsu threshold.

    Pixels strictly greater than the global Otsu threshold are foreground. Each
    connected foreground region gets its own id, and every id is shifted up by
    ``max_label`` so it does not clash with ids handed out earlier. Background
    pixels are always 0.

    Args:
        image: Intensity array to segment.
        max_label: Offset added to every foreground label.

    Returns:
        A ``uint32`` label array with the same shape as ``image``.
    """
    foreground = image > skimage.filters.threshold_otsu(image)
    labels = skimage.measure.label(foreground)
    # Shift in place so the array keeps the dtype returned by label()
    labels += max_label
    # The shift also moved background pixels away from 0; put them back
    labels[~foreground] = 0
    return labels.astype(np.uint32)

## Step 2: Open the OmeZarr container

In [None]:
from pathlib import Path

from ngio import open_ome_zarr_container
from ngio.utils import download_ome_zarr_dataset

# Download the example dataset into a "data" folder located two directory
# levels above the current working directory
download_dir = Path.cwd().parent.parent / "data"
hcs_path = download_ome_zarr_dataset("CardiomyocyteTiny", download_dir=download_dir)

# Select one image inside the plate (presumably row "B", column "03", image "0")
# and open it as an OME-Zarr container
image_path = hcs_path / "B" / "03" / "0"
ome_zarr = open_ome_zarr_container(image_path)

## Step 3: Segment the image

For this example, we will not segment the whole image at once. Instead, we will iterate over the image's fields of view (FOVs) and segment them one by one.

In [None]:
# First we need the image object and the FOV ROI table
image = ome_zarr.get_image()
roi_table = ome_zarr.get_roi_table("FOV_ROI_table")

# Second, we derive a new label image to use as the target for the segmentation
# (overwrite=True presumably replaces a "new_label" left over from a previous run)
label = ome_zarr.derive_label("new_label", overwrite=True)

# Running maximum of all labels written so far. Each ROI's labels are offset by
# it so that labels coming from different FOVs never collide.
max_label = 0
for roi in roi_table.rois():
    # Read channel 0 of this ROI as a (z, y, x) array
    image_data = image.get_roi_as_numpy(roi=roi, c=0, axes_order=["z", "y", "x"])

    # Segment the ROI, shifting its labels past every label used so far
    roi_segmentation = otsu_threshold_segmentation(image_data, max_label)

    # Only ever grow the offset. A ROI with no foreground has max() == 0, and
    # assigning that directly would reset the offset, so later ROIs would reuse
    # labels that are already written. initial=0 keeps max() defined even if
    # the patch is empty.
    max_label = max(max_label, int(roi_segmentation.max(initial=0)))

    # Write the segmentation into the label image at this ROI
    label.set_roi(roi=roi, patch=roi_segmentation)

## Step 4: Consolidate the segmentation

The `new_label` image has data at only a single resolution level. To propagate the segmentation to all other resolution levels, we need to call the `consolidate` method.

In [None]:
# The segmentation above was written at a single resolution level only;
# consolidate() propagates it to all other resolution levels of the label image.
label.consolidate()

## Step 5: Plot the segmentation

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Random colormap for the label ids: one random RGB color per colormap entry.
# A seeded generator keeps the colors identical between runs. Entry 0 is set to
# black so the lowest value (label 0, i.e. background, when present) renders black.
label_colors = np.random.default_rng(0).random((1000, 3))
label_colors[0] = 0
rand_cmap = ListedColormap(label_colors)

fig, axs = plt.subplots(2, 1, figsize=(8, 4))
axs[0].set_title("Original image")
axs[0].imshow(image.get_as_numpy(c=0, z=1, axes_order=["y", "x"]), cmap="gray")
axs[1].set_title("Final segmentation")
# Label ids are categorical. Nearest-neighbour display prevents blending
# neighbouring ids during resampling, which would paint borders in unrelated colors.
axs[1].imshow(
    label.get_as_numpy(z=1, axes_order=["y", "x"]),
    cmap=rand_cmap,
    interpolation="nearest",
)
for ax in axs:
    ax.axis("off")
plt.tight_layout()
plt.show()