# Image Splitting Script - WSI Tile Extraction

This notebook extracts and organizes image tiles from whole slide images (WSIs), primarily in Hamamatsu NDPI format (the batch tile extractor also accepts SVS and TIFF).

## Pipeline Overview
1. Convert NDPI whole slide images to thumbnails (optional preview)
2. Extract fixed-size tiles from each WSI at full (level-0) resolution
3. Organize extracted tiles by stain type

## 1. Imports

In [None]:
import os
import shutil

import numpy as np
from PIL import Image
from tqdm import tqdm
import openslide

## 2. Configuration

In [None]:
# -----------------------------------------------------------------------------
# Paths
# -----------------------------------------------------------------------------
# Absolute paths are machine-specific; the tile paths are relative to the
# notebook's working directory.
INPUT_DIR = "/Volumes/TOSHIBA EXT/Melanomas IHC"                     # source NDPI slides
OUTPUT_DIR = "/Users/cameronpreasmyer/Desktop/Images/Melanomas IHC"  # generated thumbnails / tiles
TILES_SOURCE_DIR = "Images/2048Tiles"                                # tiles to be sorted by stain
TILES_DEST_ROOT = "Images/"                                          # stain sub-folders are created here

# -----------------------------------------------------------------------------
# Processing parameters
# -----------------------------------------------------------------------------
THUMBNAIL_SIZE = (2048, 2048)  # (width, height) of preview thumbnails

TILE_SIZE = (1024, 1024)       # (width, height) of each extracted tile
STRIDE = 1024                  # equal to the tile width -> non-overlapping tiles

# Substrings looked for (upper-case) in filenames when sorting tiles by stain.
# Order matters: the first keyword found in a filename wins.
STAIN_KEYWORDS = [
    "PIR", "PBP", "MITF", "HE",
    "ANX", "BRAF", "BCL2", "BCL3",
]

# Slides excluded from batch processing (known to be problematic).
SKIP_FILES = [
    "EL-O001-PIR.ndpi",
    "EL-O001P-ANX - 2018-05-23 22.50.37.ndpi",
]

## 3. WSI Thumbnail Generation

In [None]:
def resize_ndpi_image(ndpi_path, output_path, target_size=(2048, 2048)):
    """
    Create a thumbnail from a whole slide image (NDPI format).

    Reads the pyramid level best suited to a 32x downsample, drops the alpha
    channel, resizes to exactly ``target_size`` and saves the result.

    Note: the output is stretched to ``target_size``; the slide's aspect
    ratio is not preserved.

    Args:
        ndpi_path (str): Path to input NDPI file
        output_path (str): Path for output thumbnail image
        target_size (tuple): Target dimensions (width, height)
    """
    slide = openslide.OpenSlide(ndpi_path)
    try:
        # Use a coarse pyramid level instead of level 0; reading the whole
        # full-resolution image of a WSI is typically far too large.
        level = slide.get_best_level_for_downsample(32)
        region = slide.read_region((0, 0), level, slide.level_dimensions[level])
    finally:
        # Always release the slide handle, even if reading fails, so batch
        # runs that catch per-file errors don't leak open files.
        slide.close()

    # read_region returns an independent RGBA image, so it remains usable
    # after the slide is closed.
    resized = region.convert("RGB").resize(target_size, Image.Resampling.LANCZOS)
    resized.save(output_path)

## 4. Batch Thumbnail Generation

In [None]:
def batch_create_thumbnails(input_dir, output_dir, target_size=THUMBNAIL_SIZE, skip_files=None):
    """
    Create thumbnails for all NDPI files in a directory.

    Files are processed in sorted order. The extension match is
    case-insensitive. Each thumbnail is written as ``<stem>.png`` in
    ``output_dir``. Errors on individual slides are printed and the batch
    continues.

    Args:
        input_dir (str): Directory containing NDPI files
        output_dir (str): Directory for output thumbnails
        target_size (tuple): Target thumbnail dimensions
        skip_files (list): Filenames to skip
    """
    skip = set(skip_files or ())

    os.makedirs(output_dir, exist_ok=True)

    # "._*" files are macOS AppleDouble metadata (common on external
    # non-HFS volumes); they share the slide's extension but are not slides.
    ndpi_files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(".ndpi") and not f.startswith("._")
    )
    print(f"Found {len(ndpi_files)} NDPI files")

    for filename in tqdm(ndpi_files, desc="Creating thumbnails"):
        if filename in skip:
            print(f"Skipping: {filename}")
            continue

        input_path = os.path.join(input_dir, filename)
        # splitext only replaces the final extension (str.replace would also
        # rewrite any ".ndpi" appearing elsewhere in the name).
        stem = os.path.splitext(filename)[0]
        output_path = os.path.join(output_dir, stem + ".png")

        try:
            resize_ndpi_image(input_path, output_path, target_size)
        except Exception as e:
            # Best-effort batch: report the failure and move on.
            print(f"Error processing {filename}: {e}")


# Run thumbnail generation
# batch_create_thumbnails(INPUT_DIR, OUTPUT_DIR, skip_files=SKIP_FILES)

## 5. Tile Extraction from WSI

In [None]:
def extract_tiles(slide_path, output_dir, tile_size=(512, 512), stride=512,
                   filter_informative=False, threshold=0.92):
    """
    Extract tiles from a whole slide image.

    Systematically extracts tiles at the highest resolution level (level 0)
    with configurable tile size and stride. Tiles are visited row by row
    (top to bottom, left to right). Only tiles lying entirely inside the
    slide are read; partial tiles at the right/bottom edges are skipped.
    Saved tiles are named ``<slide stem>_tile_<n>.png``, where ``n`` counts
    saved tiles only.

    Args:
        slide_path (str): Path to WSI file (NDPI, SVS, etc.)
        output_dir (str): Directory for output tiles
        tile_size (tuple): Tile dimensions (width, height)
        stride (int): Step size between tiles (use tile_size for non-overlapping)
        filter_informative (bool): If True, only save tiles with tissue content
        threshold (float): White pixel ratio threshold for informativeness

    Returns:
        int: Number of tiles saved

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    tile_w, tile_h = tile_size
    slide = openslide.OpenSlide(slide_path)
    try:
        width, height = slide.dimensions
        os.makedirs(output_dir, exist_ok=True)

        # Top-left origins of tiles that fit entirely within the slide.
        # These ranges are empty (not negative-length) when the slide is
        # smaller than one tile, so the progress total is never bogus.
        xs = range(0, width - tile_w + 1, stride)
        ys = range(0, height - tile_h + 1, stride)
        total_tiles = len(xs) * len(ys)

        count = 0
        saved = 0

        # Get base filename for naming tiles
        base_name = os.path.splitext(os.path.basename(slide_path))[0]

        with tqdm(total=total_tiles, desc=f"Extracting tiles from {base_name}") as pbar:
            for y in ys:
                for x in xs:
                    # NOTE(review): OpenSlide returns RGBA; pixels with no
                    # slide data are transparent and become black after
                    # convert("RGB"), which is_informative would treat as
                    # tissue. Consider compositing onto white if that occurs.
                    tile = slide.read_region((x, y), 0, tile_size).convert("RGB")

                    if not filter_informative or is_informative(tile, threshold):
                        tile_name = f"{base_name}_tile_{saved}.png"
                        tile.save(os.path.join(output_dir, tile_name))
                        saved += 1

                    count += 1
                    pbar.update(1)
    finally:
        # Release the slide handle even if reading/saving fails; batch
        # callers catch per-slide exceptions and keep going.
        slide.close()

    print(f"Extracted {saved} tiles from {count} regions")
    return saved


def is_informative(tile, threshold=0.92, white_level=224):
    """
    Check if a tile contains enough tissue content.

    A tile is considered informative if less than ``threshold`` fraction of
    its pixels are white/near-white (background). A pixel's intensity is the
    mean of its colour channels. An alpha channel (RGBA / LA input) is
    ignored. Single-channel (H, W) images are used directly.

    Args:
        tile: PIL Image object or array-like of shape (H, W) or (H, W, C)
        threshold (float): Maximum allowed white pixel ratio
        white_level (int): Intensity above which a pixel counts as white
            background (default 224, on a 0-255 scale)

    Returns:
        bool: True if tile is informative (has tissue), False otherwise.
        An empty tile is reported as not informative.
    """
    arr = np.asarray(tile)
    if arr.size == 0:
        return False

    if arr.ndim == 3:
        # Drop a trailing alpha channel so it doesn't bias the intensity.
        if arr.shape[-1] in (2, 4):
            arr = arr[..., :-1]
        gray = arr.mean(axis=-1)
    else:
        # Already one value per pixel; averaging over the last axis here
        # would collapse image columns instead of colour channels.
        gray = arr

    white_fraction = np.count_nonzero(gray > white_level) / gray.size
    return bool(white_fraction < threshold)

## 6. Run Tile Extraction

In [None]:
# Example: Extract tiles from a single WSI
# input_ndpi = "/path/to/your/slide.ndpi"
# output_tiles = "/path/to/output/tiles"
# extract_tiles(input_ndpi, output_tiles, tile_size=TILE_SIZE, stride=STRIDE)

# Example: Batch extract from multiple WSI files
def batch_extract_tiles(input_dir, output_dir, tile_size=(1024, 1024), stride=1024,
                        filter_informative=False, skip_files=None, threshold=0.92):
    """
    Extract tiles from all WSI files in a directory.

    Slides (.ndpi, .svs, .tif, .tiff, matched case-insensitively) are
    processed in sorted order. All tiles go into ``output_dir``; tile names
    are prefixed with the slide stem by ``extract_tiles``. Errors on
    individual slides are printed and the batch continues.

    Args:
        input_dir (str): Directory containing WSI files
        output_dir (str): Directory for output tiles
        tile_size (tuple): Tile dimensions
        stride (int): Step size between tiles
        filter_informative (bool): Filter for tissue-containing tiles
        skip_files (list): Filenames to skip
        threshold (float): White pixel ratio threshold passed to
            ``extract_tiles`` when ``filter_informative`` is True

    Returns:
        int: Total number of tiles saved across all slides
    """
    skip = set(skip_files or ())
    extensions = (".ndpi", ".svs", ".tif", ".tiff")

    os.makedirs(output_dir, exist_ok=True)

    # "._*" files are macOS AppleDouble metadata (common on external
    # non-HFS volumes); they share the slide's extension but are not slides.
    wsi_files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(extensions) and not f.startswith("._")
    )

    print(f"Found {len(wsi_files)} WSI files")

    total_tiles = 0
    for filename in wsi_files:
        if filename in skip:
            print(f"Skipping: {filename}")
            continue

        input_path = os.path.join(input_dir, filename)

        try:
            n_tiles = extract_tiles(
                input_path, output_dir,
                tile_size=tile_size,
                stride=stride,
                filter_informative=filter_informative,
                threshold=threshold,
            )
            total_tiles += n_tiles
        except Exception as e:
            # Best-effort batch: report the failure and continue.
            print(f"Error processing {filename}: {e}")

    print(f"\nTotal tiles extracted: {total_tiles}")
    return total_tiles


# Uncomment to run batch extraction
# batch_extract_tiles(INPUT_DIR, OUTPUT_DIR, tile_size=TILE_SIZE, stride=STRIDE)

## 7. Organize Tiles by Stain Type

In [None]:
def organize_by_stain(source_dir, dest_root, keywords=None, move=False):
    """
    Organize image tiles into subdirectories based on stain type keywords.

    Scans filenames for stain type keywords and copies/moves matching files
    to ``<dest_root>/<keyword>/``. Matching is case-insensitive and the
    first keyword (in ``keywords`` order) found in a filename wins, so each
    file lands in at most one folder. Only regular files are considered;
    subdirectories of ``source_dir`` (e.g. stain folders from an earlier run
    when ``dest_root`` equals ``source_dir``) are ignored.

    NOTE(review): matching is plain substring search, so short keywords such
    as "HE" can match unrelated parts of a filename; the keyword order
    decides ties.

    Args:
        source_dir (str): Directory containing tiles to organize
        dest_root (str): Root directory for organized output
        keywords (list): Stain type keywords to match
        move (bool): If True, move files; if False, copy files

    Returns:
        dict: Count of files organized per stain type
    """
    if keywords is None:
        keywords = STAIN_KEYWORDS

    counts = {kw: 0 for kw in keywords}
    unmatched = 0

    # Directories are not tiles; shutil.copy would raise on them and
    # shutil.move could relocate a whole stain folder.
    files = [
        f for f in os.listdir(source_dir)
        if os.path.isfile(os.path.join(source_dir, f))
    ]
    print(f"Processing {len(files)} files...")

    for filename in tqdm(files, desc="Organizing files"):
        upper_name = filename.upper()
        # Upper-case the keyword too, so lowercase keywords still match.
        keyword = next((kw for kw in keywords if kw.upper() in upper_name), None)
        if keyword is None:
            unmatched += 1
            continue

        # Create stain-specific subdirectory
        keyword_folder = os.path.join(dest_root, keyword)
        os.makedirs(keyword_folder, exist_ok=True)

        src_path = os.path.join(source_dir, filename)
        dst_path = os.path.join(keyword_folder, filename)

        if move:
            shutil.move(src_path, dst_path)
        else:
            shutil.copy(src_path, dst_path)

        counts[keyword] += 1

    # Print summary
    print("\nOrganization Summary:")
    print("-" * 30)
    for kw, count in counts.items():
        if count > 0:
            print(f"  {kw}: {count} files")
    print(f"  Unmatched: {unmatched} files")
    print(f"  Total processed: {sum(counts.values()) + unmatched}")

    return counts


# Run organization
# organize_by_stain(TILES_SOURCE_DIR, TILES_DEST_ROOT, move=False)

## 8. Utility Functions

In [None]:
def get_slide_info(slide_path):
    """
    Print a short summary of a whole slide image.

    Shows the level-0 size, the pyramid structure (level count, per-level
    dimensions and downsample factors) and the first five metadata
    properties.

    Args:
        slide_path (str): Path to WSI file
    """
    slide = openslide.OpenSlide(slide_path)

    # Only a preview of the metadata is shown to keep the output short.
    property_preview = dict(list(slide.properties.items())[:5])
    summary = [
        ("Dimensions (Level 0)", f"{slide.dimensions}"),
        ("Level count", f"{slide.level_count}"),
        ("Level dimensions", f"{slide.level_dimensions}"),
        ("Level downsamples", f"{slide.level_downsamples}"),
        ("Properties", f"{property_preview}..."),
    ]

    print(f"Slide: {os.path.basename(slide_path)}")
    for label, text in summary:
        print(f"  {label}: {text}")

    slide.close()


def count_tiles_by_stain(tiles_dir, keywords=None):
    """
    Count tiles by stain type in a directory.

    Each regular file is assigned to the first keyword (in ``keywords``
    order) that appears in its name, compared case-insensitively. Files
    matching no keyword are counted under ``'other'``. Subdirectories (e.g.
    stain folders created by an earlier organization run) are not counted.

    Args:
        tiles_dir (str): Directory containing tiles
        keywords (list): Stain keywords to search for

    Returns:
        dict: Counts per stain type, plus an ``'other'`` bucket
    """
    if keywords is None:
        keywords = STAIN_KEYWORDS

    counts = {kw: 0 for kw in keywords}
    counts['other'] = 0

    for filename in os.listdir(tiles_dir):
        # Directories are not tiles.
        if not os.path.isfile(os.path.join(tiles_dir, filename)):
            continue
        upper_name = filename.upper()
        # Upper-case the keyword too, so lowercase keywords still match.
        bucket = next((kw for kw in keywords if kw.upper() in upper_name), 'other')
        counts[bucket] += 1

    print("Tile counts by stain:")
    for kw, count in counts.items():
        if count > 0:
            print(f"  {kw}: {count}")

    return counts