In [1]:
from pathlib import Path
from collections import defaultdict, OrderedDict
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

# Folder layout expected under BASE_DIR (adjust as needed):
#   <BASE_DIR>/<task>/<label_idx>/<image_name>/<model>/<XAI>.jpg
BASE_DIR = Path("./heatmap_results_modelmask")
OUT_DIR = Path("./grid_modelmask")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Optional display order. Names listed here come first, in this order;
# leave a list empty to fall back to plain alphabetical ordering.
FIXED_MODEL_ORDER = [
    'ResNet-50',
    'timm_efficientnet-b4',
    'vit-base-patch16-224',
    'RETFound_mae',
]
FIXED_METHOD_ORDER = [
    'GradCAM',
    'ScoreCAM',
    'RISE',
    'Attention',
    'CRP_LXT',
]

# Supported image file extensions (lower-case, including the dot).
IMG_EXTS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".tif",
    ".tiff",
}

def _is_image(p: Path) -> bool:
    """Return True when ``p`` has a supported image extension and is a file."""
    # Cheap string check first; the filesystem is only touched for candidates.
    if p.suffix.lower() not in IMG_EXTS:
        return False
    return p.is_file()

def _safe_open(path: Path):
    """Load an image as RGB, returning ``None`` if it cannot be read.

    Args:
        path: Location of the image file.

    Returns:
        An RGB image produced by ``convert("RGB")``, or ``None`` when
        opening or decoding raises.
    """
    try:
        # Image.open only opens the file lazily; the context manager makes
        # sure the source image is closed once convert() has returned,
        # including when decoding fails part-way through.
        with Image.open(path) as src:
            return src.convert("RGB")
    except Exception:
        # Deliberate best-effort: any failure is reported as None so the
        # caller can decide how to handle an unreadable file.
        return None

def _collect_index(task_dir: Path, exts=None):
    """Index every heatmap image found under one task directory.

    Only files located exactly at
    ``<task_dir>/<label>/<image_name>/<model>/<method>.<ext>`` are indexed;
    files that are shallower or deeper than that are skipped.

    Args:
        task_dir: Root directory of a single task.
        exts: Iterable of accepted file extensions (with leading dot,
            compared case-insensitively). Defaults to ``IMG_EXTS``.

    Returns:
        A tuple ``(index, models, methods)`` where
        - index: ``{(label, image_name): {model: {method: Path}}}``
        - models: set of all model names seen
        - methods: set of all method names (file stems) seen
    """
    if exts is None:
        exts = IMG_EXTS
    allowed = {ext.lower() for ext in exts}

    index = defaultdict(lambda: defaultdict(dict))
    models, methods = set(), set()

    # Sorted traversal keeps the result deterministic when two files share a
    # stem but differ in extension (the later path wins).
    for img_path in sorted(task_dir.rglob("*")):
        if img_path.suffix.lower() not in allowed or not img_path.is_file():
            continue

        # Expected relative path: label/image_name/model/method.ext
        parts = img_path.relative_to(task_dir).parts
        if len(parts) != 4:
            # Layout does not match; skip instead of mislabelling the file.
            continue

        label, image_name, model, _ = parts
        method = img_path.stem  # file name without extension

        index[(label, image_name)][model][method] = img_path
        models.add(model)
        methods.add(method)

    return index, models, methods

def _order(items, fixed):
    """Sort ``items``, placing names listed in ``fixed`` first.

    Entries of ``fixed`` that occur in ``items`` come first, in the order
    given by ``fixed``; everything else follows alphabetically. With an empty
    or missing ``fixed`` the result is simply ``sorted(items)``.
    """
    if not fixed:
        return sorted(items)

    # dict.fromkeys keeps first-seen order and drops repeats within ``fixed``.
    pinned = list(dict.fromkeys(name for name in fixed if name in items))
    taken = set(pinned)
    remainder = [name for name in sorted(items) if name not in taken]
    return pinned + remainder

def _draw_grid_for_sample(task_name, key, cell_map, models, methods, save_dir: Path,
                          cell_size=(384, 384), dpi=220):
    """Render one sample's heatmaps as a grid (rows=models, cols=methods).

    Args:
        task_name: Task name shown in the figure title.
        key: ``(label, image_name)`` tuple identifying the sample.
        cell_map: ``{model: {method: Path}}`` mapping to the heatmap files.
        models: Ordered model names, one grid row each.
        methods: Ordered method names, one grid column each.
        save_dir: Output directory; created if it does not exist.
        cell_size: Size passed to ``Image.resize`` for every cell, or
            ``None`` to keep each image's original size.
        dpi: Resolution used when saving the figure.

    Returns:
        Path of the written ``<label>_<image_name>_grid.png`` file.

    Raises:
        ValueError: If ``models`` or ``methods`` is empty.
    """
    n_rows, n_cols = len(models), len(methods)
    if n_rows == 0 or n_cols == 0:
        raise ValueError(
            f"Cannot draw a grid with {n_rows} model row(s) and {n_cols} method column(s)."
        )

    label, image_name = key
    fig_w = max(8, n_cols * 3)
    fig_h = max(6, n_rows * 3)

    # squeeze=False always yields a 2-D array of axes, so single-row and
    # single-column grids need no special casing.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h), squeeze=False)
    try:
        # Column headers (methods)
        for c, mth in enumerate(methods):
            axes[0][c].set_title(mth, fontsize=12)

        for r, mdl in enumerate(models):
            row_cells = cell_map.get(mdl, {})
            for c, mth in enumerate(methods):
                ax = axes[r][c]

                path = row_cells.get(mth)
                im = _safe_open(path) if path is not None else None
                if im is None:
                    # Missing or unreadable heatmap
                    ax.text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12,
                            transform=ax.transAxes)
                else:
                    if cell_size is not None:
                        im = im.resize(cell_size)
                    ax.imshow(im)

                if c == 0:
                    # The first column keeps its axes visible so the row label
                    # (model name) shows; only ticks and spines are hidden.
                    ax.set_ylabel(mdl, fontsize=12, rotation=90, labelpad=20, va='center')
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for spine in ax.spines.values():
                        spine.set_visible(False)
                else:
                    ax.axis("off")

        # Overall title with image name and label
        fig.suptitle(f"{task_name} — Image: {image_name} | Label: {label}", fontsize=14)

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        save_dir.mkdir(parents=True, exist_ok=True)
        out_path = save_dir / f"{label}_{image_name}_grid.png"
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed;
        # otherwise open figures pile up across calls.
        plt.close(fig)
    return out_path

def build_all_grids(base_dir: Path = BASE_DIR, out_dir: Path = OUT_DIR):
    """Generate one grid image per (label, image) sample for every task.

    Each sub-directory of ``base_dir`` is treated as a task. Its images are
    indexed, model/method names are ordered via ``FIXED_MODEL_ORDER`` /
    ``FIXED_METHOD_ORDER``, and one grid per sample is written to
    ``out_dir/<task_name>/``.

    Args:
        base_dir: Directory containing one sub-directory per task.
        out_dir: Root directory for the generated grid images.

    Returns:
        None. Progress and warnings are printed to stdout.
    """
    if not base_dir.is_dir():
        # iterdir() would raise on a missing directory; warn instead.
        print(f"[WARN] 找不到資料夾 {base_dir}。")
        return

    # Sorted so that tasks are processed in a reproducible order.
    tasks = sorted(p for p in base_dir.iterdir() if p.is_dir())
    if not tasks:
        print(f"[WARN] 在 {base_dir} 底下沒有找到任何 task 資料夾。")
        return

    for task_dir in tasks:
        task_name = task_dir.name
        print(f"[Task] {task_name} → 掃描中…")

        index, models_set, methods_set = _collect_index(task_dir)

        if not index:
            print(f"  - 找不到任何影像（{task_dir}）。略過。")
            continue

        # Drop names ending in 'checkpoints' / '-checkpoint' from the grid.
        models = [n for n in _order(models_set, FIXED_MODEL_ORDER) if not n.endswith('checkpoints')]
        methods = [n for n in _order(methods_set, FIXED_METHOD_ORDER) if not n.endswith('-checkpoint')]

        if not models or not methods:
            # Nothing left to lay out; a 0-row or 0-column grid cannot be drawn.
            print(f"  - 過濾後沒有可用的模型或方法（{task_dir}）。略過。")
            continue

        print(f"  - 模型：{models}")
        print(f"  - 方法：{methods}")
        save_dir = out_dir / task_name

        # Only keep samples with at least one cell among the retained
        # models/methods, so no grid consisting solely of "N/A" is written.
        sample_keys = [
            key for key in sorted(index)
            if any(mth in index[key].get(mdl, {}) for mdl in models for mth in methods)
        ]

        for key in tqdm(sample_keys, desc=f"  產出 {task_name} grids"):
            _draw_grid_for_sample(task_name, key, index[key], models, methods, save_dir)

    print("✅ 全部完成！")

# Run the batch over the configured input and output folders.
build_all_grids(base_dir=BASE_DIR, out_dir=OUT_DIR)


[Task] DME → 掃描中…
  - 模型：['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
  - 方法：['GradCAM', 'ScoreCAM', 'RISE', 'Attention']


  產出 DME grids: 100%|██████████| 200/200 [06:05<00:00,  1.83s/it]

✅ 全部完成！



