In [1]:
%load_ext autoreload
%autoreload 2

# Imports
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from IPython.display import display
import random
import gc
from pathlib import Path

# Additional imports for normalization
import shutil
from PIL import Image
import cupy as cp

# -----------------------------------------------------------------------------
# Project imports
# -----------------------------------------------------------------------------
from src.preprocessing.generate_metadata import discover_wsi
from src.preprocessing.xml_to_mask import get_mask
from src.preprocessing.annotation_utils import resolve_annotation_path
from src.preprocessing.extract_patches import extract_patches
from src.preprocessing.load_wsi import load_wsi
from src.train.train_phase1 import train_phase1


# -----------------------------------------------------------------------------
# Global configuration (single top cell)
# -----------------------------------------------------------------------------
# Prefer a color-accurate WSI backend for patch extraction
# Options: 'openslide' (recommended), 'tiffslide' (good for TIFF/SVS), 'cucim' (fast GPU)
os.environ['HER2_WSI_BACKEND'] = os.environ.get('HER2_WSI_BACKEND', 'cucim')
print(f"WSI backend preference: {os.environ['HER2_WSI_BACKEND']}")

# Dataset configuration
BASE_DIR = 'data'
SOURCES = [
    'Yale_HER2_cohort',
    'Yale_trastuzumab_response_cohort',
    'TCGA_BRCA_Filtered',
]
OUTPUT_CSV = 'outputs/index/wsi_index.csv'

# Patch extraction configuration
PATCH_SIZE = 512
PATCH_STRIDE = 512
PATCH_SAVE_FORMAT = 'png'  # e.g., 'png' or 'jpg'

# Normalization configuration
# - NORM_INPLACE: if True, overwrite original patch files; else, write to NORM_OUTPUT_DIR mirroring structure
# - NORM_OUTPUT_DIR: target directory for normalized patches when not in-place (can be an absolute path to an external drive)
# - USE_NORMALIZED_FOR_TRAINING: if True and normalized CSV exists, use it in training
NORM_INPLACE = False
# prefer an env var named NORM_OUTPUT_DIR; fallback to a reasonable external path or local outputs folder
NORM_OUTPUT_DIR = os.environ.get('NORM_OUTPUT_DIR', '/media/thanakornbuath/patch/norm')
USE_NORMALIZED_FOR_TRAINING = False
# Number of distinct subfolders (cases) to sample to compute the reference stain profile
NORM_REFERENCE_SAMPLE_SUBFOLDERS = 40
# Use GPU for stain normalization numeric work (requires CuPy and GPU drivers). Set False to force CPU.
USE_GPU = True

# Optional: path to CSV column name that contains image path; leave None to auto-detect
CSV_PATH_COLUMN = None
# File extensions to consider (lowercase)
IMAGE_EXTS = {'.png', '.jpg', '.jpeg', '.tif', '.tiff'}
# Safety: chunk size when listing/processing very large CSVs (tweak if needed)
CHUNK_SIZE = 2000

# Logging configuration
log_dir = 'outputs/preprocessing/logs'
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, 'preprocessing.log')

def log(msg: str):
    print(msg)
    try:
        with open(log_path, 'a') as f:
            f.write(msg + '\n')
    except Exception:
        # Fallback to console if file logging fails
        pass



WSI backend preference: cucim


In [2]:
# Discover WSIs and build index CSV
csv_path = discover_wsi(
    base_dir=BASE_DIR,
    sources=SOURCES,
    output_path=OUTPUT_CSV,
)

# Load and display the results
df = pd.read_csv(csv_path)
display(df.head(5))


Processing sources:   0%|          | 0/3 [00:00<?, ?it/s]
Processing Yale_HER2_cohort:   0%|          | 0/192 [00:00<?, ?it/s][A
                                                                    [A
Processing Yale_trastuzumab_response_cohort:   0%|          | 0/85 [00:00<?, ?it/s][A
                                                                                   [A
Processing TCGA_BRCA_Filtered:   0%|          | 0/182 [00:00<?, ?it/s][A
Processing sources: 100%|██████████| 3/3 [00:00<00:00, 88.55it/s]     [A
                                                         

Unnamed: 0,wsi_path,slide_id,slide_name,annotation_name,annotation_path
0,data/Yale_HER2_cohort/SVS/Her2Neg_Case_01.svs,Her2Neg_Case_01,Her2Neg_Case_01.svs,Her2Neg_Case_01.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
1,data/Yale_HER2_cohort/SVS/Her2Neg_Case_02.svs,Her2Neg_Case_02,Her2Neg_Case_02.svs,Her2Neg_Case_02.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
2,data/Yale_HER2_cohort/SVS/Her2Neg_Case_03.svs,Her2Neg_Case_03,Her2Neg_Case_03.svs,Her2Neg_Case_03.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
3,data/Yale_HER2_cohort/SVS/Her2Neg_Case_04.svs,Her2Neg_Case_04,Her2Neg_Case_04.svs,Her2Neg_Case_04.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...
4,data/Yale_HER2_cohort/SVS/Her2Neg_Case_05.svs,Her2Neg_Case_05,Her2Neg_Case_05.svs,Her2Neg_Case_05.xml,data/Yale_HER2_cohort/Annotations/Her2Neg_Case...


In [3]:
# Extract patches for each slide
for idx, row in tqdm(df.iterrows(), total=len(df), desc='Processing slides'):
    wsi_path = row['wsi_path']
    wsi_path = str(wsi_path)  # ensure path-like string to satisfy os.path methods

    # Define patch output directory to check if it's already processed
    slide_base = os.path.splitext(os.path.basename(wsi_path))[0]
    out_dir_patches = os.path.join('outputs', 'patches', slide_base)

    # --- Skip if patches already exist ---
    if os.path.isdir(out_dir_patches) and len(os.listdir(out_dir_patches)) > 0:
        log(f"Skipping already extracted slide: {wsi_path}")
        continue

    # Resolve annotation path using helper (handles pandas NA, relative paths, and glob fallback)
    annotation_path = resolve_annotation_path(row.get('annotation_path', None), wsi_path, base_dir=BASE_DIR)
    if not annotation_path:
        log(f"Skipping slide without annotation: {wsi_path}")
        continue

    log(f"Processing slide: {wsi_path} with annotation: {annotation_path}")
    try:
        mask = get_mask(annotation_path, wsi_path)
    except Exception as e:
        log(f"Failed to generate mask for {wsi_path}: {e}")
        continue
    if mask is None:
        log(f"No mask generated for {wsi_path}")
        continue

    # At this point mask should be a 2D uint8 array (0 or 255)
    log(f'Mask shape: {mask.shape}')

    # Load WSI
    try:
        wsi_slide = load_wsi(wsi_path)
    except Exception as e:
        log(f"Failed to load WSI ({wsi_path}): {e}")
        continue
    if wsi_slide is None:
        log(f"Failed to load WSI: {wsi_path}")
        continue

    # Log which backend the loader selected
    backend = getattr(wsi_slide, 'backend', 'unknown')
    log(f'Loaded WSI backend: {backend}')

    # Determine if we should use GPU for patch extraction
    use_gpu_extraction = (backend == 'cucim')
    if use_gpu_extraction:
        log(f'Enabling GPU-accelerated patch extraction for {wsi_path}')

    # Extract patches
    try:
        patches = extract_patches(
            wsi_slide,
            mask=mask,
            size=PATCH_SIZE,
            stride=PATCH_STRIDE,
            save_dir=out_dir_patches,
            save_prefix=slide_base,
            save_format=PATCH_SAVE_FORMAT,
            use_gpu=use_gpu_extraction,
        )
    except Exception as e:
        log(f"Failed to extract patches for {wsi_path}: {e}")
        continue

    saved = sum(1 for p in patches if p.get('path'))
    log(f'Extracted {len(patches)} patches from {wsi_path}; saved {saved} to {out_dir_patches}')


Processing slides:  72%|███████▏  | 330/459 [00:00<00:00, 3261.57it/s]

Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_01.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_02.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_03.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_04.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_05.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_06.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_07.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_08.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_09.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_10.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_11.svs
Skipping already extracted slide: data/Yale_HER2_cohort/SVS/Her2Neg_Case_12.svs
Skipping already extracted slide: data/Y

Processing slides: 100%|██████████| 459/459 [00:00<00:00, 1126.93it/s]

Skipping already extracted slide: data/TCGA_BRCA_Filtered/SVS/TCGA-EW-A1PA-01Z-00-DX1.03B033F8-62C0-49E1-BDEA-C5217AB3460A.svs
Skipping already extracted slide: data/TCGA_BRCA_Filtered/SVS/TCGA-EW-A1PD-01Z-00-DX1.6F6A0122-A50B-4D00-9A3F-7A2502D44E38.svs
Skipping already extracted slide: data/TCGA_BRCA_Filtered/SVS/TCGA-EW-A1PF-01Z-00-DX1.9420D058-65CE-4DF8-815F-EB407003096E.svs





In [4]:
import os
from glob import glob
import pandas as pd
from tqdm import tqdm

log("=" * 80)
log("Starting train/val CSV generation")
log("=" * 80)

# --- 1. Create a mapping from Case Name to Label ---
log("Building case-to-label mapping...")
case_to_label = {}
skipped_tcga = []
skipped_unknown = []

# Load TCGA labels from Excel
tcga_excel_path = 'data/TCGA_BRCA_Filtered/case&annotation_counts_clean.xlsx'
if os.path.exists(tcga_excel_path):
    tcga_df = pd.read_excel(tcga_excel_path)
    log(f"Loaded TCGA labels from: {tcga_excel_path}")
    tcga_labels = {
        str(row['Slide']).strip(): 1 if str(row['Clinical.HER2.status']).strip() == 'Positive' else 0
        for _, row in tcga_df.iterrows()
        if str(row['Clinical.HER2.status']).strip() in ['Positive', 'Negative']
    }
    log(f"Created TCGA label lookup for {len(tcga_labels)} cases.")
else:
    tcga_labels = {}
    log(f"Warning: TCGA Excel file not found at {tcga_excel_path}")

# Get all case directories
patches_dir = 'outputs/patches'
case_dirs = [d for d in glob(os.path.join(patches_dir, '*')) if os.path.isdir(d)]
log(f"Found {len(case_dirs)} case directories to process.")

for case_dir in tqdm(case_dirs, desc='Determining labels for cases'):
    case_name = os.path.basename(case_dir)
    label = -1

    if case_name.startswith('TCGA-'):
        tcga_id = case_name.split('.')[0]
        if tcga_id in tcga_labels:
            label = tcga_labels[tcga_id]
        else:
            skipped_tcga.append(tcga_id)
            continue
    elif case_name.startswith(('S', 'O')):
        label = 1
    elif 'Her2Pos' in case_name or 'Pos' in case_name:
        label = 1
    elif 'Her2Neg' in case_name or 'Neg' in case_name:
        label = 0
    else:
        skipped_unknown.append(case_name)
        continue
    
    if label != -1:
        case_to_label[case_name] = label

if skipped_tcga:
    log(f"Skipped {len(skipped_tcga)} TCGA cases not found in Excel: {', '.join(skipped_tcga[:5])}{'...' if len(skipped_tcga) > 5 else ''}")
if skipped_unknown:
    log(f"Skipped {len(skipped_unknown)} cases with unknown labels: {', '.join(skipped_unknown[:5])}{'...' if len(skipped_unknown) > 5 else ''}")
log(f"Successfully mapped {len(case_to_label)} cases to labels.")

# --- 2. Collect Patches using the Case-to-Label Map ---
log("Collecting patches from all mapped cases...")
paths, labels, cases = [], [], []
for case_name, label in tqdm(case_to_label.items(), desc='Collecting patches'):
    patch_files = glob(os.path.join(patches_dir, case_name, f'*.{PATCH_SAVE_FORMAT}'))
    paths.extend(patch_files)
    labels.extend([label] * len(patch_files))
    cases.extend([case_name] * len(patch_files))

patches_df = pd.DataFrame({'path': paths, 'label': labels, 'case': cases})
log(f"Total patches collected: {len(patches_df)}")
print(f"\nTotal patches found: {len(patches_df)}")

# --- 3. Display Statistics ---
print("\nLabel distribution:")
print(patches_df['label'].value_counts())
neg_patches = (patches_df['label'] == 0).sum()
pos_patches = (patches_df['label'] == 1).sum()
log(f"Label distribution - Negative: {neg_patches}, Positive: {pos_patches}")

cases_by_label = patches_df.groupby('label')['case'].nunique()
print(f"\nNumber of cases:")
print(f"Negative cases: {cases_by_label.get(0, 0)}")
print(f"Positive cases: {cases_by_label.get(1, 0)}")
log(f"Number of cases - Negative: {cases_by_label.get(0, 0)}, Positive: {cases_by_label.get(1, 0)}")

# --- 4. Split Data by Case to Avoid Leakage ---
log("Performing train/val split by case...")
from sklearn.model_selection import train_test_split as _tts
unique_cases = list(case_to_label.keys())
case_labels = list(case_to_label.values())

train_cases, val_cases = _tts(
    unique_cases,
    test_size=0.2,
    random_state=42, 
    stratify=case_labels
)

train_df = patches_df[patches_df['case'].isin(train_cases)][['path', 'label']]
val_df = patches_df[patches_df['case'].isin(val_cases)][['path', 'label']]

log(f"Train patches: {len(train_df)}, Val patches: {len(val_df)}")
log(f"Train cases: {len(train_cases)}, Val cases: {len(val_cases)}")
print(f"\nTrain patches: {len(train_df)}")
print(f"Val patches: {len(val_df)}")
print(f"Train cases: {len(train_cases)}")
print(f"Val cases: {len(val_cases)}")

print(f"\nTrain label distribution:")
print(train_df['label'].value_counts())
print(f"\nVal label distribution:")
print(val_df['label'].value_counts())
log(f"Train split - Negative: {(train_df['label']==0).sum()}, Positive: {(train_df['label']==1).sum()}")
log(f"Val split - Negative: {(val_df['label']==0).sum()}, Positive: {(val_df['label']==1).sum()}")

# --- 5. Save CSV Files ---
train_csv_path = 'outputs/patches_index_train.csv'
val_csv_path = 'outputs/patches_index_val.csv'

train_df.to_csv(train_csv_path, index=False)
val_df.to_csv(val_csv_path, index=False)

log(f"Saved train CSV to: {train_csv_path}")
log(f"Saved val CSV to: {val_csv_path}")
log("Train/val CSV generation completed successfully")
log("=" * 80)

print(f"\nSaved train CSV to: {train_csv_path}")
print(f"\nSaved val CSV to: {val_csv_path}")

print("\nSample from train set:")
display(train_df.head())


Starting train/val CSV generation
Building case-to-label mapping...
Loaded TCGA labels from: data/TCGA_BRCA_Filtered/case&annotation_counts_clean.xlsx
Created TCGA label lookup for 182 cases.
Found 451 case directories to process.


Determining labels for cases: 100%|██████████| 451/451 [00:00<00:00, 1665168.23it/s]


Successfully mapped 451 cases to labels.
Collecting patches from all mapped cases...


Collecting patches: 100%|██████████| 451/451 [00:01<00:00, 381.00it/s]


Total patches collected: 1172444

Total patches found: 1172444

Label distribution:
label
1    621811
0    550633
Name: count, dtype: int64
Label distribution - Negative: 550633, Positive: 621811

Number of cases:
Negative cases: 187
Positive cases: 264
Number of cases - Negative: 187, Positive: 264
Performing train/val split by case...
Train patches: 933109, Val patches: 239335
Train cases: 360, Val cases: 91

Train patches: 933109
Val patches: 239335
Train cases: 360
Val cases: 91

Train label distribution:
label
1    497173
0    435936
Name: count, dtype: int64

Val label distribution:
label
1    124638
0    114697
Name: count, dtype: int64
Train split - Negative: 435936, Positive: 497173
Val split - Negative: 114697, Positive: 124638
Saved train CSV to: outputs/patches_index_train.csv
Saved val CSV to: outputs/patches_index_val.csv
Train/val CSV generation completed successfully

Saved train CSV to: outputs/patches_index_train.csv

Saved val CSV to: outputs/patches_index_val.csv

S

Unnamed: 0,path,label
0,outputs/patches/TCGA-AR-A0U4-01Z-00-DX1.DE722D...,0
1,outputs/patches/TCGA-AR-A0U4-01Z-00-DX1.DE722D...,0
2,outputs/patches/TCGA-AR-A0U4-01Z-00-DX1.DE722D...,0
3,outputs/patches/TCGA-AR-A0U4-01Z-00-DX1.DE722D...,0
4,outputs/patches/TCGA-AR-A0U4-01Z-00-DX1.DE722D...,0


### Macenko Normalization


In [5]:
# Helper: detect CSV path column
def detect_path_column(df, override=None):
    if override:
        if override in df.columns:
            return override
        raise KeyError(f"CSV_PATH_COLUMN override '{override}' not found in CSV columns")
    candidates = [c for c in df.columns if any(k in c.lower() for k in ('path', 'file', 'img', 'image', 'filename'))]
    return candidates[0] if candidates else df.columns[0]


# Resolve image path: prefer absolute/relative then recursive search under root
def resolve_image_path(raw_path_or_name, root_dir):
    if pd.isna(raw_path_or_name):
        return None
    raw = str(raw_path_or_name)
    if os.path.isabs(raw) and os.path.exists(raw):
        return raw
    cand = os.path.join(root_dir, raw)
    if os.path.exists(cand):
        return cand
    if os.path.exists(raw):
        return raw
    # fallback: recursive search by basename
    base = os.path.basename(raw)
    for rt, _, files in os.walk(root_dir):
        if base in files:
            return os.path.join(rt, base)
    return None


# Main normalization logic
from src.preprocessing.stain_normalization import MacenkoNormalizer

PATCHES_ROOT = 'outputs/patches'
TRAIN_CSV_PATH = globals().get('train_csv_path', 'outputs/patches_index_train.csv')
print('Using train CSV:', TRAIN_CSV_PATH)
df = pd.read_csv(TRAIN_CSV_PATH)
path_col = detect_path_column(df, override=CSV_PATH_COLUMN)
print('Detected path column:', path_col)

# Check for existing reference stats
ref_stats_path = os.path.join(NORM_OUTPUT_DIR, 'ref_stain_stats.npz')
if os.path.exists(ref_stats_path):
    print(f"Found existing reference stain stats: {ref_stats_path}")
    stats = np.load(ref_stats_path)
    mean_ref_stain_vectors = stats['mean_stain_vectors']
    mean_max_h = stats['mean_max_h']
    mean_max_e = stats['mean_max_e']
    print("Loaded reference stats. Skipping sampling and computation.")
else:
    print("Reference stain stats not found. Computing from samples.")
    # Collect candidate subfolders recursively under PATCHES_ROOT that contain at least one supported image
    candidate_dirs = []
    for root, dirs, files in os.walk(PATCHES_ROOT):
        # skip the root if it has no images
        imgs = [f for f in files if os.path.splitext(f)[1].lower() in IMAGE_EXTS]
        if imgs:
            candidate_dirs.append(root)

    if not candidate_dirs:
        raise RuntimeError(f'No image-containing subfolders found under {PATCHES_ROOT}')

    num_samples = min(NORM_REFERENCE_SAMPLE_SUBFOLDERS, len(candidate_dirs))
    random.seed(42)
    sampled_subfolders = random.sample(candidate_dirs, num_samples) if num_samples > 0 else []
    print(f"Sampling {len(sampled_subfolders)} subfolders from {len(candidate_dirs)} available for reference")

    # Instantiate Macenko normalizer (handle fallback to CPU if CuPy not available)
    try:
        mn_sampling = MacenkoNormalizer(use_gpu=USE_GPU)
    except Exception as e:
        print('Warning: requested GPU normalization but failed to initialize CuPy or GPU backend; falling back to CPU. Error:', e)
        mn_sampling = MacenkoNormalizer(use_gpu=False)

    # We'll compute reference stats incrementally to avoid holding many images in memory.
    all_stain_vectors = []  # each entry shape (3,2)
    all_max_h = []
    all_max_e = []
    sample_rows = []  # for saving per-file stats
    failed = 0
    processed_files = 0

    for folder in tqdm(sampled_subfolders, desc='Computing reference stats from sampled folders'):
        # discover images inside folder recursively
        for rt, _, files in os.walk(folder):
            for f in files:
                if os.path.splitext(f)[1].lower() not in IMAGE_EXTS:
                    continue
                fp = os.path.join(rt, f)
                try:
                    with Image.open(fp) as im:
                        im = im.convert('RGB')
                        arr = np.array(im)
                    # compute stain vectors and concentrations for this single image
                    try:
                        v, conc, _ = mn_sampling._get_stain_vectors_and_concentrations(arr)
                    except Exception as e:
                        tqdm.write(f'Failed computing stain stats for {fp}: {e}')
                        failed += 1
                        continue

                    # store numeric stats only
                    all_stain_vectors.append(v)
                    max_h = float(np.percentile(conc[:, 0], mn_sampling.percentiles[1])) if conc.size else 0.0
                    max_e = float(np.percentile(conc[:, 1], mn_sampling.percentiles[1])) if conc.size else 0.0
                    all_max_h.append(max_h)
                    all_max_e.append(max_e)
                    sample_rows.append({
                        'folder': folder,
                        'filepath': fp,
                        'max_h': max_h,
                        'max_e': max_e,
                        'stain_v_00': float(v[0,0]), 'stain_v_01': float(v[0,1]),
                        'stain_v_10': float(v[1,0]), 'stain_v_11': float(v[1,1]),
                        'stain_v_20': float(v[2,0]), 'stain_v_21': float(v[2,1]),
                    })
                    processed_files += 1

                    # free memory aggressively
                    del arr
                    del v
                    del conc
                    gc.collect()
                    # free CuPy GPU memory pools if using GPU
                    if getattr(mn_sampling, 'use_gpu', False) and getattr(mn_sampling, 'cp', None) is not None:
                        try:
                            mn_sampling.cp.get_default_memory_pool().free_all_blocks()
                            mn_sampling.cp.get_default_pinned_memory_pool().free_all_blocks()
                        except Exception:
                            # older/newer cupy versions may not have pinned pool; ignore failures
                            pass
                except Exception as e:
                    tqdm.write(f'Failed loading sample image {fp}: {e}')
                    failed += 1

    print(f'Processed {processed_files} sample images from {len(sampled_subfolders)} sampled subfolders; {failed} failures')
    if processed_files == 0:
        raise RuntimeError('No sample images found to compute reference stain profile')

    # Combine per-image stain vectors into mean reference stain vectors
    mean_ref_stain_vectors = np.mean(np.stack(all_stain_vectors, axis=0), axis=0)
    mean_max_h = float(np.mean(all_max_h))
    mean_max_e = float(np.mean(all_max_e))
    print('Computed mean stain vectors and concentrations: shapes:', mean_ref_stain_vectors.shape)

    # Save reference stats and per-sample CSV
    Path(NORM_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    np.savez(ref_stats_path, mean_stain_vectors=mean_ref_stain_vectors, mean_max_h=mean_max_h, mean_max_e=mean_max_e)
    print('Saved reference stats to', ref_stats_path)

    # save per-sample stats
    sample_stats_csv = os.path.join(NORM_OUTPUT_DIR, 'ref_stain_samples.csv')
    try:
        pd.DataFrame(sample_rows).to_csv(sample_stats_csv, index=False)
        print('Saved per-sample stain stats to', sample_stats_csv)
    except Exception as e:
        print('Warning: failed saving per-sample stats CSV:', e)

# Instantiate Macenko normalizer for the main normalization task
try:
    mn = MacenkoNormalizer(use_gpu=USE_GPU)
except Exception as e:
    print('Warning: requested GPU normalization but failed to initialize CuPy or GPU backend; falling back to CPU. Error:', e)
    mn = MacenkoNormalizer(use_gpu=False)

# Normalize images referenced by train CSV and save results mirroring structure under NORM_OUTPUT_DIR
import concurrent.futures

def normalize_image_worker(args):
    """Worker function to normalize a single image."""
    row_dict, path_col, mn_use_gpu, mean_ref_stain_vectors, mean_max_h, mean_max_e = args

    # It's crucial to instantiate the normalizer inside the worker
    # if using GPU to ensure correct context initialization per-process.
    try:
        mn_worker = MacenkoNormalizer(use_gpu=mn_use_gpu)
    except Exception as e:
        # Fallback or log if GPU init fails in worker
        mn_worker = MacenkoNormalizer(use_gpu=False)

    raw = row_dict[path_col]
    src = resolve_image_path(raw, PATCHES_ROOT)
    if not src:
        return None, f"Source not resolved for {raw}"

    try:
        with Image.open(src) as im:
            im = im.convert('RGB')
            arr = np.array(im)

        out_arr = mn_worker.normalize(
            arr,
            mean_ref_stain_vectors=mean_ref_stain_vectors,
            mean_ref_max_concentrations_tuple=(mean_max_h, mean_max_e)
        )

        rel = os.path.relpath(src, PATCHES_ROOT)
        dst = os.path.join(NORM_OUTPUT_DIR, rel)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        Image.fromarray(out_arr).save(dst)

        new_row = dict(row_dict)
        new_row[path_col] = dst
        return new_row, None

    except Exception as e:
        return None, f"Failed normalizing {src}: {e}"

# Prepare arguments for the workers
rows_to_process = [row.to_dict() for _, row in df.iterrows()]
worker_args = [
    (row_dict, path_col, USE_GPU, mean_ref_stain_vectors, mean_max_h, mean_max_e)
    for row_dict in rows_to_process
]

out_rows = []
errors = 0
out_csv = 'outputs/patches_index_train_norm.csv'

# Use ProcessPoolExecutor for parallel processing
# Using max_workers=None will default to the number of processors on the machine.
# Adjust if you want to limit CPU usage.
with concurrent.futures.ProcessPoolExecutor() as executor:
    # Using tqdm to show progress with multiprocessing
    results = list(tqdm(executor.map(normalize_image_worker, worker_args), total=len(rows_to_process), desc="Normalizing"))

for new_row, error_msg in results:
    if error_msg:
        errors += 1
        tqdm.write(error_msg)
    if new_row:
        out_rows.append(new_row)

if out_rows:
    pd.DataFrame(out_rows).to_csv(out_csv, index=False)
    print('Saved normalized CSV to', out_csv)
else:
    print('No normalized rows to save')


print('Normalization complete. Errors:', errors)


Using train CSV: outputs/patches_index_train.csv
Detected path column: path
Found existing reference stain stats: /media/thanakornbuath/patch/norm/ref_stain_stats.npz
Loaded reference stats. Skipping sampling and computation.


Normalizing:   5%|▌         | 47804/933109 [00:45<13:48, 1069.10it/s]Process ForkProcess-6:
Traceback (most recent call last):
  File "/home/thanakornbuath/anaconda3/envs/her2-class/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/thanakornbuath/anaconda3/envs/her2-class/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/thanakornbuath/anaconda3/envs/her2-class/lib/python3.12/concurrent/futures/process.py", line 246, in _process_worker
    call_item = call_queue.get(block=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/thanakornbuath/anaconda3/envs/her2-class/lib/python3.12/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
Exception in thread Thread-6:
Traceback (most recent call last):
  File "/home/thanakornbuath/anaconda3/envs/her2-class/lib/python3.12/threading.py", line 1052, in _

KeyboardInterrupt: 

## Training Phase 1 - ResNet-50

In [None]:
# Phase 1 — Train ResNet-50 (module)
# Use normalized CSV for training if requested and available
train_csv_for_training = train_csv_path
_norm_csv_default = 'outputs/patches_index_train_norm.csv'
if USE_NORMALIZED_FOR_TRAINING and os.path.exists(_norm_csv_default):
    train_csv_for_training = _norm_csv_default
    log(f"Using normalized train CSV for training: {train_csv_for_training}")
else:
    if USE_NORMALIZED_FOR_TRAINING:
        log(f"Requested normalized training but file not found: {_norm_csv_default}. Falling back to original train_csv.")

CFG = {
    'train_csv': train_csv_for_training,
    'val_csv': val_csv_path,
    'output_dir': 'outputs/phase1',
    'pretrained': True,
    'input_size': PATCH_SIZE,
    'batch_size': 32,
    'num_workers': 4,
    'epochs': 10,
    'lr': 1e-5,
    'weight_decay': 1e-4,
    'label_col': 'label',
    'path_col': 'path',
    'save_best_by': 'auc',
    'seed': 42,
}

results = train_phase1(CFG)
print('Best model:', results['best_model_path'])
print('Logs dir:', results['logs_dir'])