In [1]:
%load_ext autoreload
%autoreload 2

# Imports
import pandas as pd
import numpy as np
import re
import os
from tqdm import tqdm
from IPython.display import display
from glob import glob
import cv2

# -----------------------------------------------------------------------------
# Global configuration (single top cell)
# -----------------------------------------------------------------------------
# WSI backend preference for patch extraction (a color-accurate backend is preferred).
# An HER2_WSI_BACKEND value already present in the environment is kept; otherwise
# 'openslide' is used.
# Options: 'openslide' (recommended), 'tiffslide' (good for TIFF/SVS), 'cucim' (fast GPU)
os.environ.setdefault('HER2_WSI_BACKEND', 'openslide')
print(f"WSI backend preference: {os.environ['HER2_WSI_BACKEND']}")

# Dataset configuration: root data directory, the source cohorts under it,
# and where the slide index CSV is written.
BASE_DIR = 'data'
SOURCES = ['Yale_HER2_cohort', 'Yale_trastuzumab_response_cohort', 'TCGA_BRCA_Filtered']
OUTPUT_CSV = 'outputs/index/wsi_index.csv'

# Patch extraction configuration (size and stride are passed to extract_patches below)
PATCH_SIZE, PATCH_STRIDE = 512, 512
PATCH_SAVE_FORMAT = 'png'  # e.g., 'png' or 'jpg'

# Normalization configuration
# - NORM_INPLACE: if True, overwrite original patch files; else, write to NORM_OUTPUT_DIR mirroring structure
# - NORM_OUTPUT_DIR: target directory for normalized patches when not in-place.
#   Set the HER2_NORM_OUTPUT_DIR environment variable to override it (the value can be
#   an absolute path, e.g. a directory on an external drive); defaults to 'outputs/patches_norm'.
#   The lookup key must be a variable *name*: a filesystem path used as the key never
#   matches anything, so the default would always win.
# - USE_NORMALIZED_FOR_TRAINING: if True and normalized CSV exists, use it in training
NORM_INPLACE = False
NORM_OUTPUT_DIR = os.environ.get('HER2_NORM_OUTPUT_DIR', 'outputs/patches_norm')
USE_NORMALIZED_FOR_TRAINING = False

# Logging configuration: every message goes to the console and to one append-only file.
log_dir = 'outputs/preprocessing/logs'
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, 'preprocessing.log')

def log(msg: str):
    """Print ``msg`` and append it as a single line to the file at ``log_path``.

    File logging is best-effort: the message is always printed first, and any
    failure while opening or writing the log file is ignored so that a logging
    problem never interrupts the notebook.

    Args:
        msg: Text to emit; a trailing newline is added for the file copy.
    """
    print(msg)
    try:
        # Explicit UTF-8 so messages containing non-ASCII text (e.g. exception
        # text or file names) are not lost when the locale's default encoding
        # cannot represent them.
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(msg + '\n')
    except Exception:
        # Fallback to console if file logging fails
        pass

# -----------------------------------------------------------------------------
# Project imports
# -----------------------------------------------------------------------------
from src.preprocessing.generate_metadata import discover_wsi
from src.preprocessing.xml_to_mask import get_mask
from src.preprocessing.annotation_utils import resolve_annotation_path
from src.preprocessing.extract_patches import extract_patches
from src.preprocessing.load_wsi import load_wsi
from src.train.train_phase1 import train_phase1


WSI backend preference: openslide


In [3]:
# Scan the configured sources under BASE_DIR and build the slide index.
# The value returned by discover_wsi is read back below as a CSV path.
csv_path = discover_wsi(base_dir=BASE_DIR, sources=SOURCES, output_path=OUTPUT_CSV)

# Read the index back in and preview its first five rows
df = pd.read_csv(csv_path)
display(df.head())


Processing sources:   0%|          | 0/3 [00:00<?, ?it/s]
Processing Yale_HER2_cohort:   0%|          | 0/192 [00:00<?, ?it/s][A
                                                                    [A
Processing Yale_trastuzumab_response_cohort:   0%|          | 0/85 [00:00<?, ?it/s][A
                                                                                   [A
Processing TCGA_BRCA_Filtered:   0%|          | 0/182 [00:00<?, ?it/s][A
Processing sources: 100%|██████████| 3/3 [00:00<00:00, 85.31it/s]     [A
                                                         

Unnamed: 0,wsi_path,slide_id,slide_name,annotation_name,annotation_path
0,data/Yale_HER2_cohort/SVS/Her2Neg_Case_01.svs,Her2Neg_Case_01,Her2Neg_Case_01.svs,Her2Neg_Case_01.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
1,data/Yale_HER2_cohort/SVS/Her2Neg_Case_02.svs,Her2Neg_Case_02,Her2Neg_Case_02.svs,Her2Neg_Case_02.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
2,data/Yale_HER2_cohort/SVS/Her2Neg_Case_03.svs,Her2Neg_Case_03,Her2Neg_Case_03.svs,Her2Neg_Case_03.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
3,data/Yale_HER2_cohort/SVS/Her2Neg_Case_04.svs,Her2Neg_Case_04,Her2Neg_Case_04.svs,Her2Neg_Case_04.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
4,data/Yale_HER2_cohort/SVS/Her2Neg_Case_05.svs,Her2Neg_Case_05,Her2Neg_Case_05.svs,Her2Neg_Case_05.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...


In [None]:
# Extract patches for each slide in the index.
# Every failure for a single slide is logged and the loop moves on to the next slide.
for idx, row in tqdm(df.iterrows(), total=len(df), desc='Processing slides'):
    wsi_path = row['wsi_path']

    # Define patch output directory to check if it's already processed
    slide_base = os.path.splitext(os.path.basename(wsi_path))[0]
    out_dir_patches = os.path.join('outputs', 'patches', slide_base)

    # --- Skip if patches already exist ---
    # NOTE(review): any non-empty directory counts as "done", so a slide whose
    # extraction was interrupted part-way is not resumed on the next run.
    if os.path.isdir(out_dir_patches) and len(os.listdir(out_dir_patches)) > 0:
        log(f"Skipping already extracted slide: {wsi_path}")
        continue

    # Resolve annotation path using helper (handles pandas NA, relative paths, and glob fallback)
    annotation_path = resolve_annotation_path(row.get('annotation_path', None), wsi_path, base_dir=BASE_DIR)
    if not annotation_path:
        log(f"Skipping slide without annotation: {wsi_path}")
        continue

    log(f"Processing slide: {wsi_path} with annotation: {annotation_path}")
    try:
        mask = get_mask(annotation_path, wsi_path)
    except Exception as e:
        log(f"Failed to generate mask for {wsi_path}: {e}")
        continue
    if mask is None:
        log(f"No mask generated for {wsi_path}")
        continue

    # At this point mask should be a 2D uint8 array (0 or 255)
    log(f'Mask shape: {mask.shape}')

    # Load WSI
    try:
        wsi_slide = load_wsi(wsi_path)
    except Exception as e:
        log(f"Failed to load WSI ({wsi_path}): {e}")
        continue
    if wsi_slide is None:
        log(f"Failed to load WSI: {wsi_path}")
        continue

    # From here on the slide handle is open; the finally-clause releases it on
    # every exit path (success, failure, or `continue`).
    try:
        # Log which backend the loader selected
        backend = getattr(wsi_slide, 'backend', 'unknown')
        log(f'Loaded WSI backend: {backend}')

        # Determine if we should use GPU for patch extraction
        use_gpu_extraction = (backend == 'cucim')
        if use_gpu_extraction:
            log(f'Enabling GPU-accelerated patch extraction for {wsi_path}')

        # Extract patches
        try:
            patches = extract_patches(
                wsi_slide,
                mask=mask,
                size=PATCH_SIZE,
                stride=PATCH_STRIDE,
                save_dir=out_dir_patches,
                save_prefix=slide_base,
                save_format=PATCH_SAVE_FORMAT,
                use_gpu=use_gpu_extraction,
            )
        except Exception as e:
            log(f"Failed to extract patches for {wsi_path}: {e}")
            continue

        # Guard against a missing result so one slide cannot abort the whole loop
        if patches is None:
            log(f"No patches returned for {wsi_path}")
            continue

        # Count the entries that carry a saved file path
        saved = sum(1 for p in patches if p.get('path'))
        log(f'Extracted {len(patches)} patches from {wsi_path}; saved {saved} to {out_dir_patches}')
    finally:
        # Release the slide if the loader's handle exposes close(); without this
        # every processed slide would stay open for the lifetime of the kernel.
        close_slide = getattr(wsi_slide, 'close', None)
        if callable(close_slide):
            try:
                close_slide()
            except Exception as e:
                log(f"Warning: failed to close WSI ({wsi_path}): {e}")


Processing slides:  67%|██████▋   | 308/459 [00:00<00:00, 3066.76it/s]

Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_01.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_02.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_03.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_04.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_05.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_06.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_07.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_08.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_09.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_10.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_11.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_12.svs
Skipping already extracted slide: data/Y

In [None]:
import os
from glob import glob
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm

log("=" * 80)
log("Starting train/val CSV generation")
log("=" * 80)

# --- 1. Create a mapping from Case Name to Label ---
# Labels: 1 = HER2 positive, 0 = HER2 negative.
log("Building case-to-label mapping...")
case_to_label = {}
skipped_tcga = []
skipped_unknown = []

# Load TCGA labels from Excel: only rows whose clinical status is exactly
# 'Positive' or 'Negative' (after stripping whitespace) are kept.
tcga_excel_path = 'data/TCGA_BRCA_Filtered/case&annotation_counts_clean.xlsx'
if os.path.exists(tcga_excel_path):
    tcga_df = pd.read_excel(tcga_excel_path)
    log(f"Loaded TCGA labels from: {tcga_excel_path}")
    tcga_labels = {}
    for slide, status in zip(tcga_df['Slide'], tcga_df['Clinical.HER2.status']):
        status = str(status).strip()
        if status in ('Positive', 'Negative'):
            tcga_labels[str(slide).strip()] = 1 if status == 'Positive' else 0
    log(f"Created TCGA label lookup for {len(tcga_labels)} cases.")
else:
    tcga_labels = {}
    log(f"Warning: TCGA Excel file not found at {tcga_excel_path}")

# Get all case directories.
# Sorted because glob order depends on the file system; the seeded case-level
# split performed later is only reproducible if the case order is stable.
patches_dir = 'outputs/patches'
case_dirs = sorted(d for d in glob(os.path.join(patches_dir, '*')) if os.path.isdir(d))
log(f"Found {len(case_dirs)} case directories to process.")

for case_dir in tqdm(case_dirs, desc='Determining labels for cases'):
    case_name = os.path.basename(case_dir)

    if case_name.startswith('TCGA-'):
        # TCGA directories are matched on the part of the name before the first '.'
        tcga_id = case_name.split('.')[0]
        if tcga_id not in tcga_labels:
            skipped_tcga.append(tcga_id)
            continue
        label = tcga_labels[tcga_id]
    elif case_name.startswith(('S', 'O')):
        # NOTE(review): every name starting with 'S' or 'O' is treated as positive;
        # presumably these are the trastuzumab-response cohort slides - confirm.
        label = 1
    elif 'Pos' in case_name:
        # Also covers names containing 'Her2Pos'
        label = 1
    elif 'Neg' in case_name:
        # Also covers names containing 'Her2Neg'
        label = 0
    else:
        skipped_unknown.append(case_name)
        continue

    case_to_label[case_name] = label

if skipped_tcga:
    log(f"Skipped {len(skipped_tcga)} TCGA cases not found in Excel: {', '.join(skipped_tcga[:5])}{'...' if len(skipped_tcga) > 5 else ''}")
if skipped_unknown:
    log(f"Skipped {len(skipped_unknown)} cases with unknown labels: {', '.join(skipped_unknown[:5])}{'...' if len(skipped_unknown) > 5 else ''}")
log(f"Successfully mapped {len(case_to_label)} cases to labels.")

# --- 2. Collect Patches using the Case-to-Label Map ---
# One row per patch file; every patch inherits the label of its case directory.
log("Collecting patches from all mapped cases...")
paths, labels, cases = [], [], []
for case_name, label in tqdm(case_to_label.items(), desc='Collecting patches'):
    # Sorted so the row order of the resulting CSVs does not depend on the
    # file system's enumeration order.
    patch_files = sorted(glob(os.path.join(patches_dir, case_name, f'*.{PATCH_SAVE_FORMAT}')))
    paths.extend(patch_files)
    labels.extend([label] * len(patch_files))
    cases.extend([case_name] * len(patch_files))

patches_df = pd.DataFrame({'path': paths, 'label': labels, 'case': cases})
log(f"Total patches collected: {len(patches_df)}")
print(f"\nTotal patches found: {len(patches_df)}")

# --- 3. Display Statistics ---
print("\nLabel distribution:")
print(patches_df['label'].value_counts())
neg_patches = (patches_df['label'] == 0).sum()
pos_patches = (patches_df['label'] == 1).sum()
log(f"Label distribution - Negative: {neg_patches}, Positive: {pos_patches}")

# Number of distinct cases that contributed at least one patch, per label
cases_by_label = patches_df.groupby('label')['case'].nunique()
print(f"\nNumber of cases:")
print(f"Negative cases: {cases_by_label.get(0, 0)}")
print(f"Positive cases: {cases_by_label.get(1, 0)}")
log(f"Number of cases - Negative: {cases_by_label.get(0, 0)}, Positive: {cases_by_label.get(1, 0)}")

# --- 4. Split Data by Case to Avoid Leakage ---
# All patches of a case go to the same side of the split.
log("Performing train/val split by case...")

# Only cases that actually contributed patches take part in the split. A labelled
# directory with no matching patch files would otherwise skew the stratification
# and the reported case counts. Row order follows patches_df, so when no case is
# empty the split is the same as splitting on case_to_label directly.
case_label_df = patches_df[['case', 'label']].drop_duplicates(subset='case')
unique_cases = case_label_df['case'].tolist()
case_labels = case_label_df['label'].tolist()

empty_case_count = len(case_to_label) - len(unique_cases)
if empty_case_count:
    log(f"Excluded {empty_case_count} labelled cases with no patch files from the split.")

# 80/20 split of cases, stratified by case label, with a fixed seed
train_cases, val_cases = train_test_split(
    unique_cases,
    test_size=0.2,
    random_state=42,
    stratify=case_labels
)

train_df = patches_df[patches_df['case'].isin(train_cases)][['path', 'label']]
val_df = patches_df[patches_df['case'].isin(val_cases)][['path', 'label']]

log(f"Train patches: {len(train_df)}, Val patches: {len(val_df)}")
log(f"Train cases: {len(train_cases)}, Val cases: {len(val_cases)}")
print(f"\nTrain patches: {len(train_df)}")
print(f"Val patches: {len(val_df)}")
print(f"Train cases: {len(train_cases)}")
print(f"Val cases: {len(val_cases)}")

print(f"\nTrain label distribution:")
print(train_df['label'].value_counts())
print(f"\nVal label distribution:")
print(val_df['label'].value_counts())
log(f"Train split - Negative: {(train_df['label']==0).sum()}, Positive: {(train_df['label']==1).sum()}")
log(f"Val split - Negative: {(val_df['label']==0).sum()}, Positive: {(val_df['label']==1).sum()}")

# --- 5. Save CSV Files ---
# Each split is written without the DataFrame index (columns: path, label).
train_csv_path = 'outputs/patches_index_train.csv'
val_csv_path = 'outputs/patches_index_val.csv'

for split_frame, split_csv in ((train_df, train_csv_path), (val_df, val_csv_path)):
    split_frame.to_csv(split_csv, index=False)

# Record the outcome through log() first ...
log(f"Saved train CSV to: {train_csv_path}")
log(f"Saved val CSV to: {val_csv_path}")
log("Train/val CSV generation completed successfully")
log("=" * 80)

# ... then repeat the paths on the console
print(f"\nSaved train CSV to: {train_csv_path}")
print(f"\nSaved val CSV to: {val_csv_path}")

# Preview of the training index
print("\nSample from train set:")
display(train_df.head(5))


### Macenko Normalization


In [None]:
# Macenko stain normalization of training patches (based on train_csv)
from src.preprocessing.stain_normalization import macenko_normalization, CUPY_AVAILABLE
import shutil

# Decide whether to use GPU: honor CuPy availability
use_gpu_norm = bool(CUPY_AVAILABLE)
log("CuPy is available. Using GPU for stain normalization." if use_gpu_norm else "CuPy not found. Using CPU for stain normalization. This may be slow.")

# NOTE(review): only the patches listed in the training CSV are normalized here; the
# validation CSV is not touched by this cell. Confirm that is intended before enabling
# USE_NORMALIZED_FOR_TRAINING.

# Read training CSV
if not os.path.exists(train_csv_path):
    log(f"Training CSV not found: {train_csv_path}")
else:
    df_train = pd.read_csv(train_csv_path)
    # Filter to paths that actually exist to avoid I/O warnings
    exists_mask = df_train['path'].apply(os.path.exists).astype(bool)
    n_total = len(df_train)
    n_exist = int(exists_mask.sum())
    if n_exist == 0:
        log(f"No existing patch files found among {n_total} rows in {train_csv_path}. Run extraction first.")
    else:
        df_train = df_train[exists_mask].reset_index(drop=True)
        log(f"Normalization will process {n_exist}/{n_total} existing patch files.")

        if not NORM_INPLACE:
            os.makedirs(NORM_OUTPUT_DIR, exist_ok=True)
            try:
                total_b, used_b, free_b = shutil.disk_usage(NORM_OUTPUT_DIR)
                log(f"Normalization output dir: {os.path.abspath(NORM_OUTPUT_DIR)} (free: {free_b//(1024**3)} GB)")
            except Exception as e:
                log(f"Warning: Could not determine free space for {NORM_OUTPUT_DIR}: {e}")

        # Helper to map original path to normalized path
        def _norm_out_path(p: str) -> str:
            """Return where the normalized copy of patch ``p`` is written.

            In-place mode returns ``p`` unchanged. Otherwise the path relative to
            'outputs/patches' is mirrored under NORM_OUTPUT_DIR (falling back to the
            bare file name for paths outside that tree) and the target directory is
            created if needed.
            """
            if NORM_INPLACE:
                return p
            # Mirror original under NORM_OUTPUT_DIR preserving case subdir
            # Expect paths like outputs/patches/<case>/<file>
            rel = os.path.relpath(p, start='outputs/patches') if p.startswith('outputs/patches') else os.path.basename(p)
            out_p = os.path.join(NORM_OUTPUT_DIR, rel)
            os.makedirs(os.path.dirname(out_p), exist_ok=True)
            return out_p

        error_count = 0
        processed = 0
        skipped_existing = 0
        out_records = []  # rows for normalized CSV
        # Walk path and label together: looking the label up by path for every
        # patch would rescan the whole frame each time (quadratic in patch count).
        patch_rows = zip(df_train['path'], df_train['label'])
        for path, label in tqdm(patch_rows, total=len(df_train), desc="Normalizing training patches"):
            try:
                out_path = _norm_out_path(path)
                record = {'path': out_path, 'label': int(label)}
                # Resume: skip work if target already exists (only for out-of-place)
                if (not NORM_INPLACE) and os.path.exists(out_path):
                    skipped_existing += 1
                    out_records.append(record)
                    continue

                img_bgr = cv2.imread(path)
                if img_bgr is None:
                    log(f"Warning: Could not read image {path}. Skipping.")
                    error_count += 1
                    continue
                # OpenCV reads BGR; the normalizer is fed RGB and its result is converted back
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                normalized_rgb = macenko_normalization(img_rgb, use_gpu=use_gpu_norm)
                normalized_bgr = cv2.cvtColor(normalized_rgb, cv2.COLOR_RGB2BGR)
                # imwrite reports failure through its return value rather than raising
                if not cv2.imwrite(out_path, normalized_bgr):
                    log(f"Warning: Could not write image {out_path}. Skipping.")
                    error_count += 1
                    continue
                processed += 1
                # Record
                out_records.append(record)
            except Exception as e:
                log(f"Error normalizing {path}: {e}")
                error_count += 1

        # Save normalized train CSV (only files that exist in output)
        norm_csv_path = 'outputs/patches_index_train_norm.csv'
        if NORM_INPLACE:
            # In-place: patches were overwritten at their original paths, so the index
            # is simply every existing row of df_train (this includes rows whose
            # normalization failed and which are therefore still un-normalized).
            df_train_norm = df_train.copy()[['path', 'label']]
        else:
            # Out-of-place: only include rows we created or that already existed.
            # Explicit columns keep the frame well-formed when nothing was recorded.
            df_train_norm = pd.DataFrame(out_records, columns=['path', 'label'])
            # Filter again to ensure file presence (safety)
            df_train_norm = df_train_norm[df_train_norm['path'].apply(os.path.exists).astype(bool)].reset_index(drop=True)
        df_train_norm.to_csv(norm_csv_path, index=False)
        log(f"Stain normalization complete. Processed: {processed}, Skipped(existing): {skipped_existing}, Errors: {error_count}")
        log(f"Saved normalized train CSV to: {norm_csv_path}")


In [None]:
# Phase 1 — Train ResNet-50 (module)
# Choose the training index: the normalized CSV is used only when it is both
# requested and present on disk; otherwise the original training CSV is kept.
_norm_csv_default = 'outputs/patches_index_train_norm.csv'
train_csv_for_training = train_csv_path
if USE_NORMALIZED_FOR_TRAINING:
    if os.path.exists(_norm_csv_default):
        train_csv_for_training = _norm_csv_default
        log(f"Using normalized train CSV for training: {train_csv_for_training}")
    else:
        log(f"Requested normalized training but file not found: {_norm_csv_default}. Falling back to original train_csv.")

# Settings handed to train_phase1 as a single dict
CFG = {
    # data
    'train_csv': train_csv_for_training,
    'val_csv': val_csv_path,
    'output_dir': 'outputs/phase1',
    # model / input
    'pretrained': True,
    'input_size': PATCH_SIZE,
    # optimisation
    'batch_size': 32,
    'num_workers': 4,
    'epochs': 10,
    'lr': 1e-5,
    'weight_decay': 1e-4,
    # CSV schema
    'label_col': 'label',
    'path_col': 'path',
    # checkpointing / reproducibility
    'save_best_by': 'auc',
    'seed': 42,
}

results = train_phase1(CFG)
print('Best model:', results['best_model_path'])
print('Logs dir:', results['logs_dir'])