# ROI patching pipeline: download BRACS ROI PNGs from GCS, tile them into patches, and save the patches to Google Drive

In [1]:
!pip -q install gcsfs fsspec tqdm pillow opencv-python

# Mount Google Drive at /content/drive (interactive: Colab asks for authorization).
from google.colab import drive
drive.mount('/content/drive')  # Accept the prompt

import os, json, io, csv, re, math
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
from PIL import Image
import fsspec
from tqdm import tqdm
import cv2

In [None]:
# pip install in Colab if not present
!pip -q install google-cloud-storage pillow tqdm

import os, io, json, csv, time, shutil, glob, threading, gc
from pathlib import Path
from multiprocessing import Queue, Event, Process, Manager
from google.cloud import storage
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from google.colab import drive as gdrive

# === Mount Google Drive ===
# force_remount=True asks Colab to mount again even if Drive is already mounted.
MOUNT_POINT = "/content/drive"
gdrive.mount(MOUNT_POINT, force_remount=True)

# === Config ===
# Source blobs are listed under <ROOT_PREFIX>/<split>/<lesion>/ and filtered to *.png.
BUCKET_NAME = "bracs-dataset-bucket"
ROOT_PREFIX = "BRACS/BRACS_RoI/latest_version"
SPLITS = ["train","val","test"]
LESIONS = ["0_N","1_PB","2_UDH","3_FEA","4_ADH","5_DCIS","6_IC"]

# Output root on the mounted Drive; patches are written to
# DRIVE_ROOT/<split>/<lesion>/<roi_id>/patch_y{y}_x{x}.png.
DRIVE_ROOT = Path("/content/drive/MyDrive/BRACS/ROIPatches")
DRIVE_ROOT.mkdir(parents=True, exist_ok=True)

# Patch geometry is applied to the downscaled image. STRIDE < PATCH_SIZE, so
# neighbouring patches overlap.
PATCH_SIZE = 224
STRIDE = 74
GPU_DEVICE = "cuda:0"
# Number of download processes feeding the work queue.
NUM_PRODUCERS = 6
# Number of threads writing PNG patches to Drive.
NUM_SAVER_THREADS = 4
# Maximum number of patches bundled into one save-queue item.
SAVE_BATCH = 256
# Integer factor by which each image side is divided before patching.
DOWNSCALE_FACTOR = 4   # 40x -> 10x

# image safety guard: Pillow's pixel-count limit for opened images, set to 300M pixels
Image.MAX_IMAGE_PIXELS = 300_000_000

# === GCS ===
# Module-level client and bucket handles shared by list_blob_names and fetch_blob_bytes.
client = storage.Client()
bucket = client.bucket(BUCKET_NAME)

def parse_ids(blob_name: str):
    """Split a blob path of the form ``.../<split>/<lesion>/<roi_id>.<ext>``.

    Returns:
        A ``(split, lesion, roi_id)`` tuple: the grandparent folder name, the
        parent folder name, and the file name without its final extension.
    """
    blob_path = Path(blob_name)
    lesion_dir = blob_path.parent
    split_dir = lesion_dir.parent
    roi_id = blob_path.stem
    return split_dir.name, lesion_dir.name, roi_id

def list_blob_names(split, lesion):
    """Yield the names of PNG blobs stored under one split/lesion folder.

    Blobs are listed with the prefix ``<ROOT_PREFIX>/<split>/<lesion>/``; only
    names ending in ``.png`` (case-insensitive) whose immediate parent folder
    is named ``lesion`` are yielded.
    """
    folder_prefix = f"{ROOT_PREFIX}/{split}/{lesion}/"
    for blob in client.list_blobs(BUCKET_NAME, prefix=folder_prefix):
        if not blob.name.lower().endswith(".png"):
            continue
        if Path(blob.name).parent.name != lesion:
            continue
        yield blob.name

def fetch_blob_bytes(blob_name, retries=4):
    """Download one blob from the module-level ``bucket``, retrying on failure.

    Args:
        blob_name: Name of the blob inside the bucket.
        retries: Total number of download attempts. Values below 1 are treated
            as 1 so the function never falls through and returns ``None``.

    Returns:
        A ``(blob_name, data)`` tuple where ``data`` is the blob content as bytes.

    Raises:
        Exception: Whatever the final failed download attempt raised.
    """
    blob = bucket.blob(blob_name)
    attempts = max(1, retries)
    backoff = 1.0
    for attempt in range(attempts):
        try:
            return blob_name, blob.download_as_bytes()
        except Exception:
            if attempt == attempts - 1:
                raise
            # Wait before the next attempt; the delay grows by 1.7x each time.
            time.sleep(backoff)
            backoff *= 1.7

def decode_image_bytes(img_bytes):
    """Decode encoded image bytes and return the pixels as an RGB NumPy array."""
    stream = io.BytesIO(img_bytes)
    rgb_image = Image.open(stream).convert("RGB")
    return np.array(rgb_image)

def downscale_to_10x(img_np: np.ndarray, factor: int = DOWNSCALE_FACTOR) -> np.ndarray:
    """Shrink an image array by an integer factor using Pillow's BOX filter.

    Each side is integer-divided by ``factor`` and clamped to at least one
    pixel. Returns a new array holding the resized image.
    """
    src_h, src_w = img_np.shape[:2]
    dst_h = max(1, src_h // factor)
    dst_w = max(1, src_w // factor)
    # Pillow's resize takes (width, height); BOX is a cheap filter for shrinking.
    shrunk = Image.fromarray(img_np).resize((dst_w, dst_h), resample=Image.BOX)
    return np.array(shrunk)

# === Producer ===
def producer_worker(work_queue: Queue, blob_names, stop_event: Event):
    """Download each blob in ``blob_names`` and push ``(blob_name, bytes)`` onto ``work_queue``.

    Stops early once ``stop_event`` is set. A blob whose download or enqueue
    raises is reported on stdout and skipped.
    """
    for name in blob_names:
        if stop_event.is_set():
            return
        try:
            _, payload = fetch_blob_bytes(name)
            item = (name, payload)
            work_queue.put(item)
        except Exception as err:
            print(f"[PRODUCER][ERR] {name} -> {err}")

# === GPU consumer ===
def _mark_roi_done(stats):
    """Bump the shared ``done`` counter and print progress every 10 ROIs and at the last one."""
    if not stats:
        return
    stats["done"] += 1
    done, total = stats["done"], stats["total"]
    if done % 10 == 0 or done == total:
        print(f"[PROGRESS] {done}/{total}")


def gpu_consumer(work_q: Queue, save_q: Queue, stop_evt: Event, stats=None):
    """Turn downloaded ROI images into filtered patches and queue them for saving.

    For every ``(blob_name, image_bytes)`` item taken from ``work_q`` the image
    is decoded, downscaled by ``DOWNSCALE_FACTOR``, moved to ``GPU_DEVICE`` and
    tiled into ``PATCH_SIZE`` x ``PATCH_SIZE`` patches with step ``STRIDE``.
    Patches whose mean intensity is below 0.85 and whose standard deviation is
    above 0.02 (on a 0..1 scale) are kept and pushed onto ``save_q`` as lists
    of at most ``SAVE_BATCH`` tuples ``(blob_name, x, y, patch)``, where
    ``patch`` is a uint8 array of shape (PATCH_SIZE, PATCH_SIZE, 3) and
    ``(x, y)`` is the patch's top-left corner in the downscaled image.

    Args:
        work_q: Queue of ``(blob_name, image_bytes)`` items.
        save_q: Queue receiving lists of patch tuples.
        stop_evt: Once set, the loop exits as soon as ``work_q`` is empty.
        stats: Optional mapping with ``"done"`` and ``"total"`` keys used for
            progress reporting; every ROI taken from the queue is counted,
            including ones that fail to decode or yield no patches.
    """
    device = torch.device(GPU_DEVICE)
    torch.cuda.set_device(device)
    ks, st = PATCH_SIZE, STRIDE

    while not (stop_evt.is_set() and work_q.empty()):
        try:
            blob_name, bts = work_q.get(timeout=2)
        except Exception:
            # Nothing arrived within the timeout: re-check the exit condition.
            continue

        # decode @40x
        try:
            img_np_40 = decode_image_bytes(bts)
        except Exception as e:
            # Report the bad ROI instead of dropping it silently, and keep the
            # progress counter and its periodic print consistent.
            print(f"[GPU][ERR] decode failed for {blob_name} -> {e}")
            _mark_roi_done(stats)
            continue

        # downscale to 10x (factor 4)
        img_np = downscale_to_10x(img_np_40, DOWNSCALE_FACTOR)
        del img_np_40, bts  # free immediately
        gc.collect()

        H, W, _ = img_np.shape
        if H < ks or W < ks:
            # Image is smaller than a single patch: nothing to extract.
            _mark_roi_done(stats)
            del img_np
            gc.collect()
            continue

        # to GPU as 1 x 3 x H x W float32 in [0, 1]
        img_t = (
            torch.from_numpy(img_np)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device=device, dtype=torch.float32)
            / 255.0
        )
        del img_np
        gc.collect()

        # unfold patches
        unf = F.unfold(img_t, kernel_size=ks, stride=st)   # 1 x (3*ks*ks) x L
        L = unf.shape[-1]
        if L == 0:
            _mark_roi_done(stats)
            del img_t, unf
            torch.cuda.empty_cache()
            continue

        patches = unf.squeeze(0).transpose(0, 1).reshape(L, 3, ks, ks).contiguous()
        brightness = patches.mean(dim=(1, 2, 3))
        sat = patches.std(dim=(1, 2, 3))
        # Drop near-white (bright) and near-uniform (low-variance) patches.
        keep = (brightness < 0.85) & (sat > 0.02)
        keep_idx = torch.nonzero(keep).squeeze(1)
        if keep_idx.numel() == 0:
            _mark_roi_done(stats)
            del img_t, unf, patches, brightness, sat, keep, keep_idx
            torch.cuda.empty_cache()
            continue

        # Advanced indexing copies, so `kept` can be modified in place below.
        kept = patches[keep_idx]

        # unfold enumerates windows row-major, so window i sits at
        # row i // cols and column i % cols of the patch grid.
        cols = (W - ks) // st + 1
        ys = ((keep_idx // cols) * st).tolist()
        xs = ((keep_idx % cols) * st).tolist()

        # Convert to uint8 HWC on the GPU. Round before casting: a plain cast
        # truncates, so a value such as 254.99998 would become 254.
        kept_u8 = (
            kept.mul_(255.0).round_().clamp_(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
        )

        # enqueue for saving (to Drive mount): one device->host copy per batch
        # instead of one per patch.
        K = kept_u8.shape[0]
        for start in range(0, K, SAVE_BATCH):
            stop = start + SAVE_BATCH
            chunk_np = kept_u8[start:stop].contiguous().cpu().numpy()
            save_q.put([
                (blob_name, x, y, patch_np)
                for x, y, patch_np in zip(xs[start:stop], ys[start:stop], chunk_np)
            ])

        # concise progress
        _mark_roi_done(stats)

        # cleanup GPU/CPU memory (no local ROI files are kept)
        del img_t, unf, patches, brightness, sat, keep, keep_idx, kept, kept_u8
        torch.cuda.empty_cache()
        gc.collect()

# === Saver (writes directly to mounted Drive) ===
def saver_thread(save_q: Queue, stop_evt: Event):
    """Write queued patches as PNG files under ``DRIVE_ROOT``.

    Each item taken from ``save_q`` is a list of ``(blob_name, x, y, patch_np)``
    tuples. Every patch is saved to
    ``DRIVE_ROOT/<split>/<lesion>/<roi>/patch_y{y}_x{x}.png``, with the folder
    names derived from ``blob_name`` via ``parse_ids``. The loop ends once
    ``stop_evt`` is set and ``save_q`` is empty.

    A patch that cannot be written is reported on stdout and skipped, so one
    I/O error does not kill the thread and drop the rest of the batch.
    """
    while not (stop_evt.is_set() and save_q.empty()):
        try:
            batch = save_q.get(timeout=2)
        except Exception:
            # Nothing arrived within the timeout: re-check the exit condition.
            continue

        # Directories already ensured for this batch; avoids one mkdir per patch.
        ensured_dirs = set()
        for blob_name, x, y, patch_np in batch:
            split, lesion, roi = parse_ids(blob_name)
            out_dir = DRIVE_ROOT / split / lesion / roi
            out_path = out_dir / f"patch_y{y}_x{x}.png"
            try:
                if out_dir not in ensured_dirs:
                    out_dir.mkdir(parents=True, exist_ok=True)
                    ensured_dirs.add(out_dir)
                Image.fromarray(patch_np).save(out_path, format="PNG")
            except OSError as e:
                print(f"[SAVER][ERR] {out_path} -> {e}")

        # Best-effort flush of buffered writes. os.sync is missing on some
        # platforms (AttributeError); never let the flush abort the saver.
        try:
            os.sync()
        except (AttributeError, OSError):
            pass

# === Orchestration ===
if __name__ == "__main__":
    # Pipeline: producer processes (download) -> work_q -> GPU consumer process
    # (patch extraction) -> save_q -> saver threads (PNG files on Drive).
    mgr = Manager()
    work_q = mgr.Queue(maxsize=NUM_PRODUCERS * 4)
    save_q = mgr.Queue(maxsize=NUM_SAVER_THREADS * 8)

    # stop_evt means "all producers are finished": the GPU consumer exits once
    # it is set and work_q is empty.
    stop_evt = Event()
    # Savers need their own signal. If they shared stop_evt they could exit as
    # soon as save_q is momentarily empty while the GPU consumer is still
    # emitting batches, losing patches or deadlocking on the bounded save_q.
    saver_stop_evt = Event()

    # gather ROIs
    all_blob_names = []
    for split in SPLITS:
        for lesion in LESIONS:
            all_blob_names.extend(list_blob_names(split, lesion))
    total_rois = len(all_blob_names)
    print("Total ROIs:", total_rois)

    # stats for concise progress
    stats = mgr.dict(total=total_rois, done=0)

    # start producers: contiguous slices of the blob list, the last producer
    # also takes the remainder
    producers = []
    chunk = max(1, total_rois // max(1, NUM_PRODUCERS))
    for i in range(NUM_PRODUCERS):
        s = i * chunk
        e = None if i == NUM_PRODUCERS - 1 else (i + 1) * chunk
        p = Process(target=producer_worker, args=(work_q, all_blob_names[s:e], stop_evt))
        p.start()
        producers.append(p)

    # consumer
    gpu_p = Process(target=gpu_consumer, args=(work_q, save_q, stop_evt, stats))
    gpu_p.start()

    # savers
    savers = []
    for i in range(NUM_SAVER_THREADS):
        t = threading.Thread(target=saver_thread, args=(save_q, saver_stop_evt), daemon=True)
        t.start()
        savers.append(t)

    # wait: producers first, then let the consumer drain work_q
    for p in producers:
        p.join()
    stop_evt.set()
    gpu_p.join()

    # Nothing more will be enqueued now. Let the savers drain save_q and wait
    # for them, so in-flight writes finish before "Done" is reported.
    saver_stop_evt.set()
    for t in savers:
        t.join()

    # final cleanup (no ROI files were stored; free memory)
    gc.collect()
    print("[MAIN] Done.")
