In [None]:
%load_ext autoreload
%autoreload 2

# Import library
import pandas as pd
import numpy as np
import re
import os
from tqdm import tqdm
from IPython.display import display
from glob import glob
import cv2


# Import dependency
from src.preprocessing.generate_metadata import discover_wsi
from src.preprocessing.xml_to_mask import get_mask
from src.preprocessing.annotation_utils import resolve_annotation_path
from src.preprocessing.extract_patches import extract_patches
from src.preprocessing.load_wsi import load_wsi
#from src.train.train_phase1 import train_phase1


In [None]:
# Configuration shared by the cells below.
# BASE_DIR is handed to discover_wsi / resolve_annotation_path as `base_dir`.
BASE_DIR = 'data'

# Names passed to discover_wsi as `sources`
# (presumably cohort sub-folders of BASE_DIR -- confirm against discover_wsi).
SOURCES = ['Yale_HER2_cohort', 'Yale_trastuzumab_response_cohort', 'TCGA_BRCA_Filtered']

# Passed to discover_wsi as `output_path` for the slide index CSV.
OUTPUT_CSV = 'outputs/index/wsi_index.csv'

In [None]:
# Preprocessing messages are echoed to the notebook and appended to this file.
log_dir = 'outputs/preprocessing/logs'
os.makedirs(log_dir, exist_ok=True)  # idempotent: safe to re-run the cell
log_path = os.path.join(log_dir, 'preprocessing.log')


def log(msg):
    """Print ``msg`` and append it as a single line to ``log_path``.

    Args:
        msg: The message to record. Any object is accepted; it is converted
            with ``str()`` so that passing e.g. an exception instance does not
            raise ``TypeError`` during the string concatenation.

    The file is opened in append mode with an explicit UTF-8 encoding so that
    non-ASCII characters (for example in slide paths) are written the same way
    on every platform instead of depending on the locale's default encoding.
    """
    text = str(msg)
    print(text)
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(text + '\n')

In [None]:
# Run slide discovery with the configured inputs. The return value is treated
# as the path of the index CSV (it is passed straight to pd.read_csv below).
csv_path = discover_wsi(base_dir=BASE_DIR, sources=SOURCES, output_path=OUTPUT_CSV)

# Read the index back and preview at most the first 50 rows.
df = pd.read_csv(csv_path)
display(df.head(n=50))

In [None]:
# Patch grid settings passed to extract_patches as `size` / `stride`.
# NOTE(review): stride == size presumably yields non-overlapping tiles --
# confirm against extract_patches.
PATCH_SIZE = 512
PATCH_STRIDE = 512

# For every slide in the index: build the annotation mask, open the slide,
# extract patches to outputs/patches/<slide name>/, and log the outcome.
# Any failure is logged and the slide is skipped so one bad slide does not
# abort the whole run.
for _, row in tqdm(df.iterrows(), total=len(df), desc='Processing slides'):
    wsi_path = row['wsi_path']
    # Resolve annotation path using helper (handles pandas NA, relative paths, and glob fallback)
    annotation_path = resolve_annotation_path(row.get('annotation_path', None), wsi_path, base_dir=BASE_DIR)
    if not annotation_path:
        log(f"Skipping slide without annotation: {wsi_path}")
        continue

    log(f"Processing slide: {wsi_path} with annotation: {annotation_path}")
    try:
        mask = get_mask(annotation_path, wsi_path)
    except Exception as e:
        log(f"Failed to generate mask for {wsi_path}: {e}")
        continue
    if mask is None:
        log(f"No mask generated for {wsi_path}")
        continue

    # At this point mask is expected to be a 2D uint8 array (0 or 255)
    log(f'Mask shape: {mask.shape}')

    # Load WSI using the wrapper which prefers CuCIM when available
    try:
        wsi_slide = load_wsi(wsi_path)
    except Exception as e:
        log(f"Failed to load WSI ({wsi_path}): {e}")
        continue
    if wsi_slide is None:
        log(f"Failed to load WSI: {wsi_path}")
        continue

    # From here on the slide is open: the `finally` below releases it on every
    # exit path (success, `continue`, or an unexpected exception) so handles
    # do not pile up while iterating over many slides.
    try:
        # Log which backend the loader selected (cucim or openslide)
        backend = getattr(wsi_slide, 'backend', None)
        log(f'Loaded WSI backend: {backend}')

        # Extract patches (extract_patches can optionally save patches to disk)
        slide_base = os.path.splitext(os.path.basename(wsi_path))[0]
        out_dir_patches = os.path.join('outputs', 'patches', slide_base)
        try:
            patches = extract_patches(
                wsi_slide,
                mask=mask,
                size=PATCH_SIZE,
                stride=PATCH_STRIDE,
                save_dir=out_dir_patches,
                save_prefix=slide_base,
                save_format='png'
            )
        except Exception as e:
            log(f"Failed to extract patches for {wsi_path}: {e}")
            continue

        # A patch counts as saved when its record carries a truthy 'path'.
        saved = sum(1 for p in patches if p.get('path'))
        log(f'Extracted {len(patches)} patches from {wsi_path}; saved {saved} to {out_dir_patches}')
    finally:
        # The object returned by load_wsi is project-defined; only call close()
        # when it actually exposes one, and never let a failing close abort
        # the remaining slides.
        close_slide = getattr(wsi_slide, 'close', None)
        if callable(close_slide):
            try:
                close_slide()
            except Exception as e:
                log(f"Failed to close WSI ({wsi_path}): {e}")

# Phase 1 — Train ResNet-50 (module)
This trains a patch-level HER2 classifier using ResNet-50.

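Inputs: two CSV files — a training split and a validation split (set via `train_csv` and `val_csv` in the configuration below) — each with columns `path` and `label` (0 = negative, 1 = positive).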
Outputs:
- Best model: `outputs/phase1/models/model_phase1.pth`
- Logs/metrics: `outputs/phase1/logs`

In [None]:
# The top-of-notebook import of train_phase1 is commented out, so import it
# here; otherwise this cell raises NameError on a fresh kernel.
from src.train.train_phase1 import train_phase1

# CSV Detail: path (image path), label (0=negative, 1=positive)
# NOTE(review): no cell visible in this notebook writes the train/val CSVs
# referenced below -- confirm they are produced elsewhere before running.
CFG = {
    # Data and output locations
    'train_csv': 'outputs/patches_index_train.csv',
    'val_csv': 'outputs/patches_index_val.csv',
    'output_dir': 'outputs/phase1',
    # Model / input settings
    'pretrained': True,
    'input_size': 512,
    # Loader and optimisation settings
    'batch_size': 32,
    'num_workers': 4,
    'epochs': 10,
    'lr': 1e-4,
    'weight_decay': 1e-4,
    # Column names expected in the CSVs
    'label_col': 'label',
    'path_col': 'path',
    # Checkpoint selection and reproducibility
    'save_best_by': 'auc',
    'seed': 42,
}

# train_phase1 receives the whole config dict; the result is read as a mapping
# with at least 'best_model_path' and 'logs_dir'.
results = train_phase1(CFG)
print('Best model:', results['best_model_path'])
print('Logs dir:', results['logs_dir'])