## 1. Configuration

Set all paths and parameters in this cell.

In [1]:
# ======================================================================================
# CONFIGURATION - Adjust these paths and parameters
# ======================================================================================

# Paths
DATA_ROOT = "/media/thanakornbuath/data SSD/her2-attention-classifier/data"
OUTPUTS_ROOT = "/media/thanakornbuath/data SSD/her2-attention-classifier/outputs"
ZARR_OUTPUT_DIR = "/media/thanakornbuath/patch/zarr_norm"  # CHANGE THIS
PATCHES_ROOT = f"{OUTPUTS_ROOT}/patches"  # For reference sampling

# Reference stain normalization
REF_STAIN_STATS_PATH = f"{OUTPUTS_ROOT}/ref_stain_stats.npz"
NUM_REF_SUBFOLDERS = 100   # Number of random subfolders to sample
IMAGES_PER_FOLDER = 200  # Max images per sampled subfolder (reduced to prevent OOM)

# Patch extraction parameters
PATCH_SIZE = 512
STRIDE = 512  # No overlap
LEVEL = 0  # Highest resolution
TISSUE_THRESHOLD = 0.2  # Minimum 20% tissue in patch (via quick HSV-based check)
DOWNSAMPLE_MASK = 1  # Downsample mask for memory efficiency

# Performance parameters
NUM_WORKERS = 8  # Parallel patch extraction workers (threads)
BATCH_SIZE = 128  # Patches per batch write to Zarr
USE_GPU = True  # GPU acceleration for Macenko normalization (CuPy)
SKIP_EXISTING = True  # Skip existing .zarr files

# Dataset selection
COHORTS = [
    "TCGA_BRCA_Filtered",
    "Yale_HER2_cohort",
    "Yale_trastuzumab_response_cohort"
]

print("✓ Configuration loaded")
print(f"  Data root: {DATA_ROOT}")
print(f"  Zarr output: {ZARR_OUTPUT_DIR}")
print(f"  Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Workers: {NUM_WORKERS}, GPU: {USE_GPU}")

✓ Configuration loaded
  Data root: /media/thanakornbuath/data SSD/her2-attention-classifier/data
  Zarr output: /media/thanakornbuath/patch/zarr_norm
  Patch size: 512x512
  Workers: 8, GPU: True


## 2. Setup Environment

In [2]:
import os
import sys
import logging
import random
import json
import math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import zarr
from PIL import Image, ImageDraw

# Add project root to path
project_root = Path("/media/thanakornbuath/data SSD/her2-attention-classifier")
sys.path.insert(0, str(project_root))

# External libs used during execution
try:
    import openslide
except Exception as e:
    print("⚠ openslide not available. Please install openslide-python and libopenslide.")
    raise

# Setup logging
os.makedirs(Path(OUTPUTS_ROOT) / "logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f"{OUTPUTS_ROOT}/logs/main_preprocessing.log"),
        logging.StreamHandler()
    ]
)

# Create output directory
os.makedirs(ZARR_OUTPUT_DIR, exist_ok=True)

print("✓ Environment setup complete")
print(f"  Python version: {sys.version.split()[0]}")

# Check GPU
if USE_GPU:
    try:
        import cupy as cp
        _gpu_count = cp.cuda.runtime.getDeviceCount()
        print(f"  ✓ CuPy available: {_gpu_count} GPU(s)")
    except Exception as e:
        print(f"  ⚠ CuPy not available ({e}), using CPU")
        USE_GPU = False


✓ Environment setup complete
  Python version: 3.12.0
  ✓ CuPy available: 1 GPU(s)


## 3. Stain Normalization (Macenko) - GPU optional

In [3]:
import gc
from typing import List, Tuple, Optional, Dict

class MacenkoNormalizer:
    """Macenko stain normalization with optional CuPy acceleration for linear algebra ops.
    Only matrix ops are on GPU; I/O and conversions stay on CPU.
    """
    def __init__(self, percentiles: Tuple[float, float]=(1, 99), use_gpu: bool=False):
        self.percentiles = percentiles
        self.use_gpu = use_gpu
        self.xp = cp if (use_gpu and 'cp' in globals()) else np

    @staticmethod
    def _rgb_to_od(image_rgb: np.ndarray) -> np.ndarray:
        img = image_rgb.astype(np.float32) + 1.0  # avoid log(0), 1/255 is negligible
        od = -np.log(img / 255.0)
        return od

    @staticmethod
    def _od_to_rgb(image_od: np.ndarray) -> np.ndarray:
        rgb = (255.0 * np.exp(-image_od)).clip(0, 255).astype(np.uint8)
        return rgb

    @staticmethod
    def _to_cpu(arr):
        try:
            import cupy as cp
            if isinstance(arr, cp.ndarray):
                return cp.asnumpy(arr)
        except Exception:
            pass
        return np.asarray(arr)

    def _get_stain_vectors_and_concentrations(self, image_rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray, Tuple[int,int,int]]:
        # Convert to OD
        od = self._rgb_to_od(image_rgb)
        H, W, CH = od.shape
        od_reshaped = od.reshape(-1, 3)

        # Filter near-white background
        mask = np.sum(od_reshaped, axis=1) > 0.2
        od_filtered = od_reshaped[mask]
        if od_filtered.shape[0] < 100:
            # too few pixels; fallback to using all
            od_filtered = od_reshaped

        # Move to xp for PCA
        xp = self.xp
        odf = xp.asarray(od_filtered, dtype=xp.float32)
        # Center
        odf = odf - odf.mean(axis=0, keepdims=True)
        # SVD to get top 2 PCs
        try:
            U, S, VT = xp.linalg.svd(odf, full_matrices=False)
            v = VT.T[:, :2]  # (3,2)
            # Explicitly delete large SVD matrices
            del U, S, VT
        except Exception:
            # CPU fallback if GPU fails
            U, S, VT = np.linalg.svd(od_filtered - od_filtered.mean(axis=0, keepdims=True), full_matrices=False)
            v = VT.T[:, :2]
            del U, S, VT
            xp = np

        # Normalize columns
        v = v / xp.linalg.norm(v, axis=0, keepdims=True)

        # Ensure consistent direction (positive sum)
        for i in range(2):
            if float(v[:, i].sum()) < 0:
                v[:, i] = -v[:, i]

        # CRITICAL: Macenko angle-based stain separation
        # Project filtered OD onto the 2D plane spanned by top 2 PCs
        odf_2d = xp.asarray(od_filtered, dtype=xp.float32) @ v  # (N_filtered, 2)

        # Compute angles in 2D plane
        angles = xp.arctan2(odf_2d[:, 1], odf_2d[:, 0])

        # Find angle percentiles to define stain extremes (THIS WAS MISSING!)
        angles_cpu = self._to_cpu(angles)
        min_angle = float(np.percentile(angles_cpu, self.percentiles[0]))
        max_angle = float(np.percentile(angles_cpu, self.percentiles[1]))

        # Construct stain vectors at these extreme angles
        # min_angle → Hematoxylin (typically more bluish, first component dominant)
        # max_angle → Eosin (typically more pinkish, second component dominant)
        stain_h = xp.cos(min_angle) * v[:, 0] + xp.sin(min_angle) * v[:, 1]
        stain_e = xp.cos(max_angle) * v[:, 0] + xp.sin(max_angle) * v[:, 1]

        # Normalize stain vectors
        stain_h = stain_h / xp.linalg.norm(stain_h)
        stain_e = stain_e / xp.linalg.norm(stain_e)

        # Stack into matrix [H, E] as columns
        stain_matrix = xp.column_stack([stain_h, stain_e])

        # Project ALL pixels (od_reshaped) onto these true stain vectors
        od_all = xp.asarray(od_reshaped.astype(np.float32))
        C = od_all @ stain_matrix  # (N,2) - now correctly separated H&E concentrations
        C = xp.maximum(C, 0)  # non-negative

        # Back to CPU explicitly (avoid implicit CuPy->NumPy conversion)
        stain_vectors = self._to_cpu(stain_matrix)
        concentrations = self._to_cpu(C)

        # Cleanup intermediate arrays
        del odf_2d, angles, stain_h, stain_e, stain_matrix

        # Explicitly delete GPU arrays to free memory immediately
        del odf, od_all, v
        if self.use_gpu:
            try:
                del C  # Delete GPU version
            except Exception:
                pass

        return stain_vectors, concentrations, (H, W, CH)

    def get_mean_reference_stain_characteristics(self, list_of_reference_images_rgb: List[np.ndarray]):
        if not list_of_reference_images_rgb:
            raise ValueError("list_of_reference_images_rgb cannot be empty.")
        all_V = []
        max_h = []
        max_e = []
        for i, img in enumerate(list_of_reference_images_rgb):
            V, C, _ = self._get_stain_vectors_and_concentrations(img)
            all_V.append(V)
            # Extract percentiles and immediately delete huge C array
            h_val = float(np.percentile(C[:, 0], self.percentiles[1]))
            e_val = float(np.percentile(C[:, 1], self.percentiles[1]))
            max_h.append(h_val)
            max_e.append(e_val)
            del C, V  # Free large arrays immediately

            # Free GPU memory every 10 images to prevent pool growth
            if self.use_gpu and (i + 1) % 10 == 0:
                try:
                    import cupy as cp
                    cp.get_default_memory_pool().free_all_blocks()
                    cp.get_default_pinned_memory_pool().free_all_blocks()
                except Exception:
                    pass
            # Force garbage collection every 50 images
            if (i + 1) % 50 == 0:
                gc.collect()

        # Compute final statistics
        mean_V = np.mean(np.stack(all_V, axis=0), axis=0)
        mean_V = mean_V / np.linalg.norm(mean_V, axis=0, keepdims=True)
        mean_max_h = float(np.mean(max_h))
        mean_max_e = float(np.mean(max_e))

        # Clean up intermediate lists
        del all_V, max_h, max_e
        gc.collect()

        # Final GPU cleanup
        if self.use_gpu:
            try:
                import cupy as cp
                cp.get_default_memory_pool().free_all_blocks()
                cp.get_default_pinned_memory_pool().free_all_blocks()
            except Exception:
                pass

        return mean_V, (mean_max_h, mean_max_e)

    def normalize(self, target_image_rgb: np.ndarray,
                  mean_ref_stain_vectors: np.ndarray,
                  mean_ref_max_concentrations_tuple: Tuple[float, float]) -> np.ndarray:
        # Target characteristics
        V_t, C_t, shape = self._get_stain_vectors_and_concentrations(target_image_rgb)
        max_t_h = np.percentile(C_t[:, 0], self.percentiles[1])
        max_t_e = np.percentile(C_t[:, 1], self.percentiles[1])

        # Scale concentrations to reference
        ref_max_h, ref_max_e = mean_ref_max_concentrations_tuple
        scale_h = ref_max_h / (max_t_h + 1e-6)
        scale_e = ref_max_e / (max_t_e + 1e-6)
        Cn = C_t.copy()
        Cn[:, 0] *= scale_h
        Cn[:, 1] *= scale_e
        Cn = np.maximum(Cn, 0)

        # Reconstruct OD using reference stain vectors
        V_ref = mean_ref_stain_vectors.astype(np.float32)
        od_norm = (Cn @ V_ref.T).reshape(shape)
        rgb_norm = self._od_to_rgb(od_norm)
        return rgb_norm

# Utility to load reference params from npz
def load_reference_stain_params(npz_path: Path, use_gpu: bool=False) -> Optional[Dict]:
    try:
        if not npz_path.exists():
            return None
        data = np.load(str(npz_path))
        if 'stain_vectors' in data and ('max_h' in data or 'mean_max_h' in data):
            V = data['stain_vectors']
            max_h = float(data.get('max_h', data.get('mean_max_h')))
            max_e = float(data.get('max_e', data.get('mean_max_e')))
            return {
                'stain_vectors': V,
                'max_concentrations': (max_h, max_e),
                'use_gpu': use_gpu,
                'percentiles': (1, 99)
            }
    except Exception as e:
        logging.error(f"Failed to load reference stain parameters: {e}")
    return None


## 4. Compute Reference Stain Statistics

Sample images from existing patches and compute mean Macenko stain characteristics.

In [4]:

def load_images_from_folder(folder_path: Path, max_images: int = 10) -> List[np.ndarray]:
    images: List[np.ndarray] = []
    supported_ext = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
    if not folder_path.exists():
        return images
    # gather files (non-recursive, matches original request: per subfolder)
    image_files: List[Path] = []
    for ext in supported_ext:
        image_files.extend(list(folder_path.glob(f"*{ext}")))
        image_files.extend(list(folder_path.glob(f"*{ext.upper()}")))
    if len(image_files) == 0:
        return images
    if len(image_files) > max_images:
        image_files = random.sample(image_files, max_images)
    for p in image_files:
        try:
            with Image.open(p) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                images.append(np.array(img))
        except Exception as e:
            logging.debug(f"Skip {p}: {e}")
    return images

# Check if reference stats already exist
ref_stats_path = Path(REF_STAIN_STATS_PATH)
if ref_stats_path.exists():
    print(f"✓ Reference stats already exist at {REF_STAIN_STATS_PATH}")
    try:
        data = np.load(ref_stats_path)
        keys = list(data.keys())
        print(f"  Keys: {keys}")
        if 'max_h' in keys:
            print(f"  Max H: {float(data['max_h']):.4f}, Max E: {float(data['max_e']):.4f}")
        elif 'mean_max_h' in keys:
            print(f"  Max H: {float(data['mean_max_h']):.4f}, Max E: {float(data['mean_max_e']):.4f}")
    except Exception as e:
        print(f"  ⚠ Could not load stats: {e}. Will regenerate...")
        ref_stats_path = None
else:
    print("Reference stats not found. Computing from patches...")
    patches_root = Path(PATCHES_ROOT)
    if not patches_root.exists():
        print(f"⚠ Patches root not found: {PATCHES_ROOT}\n  Skipping reference computation.")
    else:
        # Sample reference images
        print(f"Sampling from up to {NUM_REF_SUBFOLDERS} random folders...")
        subfolders = [d for d in patches_root.iterdir() if d.is_dir()]
        num_to_sample = min(NUM_REF_SUBFOLDERS, len(subfolders))
        random.seed(42)
        sampled_folders = random.sample(subfolders, num_to_sample) if num_to_sample > 0 else []

        all_reference_images: List[np.ndarray] = []
        print(f"Loading reference images from {len(sampled_folders)} folders (max {IMAGES_PER_FOLDER} per folder)...")
        for folder_idx, folder in enumerate(tqdm(sampled_folders, desc="Loading reference images")):
            imgs = load_images_from_folder(folder, max_images=IMAGES_PER_FOLDER)
            all_reference_images.extend(imgs)
            del imgs  # Free folder batch immediately

            # Force garbage collection every 5 folders to prevent accumulation
            if (folder_idx + 1) % 5 == 0:
                gc.collect()


        print(f"✓ Loaded {len(all_reference_images)} reference images")
        if len(all_reference_images) > 0:
            print("Computing Macenko reference statistics (with aggressive memory management)...")
            normalizer = MacenkoNormalizer(use_gpu=USE_GPU, percentiles=(1, 99))

            # Compute with progress tracking
            print(f"  Processing {len(all_reference_images)} images for stain characteristics...")
            mean_V, (mean_max_h, mean_max_e) = normalizer.get_mean_reference_stain_characteristics(all_reference_images)

            # Save
            ref_stats_path = Path(REF_STAIN_STATS_PATH)
            ref_stats_path.parent.mkdir(parents=True, exist_ok=True)
            np.savez(ref_stats_path, stain_vectors=mean_V, max_h=mean_max_h, max_e=mean_max_e)

            print(f"✓ Reference statistics saved to {REF_STAIN_STATS_PATH}")
            print(f"  Max H: {mean_max_h:.4f}, Max E: {mean_max_e:.4f}")

            # Explicitly free large reference image list to avoid memory growth in notebook sessions
            del all_reference_images
            del mean_V
            del normalizer
            gc.collect()

            # Final GPU cleanup
            if USE_GPU:
                try:
                    import cupy as cp
                    cp.get_default_memory_pool().free_all_blocks()
                    cp.get_default_pinned_memory_pool().free_all_blocks()
                    print("  ✓ GPU memory pools freed")
                except Exception as e:
                    print(f"  ⚠ GPU cleanup warning: {e}")
        else:
            print("⚠ No reference images found; proceeding without normalization")


Reference stats not found. Computing from patches...
Sampling from up to 100 random folders...
Loading reference images from 100 folders (max 200 per folder)...


Loading reference images:   0%|          | 0/100 [00:00<?, ?it/s]

✓ Loaded 15889 reference images
Computing Macenko reference statistics (with aggressive memory management)...
  Processing 15889 images for stain characteristics...
✓ Reference statistics saved to /media/thanakornbuath/data SSD/her2-attention-classifier/outputs/ref_stain_stats.npz
  Max H: 2.4669, Max E: 2.4352
  ✓ GPU memory pools freed


## 5. Load Reference Stain Parameters

In [5]:
normalizer_params = load_reference_stain_params(Path(REF_STAIN_STATS_PATH), use_gpu=USE_GPU)
if normalizer_params is None:
    print("⚠️  Warning: Reference stain parameters not loaded! Patches will NOT be stain normalized.")
else:
    print("✓ Loaded reference stain parameters")
    print(f"  Max H: {normalizer_params['max_concentrations'][0]:.4f}")
    print(f"  Max E: {normalizer_params['max_concentrations'][1]:.4f}")
    print(f"  GPU: {normalizer_params['use_gpu']}")


✓ Loaded reference stain parameters
  Max H: 2.4669
  Max E: 2.4352
  GPU: True


## 6. Slide Discovery (SVS + XML)

In [6]:

def discover_slides(data_root: str, cohorts: list):
    slides = []
    for cohort in cohorts:
        cohort_dir = Path(data_root) / cohort
        svs_dir = cohort_dir / "SVS"
        xml_dir = cohort_dir / "Annotations"
        if not svs_dir.exists():
            print(f"⚠️  SVS directory not found: {svs_dir}")
            continue
        if not xml_dir.exists():
            print(f"⚠️  Annotations directory not found: {xml_dir}")
            continue
        # Optional labels
        labels_dict = {}
        for label_file_name in ['labels.csv', 'HER2_TCGA_clean.csv']:
            label_file = cohort_dir / label_file_name
            if label_file.exists():
                try:
                    df_labels = pd.read_csv(label_file)
                    if 'slide_id' in df_labels.columns and 'label' in df_labels.columns:
                        labels_dict = dict(zip(df_labels['slide_id'], df_labels['label']))
                    elif 'case_id' in df_labels.columns and 'HER2_IHC_Status' in df_labels.columns:
                        def map_her2_status(status):
                            if isinstance(status, str) and (('Positive' in status) or ('3+' in status) or ('2+' in status)):
                                return 1
                            return 0
                        labels_dict = {row['case_id']: map_her2_status(row['HER2_IHC_Status']) for _, row in df_labels.iterrows()}
                    print(f"✓ Loaded labels for {len(labels_dict)} slides from {cohort}/{label_file_name}")
                    break
                except Exception as e:
                    print(f"⚠️  Failed to load labels from {label_file}: {e}")
        # SVS files
        svs_files = list(svs_dir.glob("*.svs")) + list(svs_dir.glob("*.SVS"))
        for svs_path in svs_files:
            slide_name = svs_path.stem

            # For TCGA slides (e.g., TCGA-XX-YYYY.01234.svs), use only part before first dot
            # to match XML (e.g., TCGA-XX-YYYY.xml)
            xml_base_name = slide_name
            if slide_name.startswith("TCGA-") and "." in slide_name:
                xml_base_name = slide_name.split(".")[0]

            xml_candidates = [
                xml_dir / f"{xml_base_name}.xml",
                xml_dir / f"{xml_base_name}.XML",
                xml_dir / f"{slide_name}.xml",  # Also try full name as fallback
                xml_dir / f"{slide_name}.XML"
            ]
            xml_path = None
            for cand in xml_candidates:
                if cand.exists():
                    xml_path = cand
                    break
            if xml_path is None:
                print(f"⚠️  No XML found for {slide_name} (tried {xml_base_name}), skipping")
                continue
            label = int(labels_dict.get(slide_name, 0))
            slides.append({
                'slide_id': slide_name,
                'svs_path': str(svs_path),
                'xml_path': str(xml_path),
                'cohort': cohort,
                'label': label
            })
    return slides

slides = discover_slides(DATA_ROOT, COHORTS)
print(f"\n✓ Discovered {len(slides)} slides across {len(COHORTS)} cohorts")
print(f"\nCohort breakdown:")
for cohort in COHORTS:
    count = sum(1 for s in slides if s['cohort'] == cohort)
    print(f"  {cohort}: {count} slides")
if slides:
    labels = [s['label'] for s in slides]
    print(f"\nLabel distribution: HER2- (0): {sum(1 for l in labels if l == 0)}, HER2+ (1): {sum(1 for l in labels if l == 1)}")
else:
    print("\n⚠️  No slides discovered! Check data directory structure.")


✓ Loaded labels for 0 slides from TCGA_BRCA_Filtered/HER2_TCGA_clean.csv
⚠️  No XML found for Her2Neg_Case_34 (tried Her2Neg_Case_34), skipping
⚠️  No XML found for Her2Pos_Case_27 (tried Her2Pos_Case_27), skipping
⚠️  No XML found for Her2Pos_Case_52 (tried Her2Pos_Case_52), skipping
⚠️  No XML found for Her2Pos_Case_45 (tried Her2Pos_Case_45), skipping
⚠️  No XML found for Her2Neg_Case_31 (tried Her2Neg_Case_31), skipping

✓ Discovered 454 slides across 3 cohorts

Cohort breakdown:
  TCGA_BRCA_Filtered: 182 slides
  Yale_HER2_cohort: 187 slides
  Yale_trastuzumab_response_cohort: 85 slides

Label distribution: HER2- (0): 454, HER2+ (1): 0


## 7. XML → Mask, Grid Sampling, and Zarr Writer

In [7]:
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed

# Simple HSV tissue filter (fast) to exclude whitespace patches
def tissue_fraction_rgb(patch_rgb: np.ndarray) -> float:
    # Convert to HSV via matplotlib (fast enough)
    import matplotlib.colors as mcolors
    hsv = mcolors.rgb_to_hsv(patch_rgb.astype(np.float32) / 255.0)
    s = hsv[..., 1]
    v = hsv[..., 2]
    tissue = (s > 0.1) & (v < 0.95)
    return float(tissue.mean())

# Parse Aperio-like XML polygons
def parse_xml_polygons(xml_path: str) -> List[np.ndarray]:
    polys: List[np.ndarray] = []
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        # Look for Regions/Vertices/Vertex
        for region in root.iter("Region"):
            pts = []
            vertices = list(region.iter("Vertex"))
            if not vertices:
                # Some schemas use 'Coordinate'
                vertices = list(region.iter("Coordinate"))
            for v in vertices:
                x = v.get('X') or v.get('x') or v.get('XCoord')
                y = v.get('Y') or v.get('y') or v.get('YCoord')
                if x is None or y is None:
                    continue
                pts.append((float(x), float(y)))
            if len(pts) >= 3:
                polys.append(np.array(pts, dtype=np.float32))
    except Exception as e:
        logging.error(f"Failed to parse XML {xml_path}: {e}")
    return polys

# Rasterize polygons into a downsampled mask
def polygons_to_mask(polygons: List[np.ndarray], level0_size: Tuple[int,int], downsample: int=16) -> np.ndarray:
    W, H = level0_size  # OpenSlide gives (width,height)
    w_small = max(1, W // downsample)
    h_small = max(1, H // downsample)
    mask_img = Image.new('L', (w_small, h_small), 0)
    draw = ImageDraw.Draw(mask_img)
    for poly in polygons:
        scaled = [(p[0]/downsample, p[1]/downsample) for p in poly]
        try:
            draw.polygon(scaled, outline=1, fill=1)
        except Exception:
            # If polygon invalid, skip
            continue
    return (np.array(mask_img) > 0)

# Generate grid centers at level 0
def generate_grid_centers(level0_size: Tuple[int,int], patch: int, stride: int) -> List[Tuple[int,int]]:
    W, H = level0_size
    xs = list(range(patch//2, W - patch//2 + 1, stride))
    ys = list(range(patch//2, H - patch//2 + 1, stride))
    centers = [(x, y) for y in ys for x in xs]
    return centers

# Create Zarr group and arrays
def create_zarr_group(zarr_path: Path, num_patches: int, patch: int) -> zarr.hierarchy.Group:
    store = zarr.DirectoryStore(str(zarr_path))
    root = zarr.group(store=store, overwrite=True)
    compressor = zarr.Blosc(cname='zstd', clevel=5, shuffle=2)
    # allow resize after filtering by setting maxshape None on first dimension
    root.create_dataset('patches', shape=(num_patches, patch, patch, 3), maxshape=(None, patch, patch, 3), chunks=(min(256, num_patches), patch, patch, 3), dtype='u1', compressor=compressor, overwrite=True)
    root.create_dataset('coords', shape=(num_patches, 2), maxshape=(None, 2), chunks=(min(4096, num_patches), 2), dtype='i4', compressor=compressor, overwrite=True)
    root.create_dataset('labels', shape=(num_patches,), maxshape=(None,), chunks=(min(4096, num_patches),), dtype='i1', compressor=compressor, overwrite=True)
    return root

# Process a single slide → Zarr
def process_slide(slide_info: dict,
                  patch_size: int,
                  stride: int,
                  level: int,
                  tissue_threshold: float,
                  downsample_mask: int,
                  normalizer_params: Optional[Dict],
                  out_dir: Path,
                  num_workers: int=4,
                  batch_size: int=128) -> bool:
    slide_id = slide_info['slide_id']
    svs_path = slide_info['svs_path']
    xml_path = slide_info['xml_path']
    label = int(slide_info['label'])

    zarr_path = out_dir / f"{slide_id}.zarr"
    if zarr_path.exists() and SKIP_EXISTING:
        # ensure it's complete by checking meta.json
        meta_ok = (zarr_path / 'meta.json').exists()
        if meta_ok:
            logging.info(f"Skip existing: {slide_id}")
            return True
        else:
            logging.info(f"Existing zarr missing meta; rewriting: {slide_id}")

    # Open slide
    slide = None
    try:
        slide = openslide.OpenSlide(svs_path)
        W, H = slide.level_dimensions[0]
        # XML polygons → mask
        polygons = parse_xml_polygons(xml_path)
        if len(polygons) == 0:
            logging.warning(f"No polygons in XML for {slide_id}; skipping")
            slide.close()
            return False
        mask = polygons_to_mask(polygons, (W, H), downsample=downsample_mask)
        # Grid centers
        centers = generate_grid_centers((W, H), patch_size, stride)
        # Keep centers inside mask
        ds = downsample_mask
        valid_centers = [(x, y) for (x, y) in centers if mask[min(H//ds-1, y//ds), min(W//ds-1, x//ds)]]
        if len(valid_centers) == 0:
            logging.warning(f"No valid centers after masking for {slide_id}; skipping")
            slide.close()
            return False

        # Create normalizer if params available
        normalizer = None
        if normalizer_params is not None:
            normalizer = MacenkoNormalizer(percentiles=normalizer_params.get('percentiles', (1,99)),
                                           use_gpu=normalizer_params.get('use_gpu', False))
            V_ref = normalizer_params['stain_vectors']
            ref_max = normalizer_params['max_concentrations']
        else:
            V_ref = None
            ref_max = None

        # Pre-create zarr arrays
        z = create_zarr_group(zarr_path, len(valid_centers), patch_size)

        # Metadata (MPP, magnification)
        mpp_x = slide.properties.get(openslide.PROPERTY_NAME_MPP_X)
        mpp_y = slide.properties.get(openslide.PROPERTY_NAME_MPP_Y)
        try:
            magnification = float(slide.properties.get('aperio.AppMag') or slide.properties.get('openslide.objective-power') or 0)
        except Exception:
            magnification = None

        # Producer-consumer: thread pool per batch
        def read_and_process(idx_center_pair):
            idx, (cx, cy) = idx_center_pair
            x0 = cx - patch_size//2
            y0 = cy - patch_size//2
            # Read region returns RGBA; convert to RGB
            region = slide.read_region((x0, y0), level, (patch_size, patch_size))
            try:
                region = region.convert('RGB')
                patch = np.array(region)
            finally:
                region.close()
            # Quick tissue filter
            if tissue_threshold > 0:
                if tissue_fraction_rgb(patch) < tissue_threshold:
                    return idx, None, (cx, cy)
            # Normalize if available
            if normalizer is not None and V_ref is not None and ref_max is not None:
                try:
                    patch = normalizer.normalize(patch, mean_ref_stain_vectors=V_ref, mean_ref_max_concentrations_tuple=ref_max)
                except Exception as e:
                    # If normalization fails, fall back to original patch
                    logging.debug(f"Norm fail at idx {idx}: {e}")
            return idx, patch, (cx, cy)

        total = len(valid_centers)
        written = 0
        with ThreadPoolExecutor(max_workers=num_workers) as ex:
            for start in tqdm(range(0, total, batch_size), total=math.ceil(total/batch_size), desc=f"{slide_id}"):
                end = min(start + batch_size, total)
                futures = [ex.submit(read_and_process, (i, valid_centers[i])) for i in range(start, end)]
                for fut in as_completed(futures):
                    idx, patch, coord = fut.result()
                    if patch is None:
                        continue
                    # write sequentially to avoid holes and reduce disk
                    z['patches'][written] = patch
                    z['coords'][written] = coord
                    z['labels'][written] = label
                    written += 1
                # Free GPU pools per batch to avoid leaks
                if USE_GPU and 'cp' in globals():
                    try:
                        cp.get_default_memory_pool().free_all_blocks()
                        cp.get_default_pinned_memory_pool().free_all_blocks()
                    except Exception:
                        pass
                gc.collect()

        # shrink datasets to actual written size to save disk/memory
        try:
            z['patches'].resize((written, patch_size, patch_size, 3))
            z['coords'].resize((written, 2))
            z['labels'].resize((written,))
        except Exception as e:
            logging.debug(f"Resize failed for {slide_id}: {e}")

        # Write meta.json
        meta = {
            'slide_id': slide_id,
            'label': label,
            'num_patches': int(written),
            'mpp_x': float(mpp_x) if mpp_x else None,
            'mpp_y': float(mpp_y) if mpp_y else None,
            'magnification': magnification,
            'patch_size': patch_size,
            'stride': stride,
            'level': level
        }
        with open(zarr_path / 'meta.json', 'w') as f:
            json.dump(meta, f, indent=2)

        slide.close()
        return True
    except Exception as e:
        logging.error(f"Exception processing slide {slide_id}: {e}")
        try:
            if slide is not None:
                slide.close()
        except Exception:
            pass
        return False


## 8. Run SVS → Zarr Processing

In [8]:
print("Starting preprocessing...")
print(f"  Output directory: {ZARR_OUTPUT_DIR}")
print(f"  Config: patch={PATCH_SIZE}, stride={STRIDE}, workers={NUM_WORKERS}, batch={BATCH_SIZE}, GPU={USE_GPU}, skip={SKIP_EXISTING}")

successful = 0
failed = 0
skipped = 0
failed_slides = []

for slide_info in tqdm(slides, desc="Processing slides"):
    zarr_path = Path(ZARR_OUTPUT_DIR) / f"{slide_info['slide_id']}.zarr"
    if SKIP_EXISTING and zarr_path.exists() and (zarr_path / 'meta.json').exists():
        skipped += 1
        continue
    ok = process_slide(
        slide_info=slide_info,
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        level=LEVEL,
        tissue_threshold=TISSUE_THRESHOLD,
        downsample_mask=DOWNSAMPLE_MASK,
        normalizer_params=normalizer_params,
        out_dir=Path(ZARR_OUTPUT_DIR),
        num_workers=NUM_WORKERS,
        batch_size=BATCH_SIZE,
    )
    if ok:
        successful += 1
    else:
        failed += 1
        failed_slides.append(slide_info['slide_id'])
    if (successful + failed) % 3 == 0:
        gc.collect()
        if USE_GPU and 'cp' in globals():
            try:
                cp.get_default_memory_pool().free_all_blocks()
                cp.get_default_pinned_memory_pool().free_all_blocks()
            except Exception:
                pass

print(f"\n{'='*60}")
print("Processing complete!")
print(f"  Successful: {successful}")
print(f"  Failed: {failed}")
print(f"  Skipped (existing): {skipped}")
print(f"  Total: {len(slides)}")
if failed_slides:
    print(f"Failed slides: {', '.join(failed_slides[:10])}{'...' if len(failed_slides) > 10 else ''}")
print(f"{'='*60}")


Starting preprocessing...
  Output directory: /media/thanakornbuath/patch/zarr_norm
  Config: patch=512, stride=512, workers=8, batch=128, GPU=True, skip=True


Processing slides:   0%|          | 0/454 [00:00<?, ?it/s]

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


TCGA-AN-A046-01Z-00-DX1.C529B94F-AFE3-4701-BC98-5D6EDF7B82C0:   0%|          | 0/14 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 9. Create Train/Val Split Manifest

In [None]:
from sklearn.model_selection import train_test_split

zarr_files = list(Path(ZARR_OUTPUT_DIR).glob("*.zarr"))
print(f"Found {len(zarr_files)} Zarr files")

if len(zarr_files) > 0:
    zarr_manifest = []
    for zarr_path in tqdm(zarr_files, desc="Reading Zarr metadata"):
        meta_path = zarr_path / "meta.json"
        if meta_path.exists():
            with open(meta_path, 'r') as f:
                meta = json.load(f)
            zarr_manifest.append({
                'zarr_path': str(zarr_path),
                'slide_id': meta.get('slide_id', zarr_path.stem),
                'label': meta.get('label', 0),
                'num_patches': meta.get('num_patches', 0)
            })
    df_zarr = pd.DataFrame(zarr_manifest)
    print(f"\n✓ Loaded metadata for {len(df_zarr)} Zarr files")
    print(f"Total patches: {df_zarr['num_patches'].sum():,}")
    if len(df_zarr) > 1 and df_zarr['label'].nunique() > 1:
        train_df, val_df = train_test_split(df_zarr, test_size=0.2, stratify=df_zarr['label'], random_state=42)
    else:
        train_df, val_df = train_test_split(df_zarr, test_size=0.2, random_state=42)
    train_manifest_path = f"{OUTPUTS_ROOT}/zarr_train_manifest.csv"
    val_manifest_path = f"{OUTPUTS_ROOT}/zarr_val_manifest.csv"
    train_df.to_csv(train_manifest_path, index=False)
    val_df.to_csv(val_manifest_path, index=False)
    print(f"\n✓ Train manifest saved: {train_manifest_path}")
    print(f"✓ Val manifest saved: {val_manifest_path}")
else:
    print("⚠️  No Zarr files found! Preprocessing may have failed.")


## 10. Verify Zarr Loading and Visualization

In [None]:
zarr_files = list(Path(ZARR_OUTPUT_DIR).glob("*.zarr"))
if len(zarr_files) > 0:
    test_zarr_path = zarr_files[0]
    print(f"Testing Zarr: {test_zarr_path.name}")
    z = zarr.open(str(test_zarr_path), mode='r')
    print(f"patches: {z['patches'].shape} {z['patches'].dtype}")
    print(f"coords: {z['coords'].shape} {z['coords'].dtype}")
    print(f"labels: {z['labels'].shape} {z['labels'].dtype}")
    meta_path = test_zarr_path / "meta.json"
    if meta_path.exists():
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        print(f"Meta: slide={meta['slide_id']}, label={meta['label']}, num_patches={meta['num_patches']}")
    # Show a grid of patches
    n = min(12, z['patches'].shape[0])
    if n > 0:
        fig, axes = plt.subplots(3, 4, figsize=(12, 9))
        axes = axes.flatten()
        for i in range(n):
            axes[i].imshow(z['patches'][i])
            axes[i].set_title(f"{i}")
            axes[i].axis('off')
        for j in range(n, len(axes)):
            axes[j].axis('off')
        plt.tight_layout()
        plt.show()
else:
    print("⚠️  No Zarr files found for testing")


## 11. Summary

In [None]:
print("="*60)
print("PREPROCESSING PIPELINE SUMMARY")
print("="*60)

print(f"\n1. Reference Stain Normalization")
if Path(REF_STAIN_STATS_PATH).exists():
    print(f"   ✓ Reference stats: {REF_STAIN_STATS_PATH}")
else:
    print(f"   ✗ Reference stats not found")

print(f"\n2. Zarr Files Created")
zarr_files = list(Path(ZARR_OUTPUT_DIR).glob("*.zarr"))
print(f"   Total Zarr files: {len(zarr_files)}")
if zarr_files:
    total_patches = 0
    for zf in zarr_files:
        meta_path = zf / "meta.json"
        if meta_path.exists():
            with open(meta_path, 'r') as f:
                meta = json.load(f)
                total_patches += meta.get('num_patches', 0)
    print(f"   Total patches: {total_patches:,}")

print(f"\n3. Train/Val Manifests")
train_manifest = Path(f"{OUTPUTS_ROOT}/zarr_train_manifest.csv")
val_manifest = Path(f"{OUTPUTS_ROOT}/zarr_val_manifest.csv")
print(f"   Train manifest: {'✓' if train_manifest.exists() else '✗'} {train_manifest}")
print(f"   Val manifest: {'✓' if val_manifest.exists() else '✗'} {val_manifest}")

print("\n4. Next Steps: Use PatchDataset to train ResNet-50 / EfficientNet-B0")
print("="*60)
print("✅ Preprocessing pipeline complete!")
print("="*60)