# ROI patching pipeline

In [1]:
!pip -q install gcsfs fsspec tqdm pillow opencv-python

# Mount Google Drive at /content/drive; the patch output folder (DRIVE_ROOT) lives under this mount.
from google.colab import drive
drive.mount('/content/drive')  # Accept the prompt

import os, json, io, csv, re, math
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
from PIL import Image
import fsspec
from tqdm import tqdm
import cv2

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# pip install in Colab if not present
!pip -q install google-cloud-storage pillow tqdm

import os, io, json, csv, time
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import Queue, Event, Process
from google.cloud import storage
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Authenticate the Colab user; presumably this is what lets storage.Client() below obtain credentials.
from google.colab import auth
auth.authenticate_user()

# === Config ===
# Source: ROI PNGs live at gs://{BUCKET_NAME}/{ROOT_PREFIX}/{split}/{lesion}/{roi}.png
BUCKET_NAME = "bracs-dataset-bucket"
ROOT_PREFIX = "BRACS/BRACS_RoI/latest_version"
SPLITS = ["train","val","test"]
LESIONS = ["0_N","1_PB","2_UDH","3_FEA","4_ADH","5_DCIS","6_IC"]

# Destination: patches are written to DRIVE_ROOT/{split}/{lesion}/{roi_id}/patch_y{y}_x{x}.png
DRIVE_ROOT = Path("/content/drive/MyDrive/BRACS/ROIPatches")
DRIVE_ROOT.mkdir(parents=True, exist_ok=True)

PATCH_SIZE = 224          # patch edge length in pixels (square patches)
STRIDE = 74               # step between patch origins; smaller than PATCH_SIZE, so patches overlap
# NOTE(review): MIN_PATCHES and BATCH_MAX_PATCHES are not referenced by the code visible in this cell.
MIN_PATCHES = 5
BATCH_MAX_PATCHES = 4096  # max patches to process per GPU step (tune to GPU mem)
GPU_DEVICE = "cuda:0"     # change if using another GPU idx
NUM_PRODUCERS = 6         # number of parallel fetchers (tune with CPU & network)
NUM_SAVER_THREADS = 4     # threads writing images to Drive

# PIL image size guard: PIL's decompression-bomb check is driven by this pixel count,
# so raise it if legitimate ROIs are larger than 300 megapixels.
from PIL import Image
Image.MAX_IMAGE_PIXELS = 300_000_000

# === Helpers ===
# Module-level GCS handles, created once when this cell runs.
# NOTE(review): these are created before the worker processes are started further down;
# google-cloud-storage recommends creating clients after fork — confirm for the producers.
client = storage.Client()
bucket = client.bucket(BUCKET_NAME)

# --- Helper: parse split / lesion / roi_id from blob path ---
def parse_ids(blob_name: str):
    """Derive ``(split, lesion, roi_id)`` from a blob path.

    The last three path components are read as ``{split}/{lesion}/{roi}.png``;
    ``roi_id`` is the file name without its extension.

    Raises:
        IndexError: if the path has fewer than three components.
    """
    blob_path = Path(blob_name)
    components = blob_path.parts
    # Index from the end so the depth of the prefix in front of {split} does not matter.
    split, lesion = components[-3], components[-2]
    return split, lesion, blob_path.stem


# --- Helper: detect ROIs whose patches already exist on Drive ---
def roi_already_processed(blob_name: str) -> bool:
    """
    Tell whether patches for this ROI are already present under DRIVE_ROOT.

    An ROI counts as processed when ``DRIVE_ROOT/{split}/{lesion}/{roi_id}``
    exists and holds at least one ``patch_*.png`` file. Blob names with too few
    path components to identify split and lesion are reported as not processed.
    """
    blob_path = Path(blob_name)
    components = blob_path.parts
    try:
        lesion, split = components[-2], components[-3]
    except IndexError:
        return False

    roi_dir = DRIVE_ROOT / split / lesion / blob_path.stem
    if not roi_dir.exists():
        return False
    # One matching patch file is enough to call the ROI done.
    return any(roi_dir.glob("patch_*.png"))


# --- Template for per-ROI processing (tiling/saving is still a placeholder) ---
def process_one_blob(blob_name: str, bts: bytes):
    """
    Handle a single ROI whose bytes have already been downloaded.

    Returns a ``(num_patches, status)`` pair, where status is one of
    ``"skipped_existing"``, ``"decode_error"`` or ``"ok"``.

    NOTE: the tiling / patch-saving step is still a placeholder, so a
    successful decode currently reports zero patches.
    """
    # Nothing to do when patches for this ROI are already on Drive.
    already_done = roi_already_processed(blob_name)
    if already_done:
        print(f"[SKIP] {blob_name} (already processed)")
        return 0, "skipped_existing"

    try:
        img = decode_image_bytes(bts)
    except Exception as e:
        print(f"[ERROR] Failed to decode {blob_name}: {e}")
        return 0, "decode_error"
    else:
        # Placeholder: the real tiling / saving logic would consume `img` here
        # and collect the patches it produced.
        patches = []
        return len(patches), "ok"  # replace with actual logic


def list_blob_names(split, lesion):
    """Yield names of the ``.png`` blobs stored directly under ``{ROOT_PREFIX}/{split}/{lesion}/``."""
    listing = client.list_blobs(BUCKET_NAME, prefix=f"{ROOT_PREFIX}/{split}/{lesion}/")
    for blob in listing:
        if not blob.name.lower().endswith(".png"):
            continue
        # Keep only direct children of the lesion folder, not blobs nested deeper.
        if Path(blob.name).parent.name != lesion:
            continue
        yield blob.name

def fetch_blob_bytes(blob_name):
    """Download one blob from the module-level ``bucket``; returns ``(blob_name, bytes)``."""
    payload = bucket.blob(blob_name).download_as_bytes()
    return blob_name, payload

# quick CPU decode, called from the GPU consumer (small overhead)
def decode_image_bytes(img_bytes):
    """Decode encoded image bytes with PIL and return the RGB image as a numpy array."""
    buffer = io.BytesIO(img_bytes)
    # Force three channels so callers can rely on an RGB layout.
    rgb_image = Image.open(buffer).convert("RGB")
    return np.array(rgb_image)

# write one patch as PNG, creating its output directory if needed
def save_patch_cpu(out_path: Path, patch_np: np.ndarray):
    """Write ``patch_np`` to ``out_path`` as a PNG, creating parent directories as needed."""
    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    patch_image = Image.fromarray(patch_np)
    patch_image.save(out_path, format="PNG")

# === Producer function (runs in separate processes) ===
def _download_with_retry(gcs_bucket, blob_name, attempts=3, base_delay_s=1.0):
    """Download one blob's bytes, retrying failed attempts; re-raises the last error."""
    for attempt in range(1, attempts + 1):
        try:
            return gcs_bucket.blob(blob_name).download_as_bytes()
        except Exception:
            if attempt == attempts:
                raise
            # Linear back-off before the next attempt (e.g. after a checksum mismatch).
            time.sleep(base_delay_s * attempt)


def producer_worker(work_queue: Queue, blob_names_iterable, stop_event: Event):
    """
    Download ROI blobs and hand them to the consumer as ``(blob_name, bytes)``.

    Intended to run in its own process.

    Args:
        work_queue: queue receiving ``(blob_name, bytes)`` tuples.
        blob_names_iterable: blob names this worker is responsible for.
        stop_event: when set, the worker stops before starting the next blob.

    Behaviour:
        * ROIs that already have patches on Drive are skipped *before* the
          download, so no bandwidth is spent on them.
        * The worker creates its own storage client on first use instead of
          reusing the module-level one, so HTTP connections are not shared
          with the parent or sibling processes.
        * Each download is attempted up to three times; a blob that still
          fails is reported with ``Producer error`` and dropped, and the
          worker moves on to the next blob.
    """
    worker_bucket = None  # created lazily, inside this process

    for blob_name in blob_names_iterable:
        if stop_event.is_set():
            break

        if roi_already_processed(blob_name):
            print(f"[SKIP] {blob_name} (already processed)")
            continue

        try:
            if worker_bucket is None:
                worker_bucket = storage.Client().bucket(BUCKET_NAME)
            bts = _download_with_retry(worker_bucket, blob_name)
            work_queue.put((blob_name, bts))
        except Exception as e:
            print("Producer error", blob_name, e)
    # producer done
    return

# === GPU consumer (single process) ===
def gpu_consumer(work_queue: Queue, save_queue: Queue, stop_event: Event):
    """
    Tile downloaded ROIs on the GPU and enqueue the kept patches for saving.

    Pulls ``(blob_name, bytes)`` items from ``work_queue`` until ``stop_event``
    is set and the queue is empty. For every ROI that is not already on Drive:

    1. decode the image on the CPU (H x W x 3, uint8);
    2. move it to ``GPU_DEVICE`` as FP16 scaled to [0, 1];
    3. unfold it into PATCH_SIZE x PATCH_SIZE patches with step STRIDE, one
       horizontal band at a time so at most ~MAX_L_PER_BLOCK patches are
       materialised at once;
    4. keep patches whose mean is < 0.85 and whose std is > 0.02
       (drops bright / flat background);
    5. put lists of ``(blob_name, x, y, patch_uint8)`` onto ``save_queue``,
       where ``(x, y)`` is the patch's top-left pixel in the ROI.

    Args:
        work_queue: queue of ``(blob_name, bytes)`` produced by the fetchers.
        save_queue: queue receiving lists of patches for the saver threads.
        stop_event: set by the orchestrator once no more work will be queued.
    """
    device = torch.device(GPU_DEVICE)
    torch.cuda.set_device(device)

    # knobs
    ks = PATCH_SIZE
    st = STRIDE
    MAX_L_PER_BLOCK = 40_000   # max number of patches handled per block (tune)
    SAVE_CHUNK = 256           # how many patches to enqueue per save batch

    while not (stop_event.is_set() and work_queue.empty()):
        try:
            blob_name, bts = work_queue.get(timeout=2)
        except Exception:
            # Typically a timeout on an empty queue: loop to re-check the exit condition.
            continue

        # skip if already processed
        if roi_already_processed(blob_name):
            print(f"[SKIP] {blob_name} (already processed)")
            continue

        split, lesion, roi_id = parse_ids(blob_name)
        print(f"[PATCH] start {split}/{lesion}/{roi_id}")

        # ---- decode on CPU
        try:
            img_np = decode_image_bytes(bts)  # H x W x 3, uint8
        except Exception as e:
            print(f"[ERROR] Decode failed {blob_name}: {e}")
            continue

        H, W, _ = img_np.shape
        if H < ks or W < ks:
            print(f"[SKIP] {split}/{lesion}/{roi_id} (too small: {H}x{W})")
            continue

        # ---- to GPU (FP16), scaled to [0, 1]
        img_t = (torch.from_numpy(img_np)
                 .permute(2, 0, 1)            # 3xHxW
                 .unsqueeze(0)                # 1x3xHxW
                 .to(device=device, dtype=torch.float16)) / 255.0

        # Number of patch origins along rows/cols. H >= ks and W >= ks were
        # checked above, so both counts are at least 1.
        n_rows = (H - ks) // st + 1
        n_cols = (W - ks) // st + 1

        # choose rows per block so rows_per_block * n_cols <= MAX_L_PER_BLOCK
        rows_per_block = max(1, min(n_rows, MAX_L_PER_BLOCK // max(1, n_cols)))
        # band height we need to slice: (rows_per_block-1)*st + ks
        band_h = (rows_per_block - 1) * st + ks

        total_kept = 0
        # process horizontal bands of patch rows, top to bottom
        for row_start in range(0, n_rows, rows_per_block):
            y0 = row_start * st
            # The last band is clamped to the image bottom and may hold fewer rows.
            y1 = min(y0 + band_h, H)

            # crop band (still on GPU)
            band = img_t[:, :, y0:y1, :]  # 1x3x(band height)xW

            # F.unfold returns 1 x (3*ks*ks) x L_band. Drop the batch dimension
            # so the per-patch statistics and the kept indices below are 1-D;
            # index_select requires a 1-D index tensor.
            patches_unf = F.unfold(band, kernel_size=ks, stride=st)[0]  # (3*ks*ks) x L_band
            L_band = patches_unf.shape[-1]
            if L_band == 0:
                del patches_unf, band
                torch.cuda.empty_cache()
                continue

            # Per-patch mean as brightness and per-patch std as a cheap
            # stand-in for saturation, both over the flattened patch.
            brightness = patches_unf.mean(dim=0)   # [L_band]
            sat = patches_unf.std(dim=0)           # [L_band]

            keep_mask = (brightness < 0.85) & (sat > 0.02)
            keep_idxs = torch.nonzero(keep_mask, as_tuple=False).squeeze(1)  # [K]
            K = keep_idxs.numel()
            if K == 0:
                del patches_unf, band, brightness, sat, keep_mask, keep_idxs
                torch.cuda.empty_cache()
                continue

            # Patch index -> top-left pixel. Unfold enumerates patches row-major,
            # n_cols per row. The copy to the CPU is blocking on purpose: the
            # values are read immediately by tolist().
            keep_cpu = keep_idxs.cpu()
            xs_keep = ((keep_cpu % n_cols) * st).tolist()
            ys_keep = ((keep_cpu // n_cols) * st + y0).tolist()  # offset by band start

            # Stream kept patches to the savers in chunks to avoid one big gather.
            for start in range(0, K, SAVE_CHUNK):
                end = min(start + SAVE_CHUNK, K)
                idx_slice = keep_idxs[start:end]

                # gather subset (3*ks*ks) x m, then reshape to m x 3 x ks x ks
                sub = (patches_unf.index_select(dim=1, index=idx_slice)
                       .transpose(0, 1)
                       .reshape(-1, 3, ks, ks))

                # Back to uint8 HWC. Round before casting: the FP16 value of
                # k/255 times 255 can land just below k, and a truncating cast
                # would then turn k into k-1. The FP16 error is well below 0.5
                # for k <= 255, so rounding restores the original pixel value.
                sub_cpu = (sub.permute(0, 2, 3, 1)            # m x ks x ks x 3
                           .to(dtype=torch.float32)
                           .mul_(255.0)
                           .round_()
                           .clamp_(0.0, 255.0)
                           .to(dtype=torch.uint8)
                           .contiguous()
                           .cpu()                             # blocking copy
                           .numpy())

                # enqueue save batch: one tuple per kept patch in this chunk
                batch = [
                    (blob_name, xs_keep[j], ys_keep[j], sub_cpu[j - start])
                    for j in range(start, end)
                ]
                save_queue.put(batch)
                total_kept += end - start

                # free slice tensors
                del sub, sub_cpu

            # free band tensors
            del patches_unf, band, brightness, sat, keep_mask, keep_idxs
            torch.cuda.empty_cache()

        print(f"[PATCH] extracted {total_kept} patches for {split}/{lesion}/{roi_id} -> enqueued")

        # free big image tensor
        del img_t
        torch.cuda.empty_cache()


# === Saver: runs in ThreadPool (multiple threads) to write PNGs to Drive ===
def saver_thread(save_queue: Queue, stop_event: Event):
    """
    Drain ``save_queue`` and write every patch to Drive as a PNG.

    Runs until ``stop_event`` is set and the queue is empty. Each queue item is
    a list of ``(blob_name, x, y, patch_np)`` tuples; a patch is written to
    ``DRIVE_ROOT/{split}/{lesion}/{roi_id}/patch_y{y}_x{x}.png``, with
    split / lesion / roi_id taken from the blob path by ``parse_ids``.

    A patch that cannot be written is reported and skipped so that one failed
    write does not end the thread and discard the rest of the batch.
    """
    while not (stop_event.is_set() and save_queue.empty()):
        try:
            batch = save_queue.get(timeout=2)  # batch = list of (blob_name, x, y, patch_np)
        except Exception:
            # Typically a timeout on an empty queue: loop to re-check the exit condition.
            continue

        # Output directory per blob, created once per batch rather than once per patch.
        ready_dirs = {}
        for blob_name, x, y, patch_np in batch:
            try:
                out_dir = ready_dirs.get(blob_name)
                if out_dir is None:
                    # blob path layout: .../{split}/{lesion}/{roi}.png
                    split, lesion, roi_id = parse_ids(blob_name)
                    out_dir = DRIVE_ROOT / split / lesion / roi_id
                    out_dir.mkdir(parents=True, exist_ok=True)
                    ready_dirs[blob_name] = out_dir
                out_path = out_dir / f"patch_y{y}_x{x}.png"
                Image.fromarray(patch_np).save(out_path, format="PNG")
            except Exception as e:
                print(f"[ERROR] Save failed {blob_name} (x={x}, y={y}): {e}")
        # optionally flush, write manifest etc.

# === Orchestration ===
if __name__ == "__main__":
    # Pipeline: N producer processes download ROIs -> work_q -> one GPU consumer
    # process tiles and filters -> save_q -> saver threads write PNGs to Drive.
    import threading
    from multiprocessing import Manager

    mgr = Manager()
    work_q = mgr.Queue(maxsize=NUM_PRODUCERS * 4)
    save_q = mgr.Queue(maxsize=NUM_SAVER_THREADS * 8)
    # Two separate stop signals. Savers must not share the producers' signal:
    # the consumer keeps emitting patches after the producers are done, and a
    # saver that sees "stop set + queue momentarily empty" would exit too early.
    stop_evt = mgr.Event()              # no more downloads: tells the GPU consumer to finish
    saver_stop_evt = threading.Event()  # no more patches: tells the saver threads to finish

    # prepare the list of blob names
    all_blob_names = [
        name
        for split in SPLITS
        for lesion in LESIONS
        for name in list_blob_names(split, lesion)
    ]
    print("Total ROIs:", len(all_blob_names))

    producers = []
    gpu_proc = None
    savers = []
    try:
        # start producers as Processes, each with a contiguous slice of the list;
        # the last one also takes the remainder
        chunk_size = max(1, len(all_blob_names) // NUM_PRODUCERS)
        for i in range(NUM_PRODUCERS):
            start = i * chunk_size
            end = None if i == NUM_PRODUCERS - 1 else (i + 1) * chunk_size
            p = Process(target=producer_worker,
                        args=(work_q, all_blob_names[start:end], stop_evt))
            p.start()
            producers.append(p)

        # start GPU consumer
        gpu_proc = Process(target=gpu_consumer, args=(work_q, save_q, stop_evt))
        gpu_proc.start()

        # start saver threads
        for _ in range(NUM_SAVER_THREADS):
            t = threading.Thread(target=saver_thread, args=(save_q, saver_stop_evt), daemon=True)
            t.start()
            savers.append(t)

        # Wait for producers. If the consumer dies meanwhile, the producers
        # would block forever on a full work queue, so stop instead of hanging.
        for p in producers:
            while p.is_alive():
                p.join(timeout=5)
                if not gpu_proc.is_alive():
                    raise RuntimeError(
                        f"GPU consumer exited early (exit code {gpu_proc.exitcode}); aborting."
                    )

        # producers done: let the consumer drain the work queue and exit
        stop_evt.set()
        gpu_proc.join()

        # consumer done: every batch is now in save_q, so the savers can drain it and exit
        saver_stop_evt.set()
        for t in savers:
            t.join()
        print("All done.")
    finally:
        # Also reached on KeyboardInterrupt / errors: release everything still running.
        stop_evt.set()
        saver_stop_evt.set()
        for proc in producers + ([gpu_proc] if gpu_proc is not None else []):
            if proc.is_alive():
                proc.terminate()


Total ROIs: 4539
[SKIP] BRACS/BRACS_RoI/latest_version/train/1_PB/BRACS_1642_PB_3.png (already processed)
[SKIP] BRACS/BRACS_RoI/latest_version/train/4_ADH/BRACS_1903_ADH_8.png (already processed)
[SKIP] BRACS/BRACS_RoI/latest_version/train/0_N/BRACS_1003675_N_1.png (already processed)
[SKIP] BRACS/BRACS_RoI/latest_version/train/3_FEA/BRACS_1506_FEA_11.png (already processed)Producer error
 BRACS/BRACS_RoI/latest_version/train/5_DCIS/BRACS_752_DCIS_38.png Checksum mismatch while downloading:

  https://storage.googleapis.com/download/storage/v1/b/bracs-dataset-bucket/o/BRACS%2FBRACS_RoI%2Flatest_version%2Ftrain%2F5_DCIS%2FBRACS_752_DCIS_38.png?alt=media

The X-Goog-Hash header indicated an MD5 checksum of:

  TN/TDfQ4Z8Z8giiXRDmJog==

but the actual MD5 checksum of the downloaded contents was:

  FqDALFhwtyRqZdBPQL6gDg==

The X-Goog-Stored-Content-Length is 11230062. The X-Goog-Stored-Content-Encoding is identity.

The download request read 11230062 bytes of data.
If the download was i

KeyboardInterrupt: 

Process Process-6:
Process Process-8:
Process Process-7:
Traceback (most recent call last):
