In [17]:
## imports and file paths
import os
import numpy as np
from skimage import io, segmentation
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import cv2

# ─── USER PARAMETERS ────────────────────────────────────────────────────────────
# Matching rule: minimum IoU a DAPI/PKC mask pair must reach to count as overlapping
OVERLAP_THRESHOLD = 0.3

# File-name conventions
IMAGE_EXT = '.png'                         # extension of the input images
SEG_EXT = '_seg.npy'                       # appended to an image path (minus extension) to locate its segmentation
NAME_TEMPLATE = 'rod_bipolar_{:04d}.png'   # formatted with the running cell counter
OVERLAY_TMPL = '{}_overlay.png'            # formatted with the DAPI image base name

# Input and output folders
OUT_DIR = '/content/drive/MyDrive/biotech/Retina_Lab/VGG_next_level/extracted_rbcs'
PKC_DIR = '/content/drive/MyDrive/biotech/Retina_Lab/VGG_next_level/pkc_img_seg'
DAPI_DIR = '/content/drive/MyDrive/biotech/Retina_Lab/VGG_next_level/dapi_img_seg'
# ────────────────────────────────────────────────────────────────────────────────

## Calculate Intersection over Union

In [18]:
def calculate_iou(mask1, mask2):
    """
    Intersection over Union (IoU) of two masks.

    Both inputs are cast to boolean first, so every non-zero element counts
    as foreground.

    Args:
        mask1: First mask (array)
        mask2: Second mask (array)

    Returns:
        Number of shared foreground elements divided by the number of
        elements that are foreground in either mask; 0.0 when both masks
        are empty
    """
    first = mask1.astype(bool)
    second = mask2.astype(bool)

    shared = (first & second).sum()
    covered = (first | second).sum()

    # Two empty masks have an empty union: report no overlap instead of dividing by zero
    return shared / covered if covered != 0 else 0.0

## Align images
- Zero-pads (or crops) the PKC image and mask to the DAPI image size, anchored at the top-left corner

In [19]:
def _fit_top_left(array, height, width):
    """Zero-pad or crop `array` to (height, width), keeping its top-left corner, dtype and trailing axes."""
    fitted = np.zeros((height, width) + array.shape[2:], dtype=array.dtype)
    rows = min(height, array.shape[0])
    cols = min(width, array.shape[1])
    fitted[:rows, :cols] = array[:rows, :cols]
    return fitted


def align_images(dapi_img, pkc_img, dapi_mask, pkc_mask):
    """
    Align PKC image and mask to the DAPI frame.

    DAPI images are expected to be 1024x1024 and PKC images 1024x(varying
    height). Alignment is done by matching top-left corners: PKC data that
    falls outside the DAPI frame is cropped, and the part of the DAPI frame
    not covered by PKC data is filled with zeros.

    The returned arrays keep the PKC arrays' own dtype and trailing (channel)
    axes. Allocating them from the DAPI arrays instead would fail when one
    image is grayscale and the other multi-channel, and could silently wrap
    PKC label values when the DAPI mask uses a narrower integer type.

    Args:
        dapi_img: DAPI image; only its first two dimensions are used
        pkc_img: PKC image (2-D, or 3-D with channels last)
        dapi_mask: DAPI segmentation mask; only its first two dimensions are used
        pkc_mask: PKC segmentation mask

    Returns:
        Tuple (aligned_pkc_img, aligned_pkc_mask): the PKC image resized to
        the DAPI image's height/width and the PKC mask resized to the DAPI
        mask's height/width
    """
    img_height, img_width = dapi_img.shape[:2]
    # The mask is fitted to the DAPI *mask* so the two label masks can be
    # compared element-wise even if a mask and its image differ in size.
    mask_height, mask_width = dapi_mask.shape[:2]

    aligned_pkc_img = _fit_top_left(pkc_img, img_height, img_width)
    aligned_pkc_mask = _fit_top_left(pkc_mask, mask_height, mask_width)

    return aligned_pkc_img, aligned_pkc_mask

## Overlay Masks

In [20]:
def overlay_masks(dapi_mask, pkc_mask):
    """
    Build an RGB picture showing where the DAPI and PKC masks lie.

    Pixels labelled (> 0) in the DAPI mask light up the blue channel, pixels
    labelled in the PKC mask light up the green channel, so pixels covered by
    both appear cyan.

    Args:
        dapi_mask: DAPI segmentation mask
        pkc_mask: PKC segmentation mask

    Returns:
        uint8 RGB overlay image with the DAPI mask's shape plus a trailing
        channel axis of length 3
    """
    dapi_fg = dapi_mask > 0
    pkc_fg = pkc_mask > 0

    canvas = np.zeros(dapi_mask.shape + (3,), dtype=np.uint8)

    # Green = PKC, blue = DAPI
    canvas[:, :, 1] = np.where(pkc_fg, 255, 0)
    canvas[:, :, 2] = np.where(dapi_fg, 255, 0)

    # Pixels covered by both masks are written explicitly as cyan (green + blue)
    both = dapi_fg & pkc_fg
    canvas[both] = (0, 255, 255)

    return canvas

## Find all matching files

In [21]:
def find_matching_files(dapi_dir, scgn_dir, image_ext=None,
                        dapi_prefix='C1-', partner_prefixes=('C2-', 'C3-')):
    """
    Pair each DAPI image with the second-channel image of the same sample/slice.

    A DAPI file named ``<dapi_prefix><id><image_ext>`` (e.g. ``C1-C30000.png``)
    is paired with ``<prefix><id><image_ext>`` in ``scgn_dir``, where
    ``prefix`` is the first entry of ``partner_prefixes`` for which such a
    file exists. Previously the function globbed for ``C3-`` files but then
    looked up ``C2-`` names only; both channel prefixes are now accepted,
    with ``C2-`` taking precedence so earlier matches are unchanged.

    Args:
        dapi_dir: Directory containing the DAPI images
        scgn_dir: Directory containing the second-channel images (the
            notebook passes the PKC directory here; the parameter name is
            kept for existing callers)
        image_ext: Image file extension including the dot; defaults to the
            module-level ``IMAGE_EXT``
        dapi_prefix: File-name prefix that marks a DAPI image
        partner_prefixes: Candidate file-name prefixes for the second
            channel, tried in order

    Returns:
        List of ``(dapi_path, partner_path)`` tuples, sorted by DAPI path so
        that downstream numbering is reproducible between runs
    """
    if image_ext is None:
        image_ext = IMAGE_EXT

    # Escape the literal parts so directories or prefixes containing glob
    # metacharacters ("[", "*", "?") are matched verbatim.
    pattern = os.path.join(
        glob.escape(dapi_dir),
        f"{glob.escape(dapi_prefix)}*{glob.escape(image_ext)}",
    )

    matching_pairs = []

    for dapi_file in sorted(glob.glob(pattern)):
        dapi_basename = os.path.basename(dapi_file)
        # Strip the prefix and extension by position; str.replace would also
        # remove the same text if it happened to occur inside the identifier.
        sample_slice = dapi_basename[len(dapi_prefix):len(dapi_basename) - len(image_ext)]

        for prefix in partner_prefixes:
            partner_file = os.path.join(scgn_dir, f"{prefix}{sample_slice}{image_ext}")
            if os.path.exists(partner_file):
                matching_pairs.append((dapi_file, partner_file))
                break  # at most one partner per DAPI image

    return matching_pairs

## Extract all masks from seg data

In [22]:

def extract_masks_from_seg(seg_data):
    """
    Pull the mask array out of whatever a segmentation .npy file contained.

    Three layouts are accepted: an object-dtype numpy array wrapping a single
    value (unwrapped with ``.item()``), a dict holding a ``'masks'`` entry,
    or anything else, which is taken to be the mask itself.

    Args:
        seg_data: Raw segmentation data from .npy file

    Returns:
        Mask array, or None if the dict has no 'masks' entry or extraction
        raises
    """
    try:
        wrapped = isinstance(seg_data, np.ndarray) and seg_data.dtype == 'object'
        payload = seg_data.item() if wrapped else seg_data

        # Anything that is not a dict is assumed to already be the mask
        if not isinstance(payload, dict):
            return payload

        found = payload.get('masks', None)
        if found is None:
            print("Warning: 'masks' key not found in segmentation data")
        return found

    except Exception as e:
        print(f"Error extracting masks: {e}")
        return None


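## Main function
- Brings it all together: pairs DAPI and PKC files, aligns them, and saves every DAPI cell whose mask overlaps a PKC mask (IoU ≥ `OVERLAP_THRESHOLD`)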

In [23]:
def _foreground_labels(label_mask):
    """Return the sorted unique non-zero labels found in `label_mask`."""
    labels = np.unique(label_mask)
    # Filter by value: slicing off the first entry would drop a real cell
    # whenever the mask happens to contain no background (0) pixel.
    return labels[labels != 0]


def _masked_crop(image, region):
    """Crop `image` to the bounding box of boolean `region`, zeroing pixels outside it; None if `region` is empty."""
    ys, xs = np.where(region)
    if len(ys) == 0:
        return None

    y0, x0 = ys.min(), xs.min()
    y1, x1 = ys.max() + 1, xs.max() + 1

    patch = image[y0:y1, x0:x1].copy()
    mask_crop = region[y0:y1, x0:x1]

    if patch.ndim == 3:
        patch[~mask_crop] = 0
    else:
        patch = patch * mask_crop
    return patch


def extract_rod_bipolar_cells():
    """
    Main function to extract rod bipolar cells from DAPI and PKC image pairs.

    For every DAPI/PKC pair returned by find_matching_files(DAPI_DIR, PKC_DIR):
      1. load both images and their ``<image base> + SEG_EXT`` segmentations,
      2. align the PKC image/mask to the DAPI frame,
      3. save a mask overlay to OUT_DIR (named via OVERLAY_TMPL),
      4. for each DAPI label whose IoU with a PKC label is at least
         OVERLAP_THRESHOLD, save the masked DAPI crop to OUT_DIR
         (named via NAME_TEMPLATE and a running counter).

    A DAPI cell is saved once per PKC label that passes the threshold.
    Pairs with missing or unreadable segmentations are skipped with a
    message, and any exception while processing a pair is reported and the
    loop continues with the next pair (best effort).

    Returns:
        None. Results are written to OUT_DIR and progress is printed.
    """
    print("Starting rod bipolar cell extraction...")

    # Create output directory
    os.makedirs(OUT_DIR, exist_ok=True)

    # Find matching DAPI and PKC files
    matching_pairs = find_matching_files(DAPI_DIR, PKC_DIR)
    print(f"Found {len(matching_pairs)} matching image pairs")

    if not matching_pairs:
        print("No matching pairs found. Please check file naming convention and paths.")
        return

    cell_counter = 0
    total_overlaps = 0

    for dapi_path, pkc_path in tqdm(matching_pairs, desc="Processing image pairs"):
        try:
            # Load images
            dapi_img = io.imread(dapi_path)
            pkc_img = io.imread(pkc_path)

            # Segmentation files sit next to the images: <image base> + SEG_EXT
            dapi_seg_path = os.path.splitext(dapi_path)[0] + SEG_EXT
            pkc_seg_path = os.path.splitext(pkc_path)[0] + SEG_EXT

            if not os.path.exists(dapi_seg_path) or not os.path.exists(pkc_seg_path):
                print(f"Skipping {os.path.basename(dapi_path)}: missing segmentation files")
                continue

            # NOTE(review): allow_pickle=True executes pickled content on load;
            # only use with segmentation files from a trusted source.
            dapi_seg_raw = np.load(dapi_seg_path, allow_pickle=True)
            pkc_seg_raw = np.load(pkc_seg_path, allow_pickle=True)

            # Extract masks from segmentation data
            dapi_masks = extract_masks_from_seg(dapi_seg_raw)
            pkc_masks = extract_masks_from_seg(pkc_seg_raw)

            if dapi_masks is None or pkc_masks is None:
                print(f"Skipping {os.path.basename(dapi_path)}: could not extract masks")
                continue

            # Align PKC image and mask to DAPI dimensions
            aligned_pkc_img, aligned_pkc_masks = align_images(dapi_img, pkc_img, dapi_masks, pkc_masks)

            # Create overlay for visualization
            overlay = overlay_masks(dapi_masks, aligned_pkc_masks)
            overlay_name = OVERLAY_TMPL.format(os.path.splitext(os.path.basename(dapi_path))[0])
            io.imsave(os.path.join(OUT_DIR, overlay_name), overlay)

            # Labels of every segmented object (background 0 excluded)
            dapi_labels = _foreground_labels(dapi_masks)
            pkc_labels = _foreground_labels(aligned_pkc_masks)

            pair_overlaps = 0

            for dapi_label in dapi_labels:
                dapi_region = (dapi_masks == dapi_label)

                # A PKC label sharing no pixel with this DAPI region has
                # IoU 0 and can only pass a non-positive threshold, so with a
                # positive threshold only the labels found under the region
                # need a full IoU computation.
                if OVERLAP_THRESHOLD > 0:
                    candidate_labels = _foreground_labels(aligned_pkc_masks[dapi_region])
                else:
                    candidate_labels = pkc_labels

                for pkc_label in candidate_labels:
                    pkc_region = (aligned_pkc_masks == pkc_label)

                    # Calculate IoU between regions
                    iou = calculate_iou(dapi_region, pkc_region)
                    if iou < OVERLAP_THRESHOLD:
                        continue

                    # Extract DAPI region (rod bipolar cell)
                    patch = _masked_crop(dapi_img, dapi_region)
                    if patch is None:
                        continue

                    # Count only cells that are actually written, so file
                    # numbers and the reported totals stay in step.
                    cell_counter += 1
                    pair_overlaps += 1

                    # Save extracted cell
                    # NOTE(review): the uint8 cast wraps values above 255 if
                    # the DAPI image has a wider dtype — confirm inputs are 8-bit.
                    out_name = NAME_TEMPLATE.format(cell_counter)
                    Image.fromarray(patch.astype(np.uint8)).save(os.path.join(OUT_DIR, out_name))

                    print(f"  Extracted cell {cell_counter} (IoU: {iou:.3f})")

            total_overlaps += pair_overlaps
            print(f"  Found {pair_overlaps} overlapping cells in {os.path.basename(dapi_path)}")

        except Exception as e:
            # Best effort: report the failing pair and keep going
            print(f"Error processing {os.path.basename(dapi_path)}: {e}")
            continue

    print("\nExtraction complete!")
    print(f"Total cells extracted: {cell_counter}")
    print(f"Total overlapping regions found: {total_overlaps}")


In [24]:
# Run the full extraction pipeline using the USER PARAMETERS defined at the top of the notebook
extract_rod_bipolar_cells()

Starting rod bipolar cell extraction...
Found 0 matching image pairs
No matching pairs found. Please check file naming convention and paths.
