In [None]:
from pathlib import Path
from collections import defaultdict, OrderedDict
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from PIL import Image, ImageDraw
from tqdm import tqdm
import os
import numpy as np
import pandas as pd
from torchvision import transforms
import cv2
import torch

In [None]:
# When True, load_sample_data() zeroes every pixel outside the segmented layers.
IMG_MASK = True
# When either IMG_MASK or HEATMAP_MASK is True, load_sample_data() merges the
# thickness CSV into the label table and loads one layer mask per image.
HEATMAP_MASK = False
# NOTE(review): not referenced in the visible cells — presumably toggles drawing
# of layer lines on the heatmaps; confirm against the cells that use it.
DRAW_LAYER = True
# Root directory of the layer-surface arrays; joined with the 'folder' and
# 'Surface Name' columns to build the path passed to np.load().
Thickness_DIR = "/orange/ruogu.fang/tienyuchang/IRB2024_OCT_thickness/Data/"
# CSV read by load_sample_data(); its 'OCT' column is renamed to 'folder' and
# rows without a 'Surface Name' are dropped before the merge.
Thickness_CSV = "/orange/ruogu.fang/tienyuchang/IRB2024_OCT_thickness/thickness_map.csv"
# NOTE(review): not referenced in the visible cells — presumably the tasks to process.
Task_list = ['DME']
# Label table located at <dataset_dir>/<task>_sampled/<dataset_fname>.
dataset_fname = 'sampled_labels01.csv'
dataset_dir = '/blue/ruogu.fang/tienyuchang/OCT_EDA'
# Image path relative to <dataset_dir>/<task>_sampled/: filled with (label, file name).
img_p_fmt = "label_%d/%s" #label index and oct_img name

In [None]:
# Mask function
def masked_img_func(img, mask_slice):
    """Zero out every pixel that lies outside the segmented layer bands.

    For each pair of consecutive boundaries ``mask_slice[i]`` and
    ``mask_slice[i + 1]`` the rows ``[upper, lower)`` of every column are kept;
    all other pixels are set to 0.

    Boundary values are truncated to integer rows and clamped to the image
    height. Columns where either boundary of a pair is NaN/inf contribute
    nothing for that pair (previously the NaN was cast to an undefined integer,
    which usually kept everything above the lower boundary). Negative
    boundaries are clamped to row 0 instead of wrapping around as negative
    Python indices.

    Args:
        img: Array indexed as ``img[row, col]`` or ``img[row, col, channel]``.
        mask_slice: 2-D array of boundary row positions with one row per
            boundary and one column per image column. Extra columns beyond the
            image width are ignored.

    Returns:
        A copy of ``img`` (same shape and dtype) with the pixels outside the
        bands set to 0. ``img`` itself is not modified.

    Raises:
        IndexError: If ``mask_slice`` has fewer columns than ``img``.
    """
    img = np.asarray(img)
    bounds = np.asarray(mask_slice, dtype=float)
    height, width = img.shape[0], img.shape[1]
    if bounds.shape[1] < width:
        raise IndexError(
            f"mask_slice has {bounds.shape[1]} columns but the image is {width} pixels wide"
        )
    bounds = bounds[:, :width]

    keep = np.zeros((height, width), dtype=bool)
    rows = np.arange(height)[:, None]  # (H, 1), broadcasts against per-column bounds
    for i in range(bounds.shape[0] - 1):
        upper, lower = bounds[i], bounds[i + 1]
        valid = np.isfinite(upper) & np.isfinite(lower)
        # Replace non-finite entries before the integer cast; they are excluded via `valid`.
        top = np.clip(np.where(valid, upper, 0), 0, height).astype(int)
        bottom = np.clip(np.where(valid, lower, 0), 0, height).astype(int)
        keep |= valid & (rows >= top) & (rows < bottom)

    # Apply the mask: everything not covered by a band becomes 0.
    masked_img = img.copy()
    masked_img[~keep] = 0

    return masked_img

# Data loading and preprocessing functions
def load_sample_data(task, num_sample=-1):
    """Load the sampled images (and optionally their layer masks) for one task.

    Reads ``<dataset_dir>/<task>_sampled/<dataset_fname>``. When ``IMG_MASK`` or
    ``HEATMAP_MASK`` is set, the table is inner-joined with ``Thickness_CSV`` on
    the ``folder`` column, so only samples with a layer segmentation survive.
    Only rows whose ``label`` is 0 or 1 are used.

    Args:
        task: Task name; selects the ``<task>_sampled`` sub-directory.
        num_sample: If > 0, draw that many rows at random (fixed seed 42). A
            request larger than the number of available rows is clamped to all
            rows instead of raising. Any other value uses every row.

    Returns:
        ``(images, labels, filenames, mask_slices)`` — four lists of equal
        length. ``images`` holds RGB ``PIL.Image`` objects (layer-masked when
        ``IMG_MASK`` is set), ``filenames`` the image file names without
        extension, and ``mask_slices`` the ``mask[:, slice_index, :]`` array
        for each image, or ``None`` when no mask was loaded.

    Notes:
        Rows whose image file is missing are skipped silently; rows that fail
        while loading are reported and skipped. A skipped row contributes to
        none of the four lists.
    """
    task_root = os.path.join(dataset_dir, "%s_sampled" % task)
    use_mask = IMG_MASK or HEATMAP_MASK

    df = pd.read_csv(os.path.join(task_root, dataset_fname))
    if use_mask:
        masked_df = pd.read_csv(Thickness_CSV)
        masked_df = masked_df.rename(columns={'OCT': 'folder'}).dropna(subset=['Surface Name'])
        df = df.merge(masked_df, on='folder', how='inner').reset_index(drop=True)
        print('After adding mask, data len: ', df.shape[0])
    task_df = df[df['label'].isin([0, 1])]  # Adjust based on actual DME labels

    # Sample random images (clamped so an over-sized request cannot raise)
    if num_sample > 0:
        n_draw = min(num_sample, len(task_df))
        task_df = task_df.sample(n=n_draw, random_state=42).reset_index(drop=True)
    else:
        task_df = task_df.reset_index(drop=True)

    images = []
    labels = []
    filenames = []
    mask_slices = []

    for _, row in task_df.iterrows():
        # NOTE(review): 'OCT' must come from the label CSV itself — the thickness
        # table's 'OCT' column was renamed to 'folder' above. Confirm the schema.
        # Extract just the filename from oct_img
        filename = os.path.basename(row['OCT']) if isinstance(row['OCT'], str) else row['OCT']
        img_path = os.path.join(task_root, img_p_fmt % (row['label'], filename))
        if not os.path.exists(img_path):
            continue

        try:
            # The context manager releases the file handle; convert() returns
            # an independent, fully loaded image.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')

            mask_slice = None
            if use_mask:
                mask_path = os.path.join(Thickness_DIR, row['folder'], row['Surface Name'])
                mask = np.load(mask_path)  # (Layer, slice, W)

                # The slice index is the trailing "_<n>" of the image file name.
                slice_index = int(os.path.basename(img_path).split("_")[-1].split(".")[0])
                mask_slice = mask[:, slice_index, :]  # shape: (Layer, W)

            if IMG_MASK:
                img = Image.fromarray(masked_img_func(np.array(img), mask_slice))

            # Store filename without extension for directory naming
            image_name = os.path.splitext(filename)[0]
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            continue

        # Append to all four lists only after every step succeeded. Appending
        # the mask before the masking step used to leave mask_slices one entry
        # longer than the other lists whenever masking failed, which shifted
        # every later mask onto the wrong image.
        images.append(img)
        labels.append(row['label'])
        filenames.append(image_name)
        mask_slices.append(mask_slice)

    return images, labels, filenames, mask_slices

def _to_batched_tensor(out):
    """Turn a processor result into a batched tensor, or return None if it is not usable.

    Accepts a bare ``torch.Tensor`` or any mapping with a ``"pixel_values"``
    entry holding a tensor, a NumPy array, or a list/tuple of those. 3-D results
    get a leading batch dimension.
    """
    # Local import: collections.abc is not imported at file level.
    from collections.abc import Mapping

    if isinstance(out, torch.Tensor):
        x = out
    elif isinstance(out, Mapping) and "pixel_values" in out:
        # Mapping (not dict): HuggingFace BatchFeature derives from UserDict,
        # which is a Mapping but not a dict subclass.
        x = out["pixel_values"]
        if isinstance(x, (list, tuple)):
            if not x or not all(isinstance(a, (np.ndarray, torch.Tensor)) for a in x):
                return None
            x = torch.stack([torch.as_tensor(a) for a in x])
        elif isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if not isinstance(x, torch.Tensor):
            return None
    else:
        return None

    if x.ndim == 3:  # [C,H,W] -> [1,C,H,W]
        x = x.unsqueeze(0)
    return x


def preprocess_image(image, processor=None, input_size=224, device=None, dtype=torch.float32):
    """Convert a PIL image into a batched model input tensor.

    When ``processor`` is given, it is tried with several calling conventions
    in order, and the first call that yields a usable result wins:

      1. ``processor(image)`` — plain callable transforms,
      2. ``processor(image, return_tensors="pt")`` — HuggingFace style,
      3. ``processor([image], return_tensors="pt")`` — list-only processors,
      4. ``processor([image])``.

    A call that raises ``TypeError`` is treated as "wrong calling convention"
    and the next one is tried; any other exception propagates. If no attempt
    produces a usable result, or no processor is given, the image is resized to
    ``input_size`` x ``input_size`` and normalised with the ImageNet mean/std.

    Args:
        image: Input ``PIL.Image``.
        processor: Optional callable transform or image processor.
        input_size: Side length used by the fallback resize.
        device: Target device; defaults to ``"cuda"`` when available, else ``"cpu"``.
        dtype: Target dtype of the returned tensor.

    Returns:
        Tensor on ``device`` with dtype ``dtype``; 3-D processor outputs get a
        leading batch dimension, and the fallback yields ``[1, 3, H, W]``.

    Raises:
        AssertionError: If ``image`` is not a ``PIL.Image``.
    """
    assert isinstance(image, Image.Image), f"expect PIL.Image, got {type(image)}"
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if processor is not None:
        attempts = (
            lambda: processor(image),                          # A) directly callable
            lambda: processor(image, return_tensors="pt"),     # B) HuggingFace style
            lambda: processor([image], return_tensors="pt"),   # C) list-only implementations
            lambda: processor([image]),
        )
        for attempt in attempts:
            try:
                x = _to_batched_tensor(attempt())
            except TypeError:
                continue
            if x is not None:
                return x.to(device=device, dtype=dtype)

    # D) Fallback: standard ImageNet preprocessing
    fallback = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),  # [0,1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    x = fallback(image)            # [3,H,W]
    x = x.unsqueeze(0)             # [1,3,H,W]
    return x.to(device=device, dtype=dtype)

# Smoke test of the loader: -1 disables sub-sampling, so every DME row is loaded.
dme_imgs, dme_labels, dme_img_names, dme_mask_slices = load_sample_data('DME',-1)
print(len(dme_imgs))
# Indexing [0] raises IndexError when no image could be loaded.
print(dme_imgs[0])
print(dme_labels)

# Display the first loaded image
plt.figure(figsize=(5, 5))
plt.title("OCT slice")
plt.imshow(dme_imgs[0], cmap="gray")
plt.show()

In [None]:
def add_layer_line(overlay, mask_slice, width=1, cmap_name="rainbow"):
    """Draw one coloured polyline per layer boundary on top of an image.

    Mask values are used directly as pixel coordinates of ``overlay``: column
    ``x`` of ``mask_slice`` is drawn at image column ``x`` and its value is the
    image row, clamped to the image height. No rescaling is done, so the caller
    must pass an overlay at the resolution the mask refers to.

    Non-finite boundary samples (NaN/inf) are treated as missing: the line is
    broken there instead of being pulled to row 0, which used to draw spurious
    vertical spikes.

    Args:
        overlay: ``PIL.Image`` or NumPy array. Non-``uint8`` arrays are
            multiplied by 255 and clipped, i.e. assumed to be scaled to [0, 1].
        mask_slice: 2-D array with one row per boundary and one column per
            image column.
        width: Line width in pixels.
        cmap_name: Matplotlib colormap; boundaries are coloured evenly across it.

    Returns:
        A ``PIL.Image`` with the lines drawn. A PIL input is converted to a new
        RGB image first, so the caller's image is not modified.
    """
    # Convert to PIL.Image if needed
    if isinstance(overlay, np.ndarray):
        if overlay.dtype != np.uint8:
            overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
        overlay_img = Image.fromarray(overlay)
    else:
        overlay_img = overlay.convert("RGB")
    draw = ImageDraw.Draw(overlay_img)

    n_layers, _ = mask_slice.shape
    cmap = plt.get_cmap(cmap_name)
    max_y = overlay_img.height - 1

    for i in range(n_layers):
        # Plain Python ints: PIL colour tuples should not carry NumPy scalars.
        r, g, b = cmap(i / max(1, n_layers - 1))[:3]
        color = (int(r * 255), int(g * 255), int(b * 255))

        ys = np.asarray(mask_slice[i], dtype=float)
        finite_idx = np.flatnonzero(np.isfinite(ys))
        if finite_idx.size == 0:
            continue
        ys = np.clip(ys, 0, max_y)

        # Split the finite columns into runs of consecutive indices and draw
        # each run as its own polyline so gaps stay gaps.
        breaks = np.flatnonzero(np.diff(finite_idx) > 1) + 1
        for run in np.split(finite_idx, breaks):
            points = [(int(x), float(ys[x])) for x in run]
            if len(points) == 1:
                draw.point(points, fill=color)
            else:
                draw.line(points, fill=color, width=width)

    return overlay_img

In [None]:
# Expected folder layout (adjust as needed):
# ./heatmap_results/<task>/<label_idx>/<image_name>/<model>/<XAI>.jpg
BASE_DIR = Path("./heatmap_results_modelmask")
OUT_DIR  = Path("./grid_modelmask")
# Side effect at cell execution time: the output directory is created immediately.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Optional: list names here to pin the display order; names not listed are
# appended alphabetically, and an empty list means purely alphabetical order.
FIXED_MODEL_ORDER  = ['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
FIXED_METHOD_ORDER = ['GradCAM', 'ScoreCAM', 'RISE', 'Attention', 'CRP_LXT']

# Supported image file extensions (lower-case, including the dot)
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

def _is_image(p: Path) -> bool:
    """Return True when ``p`` has a supported image extension and is an existing file."""
    # Check the extension first so the filesystem is only touched for candidates.
    if p.suffix.lower() not in IMG_EXTS:
        return False
    return p.is_file()

def _safe_open(path: Path):
    """Open ``path`` as an RGB ``PIL.Image``; return ``None`` if it cannot be read.

    The file is opened in a context manager so its handle is released right
    away; ``convert`` returns an independent, fully loaded image. Failures are
    reported rather than swallowed silently, so an unreadable file can be told
    apart from a missing one.
    """
    try:
        with Image.open(path) as src:
            return src.convert("RGB")
    except Exception as e:
        print(f"[WARN] could not open image {path}: {e}")
        return None

def _collect_index(task_dir: Path):
    """
    回傳：
    - index: {(label, image_name): {model: {method: Path}}}
    - models: set([...])
    - methods: set([...])
    """
    index = defaultdict(lambda: defaultdict(dict))
    models, methods = set(), set()

    # 走訪 task 之下所有檔案，篩出影像
    for img_path in task_dir.rglob("*.jpg"):
        # 期望路徑：task/label/image_name/model/method.ext
        try:
            method = img_path.stem                      # 檔名（不含副檔名）
            model  = img_path.parent.name               # 倒數第2層
            image_name = img_path.parent.parent.name    # 倒數第3層
            label = img_path.parent.parent.parent.name  # 倒數第4層
        except Exception:
            # 結構不符就略過
            continue

        key = (label, image_name)
        index[key].setdefault(model, {})[method] = img_path
        models.add(model)
        methods.add(method)

    return index, models, methods

def _order(items, fixed):
    """若有固定順序，先照固定清單排序，其餘依字母序接在後面"""
    if not fixed:
        return sorted(items)
    seen = set()
    ordered = []
    # 先加入 fixed 中存在的
    for x in fixed:
        if x in items and x not in seen:
            ordered.append(x); seen.add(x)
    # 再加上剩餘（字母序）
    for x in sorted(items):
        if x not in seen:
            ordered.append(x)
    return ordered

def _draw_grid_for_sample(task_name, key, cell_map, models, methods, save_dir: Path,
                          cell_size=(384, 384), dpi=220, mask_index=None, line_width=2):
    """Render one sample as a grid of heatmaps (rows = models, columns = methods).

    Args:
        task_name: Task name shown in the figure title.
        key: ``(label, image_name)`` identifying the sample.
        cell_map: ``{model: {method: Path}}`` with the heatmap files of this sample.
        models: Row order.
        methods: Column order.
        save_dir: Directory for the PNG; created if missing.
        cell_size: ``(width, height)`` every heatmap is resized to, or ``None``
            to keep the native size.
        dpi: Resolution of the saved figure.
        mask_index: Optional dict mapping ``(label, image_name)`` to a layer
            mask of shape ``(Layer, W)``. Missing keys and ``None`` values mean
            "no layer lines" for that sample.
        line_width: Width of the layer lines in pixels.

    Returns:
        Path of the written PNG, or ``None`` when ``models`` or ``methods`` is
        empty and nothing can be drawn.
    """
    n_rows, n_cols = len(models), len(methods)
    if n_rows == 0 or n_cols == 0:
        # plt.subplots() rejects a zero-sized grid; skip instead of crashing.
        print(f"[WARN] nothing to draw for {key}: {n_rows} models x {n_cols} methods")
        return None

    # Resolve the mask once. A None entry (sample loaded without masks) is
    # treated like a missing one instead of failing inside add_layer_line for
    # every single cell.
    layer_mask = mask_index.get(key) if mask_index is not None else None

    fig_w = max(8, n_cols * 3)
    fig_h = max(6, n_rows * 3)

    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h), squeeze=False)
    try:
        # Column headers (methods)
        for c, mth in enumerate(methods):
            axes[0][c].set_title(mth, fontsize=12)

        # Row labels (models): keep the first column's axis on so the ylabel
        # shows, but hide its ticks and frame.
        for r, mdl in enumerate(models):
            first = axes[r][0]
            first.set_ylabel(mdl, fontsize=12, rotation=90, labelpad=20, va='center')
            first.set_xticks([]); first.set_yticks([])
            for spine in first.spines.values():
                spine.set_visible(False)

        # Cells
        for r, mdl in enumerate(models):
            for c, mth in enumerate(methods):
                ax = axes[r][c]
                if c != 0:
                    ax.axis("off")

                path = cell_map.get(mdl, {}).get(mth, None)
                im = _safe_open(path) if path is not None else None
                if im is None:
                    ax.text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12, transform=ax.transAxes)
                    continue

                if cell_size is not None:
                    im = im.resize(cell_size)

                # Overlay the layer lines (if a mask exists for this sample).
                # NOTE(review): the mask is drawn in the pixel coordinates of the
                # resized cell; if the mask refers to a different resolution the
                # lines will be misplaced — confirm the heatmap/mask geometry.
                if layer_mask is not None:
                    try:
                        im = add_layer_line(im, layer_mask, width=line_width, cmap_name="rainbow")
                    except Exception as e:
                        print(f"[WARN] add_layer_line failed for {key} ({mdl}/{mth}): {e}")

                ax.imshow(im)
                if c == 0:
                    ax.set_xticks([]); ax.set_yticks([])

        # Figure title
        label, image_name = key
        fig.suptitle(f"{task_name} — Image: {image_name} | Label: {label}", fontsize=14)

        fig.tight_layout(rect=[0, 0, 1, 0.95])
        save_dir.mkdir(parents=True, exist_ok=True)
        out_path = save_dir / f"{label}_{image_name}_grid.png"
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    return out_path

def build_all_grids(base_dir: Path = BASE_DIR, out_dir: Path = OUT_DIR, mask_index=None, line_width=2):
    """Build one heatmap grid per sample for every task directory below ``base_dir``.

    For each task the images are indexed, models and methods are ordered by
    ``FIXED_MODEL_ORDER`` / ``FIXED_METHOD_ORDER`` (Jupyter checkpoint names are
    dropped), and one grid PNG per sample is written to ``out_dir/<task>``.

    Args:
        base_dir: Root directory containing one sub-directory per task.
        out_dir: Root directory for the generated grids.
        mask_index: Optional dict ``(label, image_name) -> layer mask``, passed
            through to the grid drawer.
        line_width: Width of the layer lines in pixels.

    Returns:
        None. Progress and warnings are printed.
    """
    if not base_dir.is_dir():
        print(f"[WARN] {base_dir} does not exist or is not a directory.")
        return

    # Sorted so tasks are processed in a reproducible order.
    tasks = sorted(p for p in base_dir.iterdir() if p.is_dir())
    if not tasks:
        print(f"[WARN] 在 {base_dir} 底下沒有找到任何 task 資料夾。")
        return

    for task_dir in tasks:
        task_name = task_dir.name
        print(f"[Task] {task_name} → 掃描中…")

        index, models_set, methods_set = _collect_index(task_dir)
        if not index:
            print(f"  - 找不到任何影像（{task_dir}）。略過。")
            continue

        models  = [n for n in _order(models_set, FIXED_MODEL_ORDER) if not n.endswith('checkpoints')]
        methods = [n for n in _order(methods_set, FIXED_METHOD_ORDER) if not n.endswith('-checkpoint')]

        print(f"  - 模型：{models}")
        print(f"  - 方法：{methods}")
        if not models or not methods:
            # Only checkpoint artefacts were found; an empty grid cannot be drawn.
            print(f"  - [WARN] no usable models/methods in {task_dir}; skipped.")
            continue
        save_dir = out_dir / task_name

        # Keep only samples with at least one cell among the retained models
        # and methods; the others would render as a grid made only of "N/A".
        wanted_models, wanted_methods = set(models), set(methods)
        keys = [
            k for k in sorted(index.keys())
            if any(m in wanted_models and wanted_methods.intersection(index[k][m])
                   for m in index[k])
        ]

        for key in tqdm(keys, desc=f"  產出 {task_name} grids"):
            _draw_grid_for_sample(task_name, key, index[key], models, methods, save_dir,
                                  mask_index=mask_index, line_width=line_width)

    print("✅ 全部完成！")

[Task] DME → 掃描中…
  - 模型：['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
  - 方法：['GradCAM', 'ScoreCAM', 'RISE', 'Attention']


  產出 DME grids: 100%|██████████| 200/200 [06:05<00:00,  1.83s/it]

✅ 全部完成！





In [None]:
# Build the lookup first: (label as string, image name) -> layer mask of that image.
# The label is stringified because the grid builder takes it from a directory name.
mask_index = {
    (str(label), name): ms
    for name, label, ms in zip(dme_img_names, dme_labels, dme_mask_slices)
}
build_all_grids(BASE_DIR, OUT_DIR, mask_index=mask_index, line_width=2)