In [1]:
from PIL import Image
import numpy as np
import os
import glob

# Build a colour-coded instance overlay for every image in ./data/images that
# has a matching mask in ./data/instance_masks, and write it to ./data/overlay.
os.makedirs("./data/overlay", exist_ok=True)

# Collect both TIFF spellings; the two patterns never match the same file.
image_files = glob.glob("./data/images/*.tiff") + glob.glob("./data/images/*.tif")

print(f"Found {len(image_files)} image files to process\n")

for img_path in image_files:
    # File name without directory and extension; used to locate the mask and
    # to name the output.
    base_name = os.path.splitext(os.path.basename(img_path))[0]

    # The mask is "<base_name>_masks" with either TIFF extension.
    mask_path = None
    for mask_ext in ['_masks.tif', '_masks.tiff']:
        candidate_path = f"./data/instance_masks/{base_name}{mask_ext}"
        if os.path.exists(candidate_path):
            mask_path = candidate_path
            break

    if mask_path is None:
        print(f"Warning: Mask not found for {base_name}, skipping...")
        continue

    print(f"Processing: {base_name}")

    # Load the image inside a context manager so the file handle is released
    # even for TIFFs, which PIL keeps open after decoding.
    with Image.open(img_path) as img:
        src_mode = img.mode
        img_np = np.array(img)
        if src_mode in ('L', 'I', 'I;16', 'F'):
            # Already single-channel: use the raw samples.
            gray_np = img_np
        else:
            # RGB/RGBA and any other mode are collapsed to one 8-bit channel,
            # so the display image is always 2-D before it is stacked to RGB.
            gray_np = np.array(img.convert('L'))

    # Load the instance mask; if it is multi-channel, keep the first channel.
    with Image.open(mask_path) as mask:
        mask_np = np.array(mask)
    if mask_np.ndim == 3:
        mask_np = mask_np[..., 0]

    # Build the 8-bit display image with a gentle percentile contrast stretch.
    if src_mode in ('RGB', 'RGBA', 'L', 'I', 'I;16', 'F'):
        img_array = gray_np.astype(np.float32)
        p_low = np.percentile(img_array, 0.5)
        p_high = np.percentile(img_array, 99.5)
        img_array = np.clip(img_array, p_low, p_high)
        intensity_span = p_high - p_low
        if intensity_span > 0:
            img_disp = ((img_array - p_low) / intensity_span * 255).astype(np.uint8)
        else:
            # Flat image: nothing to stretch, and dividing would be 0/0.
            img_disp = np.zeros(img_array.shape, dtype=np.uint8)
    else:
        # Other modes are only converted to 8-bit grayscale, not stretched.
        img_disp = gray_np

    # Mode of the display image, reported the same way as before: the source
    # mode for RGB/L/RGBA inputs, 'L' for everything that was converted.
    shown_mode = src_mode if src_mode in ('RGB', 'L', 'RGBA') else 'L'

    # Sorted unique labels plus, for every pixel, the index of its label.
    instance_ids, inverse = np.unique(mask_np, return_inverse=True)

    print(f"  Image shape: {img_np.shape}, dtype: {img_np.dtype}, mode: {shown_mode}")
    print(f"  Mask shape: {mask_np.shape}, unique values: {len(instance_ids)}")

    if mask_np.shape != img_disp.shape:
        # Blending needs pixel-aligned arrays; skip instead of aborting the batch.
        print(f"Warning: Mask shape {mask_np.shape} does not match image shape {img_disp.shape} for {base_name}, skipping...")
        continue

    # -----------------------
    # One random colour per instance (fixed seed for consistency)
    # -----------------------
    rng = np.random.default_rng(42)  # Fixed seed, re-created per image
    palette = np.zeros((len(instance_ids), 3), dtype=np.float32)
    color_map = {}
    for palette_idx, instance_id in enumerate(instance_ids):
        if instance_id == 0:  # Background keeps the black palette entry
            continue
        color_map[instance_id] = rng.random(3)
        palette[palette_idx] = color_map[instance_id]

    print(f"  Created color map for {len(color_map)} unique objects")

    # -----------------------
    # Create full image overlay
    # -----------------------
    # A single palette lookup paints every instance at once, instead of one
    # full-image comparison per instance.
    rgb_mask_full = palette[inverse.reshape(mask_np.shape)]

    # Blend: 60% grayscale image, 40% instance colours.
    overlay_full = np.stack([img_disp, img_disp, img_disp], axis=-1).astype(np.float32) / 255.0
    overlay_full = overlay_full * 0.6 + rgb_mask_full * 0.4
    overlay_full = (overlay_full * 255).astype(np.uint8)

    # Save the overlay under the image's base name.
    output_path = f"./data/overlay/{base_name}_overlay.tif"
    Image.fromarray(overlay_full).save(output_path)
    print(f"  Saved overlay to {output_path}\n")

print("All files processed!")

Found 1 image files to process

Processing: 20211222_125057_petiole4_00012
  Image shape: (2560, 2560), dtype: float32, mode: L
  Mask shape: (2560, 2560), unique values: 2085
  Created color map for 2084 unique objects
  Saved overlay to ./data/overlay/20211222_125057_petiole4_00012_overlay.tif

All files processed!


In [2]:
from PIL import Image
import numpy as np
import os
import glob
import tifffile as tiff

def rgb_tif_to_gray_tif(
    input_tif: str,
    output_tif: str,
    method: str = "luminance"
):
    """
    Convert a 3-channel TIFF to a grayscale TIFF with shape [1, H, W].

    Args:
        input_tif: path of the TIFF to read; it must decode to a 3-D array
            whose last axis has length 3.
        output_tif: path of the grayscale TIFF to write.
        method: "luminance" (0.299 R + 0.587 G + 0.114 B) or "average"
            (plain mean of the three channels).

    Raises:
        ValueError: if the input is not a 3-channel array, or if ``method``
            is not one of the two supported names.

    For integer inputs the result is rounded to the nearest integer, clipped
    to [0, dtype max] and cast back to the input dtype. For other dtypes the
    float32 result is written as is.
    """
    img = tiff.imread(input_tif)

    if img.ndim != 3 or img.shape[-1] != 3:
        raise ValueError(f"Expected 3-channel TIFF, got shape {img.shape}")

    in_dtype = img.dtype
    img = img.astype(np.float32)

    if method == "luminance":
        gray = (
            0.299 * img[..., 0] +
            0.587 * img[..., 1] +
            0.114 * img[..., 2]
        )
    elif method == "average":
        gray = img.mean(axis=-1)
    else:
        raise ValueError("method must be 'luminance' or 'average'")

    # Restore the integer dtype. Round first: a bare cast truncates toward
    # zero, which darkens pixels by up to one level (149.685 -> 149) and can
    # turn pure white into max-1 through float32 rounding error.
    if np.issubdtype(in_dtype, np.integer):
        max_val = np.iinfo(in_dtype).max
        gray = np.clip(np.rint(gray), 0, max_val).astype(in_dtype)

    # Add a leading channel axis -> (1, H, W)
    gray = gray[np.newaxis, :, :]

    # Save without squeezing, so the channel axis is kept in the file
    tiff.imwrite(
        output_tif,
        gray,
        photometric="minisblack"
    )

    print(f"  Saved grayscale TIFF with shape {gray.shape}: {output_tif}")


# Build a colour-coded instance overlay plus a grayscale copy of it for every
# image in ./data/images that has a matching mask in ./data/instance_masks.
os.makedirs("./data/overlay", exist_ok=True)

# Collect both TIFF spellings; the two patterns never match the same file.
image_files = glob.glob("./data/images/*.tiff") + glob.glob("./data/images/*.tif")

print(f"Found {len(image_files)} image files to process\n")

for img_path in image_files:
    # File name without directory and extension; used to locate the mask and
    # to name the outputs.
    base_name = os.path.splitext(os.path.basename(img_path))[0]

    # The mask is "<base_name>_masks" with either TIFF extension.
    mask_path = None
    for mask_ext in ['_masks.tif', '_masks.tiff']:
        candidate_path = f"./data/instance_masks/{base_name}{mask_ext}"
        if os.path.exists(candidate_path):
            mask_path = candidate_path
            break

    if mask_path is None:
        print(f"Warning: Mask not found for {base_name}, skipping...")
        continue

    print(f"Processing: {base_name}")

    # Define output paths
    output_path = f"./data/overlay/{base_name}_overlay.tif"
    gray_output_path = f"./data/overlay/{base_name}_overlay_gray.tif"

    # The gray overlay is the final product: if it already exists there is no
    # need to rebuild the RGB overlay either.
    if os.path.exists(gray_output_path):
        print(f"  Gray overlay already exists, skipping...")
        continue

    # Load the image inside a context manager so the file handle is released
    # even for TIFFs, which PIL keeps open after decoding.
    with Image.open(img_path) as img:
        src_mode = img.mode
        img_np = np.array(img)
        if src_mode in ('L', 'I', 'I;16', 'F'):
            # Already single-channel: use the raw samples.
            gray_np = img_np
        else:
            # RGB/RGBA and any other mode are collapsed to one 8-bit channel,
            # so the display image is always 2-D before it is stacked to RGB.
            gray_np = np.array(img.convert('L'))

    # Load the instance mask; if it is multi-channel, keep the first channel.
    with Image.open(mask_path) as mask:
        mask_np = np.array(mask)
    if mask_np.ndim == 3:
        mask_np = mask_np[..., 0]

    # Build the 8-bit display image with a gentle percentile contrast stretch.
    if src_mode in ('RGB', 'RGBA', 'L', 'I', 'I;16', 'F'):
        img_array = gray_np.astype(np.float32)
        p_low = np.percentile(img_array, 0.5)
        p_high = np.percentile(img_array, 99.5)
        img_array = np.clip(img_array, p_low, p_high)
        intensity_span = p_high - p_low
        if intensity_span > 0:
            img_disp = ((img_array - p_low) / intensity_span * 255).astype(np.uint8)
        else:
            # Flat image: nothing to stretch, and dividing would be 0/0.
            img_disp = np.zeros(img_array.shape, dtype=np.uint8)
    else:
        # Other modes are only converted to 8-bit grayscale, not stretched.
        img_disp = gray_np

    # Mode of the display image, reported the same way as before: the source
    # mode for RGB/L/RGBA inputs, 'L' for everything that was converted.
    shown_mode = src_mode if src_mode in ('RGB', 'L', 'RGBA') else 'L'

    # Sorted unique labels plus, for every pixel, the index of its label.
    instance_ids, inverse = np.unique(mask_np, return_inverse=True)

    print(f"  Image shape: {img_np.shape}, dtype: {img_np.dtype}, mode: {shown_mode}")
    print(f"  Mask shape: {mask_np.shape}, unique values: {len(instance_ids)}")

    if mask_np.shape != img_disp.shape:
        # Blending needs pixel-aligned arrays; skip instead of aborting the batch.
        print(f"Warning: Mask shape {mask_np.shape} does not match image shape {img_disp.shape} for {base_name}, skipping...")
        continue

    # -----------------------
    # One random colour per instance (fixed seed for consistency)
    # -----------------------
    rng = np.random.default_rng(42)  # Fixed seed, re-created per image
    palette = np.zeros((len(instance_ids), 3), dtype=np.float32)
    color_map = {}
    for palette_idx, instance_id in enumerate(instance_ids):
        if instance_id == 0:  # Background keeps the black palette entry
            continue
        color_map[instance_id] = rng.random(3)
        palette[palette_idx] = color_map[instance_id]

    print(f"  Created color map for {len(color_map)} unique objects")

    # -----------------------
    # Create full image overlay
    # -----------------------
    # A single palette lookup paints every instance at once, instead of one
    # full-image comparison per instance.
    rgb_mask_full = palette[inverse.reshape(mask_np.shape)]

    # Blend: 60% grayscale image, 40% instance colours.
    overlay_full = np.stack([img_disp, img_disp, img_disp], axis=-1).astype(np.float32) / 255.0
    overlay_full = overlay_full * 0.6 + rgb_mask_full * 0.4
    overlay_full = (overlay_full * 255).astype(np.uint8)

    # Save full RGB overlay
    Image.fromarray(overlay_full).save(output_path)
    print(f"  Saved RGB overlay to {output_path}")

    # Re-read the RGB overlay and write its grayscale (1, H, W) version.
    rgb_tif_to_gray_tif(
        output_path,
        gray_output_path,
        method="luminance"
    )
    print()

print("All files processed!")

Found 1 image files to process

Processing: 20211222_125057_petiole4_00012
  Image shape: (2560, 2560), dtype: float32, mode: L
  Mask shape: (2560, 2560), unique values: 2085
  Created color map for 2084 unique objects
  Saved RGB overlay to ./data/overlay/20211222_125057_petiole4_00012_overlay.tif
  Saved grayscale TIFF with shape (1, 2560, 2560): ./data/overlay/20211222_125057_petiole4_00012_overlay_gray.tif

All files processed!


In [None]:
import numpy as np
from PIL import Image
import cv2
from pathlib import Path
import matplotlib.pyplot as plt


def find_class_contours(semantic_mask, class_id):
    """
    Return the outer contours of all regions labelled ``class_id``.

    Args:
        semantic_mask: semantic segmentation mask (array of class IDs)
        class_id: the class ID whose regions are traced

    Returns:
        the contours reported by ``cv2.findContours`` for that class
    """
    # 1 where the pixel belongs to the class, 0 elsewhere, as 8-bit.
    class_pixels = np.equal(semantic_mask, class_id).astype(np.uint8)

    # Outer boundaries only, with straight runs compressed to their end points.
    found = cv2.findContours(class_pixels, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlines, _hierarchy = found
    return outlines


def draw_polygons_on_image(img_normalized, semantic_mask, class_colors):
    """
    Outline every non-background class region on a copy of the image.

    Args:
        img_normalized: normalized image; a 2-D array is replicated into
            three channels, anything else is copied as is
        semantic_mask: semantic segmentation mask
        class_colors: dictionary mapping class IDs to RGB colors; classes
            missing from it are drawn in white

    Returns:
        new image array with the class boundaries drawn on it
    """
    # Work on a fresh 3-channel canvas so the caller's array is not modified.
    if img_normalized.ndim == 2:
        canvas = np.stack((img_normalized,) * 3, axis=-1)
    else:
        canvas = img_normalized.copy()

    # Every label present in the mask except background (0).
    labels = np.unique(semantic_mask)
    foreground_labels = labels[labels != 0]

    for class_id in foreground_labels:
        outlines = find_class_contours(semantic_mask, class_id)
        outline_color = class_colors.get(int(class_id), [255, 255, 255])

        # -1 draws all contours of the class; 3 px keeps them visible.
        cv2.drawContours(canvas, outlines, -1, outline_color, thickness=3)

    return canvas


def visualize_and_save(original_img, semantic_mask, output_path, norm_images_dir):
    """
    Visualize semantic mask as overlay on original image and save both.

    Writes four files: the contrast-normalized image (into
    ``norm_images_dir``), the colored semantic mask (``output_path``), and an
    overlay and a polygon-boundary image next to ``output_path``. It then
    shows a 2x2 matplotlib figure of the four views.

    Args:
        original_img: original image array
        semantic_mask: 2-D semantic segmentation mask of class IDs
        output_path: ``pathlib.Path`` to save the colored semantic mask; its
            stem and parent are used to name the other outputs
        norm_images_dir: ``pathlib.Path`` of the directory for normalized images
    """
    # Define colors for each class
    class_colors = {
        0: [128, 128, 128],  # Gray - Background
        1: [0, 0, 255],      # Blue - Cortex
        2: [0, 255, 0],      # Green - Phloem Fibers
        3: [128, 0, 128],    # Purple - Phloem
        4: [255, 0, 0],      # Red - Xylem vessels
        5: [255, 255, 0],    # Yellow - Air-based Pith cells
        6: [255, 165, 0],    # Orange - Water-based Pith cells
    }

    # Create colored semantic mask; labels without a color entry stay black
    h, w = semantic_mask.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)

    for class_id, color in class_colors.items():
        colored_mask[semantic_mask == class_id] = color

    # Print class distribution
    unique_classes = np.unique(semantic_mask)
    print(f"Classes present: {unique_classes}")
    for class_id in unique_classes:
        pixel_count = np.sum(semantic_mask == class_id)
        print(f"Class {class_id}: {pixel_count} pixels")

    # Normalize original image
    img_array = original_img.copy().astype(np.float32)

    # Use percentile-based contrast stretching
    p_low = np.percentile(img_array, 0.5)
    p_high = np.percentile(img_array, 99.5)

    # Clip and normalize
    img_array = np.clip(img_array, p_low, p_high)
    intensity_span = p_high - p_low
    if intensity_span > 0:
        img_normalized = ((img_array - p_low) / intensity_span * 255).astype(np.uint8)
    else:
        # Flat image: both percentiles coincide, so the stretch would be 0/0.
        img_normalized = np.zeros(img_array.shape, dtype=np.uint8)

    # Save normalized image
    norm_image_path = norm_images_dir / f"{output_path.stem.replace('_semantic_mask', '')}_normalized.png"
    Image.fromarray(img_normalized).save(norm_image_path)
    print(f"Normalized image saved to: {norm_image_path}")

    # Convert to RGB if grayscale
    if len(img_normalized.shape) == 2:
        img_normalized_rgb = np.stack([img_normalized] * 3, axis=-1)
    else:
        img_normalized_rgb = img_normalized

    # Create overlay with transparency (mask weighted at alpha)
    alpha = 0.2
    overlay = cv2.addWeighted(img_normalized_rgb, 1-alpha, colored_mask, alpha, 0)

    # Draw polygons on normalized image
    img_with_polygons = draw_polygons_on_image(img_normalized, semantic_mask, class_colors)

    # Save colored semantic mask (same colors as overlay)
    Image.fromarray(colored_mask).save(output_path)
    print(f"Colored semantic mask saved to: {output_path}")

    # Save overlay image in the same directory
    overlay_path = output_path.parent / f"{output_path.stem}_overlay.png"
    Image.fromarray(overlay).save(overlay_path)
    print(f"Overlay image saved to: {overlay_path}")

    # Save image with polygons
    polygon_path = output_path.parent / f"{output_path.stem}_polygons.png"
    Image.fromarray(img_with_polygons).save(polygon_path)
    print(f"Image with polygons saved to: {polygon_path}")

    # Visualize (now with 4 subplots)
    fig, axes = plt.subplots(2, 2, figsize=(18, 18))
    axes = axes.flatten()

    axes[0].imshow(img_normalized, cmap='gray')
    axes[0].set_title('Original Image (Normalized)', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(colored_mask)
    axes[1].set_title('Semantic Mask', fontsize=12, fontweight='bold')
    axes[1].axis('off')

    axes[2].imshow(overlay)
    axes[2].set_title('Overlay (20% transparency)', fontsize=12, fontweight='bold')
    axes[2].axis('off')

    axes[3].imshow(img_with_polygons)
    axes[3].set_title('Polygon Boundaries', fontsize=12, fontweight='bold')
    axes[3].axis('off')

    # Add legend (matplotlib wants face colors as 0-1 floats)
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor=np.array(class_colors[i])/255, label=f'Class {i}')
        for i in sorted(class_colors.keys())
    ]
    axes[3].legend(handles=legend_elements, loc='upper right', fontsize=8)

    plt.tight_layout()
    plt.show()


def process_single_file(image_path, mask_path, output_path, norm_images_dir):
    """
    Process a single image and mask file to create overlay visualization.

    Loads the image and the ``.npy`` mask, reports their shapes, then hands
    both to ``visualize_and_save``.

    Args:
        image_path: path to the original image
        mask_path: path to the .npy mask file
        output_path: path to save the semantic mask
        norm_images_dir: directory to save normalized images
    """
    print(f"\n{'='*80}")
    print(f"Processing: {Path(image_path).name}")
    print(f"{'='*80}")

    # Load image; the context manager closes the file once the pixels have
    # been copied into the array, instead of leaving the handle to the GC.
    with Image.open(image_path) as image_file:
        original_img = np.array(image_file)
    print(f"Image shape: {original_img.shape}")

    # Load mask
    semantic_mask = np.load(mask_path)
    print(f"Mask shape: {semantic_mask.shape}")

    # Visualize and save
    print("\nVisualizing and saving...")
    visualize_and_save(original_img, semantic_mask, output_path, norm_images_dir)


def batch_process(images_dir, masks_dir, output_dir, norm_images_dir):
    """
    Run ``process_single_file`` on every image that has a matching mask.

    Args:
        images_dir: directory containing image files (.tiff or .tif)
        masks_dir: directory containing mask files (*_mask.npy)
        output_dir: directory to save semantic masks and overlays
        norm_images_dir: directory to save normalized images

    Both output directories are created if missing. Images without a mask and
    images whose processing raises are counted as skipped; a failure is
    printed with its traceback and the batch continues.
    """
    # Make sure both output locations exist before any work starts.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    norm_images_dir = Path(norm_images_dir)
    norm_images_dir.mkdir(parents=True, exist_ok=True)

    # Gather images with either TIFF extension.
    images_dir = Path(images_dir)
    image_files = [*images_dir.glob("*.tiff"), *images_dir.glob("*.tif")]

    print(f"Found {len(image_files)} image files")

    banner = '=' * 80
    processed_count = skipped_count = 0

    for image_path in sorted(image_files):
        base_name = image_path.stem

        # The mask shares the image's stem: <stem>_mask.npy
        mask_path = Path(masks_dir) / f"{base_name}_mask.npy"
        if not mask_path.exists():
            print(f"\nSkipping {base_name}: mask file not found")
            skipped_count += 1
            continue

        # The overlay and polygon images are written next to this file.
        output_path = output_dir / f"{base_name}_semantic_mask.png"

        try:
            process_single_file(image_path, mask_path, output_path, norm_images_dir)
        except Exception as e:
            print(f"\nError processing {base_name}: {str(e)}")
            import traceback
            traceback.print_exc()
            skipped_count += 1
        else:
            processed_count += 1

        # Close all matplotlib figures after success or failure to free memory.
        plt.close('all')

    print(f"\n{banner}")
    print("Batch processing complete!")
    print(f"Processed: {processed_count} files")
    print(f"Skipped: {skipped_count} files")
    print(f"{banner}")


if __name__ == "__main__":
    # Locations for the batch run; the masks directory holds the *_mask.npy files.
    images_dir, masks_dir = "./data/images", "./data/annotations"
    output_dir, norm_images_dir = "./data/semantic_masks", "./data/norm_images"

    # Process every image that has a matching mask.
    batch_process(
        images_dir=images_dir,
        masks_dir=masks_dir,
        output_dir=output_dir,
        norm_images_dir=norm_images_dir,
    )