# ROI patching pipeline: BRACS ROIs from GCS -> downscaled tissue patches on Google Drive

In [6]:
# --- Deps (Colab) ---
!pip -q install google-cloud-storage pillow tqdm

import io, time, gc, os, concurrent.futures
from pathlib import Path
from typing import List
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F

from google.colab import auth as colab_auth, drive as colab_drive
import google.auth
from google.cloud import storage

# -------------------
# Auth & Mounts
# -------------------
# GCS auth for source reads (interactive prompt in Colab).
colab_auth.authenticate_user()
# NOTE(review): only GCS *reads* are visible in this cell (outputs go to Drive);
# devstorage.read_only would likely suffice — confirm nothing else writes to the bucket.
creds, _ = google.auth.default(scopes=['https://www.googleapis.com/auth/devstorage.read_write'])
client = storage.Client(credentials=creds)

# Mount Google Drive for destination writes (force_remount re-mounts if already mounted).
colab_drive.mount('/content/drive', force_remount=True)

# -------------------
# Config
# -------------------
BUCKET_NAME = "bracs-dataset-bucket"
# Bucket handle used by the download helper; list_roi_blobs lists via client + BUCKET_NAME.
bucket = client.bucket(BUCKET_NAME)

# SOURCE: GCS (unchanged)
# Expected object layout: <SRC_PREFIX>/<split>/<lesion>/<roi_id>.png
SRC_PREFIX = "BRACS/BRACS_RoI/latest_version"   # ROIs source

# DESTINATION: GOOGLE DRIVE
# Output layout: <DEST_DRIVE_ROOT>/<split>/<lesion>/<roi_id>/patch_y{y}_x{x}.png
DEST_DRIVE_ROOT = "/content/drive/MyDrive/BRACS/ROIPatches"

# Split and lesion-class folder names iterated when listing ROIs.
# Presumably the numeric prefix is the class index — confirm against the dataset docs.
SPLITS  = ["train","val","test"]
LESIONS = ["0_N","1_PB","2_UDH","3_FEA","4_ADH","5_DCIS","6_IC"]

# Patch geometry in pixels of the *downscaled* image. With STRIDE < PATCH_SIZE,
# horizontally/vertically adjacent windows overlap by PATCH_SIZE - STRIDE = 150 px.
PATCH_SIZE = 224
STRIDE     = 74
# Integer downscale applied to each decoded ROI. Assumes source ROIs are 40x — confirm.
DOWNSCALE_FACTOR = 4   # 40x -> 10x

# GPU / perf
GPU_DEVICE = "cuda:0"
# NOTE(review): the GPU work in this cell is unfold/mean/std, not convolution,
# so this cuDNN autotuning flag likely has no effect here.
torch.backends.cudnn.benchmark = True

# Threading
# Threads per ROI that encode patches to PNG and write them to Drive.
ENCODE_WRITE_WORKERS = 8
# Maximum attempts for each GCS download / Drive write (backoff starts at 1.0 s, x1.7 per retry).
RETRY_DOWNLOADS = 5
RETRY_WRITES    = 5

# Batch ROIs (looping granularity): ROIs are still processed one at a time;
# this only controls how often a PROGRESS line is printed by the run loop.
BATCH_ROIS = 32

# Raise Pillow's decompression-bomb pixel limit so very large ROIs can be opened
# (prevents DecompressionBomb warnings).
Image.MAX_IMAGE_PIXELS = 300_000_000

# -------------------
# Helpers
# -------------------
def ensure_dir(p: Path):
    """Create directory ``p`` and any missing parents; no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def parse_ids(blob_name: str):
    """Split a blob name ``.../<split>/<lesion>/<roi>.png`` into its ids.

    Returns:
        ``(split, lesion, roi_id)`` — the grandparent folder name, the parent
        folder name, and the file stem. Missing levels come back as ``""``.
    """
    path = Path(blob_name)
    lesion_dir = path.parent
    split_dir = lesion_dir.parent
    return split_dir.name, lesion_dir.name, path.stem

def list_roi_blobs() -> List[str]:
    """List every ROI PNG object name under ``SRC_PREFIX`` for all splits and lesions.

    Only ``.png`` objects (case-insensitive) sitting directly inside a
    ``<split>/<lesion>/`` folder are returned. Deeper nested objects are
    ignored. Order is split-major, then lesion, then GCS listing order.
    """
    def _direct_pngs(split_name: str, lesion_name: str) -> List[str]:
        listing = client.list_blobs(BUCKET_NAME, prefix=f"{SRC_PREFIX}/{split_name}/{lesion_name}/")
        return [
            blob.name
            for blob in listing
            if blob.name.lower().endswith(".png")
            and Path(blob.name).parent.name == lesion_name
        ]

    return [
        name
        for split_name in SPLITS
        for lesion_name in LESIONS
        for name in _direct_pngs(split_name, lesion_name)
    ]

def roi_already_done_drive(split: str, lesion: str, roi_id: str, root=None) -> bool:
    """Return True if the ROI's destination folder exists and is non-empty.

    Args:
        split, lesion, roi_id: path components of the ROI's output folder.
        root: destination root. Defaults to ``DEST_DRIVE_ROOT`` when None.

    Any entry counts (file or sub-directory). A missing folder, a path that is
    not a directory, or any other OS-level error (e.g. a flaky Drive mount)
    yields False, so the ROI will be (re)processed.

    NOTE: an ROI interrupted part-way through writing also has a non-empty
    folder and is therefore reported as done.
    """
    base = Path(DEST_DRIVE_ROOT if root is None else root)
    dest_dir = base / split / lesion / roi_id
    try:
        # os.scandir reads lazily, so pulling a single entry avoids listing a
        # folder that may hold thousands of patches (Path.iterdir reads it all).
        with os.scandir(dest_dir) as entries:
            return next(entries, None) is not None
    except OSError:
        return False

def download_bytes_with_retry(blob_name: str, retries=RETRY_DOWNLOADS):
    """Download a GCS object's bytes from ``bucket``, retrying with exponential backoff.

    Args:
        blob_name: object name within the module-level ``bucket``.
        retries: maximum number of attempts. Values <= 0 are treated as 1, so a
            download is always attempted (previously the loop never ran and the
            function silently returned None).

    Returns:
        The object's contents as bytes.

    Raises:
        Whatever the last failed attempt raised. Between attempts it sleeps
        1.0 s, then 1.7 s, 2.89 s, ...
    """
    attempts = max(1, retries)
    blob = bucket.blob(blob_name)
    delay = 1.0
    for attempt in range(attempts):
        try:
            return blob.download_as_bytes()
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(delay)
            delay *= 1.7

def write_bytes_with_retry(path: Path, data: bytes, retries=RETRY_WRITES):
    """Write ``data`` to ``path`` (creating parent dirs), retrying with exponential backoff.

    Args:
        path: destination file. It is overwritten on each attempt.
        data: bytes to write.
        retries: maximum number of attempts. Values <= 0 are treated as 1.
            Previously such a value skipped the write yet returned normally,
            so callers counted the file as saved.

    Raises:
        Whatever the last failed attempt raised. Before re-raising, any
        partially written file is removed so a truncated file is not left
        behind. Between attempts it sleeps 1.0 s, then 1.7 s, 2.89 s, ...
    """
    attempts = max(1, retries)
    delay = 1.0
    for attempt in range(attempts):
        try:
            ensure_dir(path.parent)
            with open(path, "wb") as f:
                f.write(data)
            return
        except Exception:
            if attempt == attempts - 1:
                try:
                    path.unlink()
                except OSError:
                    pass  # nothing was created, or cleanup failed; surface the original error
                raise
            time.sleep(delay)
            delay *= 1.7

def decode_image_bytes(img_bytes: bytes) -> np.ndarray:
    """Decode encoded image bytes (e.g. PNG) into an (H, W, 3) uint8 RGB array.

    Both the opened image and its RGB conversion are closed explicitly. That
    releases their decoded pixel buffers right away instead of at garbage
    collection, which matters for very large ROIs.
    """
    with Image.open(io.BytesIO(img_bytes)) as src:
        rgb = src.convert("RGB")  # forces the decode while the source is open
    try:
        return np.array(rgb)  # independent copy of the pixels
    finally:
        rgb.close()

def downscale_to_10x(img_np: np.ndarray, factor: int = DOWNSCALE_FACTOR) -> np.ndarray:
    """Shrink an image array by an integer ``factor`` using box (area-average) resampling.

    Each output side is ``side // factor``, clamped to at least 1 pixel.
    """
    height, width = img_np.shape[0], img_np.shape[1]
    target_size = (max(width // factor, 1), max(height // factor, 1))  # PIL wants (w, h)
    reduced = Image.fromarray(img_np).resize(target_size, resample=Image.BOX)
    return np.array(reduced)

def encode_png(np_img: np.ndarray) -> bytes:
    """Serialize an image array to PNG bytes (with Pillow's ``optimize`` flag)."""
    with io.BytesIO() as buffer:
        Image.fromarray(np_img).save(buffer, format="PNG", optimize=True)
        return buffer.getvalue()

# -------------------
# Core per-ROI processing (GPU)
# -------------------
@torch.no_grad()
def _tissue_patch_origins(img: np.ndarray, ks: int, st: int, device: torch.device,
                          max_patches_per_chunk: int = 512):
    """Return ``(ys, xs)``: top-left corners of ks x ks windows (stride st) that pass the tissue filter.

    Filter: keep a window when its mean intensity (in [0, 1]) is < 0.85 and the
    standard deviation over all its channels/pixels is > 0.02. Windows are
    scored one horizontal band at a time. GPU memory is therefore bounded by
    roughly ``max_patches_per_chunk * 3 * ks * ks`` floats instead of growing
    with the whole ROI. Assumes ``img`` is (H, W, 3) uint8 with H, W >= ks.
    """
    H, W = img.shape[:2]
    rows = (H - ks) // st + 1
    cols = (W - ks) // st + 1
    rows_per_chunk = max(1, max_patches_per_chunk // cols)

    x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device, dtype=torch.float32) / 255.0

    ys, xs = [], []
    for r0 in range(0, rows, rows_per_chunk):
        r1 = min(rows, r0 + rows_per_chunk)
        # Window rows r0..r1-1 cover image rows [r0*st, (r1-1)*st + ks).
        band = x[:, :, r0 * st:(r1 - 1) * st + ks, :]
        unf = F.unfold(band, kernel_size=ks, stride=st)  # 1 x (3*ks*ks) x ((r1-r0)*cols), row-major
        brightness = unf.mean(dim=1).squeeze(0)
        contrast = unf.std(dim=1).squeeze(0)  # unbiased std, as before
        keep = torch.nonzero((brightness < 0.85) & (contrast > 0.02)).squeeze(1)
        for j in keep.tolist():
            row, col = divmod(j, cols)
            ys.append((r0 + row) * st)
            xs.append(col * st)
        del band, unf, brightness, contrast, keep
    del x
    return ys, xs


@torch.no_grad()
def process_one_roi(blob_name: str):
    """Cut one ROI into tissue patches and save them as PNGs on Google Drive.

    Steps:
        1. Skip if ``roi_already_done_drive`` reports a non-empty output folder.
        2. Download the ROI from GCS and decode it.
        3. Downscale it by ``DOWNSCALE_FACTOR``.
        4. Score every PATCH_SIZE x PATCH_SIZE window (stride STRIDE) on the GPU.
        5. Write each kept window as
           ``<DEST_DRIVE_ROOT>/<split>/<lesion>/<roi_id>/patch_y{y}_x{x}.png``.

    Patches are cropped directly from the downscaled uint8 image on the CPU.
    Saved pixels therefore match the source exactly, with no float32
    round-trip or per-patch GPU->CPU copy.

    Returns:
        ``(num_patches_saved, skipped)``. ``skipped`` is True when the ROI was
        already done, is smaller than one patch, or has no window passing the
        tissue filter. Download/decode failures print a warning and return
        ``(0, False)``.
    """
    split, lesion, roi_id = parse_ids(blob_name)

    # Resume support: a non-empty destination folder counts as done.
    # NOTE(review): an ROI interrupted mid-write, or with failed patch writes,
    # also leaves a non-empty folder and will not be retried on a later run.
    if roi_already_done_drive(split, lesion, roi_id):
        return 0, True  # 0 saved, skipped

    try:
        roi_bytes = download_bytes_with_retry(blob_name)
    except Exception as err:
        print(f"WARN: download failed for {blob_name}: {err!r}")
        return 0, False

    # Decode at full resolution, then downscale (40x -> 10x per the config).
    try:
        img40 = decode_image_bytes(roi_bytes)
        img = downscale_to_10x(img40, DOWNSCALE_FACTOR)
        del img40, roi_bytes
    except Exception as err:
        print(f"WARN: decode failed for {blob_name}: {err!r}")
        return 0, False

    ks, st = PATCH_SIZE, STRIDE
    H, W = img.shape[:2]
    if H < ks or W < ks:
        return 0, True  # smaller than one patch: nothing to do

    ys, xs = _tissue_patch_origins(img, ks, st, torch.device(GPU_DEVICE))
    torch.cuda.empty_cache()
    if not ys:
        return 0, True  # no window passed the tissue filter

    dest_dir = Path(DEST_DRIVE_ROOT) / split / lesion / roi_id
    ensure_dir(dest_dir)

    def encode_and_write(y: int, x: int) -> None:
        # Exact uint8 crop; made contiguous so Pillow gets a plain buffer.
        patch = np.ascontiguousarray(img[y:y + ks, x:x + ks])
        out_path = dest_dir / f"patch_y{y}_x{x}.png"
        write_bytes_with_retry(out_path, encode_png(patch), retries=RETRY_WRITES)

    saved = 0
    failed = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=ENCODE_WRITE_WORKERS) as ex:
        futures = [ex.submit(encode_and_write, y, x) for y, x in zip(ys, xs)]
        for fut in concurrent.futures.as_completed(futures):
            try:
                fut.result()
            except Exception:
                failed += 1  # best-effort: keep writing the remaining patches
            else:
                saved += 1

    if failed:
        print(f"WARN: {failed}/{len(futures)} patch writes failed for {blob_name}")

    return saved, False  # saved count, not skipped

# -------------------
# Run (Batched)
# -------------------
if __name__ == "__main__":
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA device visible.")
    device = torch.device(GPU_DEVICE)
    torch.cuda.set_device(device)

    all_blobs = list_roi_blobs()
    total = len(all_blobs)
    done = 0
    saved_total = 0
    skipped_total = 0
    failed_total = 0

    # ROIs are processed one at a time; BATCH_ROIS only sets the progress-print cadence.
    for start in range(0, total, BATCH_ROIS):
        batch = all_blobs[start : start + BATCH_ROIS]
        for blob_name in batch:
            ok = True
            try:
                saved, skipped = process_one_roi(blob_name)
            except Exception as err:
                # Isolate per-ROI failures (e.g. CUDA OOM, Drive I/O) so one bad
                # ROI does not abort the whole multi-hour run.
                print(f"ERROR: {blob_name}: {err!r}")
                saved, skipped, ok = 0, False, False
            if not ok:
                failed_total += 1
                torch.cuda.empty_cache()  # after the except block, once its traceback is released
            done += 1
            saved_total += saved
            skipped_total += 1 if skipped else 0
        print(f"PROGRESS: {done}/{total} saved_patches={saved_total} "
              f"skipped_rois={skipped_total} failed_rois={failed_total}")

    print("FINISHED:",
          f"ROIs={total}, saved_patches={saved_total}, skipped_rois={skipped_total}, "
          f"failed_rois={failed_total}")


Mounted at /content/drive
PROGRESS: 32/4539 saved_patches=310 skipped_rois=5
PROGRESS: 64/4539 saved_patches=1400 skipped_rois=5
PROGRESS: 96/4539 saved_patches=2268 skipped_rois=6
PROGRESS: 128/4539 saved_patches=3210 skipped_rois=6
PROGRESS: 160/4539 saved_patches=3829 skipped_rois=7
PROGRESS: 192/4539 saved_patches=4547 skipped_rois=9
PROGRESS: 224/4539 saved_patches=5243 skipped_rois=14
PROGRESS: 256/4539 saved_patches=5435 skipped_rois=33
PROGRESS: 288/4539 saved_patches=6014 skipped_rois=44
PROGRESS: 320/4539 saved_patches=6336 skipped_rois=50
PROGRESS: 352/4539 saved_patches=6846 skipped_rois=52
PROGRESS: 384/4539 saved_patches=15065 skipped_rois=56
PROGRESS: 416/4539 saved_patches=15508 skipped_rois=63
PROGRESS: 448/4539 saved_patches=16736 skipped_rois=63
PROGRESS: 480/4539 saved_patches=20047 skipped_rois=63
PROGRESS: 512/4539 saved_patches=22003 skipped_rois=63
PROGRESS: 544/4539 saved_patches=23372 skipped_rois=63
PROGRESS: 576/4539 saved_patches=25843 skipped_rois=63
PROGR