# ERK Band Extraction Pipeline

This notebook implements the ERK band extraction methodology described in the paper. The pipeline:

1. **Extracts ERK signaling images** from specific timepoints across multiple fields of view (FOVs)
2. **Creates manual annotations** of ERK signaling zones using Napari
3. **Applies distance-based segmentation** to define outer and inner ERK bands based on measured distances
4. **Maps ERK band information** to single-cell tracking data
5. **Saves results** as both image labels and parquet files for downstream analysis

The ERK bands are defined based on the intersection distance between pSMAD and pERK signals, providing a biologically relevant spatial framework for analyzing ERK signaling patterns.


In [None]:
import os
import natsort
import glob


def get_project_path():
    """
    Locate the root folder of the stem-cell project on the imaging data server.

    Returns:
        str: UNC network path when running on Windows, otherwise the mount
        path used on Linux/Mac.
    """
    # Windows network share of the imaging data server
    windows_root = "\\\\izbkingston.izb.unibe.ch\\imaging.data\\PertzLab\\StemCellProject\\"
    # Mount point of the same share on Unix-like systems
    posix_root = "/mnt/imaging.data/PertzLab/StemCellProject"

    # os.name is "nt" on Windows only
    return windows_root if os.name == "nt" else posix_root


def get_output_path():
    """
    Build the path of the folder that holds the analysed data of the experiment.

    Returns:
        str: Project path joined with the experiment, acquisition and
        "Analysed_Data" sub-folders.
    """
    # Fixed directory layout of this experiment below the project root
    experiment_dir = "20240609_20xConf_Colonies_E6_bFGF_BMP4_10minInterval_ExpPAOLO"
    acquisition_dir = "1stPart_liveImaging_Confocal1plane"
    analysed_dir = "Analysed_Data"
    return os.path.join(
        get_project_path(), experiment_dir, acquisition_dir, analysed_dir
    )


def get_fovs(output_path: "str | None" = None):
    """
    Returns a sorted list of FOV names from the output path.

    Args:
        output_path: Path to search for FOV directories. When omitted (None),
            the experiment's analysed-data directory from get_output_path()
            is used. It is resolved on every call rather than once when the
            function is defined, so the default cannot go stale.

    Returns:
        list: Naturally sorted list of FOV names (e.g., ['FOV_0', 'FOV_1', ...]).
        Only directories are returned; files matching "FOV_*" are ignored.
    """
    if output_path is None:
        output_path = get_output_path()

    # Keep only directories: per-FOV result files share the "FOV_" prefix
    fov_names = [
        os.path.basename(candidate)
        for candidate in glob.glob(os.path.join(output_path, "FOV_*"))
        if os.path.isdir(candidate)
    ]
    # Use natural sorting to ensure proper numeric order (FOV_1, FOV_2, ..., FOV_10)
    return natsort.natsorted(fov_names)


# Configuration constants (defaults of the frame extraction further down)
# Timepoint (frame index) at which ERK bands are most clearly visible
FRAME_TO_DRAW_ERK_BAND = 36
# Channel index of the ERK reporter in the image data (0-based indexing)
ERK_CHANNEL = 2

## Step 1: Extract ERK Images and Create Manual Annotations

This section extracts ERK channel images from a specific timepoint across all FOVs and opens them in Napari for manual annotation of ERK signaling zones.

Load one specific frame for all FOVs of the experiment.

### Extract ERK Images from Specific Timepoint

We extract ERK channel images from frame 36 (when ERK bands are most clearly visible) across all FOVs in the experiment. This creates a stack of images suitable for manual annotation in Napari.

In [None]:
import napari


def extract_frame_from_all_fovs(
    frame: int = FRAME_TO_DRAW_ERK_BAND, erk_channel: int = ERK_CHANNEL
):
    """
    Collect one timepoint of one channel from every FOV into a single stack.

    Each FOV folder is opened as an OME-Zarr dataset with ome_zarr.reader and
    the requested plane is taken from the first array of the first node.

    Args:
        frame: Timepoint to extract (default: FRAME_TO_DRAW_ERK_BAND)
        erk_channel: Channel index for ERK reporter (default: ERK_CHANNEL)

    Returns:
        dask.array: Stack of ERK images with shape (n_fovs, height, width),
        ordered like get_fovs()
    """
    import dask.array as da
    from ome_zarr.io import parse_url
    from ome_zarr.reader import Reader

    base_dir = get_output_path()

    def read_plane(fov_name):
        """Return the (y, x) plane of a single FOV."""
        zarr_nodes = list(Reader(parse_url(os.path.join(base_dir, fov_name)))())
        # First node, first array: image data indexed as (t, c, y, x)
        image = zarr_nodes[0].data[0]
        return image[frame, erk_channel, :, :]

    # One plane per FOV, stacked along a new leading axis
    planes = [read_plane(fov_name) for fov_name in get_fovs()]
    return da.stack(planes, axis=0)


# Build the per-FOV ERK stack for the annotation frame and show it in Napari
image_data = extract_frame_from_all_fovs(
    frame=FRAME_TO_DRAW_ERK_BAND, erk_channel=ERK_CHANNEL
)
napari.view_image(image_data)

### Manual Annotation Protocol

The extracted images are downscaled 5x to facilitate manual annotation. In Napari, the following annotation scheme is used:

- **Label 1**: Background regions (to be set to 0)
- **Label 2**: Outer ERK zone (high ERK activity)
- **Label 3**: Inner zone (low ERK activity)
- **Label 4**: Inner ERK zone (regions where ERK activity can be observed again)

The manually annotated masks are saved as `frame36_segmentation.tif`. These initial annotations are then processed to create standardized ERK bands based on the measured distance between pSMAD and pERK signals (56 μm, as determined in Figure 5 of the paper).

In [None]:
# Configuration constants for distance-based segmentation
# Pixel size of the full-resolution images, in micrometers per pixel
# (a distance in micrometers is divided by it to obtain pixels)
CONVERSION_FACTOR_PIXEL_UM = 0.649
# ERK band width in micrometers (from Figure 5)
DISTANCE_ERK = 56
# Scaling factor between original images and annotation masks (5x downscaling)
MASK_SCALING = 0.2
# File (inside the analysed-data folder) containing the manually annotated masks
MASK_FILE = "frame36_segmentation.tif"

import tifffile

# Load the manually annotated segmentation masks; the stack is indexed by FOV
# along its first axis in the processing step that follows
segmentation_tot = tifffile.imread(os.path.join(get_output_path(), MASK_FILE))

In [59]:
import numpy as np
from scipy.ndimage import distance_transform_edt, binary_fill_holes

# Fixed ERK band width expressed in pixels of the downscaled masks:
#   micrometers -> full-resolution pixels -> mask pixels
# It is the same for every FOV, so it is computed once outside the loop.
fixed_distance = DISTANCE_ERK / CONVERSION_FACTOR_PIXEL_UM * MASK_SCALING

# Preprocessing: clean up the annotation labels (in place on the loaded stack)
segmentation_tot[segmentation_tot == 1] = 0  # Set background to 0
segmentation_tot[segmentation_tot == 4] = 3  # Merge inner ERK zones (labels 3 and 4)

new_segmentation = []

# Process each FOV individually (iterating the stack yields one 2D mask per FOV)
for segmentation in segmentation_tot:
    # After the preprocessing above only labels 0, 2 and 3 remain, so every
    # annotated (non-zero) pixel belongs to the colony. Filling the holes gives
    # a solid mask for the distance calculation.
    ring_mask = binary_fill_holes(segmentation > 0)

    # Euclidean distance of each colony pixel to the nearest pixel outside
    # the mask, i.e. the distance from the outer edge inward
    distance_from_outer = distance_transform_edt(ring_mask)

    # Split the colony at the fixed distance from its outer edge
    mask_outer = ring_mask & (distance_from_outer <= fixed_distance)  # Outer ERK band
    mask_inner = ring_mask & (distance_from_outer > fixed_distance)  # Inner region

    # Generate the final segmentation with standardized bands
    segmentation_fixed_distance = segmentation.copy()
    segmentation_fixed_distance[mask_outer] = 1  # Outer ERK band
    segmentation_fixed_distance[mask_inner] = 2  # Inner region
    # Manually annotated center regions take precedence over the distance bands
    segmentation_fixed_distance[segmentation == 3] = 3

    new_segmentation.append(segmentation_fixed_distance)

# Stack all processed segmentations back into (n_fovs, height, width)
new_segmentation = np.stack(new_segmentation, axis=0)

# Visualize the results and save
napari.view_labels(new_segmentation, name="segmentation_fixed_distance")
tifffile.imwrite(
    os.path.join(get_output_path(), "frame36_segmentation_fixed_distance.tif"),
    new_segmentation,
    ome=True,
    compression="zlib",
)

Viewer(camera=Camera(center=(0.0, np.float64(276.0), np.float64(276.0)), zoom=np.float64(0.9912296564195298), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(4.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=np.float64(10.0)), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(9.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(552.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(552.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(4.0), np.float64(276.0), np.float64(276.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Labels layer 'segmentation_fixed_distance' at 0x1ee042ec590>], help='use <7> for transform, use <1> for activate the label era

## Step 3: Map ERK Bands to Single-Cell Tracking Data

This section applies the ERK band masks to single-cell tracking data and identifies lineage relationships. The process involves:

1. **Loading tracking data** from ultrack results
2. **Mapping ERK band labels** to individual cell positions  
3. **Identifying lineage relationships** by finding ancestor cells
4. **Saving results** as parquet files for downstream analysis

In [None]:
import lzma
import pandas as pd
import pickle
import zarr
import os
import tifffile
import ome_zarr.io as ozi
import ome_zarr.writer as ozw
import os
import numpy as np
import ome_zarr.reader as ozr
import skimage
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
import dask.array as da
import dask

# Load the processed ERK band segmentation (the fixed-distance masks saved as
# "frame36_segmentation_fixed_distance.tif"); it is indexed per FOV below
segmentation_tot = tifffile.imread(
    os.path.join(get_output_path(), "frame36_segmentation_fixed_distance.tif")
)

# Identifier for the tracking batch used; it is part of the tracking file
# names read below (FOV_<n>_graph_<id>.xz and FOV_<n>_df_tracks_<id>.xz)
TRACKING_BATCH_ID = 2


@dask.delayed
def add_new_label_info_to_df(tracksdf, label, name):
    """
    Look up the label-image value under every cell position and store it
    in a new column of a copy of the tracking dataframe.

    The function is wrapped in dask.delayed, so calling it returns a lazy
    object; call .compute() on it to obtain the DataFrame.

    Args:
        tracksdf: DataFrame with cell tracking information ("x" and "y" columns)
        label: 2D array with ERK band labels, indexed as label[y, x]
        name: Column name for the new label information

    Returns:
        Copy of tracksdf with the added column; the input is not modified
    """
    annotated = tracksdf.copy()
    # Positions become integer pixel indices (cast to uint16 as in the lookup below)
    col_idx = annotated["x"].to_numpy().astype(np.uint16)
    row_idx = annotated["y"].to_numpy().astype(np.uint16)
    # Fancy indexing picks one label value per cell: rows are y, columns are x
    annotated.loc[:, name] = label[row_idx, col_idx]
    return annotated


def find_root_vectorized(track_ids, parent_dict):
    """
    Find the root (ancestor) cell for each track using vectorized operations.

    Every track is followed upwards through parent_dict until a track without
    a known parent is reached. A track counts as a root when its parent is -1,
    when its parent entry is missing/NaN, or when the track itself does not
    appear in parent_dict.

    Args:
        track_ids: Series of track IDs
        parent_dict: Dictionary mapping track_id to parent_track_id
            (-1 marks a track without parent)

    Returns:
        Series of root track IDs with the same index and dtype as track_ids;
        the input Series is not modified

    Raises:
        ValueError: If the parent links contain a cycle, so that no root can
            be reached.
    """
    roots = track_ids.copy()
    # A lineage can be at most as long as the number of known tracks, so one
    # more pass than that can only be needed if the links form a cycle.
    for _ in range(len(parent_dict) + 1):
        parents = roots.map(parent_dict)
        # Unknown ids map to NaN; treat them like -1 (no parent) instead of
        # climbing to NaN, which would never terminate
        has_parent = parents.notna() & (parents != -1)
        if not has_parent.any():
            return roots
        # Move one generation up where a parent exists; the cast undoes the
        # float upcast caused by NaN placeholders in `parents`
        roots = roots.where(~has_parent, parents).astype(track_ids.dtype)
    raise ValueError("parent_dict contains a cycle; root tracks cannot be resolved")


dfs = []  # List to collect DataFrames from all FOVs

# The masks were annotated on 5x-downscaled images; this factor brings them
# back to the pixel grid of the original images.
MASK_UPSCALE_FACTOR = 5

# Reference timepoint for ERK band analysis: the frame the masks were drawn on
base_frame_mother = FRAME_TO_DRAW_ERK_BAND

# Process each FOV individually
for fov_position, fov in enumerate(get_fovs()):
    fov_i = int(fov.split("_")[1])  # FOV number, used in the per-FOV file names

    # Get the ERK band segmentation for this FOV. The mask stack was built by
    # iterating get_fovs(), so it is addressed by position in that list rather
    # than by the number embedded in the folder name.
    segmentation = segmentation_tot[fov_position]

    # Load the tracking dataframe.
    # NOTE: read_pickle executes pickled content; only load files that were
    # written by this pipeline.
    tracks_df = pd.read_pickle(
        os.path.join(get_output_path(), f"FOV_{fov_i}_df_tracks_{TRACKING_BATCH_ID}.xz")
    )

    # Identify tracks present at the base frame
    tracks_at_t = tracks_df.loc[tracks_df["t"] == base_frame_mother, "track_id"]

    # Set parent relationships: cells at base frame are considered roots.
    # All rows of these tracks are updated, including the base-frame rows.
    tracks_df.loc[tracks_df["track_id"].isin(tracks_at_t), "parent_track_id"] = -1

    # Filter data to only include timepoints from base frame onward. The copy
    # makes the column assignment below act on an independent frame instead of
    # a slice of the loaded one.
    tracks_df = tracks_df[tracks_df["t"] >= base_frame_mother].copy()

    # Create parent dictionary and find root cells
    parent_dict = tracks_df.set_index("track_id")["parent_track_id"].to_dict()
    tracks_df["mother_track_id"] = find_root_vectorized(
        tracks_df["track_id"], parent_dict
    )

    # Create clean tracking dataframe with lineage information (one row per track)
    tracks_df_clean = (
        tracks_df[["track_id", "parent_track_id", "mother_track_id"]]
        .drop_duplicates()
        .set_index("track_id")
    )

    # Upscale segmentation to match original image resolution. Labels are
    # categorical, so nearest-neighbour interpolation (order=0) is required:
    # the default linear interpolation would invent intermediate label values
    # along the boundary between two bands.
    segmentation_up = skimage.transform.rescale(
        segmentation.astype(np.uint8),
        MASK_UPSCALE_FACTOR,
        order=0,
        preserve_range=True,
        anti_aliasing=False,
    ).astype(np.uint8)

    # Load cell position data
    fov_table_path = os.path.join(get_output_path(), f"FOV_{fov_i}_df.parquet")
    tracks_extracted_wholedf = pd.read_parquet(fov_table_path)
    tracks_extracted = tracks_extracted_wholedf[["t", "label", "x", "y"]]

    # Merge tracking and position data (keep every extracted cell)
    df_tot = tracks_df_clean.merge(
        tracks_extracted, left_index=True, right_on="label", how="right"
    )

    # Add ERK band information for the reference timepoint
    df_t = add_new_label_info_to_df(
        df_tot[df_tot["t"] == base_frame_mother],
        segmentation_up,
        "simplified_erk_band",
    ).compute()

    # Save results for this FOV
    df_t.to_parquet(
        os.path.join(
            get_output_path(), f"FOV_{fov_i}_df_new_simplified_erk_band.parquet"
        )
    )

    # Integrate ERK band information into the main DataFrame, replacing any
    # "erk_band" column left over from an earlier run
    tracks_extracted_wholedf = (
        tracks_extracted_wholedf.drop(columns=["erk_band"], errors="ignore")
        .merge(
            df_t[["label", "t", "simplified_erk_band"]],
            how="left",
            on=["t", "label"],
        )
        .rename(columns={"simplified_erk_band": "erk_band"})
    )
    # The per-FOV table is overwritten in place
    tracks_extracted_wholedf.to_parquet(fov_table_path)

    df_t["fov"] = fov_i
    dfs.append(df_t)

# Combine results from all FOVs
df_tot = pd.concat(dfs, axis=0).reset_index(drop=True)
df_tot.to_parquet(
    os.path.join(get_output_path(), "all_FOVs_new_simplified_erk_band.parquet")
)

If required, the ERK band can be saved back as an image and laid on top of the segmented cells.

## Step 4: Optional - Save ERK Bands as OME-Zarr Labels

This optional step saves the ERK band information back as label images that can be overlaid on the original segmented cells. This is useful for visually checking which ERK band each segmented cell was assigned to.

In [None]:
import ome_zarr


# --- Map DataFrame values to label image (Dask, all frames) ---
def label_to_value_dask(tracks, labels_stack, what):
    """
    Map values from a DataFrame to a label image stack using Dask for parallelization.

    This function creates a new image where each cell (identified by its label) is
    assigned the corresponding value from the DataFrame. Rows whose value is
    NaN or infinite are ignored. The result is lazy; nothing is computed until
    the returned array is evaluated or written.

    Args:
        tracks: DataFrame with cell tracking data including the columns
            't', 'label' and the column named by 'what'
        labels_stack: 3D array of cell labels (t, y, x), NumPy-like or Dask
        what: Column name to map to the image

    Returns:
        dask.array: Image stack with mapped values, one frame per chunk along t
    """
    # Prepare tracking data for mapping: drop rows without a finite value
    tracks_df_norm = tracks[["t", "label", what]].copy()
    tracks_df_norm.replace([np.inf, -np.inf], np.nan, inplace=True)
    tracks_df_norm.dropna(inplace=True)

    # Determine appropriate output data type (uint16 unless listed below)
    value_dtype = tracks_df_norm[what].dtype
    dtype = np.uint16
    if value_dtype in [np.float16, np.float32, np.float64]:
        dtype = np.float32
    elif value_dtype in [np.uint32, np.uint64, np.int32, np.int64]:
        dtype = np.uint32
    elif value_dtype in [np.uint8, np.int8]:
        dtype = np.uint8

    def block_func(labels_f, block_info=None):
        """Process a single frame/block of the label image."""
        # Start index of this block along t = frame number (blocks hold one frame)
        frame = block_info[0]["array-location"][0][0]
        if labels_f.shape[0] == 1:
            labels_f = labels_f[0]
        result = label_to_value_frame(
            tracks_df_norm, labels_f, frame, what, out_dtype=dtype
        )
        return result[None, :, :]  # Output shape: (1, Y, X)

    # The per-block mapping needs exactly one frame per chunk along t;
    # enforce that for Dask inputs as well, whatever their original chunking.
    if isinstance(labels_stack, da.Array):
        labels_stack = labels_stack.rechunk({0: 1})
    else:
        labels_stack = da.from_array(
            labels_stack, chunks=(1, labels_stack.shape[1], labels_stack.shape[2])
        )

    # Apply mapping function to each frame
    gen_image = labels_stack.map_blocks(
        block_func,
        dtype=dtype,
        drop_axis=[],
        new_axis=[],
    )
    return gen_image


def label_to_value_frame(tracks_df_norm, labels_f, frame, what, out_dtype=None):
    """
    Map values from DataFrame to a single frame of labels.

    Args:
        tracks_df_norm: Normalized tracking DataFrame with columns 't',
            'label' and the column named by 'what'
        labels_f: 2D integer label array for single frame
        frame: Frame number; only rows with t == frame are used
        what: Column name to map
        out_dtype: Output data type (defaults to the dtype of the mapped column)

    Returns:
        2D array with mapped values, same shape as labels_f
    """
    # Get data for this specific frame
    tracks_f = tracks_df_norm[tracks_df_norm["t"] == frame]
    # Label ids must stay exact integers. Cast them to the dtype of the label
    # image; a float16 detour is exact only up to 2048 and would silently
    # change larger ids.
    from_label = tracks_f["label"].to_numpy().astype(labels_f.dtype)
    to_particle = tracks_f[what].to_numpy()

    if out_dtype is None:
        out_dtype = to_particle.dtype

    # Create output array and map values
    out = np.zeros_like(labels_f, dtype=out_dtype)
    skimage.util.map_array(labels_f, from_label, to_particle, out=out)
    return out


# --- Utility function to save label images to OME-Zarr ---
def save_labels(label, label_name, root, greyscale=False):
    """
    Save a label image to the OME-Zarr group, removing any existing label with the same name.

    Args:
        label: Label image array with axes (t, y, x)
        label_name: Name for the label in OME-Zarr
        root: OME-Zarr root group; must contain the image array "0", whose
            last two dimensions define the chunk size
        greyscale: Whether to treat as greyscale label (stored as the
            "is_grayscale_label" metadata entry)
    """
    # Remove existing label if present to avoid conflicts
    if "labels" in root:
        labels_group = root["labels"]
        # Drop the name from the list of registered labels, if it is listed
        registered = list(labels_group.attrs.get("labels", []))
        if label_name in registered:
            labels_group.attrs["labels"] = [
                lbl for lbl in registered if lbl != label_name
            ]
        # Delete the stored data; it may exist even when it was not registered
        try:
            del labels_group[label_name]
        except KeyError:
            pass  # Nothing stored under this name

    # Get image dimensions
    Y_dim = root["0"].shape[-2]
    X_dim = root["0"].shape[-1]

    # Write label data to OME-Zarr, one full frame per chunk
    ozw.write_labels(
        labels=label,
        group=root,
        name=label_name,
        axes="tyx",
        scaler=ome_zarr.scale.Scaler(max_layer=1),
        chunks=(1, Y_dim, X_dim),
        storage_options={
            "compressor": zarr.storage.Blosc(cname="zstd", clevel=5),
        },
        metadata={"is_grayscale_label": greyscale},
        delayed=True,
    )


# Process each FOV and save ERK band labels
for fov in get_fovs():
    fov_i = fov.split("_")[1]  # FOV number as used in the per-FOV file names

    # Set up OME-Zarr access (opened in append mode because labels are written)
    dest = os.path.join(get_output_path(), fov)
    store = ozi.parse_url(dest, mode="a").store
    root = zarr.group(store=store)

    # Load existing tracking labels.
    # NOTE(review): the "+ 2" offset assumes the label arrays follow the image
    # and labels nodes in reader order - confirm against the dataset layout.
    nodes = list(ozr.Reader(ozi.parse_url(dest, mode="r"))())
    i_tracked = nodes[1].zarr.root_attrs["labels"].index("tracked_2")
    tracked = nodes[i_tracked + 2].data[0]

    # Load ERK band data for this FOV
    df = pd.read_parquet(
        os.path.join(
            get_output_path(),
            f"FOV_{fov_i}_df_new_simplified_erk_band.parquet",
        )
    )

    # Map ERK band values to label image and save
    save_labels(
        label_to_value_dask(df, tracked, "simplified_erk_band"),
        "simplified_erk_band",
        root,
        greyscale=False,
    )