In [3]:
import os
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split
 
# Configuration
image_dir = "../../../data/images"
mask_dir = "../../../data/masks"
mask_suffixes = ["root_mask"]
image_ext = ".png"
mask_ext = ".tif"
dry_run = False
 
# Cleanup Stats
deleted_images = 0
deleted_masks = 0
total_checked_images = 0
kept_images = []
 
# 1. Load mask base names (without extension)
mask_basenames = set(os.path.splitext(f)[0] for f in os.listdir(mask_dir) if f.endswith(mask_ext))
 
# 2. Clean images that are missing any mask
for filename in os.listdir(image_dir):
    if not filename.endswith(image_ext):
        continue
 
    total_checked_images += 1
    image_base = filename[:-len(image_ext)]
 
    required_masks = [f"{image_base}_{suffix}" for suffix in mask_suffixes]
    missing_masks = [m for m in required_masks if m not in mask_basenames]
 
    if missing_masks:
        print(f"Missing masks for {filename}: {missing_masks}")
        if not dry_run:
            os.remove(os.path.join(image_dir, filename))
            print(f"Deleted image: {filename}")
        deleted_images += 1
    else:
        kept_images.append(image_base)
 
# Clean orphaned masks (masks without corresponding image)
valid_image_bases = set(kept_images)
 
for mask_file in os.listdir(mask_dir):
    if not mask_file.endswith(mask_ext):
        continue
 
    mask_base, _ = os.path.splitext(mask_file)
 
    # Extract the original image base name by removing one of the known suffixes
    image_base_candidates = [
        mask_base.replace(f"_{suffix}", "") for suffix in mask_suffixes if f"_{suffix}" in mask_base
    ]
 
    if not image_base_candidates:
        print(f"Unrecognized mask format: {mask_file}")
        os.remove(os.path.join(mask_dir, mask_file))
        continue
 
    image_base = image_base_candidates[0]
 
    if image_base not in valid_image_bases:
        mask_path = os.path.join(mask_dir, mask_file)
        if dry_run:
            print(f"[Dry-run] Would delete orphaned mask: {mask_file}")
        else:
            os.remove(mask_path)
            print(f"Deleted orphaned mask: {mask_file}")
        deleted_masks += 1
 
# Summary
print("\nCleanup Summary")
print(f"Checked images       : {total_checked_images}")
print(f"Deleted images       : {deleted_images}")
print(f"Deleted orphaned masks: {deleted_masks}")
print(f"Dry run mode         : {dry_run}")


Cleanup Summary
Checked images       : 0
Deleted images       : 0
Deleted orphaned masks: 0
Dry run mode         : False


In [4]:
# Updated get_id_from_filename to handle both images and masks
def get_id_from_filename(filename: str, mask_suffixes=None) -> str:
    name, _ = os.path.splitext(filename)
    if mask_suffixes:
        for suffix in mask_suffixes:
            if name.endswith(f"_{suffix}"):
                name = name[: -len(f"_{suffix}")]
                break
    return name

In [5]:
def main(image_dir, mask_dir, train_images, train_masks, val_images, val_masks, test_images, test_masks):
    image_dir = Path(image_dir)
    mask_dir = Path(mask_dir)
 
    mask_suffixes = ["root_mask"]
    image_ext = ".png"
    mask_ext = ".tif"
 
    # Ensure output directories exist
    for out_dir in [train_images, train_masks, val_images, val_masks, test_images, test_masks]:
        os.makedirs(out_dir, exist_ok=True)
 
    # Map image files to IDs
    image_files = list(image_dir.glob(f"*{image_ext}"))
    image_id_map = {get_id_from_filename(img.name): img for img in image_files}
 
    # Map mask files to IDs using suffix stripping
    mask_files = list(mask_dir.glob(f"*{mask_ext}"))
    mask_id_map = {
        get_id_from_filename(msk.name, mask_suffixes=mask_suffixes): msk
        for msk in mask_files
    }
 
    # Match only IDs that exist in both
    common_ids = sorted(set(image_id_map.keys()) & set(mask_id_map.keys()))
 
    if not common_ids:
        raise ValueError("No matching image-mask ID pairs found. Please check your data and naming conventions.")
 
    # Split into train/val/test (e.g., 70/15/15)
    train_ids, temp_ids = train_test_split(common_ids, test_size=0.3, random_state=42)
    val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)
 
    def copy_files(ids, img_out, mask_out):
        for id_ in ids:
            shutil.copy(image_id_map[id_], os.path.join(img_out, image_id_map[id_].name))
            shutil.copy(mask_id_map[id_], os.path.join(mask_out, mask_id_map[id_].name))
 
    # Copy files
    copy_files(train_ids, train_images, train_masks)
    copy_files(val_ids, val_images, val_masks)
    copy_files(test_ids, test_images, test_masks)
 
    print(f"[INFO] Data split completed: {len(train_ids)} train, {len(val_ids)} val, {len(test_ids)} test")

In [7]:
# Define output directories for train, validation, and test splits
train_images = "../../../data/train_images"
train_masks = "../../../data/train_masks"
val_images = "../../../data/val_images"
val_masks = "../../../data/val_masks"
test_images = "../../../data/test_images"
test_masks = "../../../data/test_masks"

# Call the main function
main(image_dir, mask_dir, train_images, train_masks, val_images, val_masks, test_images, test_masks)

ValueError: No matching image-mask ID pairs found. Please check your data and naming conventions.