In [None]:
# Set environment variable to avoid symbolic tracing issues
import os
os.environ['TIMM_FUSED_ATTN'] = '0'
import warnings
warnings.filterwarnings('ignore')
from torchvision import transforms
from datasets import load_dataset
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import cv2
import torch
from typing import List, Callable, Optional

# Import XAI methods
from baselines.CRP_LXT import CRP_LXT

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from PIL import Image, ImageDraw
from pathlib import Path


In [2]:
from transformers import (
    ViTImageProcessor, ViTForImageClassification,
    AutoImageProcessor, EfficientNetForImageClassification,
    ResNetForImageClassification, AutoModel
)
import models_vit as models
from util.datasets import TransformWrapper
import timm

#get model
def get_model(task,model,input_size,nb_classes):
    if 'ADCon' in task:
        id2label = {0: "control", 1: "ad"}
        label2id = {v: k for k, v in id2label.items()}
    else:
        id2label = {i: f"class_{i}" for i in range(nb_classes)}
        label2id = {v: k for k, v in id2label.items()}
    processor = None
    if 'RETFound_mae' in model:
        model = models.__dict__['RETFound_mae'](
        img_size=input_size,
        num_classes=nb_classes,
        drop_path_rate=0.2,
        global_pool=True,
    )
    elif 'vit-base-patch16-224' in model:
        # ViT-base-patch16-224 preprocessor
        model_ = 'google/vit-base-patch16-224'
        processor = TransformWrapper(ViTImageProcessor.from_pretrained(model_))
        model = ViTForImageClassification.from_pretrained(
            model_,
            image_size=input_size, #Not in tianhao code, default 224
            num_labels=nb_classes,
            hidden_dropout_prob=0.0, #Not in tianhao code, default 0.0
            attention_probs_dropout_prob=0.0, #Not in tianhao code, default 0.0
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True,
            attn_implementation="eager",      # ← key line
        )
        model.config.return_dict = True
        model.config.output_attentions = True
    elif 'timm_efficientnet-b4' in model:
        model = timm.create_model('efficientnet_b4', pretrained=True, num_classes=nb_classes)
        processor  = transforms.Compose([
            transforms.Resize((380,380)),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ])
    elif 'ResNet-50' in model:
        model_name = 'microsoft/resnet-50'
        processor = TransformWrapper(AutoImageProcessor.from_pretrained(model_name))
        model = ResNetForImageClassification.from_pretrained(
            model_name,
            num_labels=nb_classes,
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True
        )
        model.config.return_dict = True
        model.config.output_attentions = True

    return model, processor

# Data Load

In [None]:
# task and dataset
#Task_list = ['ADCon','DME']
IMG_MASK = True
HEATMAP_MASK = False
DRAW_LAYER = True
Thickness_DIR = "/orange/ruogu.fang/tienyuchang/IRB2024_OCT_thickness/Data/"
Thickness_CSV = "/orange/ruogu.fang/tienyuchang/IRB2024_OCT_thickness/thickness_map.csv"
Task_list = ['DME']
dataset_fname = 'sampled_labels01.csv'
dataset_dir = '/blue/ruogu.fang/tienyuchang/OCT_EDA'
img_p_fmt = "label_%d/%s" #label index and oct_img name

# model
Model_root = "/blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir"
Model_fname = "checkpoint-best.pth"
Model_list = ['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
ADCon_finetuned = [
    "ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-resnet-50-OCT-defaulteval---bal_sampler-/",
    "ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-timm_efficientnet-b4-OCT-defaulteval---bal_sampler-/",
    "ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-vit-base-patch16-224-OCT-defaulteval---bal_sampler-/",
    "ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-RETFound_mae-OCT-defaulteval---bal_sampler-/"
]
DME_finetuned = [
    "DME_binary_all_split-IRB2024_v5-all-microsoft/resnet-50-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/",
    "DME_binary_all_split-IRB2024_v5-all-timm_efficientnet-b4-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/",
    "DME_binary_all_split-IRB2024_v5-all-google/vit-base-patch16-224-in21k-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/",
    "DME_binary_all_split-IRB2024_v5-all-RETFound_mae_natureOCT-OCT-bs16ep50lr5e-4optadamw-roc_auceval-trsub0--/"
]
Model_root = "/orange/ruogu.fang/tienyuchang/RETfound_results"
DME_finetuned_masked = [
    "DME_binary_all_split-IRB2024_v5-all-microsoft/resnet-50-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0---add_mask---train_no_aug/",
    "DME_binary_all_split-IRB2024_v5-all-timm_efficientnet-b4-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0---add_mask---train_no_aug/",
    "DME_binary_all_split-IRB2024_v5-all-google/vit-base-patch16-224-in21k-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0---add_mask---train_no_aug/",
    "DME_binary_all_split-IRB2024_v5-all-RETFound_mae_natureOCT-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0---add_mask---train_no_aug/"
]

In [None]:
#mask function
def masked_img_func(img, mask_slice):
    binary_mask = np.zeros_like(img, dtype=np.uint8)
    for i in range(mask_slice.shape[0]-1):
        upper = mask_slice[i].astype(int)
        lower = mask_slice[i+1].astype(int)
        for x in range(img.shape[1]):
            binary_mask[upper[x]:lower[x], x] = 1

    # 套用 mask (把 mask=0 的地方設為 0)
    masked_img = img.copy()
    masked_img[binary_mask == 0] = 0

    return masked_img

# Data loading and preprocessing functions
def load_sample_data(task, num_sample=-1):
    """Load sample images for a given task"""
    df = pd.read_csv(os.path.join(dataset_dir, "%s_sampled"%task, dataset_fname))
    if IMG_MASK or HEATMAP_MASK:
        masked_df = pd.read_csv(Thickness_CSV)
        masked_df = masked_df.rename(columns={'OCT':'folder'}).dropna(subset=['Surface Name'])
        df = df.merge(masked_df,on='folder',how='inner').reset_index(drop=True)
        print('After adding mask, data len: ', df.shape[0])
    task_df = df[df['label'].isin([0, 1])]  # Adjust based on actual DME labels
    # Sample random images
    if num_sample > 0:
        task_df = task_df.sample(n=num_sample, random_state=42).reset_index(drop=True)
    else:
        task_df = task_df.reset_index(drop=True)
    
    images = []
    labels = []
    filenames = []
    mask_slices = []
    
    for _, row in task_df.iterrows():
        # Extract just the filename from oct_img
        filename = os.path.basename(row['OCT']) if isinstance(row['OCT'], str) else row['OCT']
        img_path = os.path.join(dataset_dir, "%s_sampled"%task, img_p_fmt % (row['label'], filename))
        if os.path.exists(img_path):
            try:
                img = Image.open(img_path).convert('RGB')
                if IMG_MASK or HEATMAP_MASK:
                    mask_path = os.path.join(Thickness_DIR, row['folder'], row['Surface Name'])
                    mask = np.load(mask_path) # (Layer, slice, W)

                    # 假設我們要套用其中某一 slice 的 mask，例如 slice_index = 13
                    slice_index = int(os.path.basename(img_path).split("_")[-1].split(".")[0])  # 從檔名抓 13
                    mask_slice = mask[:, slice_index, :]  # shape: (Layer, W)
                    mask_slices.append(mask_slice)
                else:
                    mask_slices.append(None)
                    
                if IMG_MASK:
                    img_np = np.array(img)  # Convert PIL image to numpy array
                    masked_img_np = masked_img_func(img_np, mask_slice)
                    masked_img = Image.fromarray(masked_img_np)
                    images.append(masked_img)
                else:
                    images.append(img)
                labels.append(row['label'])
                # Store filename without extension for directory naming
                image_name = os.path.splitext(filename)[0]
                filenames.append(image_name)
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                continue


    return images, labels, filenames, mask_slices

def preprocess_image(image, processor=None, input_size=224, device=None, dtype=torch.float32):
    assert isinstance(image, Image.Image), f"expect PIL.Image, got {type(image)}"
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if processor is not None:
        # A) 先尝试“直接可调用”形式（多数 timm/torchvision transform）
        try:
            out = processor(image)
            if isinstance(out, torch.Tensor):
                x = out
                if x.ndim == 3:  # [C,H,W] -> [1,C,H,W]
                    x = x.unsqueeze(0)
                return x.to(device=device, dtype=dtype)
            if isinstance(out, dict) and "pixel_values" in out:
                x = out["pixel_values"]
                if isinstance(x, np.ndarray):
                    x = torch.from_numpy(x)
                if x.ndim == 3:
                    x = x.unsqueeze(0)
                return x.to(device=device, dtype=dtype)
        except TypeError:
            pass

        # B) 再尝试 HuggingFace 风格（不使用 images= 关键字）
        try:
            out = processor(image, return_tensors="pt")
            if isinstance(out, dict) and "pixel_values" in out:
                x = out["pixel_values"]  # [1,3,H,W]
                return x.to(device=device, dtype=dtype)
            if isinstance(out, torch.Tensor):
                x = out
                if x.ndim == 3:
                    x = x.unsqueeze(0)
                return x.to(device=device, dtype=dtype)
        except TypeError:
            pass

        # C) 某些实现仅接受列表
        for attempt in (lambda: processor([image], return_tensors="pt"),
                        lambda: processor([image])):
            try:
                out = attempt()
                if isinstance(out, dict) and "pixel_values" in out:
                    x = out["pixel_values"]
                    if isinstance(x, np.ndarray):
                        x = torch.from_numpy(x)
                    return x.to(device=device, dtype=dtype)
                if isinstance(out, torch.Tensor):
                    x = out
                    if x.ndim == 3:
                        x = x.unsqueeze(0)
                    return x.to(device=device, dtype=dtype)
            except TypeError:
                pass

    # D) 回退：标准 ImageNet 预处理
    fallback = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),  # [0,1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    x = fallback(image)            # [3,H,W]
    x = x.unsqueeze(0)             # [1,3,H,W]
    return x.to(device=device, dtype=dtype)

#test dataset
dme_imgs, dme_labels, dme_img_names, dme_mask_slices = load_sample_data('DME',-1)
print(len(dme_imgs))
print(dme_imgs[0])
print(dme_labels)
# 顯示
plt.figure(figsize=(5, 5))
plt.title("OCT slice")
plt.imshow(dme_imgs[0], cmap="gray")
plt.show()

200
<PIL.Image.Image image mode=RGB size=512x496 at 0x1473E320ABB0>
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


# Model and XAI

In [None]:

# Load trained models function
def load_trained_model(task, model_name, input_size=224, nb_classes=2):
    """Load a trained model for a specific task"""
    model, processor = get_model(task, model_name, input_size, nb_classes)
    
    # Load model weights based on task and model
    if task == 'ADCon':
        model_paths = ADCon_finetuned
    elif task == 'DME':
        if IMG_MASK or HEATMAP_MASK:
            model_paths = DME_finetuned_masked
        else:
            model_paths = DME_finetuned
    else:
        print(f"Unknown task: {task}")
        model.eval()
        return model, processor
    
    model_idx = Model_list.index(model_name)
    model_dir = model_paths[model_idx]
    model_path = os.path.join(Model_root, model_dir, Model_fname)
    
    # Load finetuned model if specified (following main_XAI_evaluation.py pattern)
    if model_path and model_path != '':
        if os.path.exists(model_path):
            try:
                # Load checkpoint
                if model_path.startswith('https'):
                    checkpoint = torch.hub.load_state_dict_from_url(
                        model_path, map_location='cpu', check_hash=True)
                else:
                    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
                
                # Extract model state dict
                if 'model' in checkpoint:
                    checkpoint_model = checkpoint['model']
                else:
                    checkpoint_model = checkpoint
                
                # Load with strict=False to handle potential mismatches
                model.load_state_dict(checkpoint_model, strict=False)
                print(f"Resume checkpoint {model_path} for {model_name} on {task}")
                
            except Exception as e:
                print(f"Error loading model {model_name} for {task}: {e}")
                print("Using pretrained weights instead")
        else:
            print(f"Model path not found: {model_path}")
            print(f"Using pretrained weights for {model_name} on {task}")
    else:
        print(f"No checkpoint specified for {model_name} on {task}, using pretrained weights")
    
    model.eval()
    return model, processor
'''
#Model_list = ['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
model_list = ['RETFound_mae']
for model_name in model_list:
    model, processor = load_trained_model('DME', model_name, 224)
    print(model)
'''

"\n#Model_list = ['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']\nmodel_list = ['RETFound_mae']\nfor model_name in model_list:\n    model, processor = load_trained_model('DME', model_name, 224)\n    print(model)\n"

In [6]:
# XAI Methods Implementation
class XAIGenerator:
    def __init__(self, model, model_name, input_size=224):
        self.model = model
        self.model_name = model_name
        self.input_size = input_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        
        # Initialize XAI methods
        self.init_xai_methods()
    
    def get_model_specific_config(self):
        """Get model-specific configuration for XAI methods"""
        config = {
            'patch_size': 14,
            'gpu_batch': 1,
            'attention_layers': 12
        }
        
        # Model-specific configurations
        if 'resnet' in self.model_name.lower():
            config.update({
                'patch_size': 7,  # ResNet has different spatial resolution
                'gpu_batch': 1,  # ResNet can handle larger batches
            })
        elif 'efficientnet' in self.model_name.lower():
            config.update({
                'patch_size': 7,  # EfficientNet spatial resolution
                'gpu_batch': 1,
            })
        elif 'vit' in self.model_name.lower():
            config.update({
                'patch_size': 16,  # ViT patch size
                'gpu_batch': 1,
                'attention_layers': 12,  # Standard ViT-Base layers
            })
        elif 'retfound' in self.model_name.lower():
            config.update({
                'patch_size': 16,  # RETFound uses ViT architecture
                'gpu_batch': 1,
                'attention_layers': 12,
            })
        
        return config
    
    def init_xai_methods(self):
        """Initialize all XAI methods with model-specific configurations"""
        
        # CRP_LXT with model-specific batch size
        # Reduce batch for memory-heavy models
        self.crp_lxt = CRP_LXT(
            self.model, 
            self.model_name,
            img_size=(self.input_size, self.input_size)
        )
        print(f"✓ CRP_LXT initialized for {self.model_name}")
    
    def generate_crp_lxt(self, image_tensor, target_class=None):
        """Generate CRP_LXT heatmap"""
        if self.crp_lxt is None:
            return None
        image_tensor = image_tensor.to(self.device)

        if target_class is None:
            # Get predicted class
            with torch.no_grad():
                outputs = self.model(image_tensor)
                target_class = outputs.argmax(dim=1).item()

        heatmaps = self.crp_lxt(image_tensor, target_class)
        return heatmaps
    
    def generate_all_heatmaps(self, image_tensor, target_class=None):
        """Generate all available heatmaps for an image"""
        heatmaps = {}
            
        # CRP_LXT
        crp_lxt_map = self.generate_crp_lxt(image_tensor, target_class)
        if crp_lxt_map is not None:
            heatmaps['CRP_LXT'] = crp_lxt_map
        
        return heatmaps
    
#test
for model_name in Model_list:
    model, processor = load_trained_model('DME', model_name, 224)
    XAIGenerator(model, model_name)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Resume checkpoint /blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir/DME_binary_all_split-IRB2024_v5-all-microsoft/resnet-50-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/checkpoint-best.pth for ResNet-50 on DME
✓ CRP_LXT initialized for ResNet-50
Resume checkpoint /blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir/DME_binary_all_split-IRB2024_v5-all-timm_efficientnet-b4-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/checkpoint-best.pth for timm_efficientnet-b4 on DME
✓ CRP_LXT initialized for timm_efficientnet-b4


# Run Heatmap

In [None]:
# Visualization functions
def masked_heatmap_func(heatmap, mask_slice):
    if heatmap is None:
        return None
    
    binary_mask = np.zeros_like(heatmap, dtype=np.uint8)
    for i in range(mask_slice.shape[0]-1):
        upper = mask_slice[i].astype(int)
        lower = mask_slice[i+1].astype(int)
        for x in range(heatmap.shape[1]):
            binary_mask[upper[x]:lower[x], x] = 1

    # 套用 mask (把 mask=0 的地方設為 0)
    masked_heatmap = heatmap.copy()
    masked_heatmap[binary_mask == 0] = 0

    return masked_heatmap

def add_layer_line(overlay, mask_slice, width=1, cmap_name="rainbow"):
    # Convert to PIL.Image if needed
    if isinstance(overlay, np.ndarray):
        if overlay.dtype != np.uint8:
            overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)
        overlay_img = Image.fromarray(overlay)
    else:
        overlay_img = overlay.convert("RGB")
    draw = ImageDraw.Draw(overlay_img)
    n_layers, W = mask_slice.shape
    xs = np.arange(W)
    # Generate rainbow colors for each layer
    cmap = plt.get_cmap(cmap_name)
    colors = (np.array([cmap(i / max(1, n_layers - 1))[:3] for i in range(n_layers)]) * 255).astype(int)
    # Draw each layer line
    for i in range(n_layers):
        ys = np.nan_to_num(mask_slice[i].astype(float), nan=0.0)
        ys = np.clip(ys, 0, overlay_img.height - 1)
        points = list(zip(xs, ys))
        color = tuple(colors[i])
        draw.line(points, fill=color, width=width)

    return overlay_img

def normalize_heatmap(heatmap):
    """Normalize heatmap to 0-1 range"""
    if heatmap is None:
        return None
    
    heatmap = np.array(heatmap)
    if heatmap.max() == heatmap.min():
        return np.zeros_like(heatmap)
    
    return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

def overlay_heatmap_on_image(image, heatmap, mask_slice = None, alpha=0.4, colormap='jet'):
    """Return uint8 RGB overlay, shape (H,W,3)."""

    if heatmap is None:
        # 保證回傳 (H,W,3) uint8
        if isinstance(image, Image.Image):
            img = np.array(image)
        else:
            img = np.array(image)
        if img.ndim == 2:                 # 灰階 → RGB
            img = np.repeat(img[..., None], 3, axis=-1)
        if img.shape[-1] == 4:            # RGBA → RGB
            img = img[..., :3]
        if img.dtype != np.uint8:
            img = np.clip(img, 0, 255).astype(np.uint8)
        return img
    
    if HEATMAP_MASK and mask_slice is not None:
        heatmap = masked_heatmap_func(heatmap, mask_slice)

    # 1) 統一 image → RGB uint8
    img = np.array(image if not isinstance(image, Image.Image) else np.array(image))
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)
    if img.shape[-1] == 4:
        img = img[..., :3]
    if img.dtype != np.uint8:
        # 若是 0–1，乘回 255；否則直接裁切/轉型
        mx = float(img.max()) if img.size else 1.0
        if mx <= 1.0:
            img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
        else:
            img = np.clip(img, 0, 255).astype(np.uint8)
    H, W = img.shape[:2]

    # 2) 統一 heatmap → 2D float [0,1]
    hm = heatmap
    if "torch" in str(type(hm)):
        hm = hm.detach().float().cpu().numpy()
    hm = np.array(hm)
    hm = np.squeeze(hm)                   # (H,W) 最佳
    if hm.ndim == 3 and hm.shape[-1] == 3:
        hm = hm.mean(axis=-1)             # 轉成單通道
    if hm.ndim != 2:
        raise ValueError(f"heatmap must be 2D after squeeze; got {hm.shape}")
    # normalize to [0,1]
    hm = hm.astype(np.float32)
    ptp = hm.max() - hm.min()
    hm = (hm - hm.min()) / (ptp + 1e-12)

    # 3) resize heatmap 到影像大小
    hm_resized = cv2.resize(hm, (W, H), interpolation=cv2.INTER_LINEAR)

    # 4) 上 colormap → 得到 RGBA，再取前 3 個通道 (RGB)
    cmap = plt.get_cmap(colormap)
    hm_rgb = cmap(hm_resized)[..., :3].astype(np.float32)   # (H,W,3), 0–1

    # 5) 影像轉 0–1，做疊加
    img_rgb = img.astype(np.float32) / 255.0
    overlay = alpha * hm_rgb + (1.0 - alpha) * img_rgb
    overlay = np.clip(overlay, 0.0, 1.0)

    return (overlay * 255.0).astype(np.uint8)

In [None]:
# Updated function with new directory structure for heatmap saving
def generate_comprehensive_heatmaps_v2(num_samples=3,task_list=Task_list,model_list=Model_list,heatmap_dir="./heatmap_results"):
    """Generate heatmaps for all task-model combinations with new directory structure"""
    
    results = {}
    input_size = 224
    
    print("Starting comprehensive heatmap generation...")
    print(f"Tasks: {task_list}")
    print(f"Models: {model_list}")
    print(f"Samples per task: {num_samples}")
    
    for task in task_list:
        print(f"\n=== Processing Task: {task} ===")
        results[task] = {}
        
        # Load sample data for this task (now returns filenames too)
        try:
            images, labels, filenames, mask_slices = load_sample_data(task, num_samples)
            print(f"Loaded {len(images)} images for {task}")
        except Exception as e:
            print(f"Error loading data for {task}: {e}")
            continue
        
        for model_name in model_list:
            print(f"\n--- Processing Model: {model_name} ---")
            # Load trained model
            model, processor = load_trained_model(task, model_name, input_size)
            # Initialize XAI generator
            xai_generator = XAIGenerator(model, model_name, input_size)
            # Store results for this model
            results[task][model_name] = {
                'images': images,
                'labels': labels,
                "mask_slices": mask_slices,
                'heatmaps': []
            }
            # Process each image with filename
            for idx, (image, label, filename, mask_slice) in enumerate(zip(images, labels, filenames, mask_slices)):
                # Preprocess image
                image_tensor = preprocess_image(image, processor, input_size)
                # Generate all heatmaps for this image
                heatmaps = xai_generator.generate_all_heatmaps(image_tensor, target_class=label)
                results[task][model_name]['heatmaps'].append(heatmaps)
                for xai_name, heatmap in heatmaps.items():
                    overlay = overlay_heatmap_on_image(image, heatmap, mask_slice)
                    if DRAW_LAYER:
                        overlay = add_layer_line(overlay, mask_slice)
                    # overlay is np.uint8 HxWx3 per implementation
                    # Create directory structure: ./heatmap_results/<task_name>/<label_idx>/<image_name>/<baselinemodel>/<XAI>.jpg
                    save_dir = Path(heatmap_dir) / task / str(label) / filename / model_name
                    save_dir.mkdir(parents=True, exist_ok=True)
                    out_path = save_dir / f"{xai_name}.jpg"
                    try:
                        Image.fromarray(overlay).save(out_path, format='JPEG', quality=95)
                    except Exception as e:
                        print(f"Failed to save {out_path}: {e}")
            print(f"Completed {model_name} for {task}")
            #delete after finish
            del xai_generator
    return results


In [None]:
# Test the new function with improved directory structure
print("=== Testing New Heatmap Generation Function ===")
print("This will save heatmaps in the structure: ./heatmap_results/<task_name>/<label_idx>/<image_name>/<baselinemodel>/<XAI>.jpg")

# Test with a small number of samples first
#Model_list = ['ResNet-50', 'timm_efficientnet-b4', 'vit-base-patch16-224', 'RETFound_mae']
heatmap_results_v2 = generate_comprehensive_heatmaps_v2(num_samples=-1,task_list=Task_list,model_list=Model_list,heatmap_dir="./heatmap_results_modelmask")  # Start with 3 samples for testing
#heatmap_results_v2 = generate_comprehensive_heatmaps_v2(num_samples=1,task_list=['DME'],model_list=['ResNet-50'])  # Start with 3 samples for testing


=== Testing New Heatmap Generation Function ===
This will save heatmaps in the structure: ./heatmap_results/<task_name>/<label_idx>/<image_name>/<baselinemodel>/<XAI>.jpg
Starting comprehensive heatmap generation...
Tasks: ['ADCon', 'DME']
Models: ['ResNet-50', 'timm_efficientnet-b4']
Samples per task: -1

=== Processing Task: ADCon ===
Loaded 200 images for ADCon

--- Processing Model: ResNet-50 ---


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Resume checkpoint /blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir/ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-resnet-50-OCT-defaulteval---bal_sampler-/checkpoint-best.pth for ResNet-50 on ADCon
✓ CRP_LXT initialized for ResNet-50
Processing image 1/200 (Label: 0, File: 1.2.840.114158.46801694025402817689310895532469688711_latL_13)
Processing image 2/200 (Label: 0, File: 1.2.840.114158.502621114697249271110805189119597217173_latL_13)
Processing image 3/200 (Label: 0, File: 1.2.840.114158.48723819094541232528721935005684885894_latL_13)
Processing image 4/200 (Label: 0, File: 1.2.840.114158.535203666622458697717439954291493020309_latR_13)
Processing image 5/200 (Label: 0, File: 1.2.840.114158.544986813490060889617660541551129217671_latR_13)
Processing image 6/200 (Label: 0, File: 1.2.840.114158.49167721181529507523635972451615719590_latR_13)
Processing image 7/200 (Label: 0, File: 1.2.840.114158.55959335593077182251739677378616809363_latL_13)
Processing image 8/200 (Label: 0, 

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Resume checkpoint /blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir/DME_binary_all_split-IRB2024_v5-all-microsoft/resnet-50-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/checkpoint-best.pth for ResNet-50 on DME
✓ CRP_LXT initialized for ResNet-50
Processing image 1/200 (Label: 0, File: 1.2.840.114158.50071404597218188804561682135446663304_latR_13)
Processing image 2/200 (Label: 0, File: 1.2.840.114158.536026259510780249311212458532212434831_latL_13)
Processing image 3/200 (Label: 0, File: 1.2.840.114158.467916554174621643614971482669568579718_latR_13)
Processing image 4/200 (Label: 0, File: 1.2.840.114158.51397392179968342934233302352253723286_latR_13)
Processing image 5/200 (Label: 0, File: 1.2.840.114158.572295763583983052317497210302308949182_latL_10)
Processing image 6/200 (Label: 0, File: 1.2.840.114158.532435480947824630410785238965867404435_latR_13)
Processing image 7/200 (Label: 0, File: 1.2.840.114158.516497366494051081910080350389676234380_latL_13)
Processing image 8/200

In [None]:
# Test the new function with improved directory structure
print("=== Testing New Heatmap Generation Function ===")
print("This will save heatmaps in the structure: ./heatmap_results/<task_name>/<label_idx>/<image_name>/<baselinemodel>/<XAI>.jpg")

# Test with a small number of samples first
Model_list = ['vit-base-patch16-224', 'RETFound_mae']
#heatmap_results_v2 = generate_comprehensive_heatmaps_v2(num_samples=-1,task_list=Task_list,model_list=Model_list)
heatmap_results_v2 = generate_comprehensive_heatmaps_v2(num_samples=1,task_list=['DME'],model_list=['vit-base-patch16-224'],heatmap_dir="./heatmap_results_modelmask")


=== Testing New Heatmap Generation Function ===
This will save heatmaps in the structure: ./heatmap_results/<task_name>/<label_idx>/<image_name>/<baselinemodel>/<XAI>.jpg
Starting comprehensive heatmap generation...
Tasks: ['DME']
Models: ['vit-base-patch16-224']
Samples per task: 1

=== Processing Task: DME ===
Loaded 1 images for DME

--- Processing Model: vit-base-patch16-224 ---


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Resume checkpoint /blue/ruogu.fang/tienyuchang/RETFound_MAE/output_dir/DME_binary_all_split-IRB2024_v5-all-microsoft/resnet-50-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/checkpoint-best.pth for vit-base-patch16-224 on DME


ValueError: ViTModel not yet supported. Supported models are: transformers.models.llama.modeling_llama, transformers.models.qwen2.modeling_qwen2, transformers.models.qwen3.modeling_qwen3, transformers.models.gemma3.modeling_gemma3, transformers.models.bert.modeling_bert, transformers.models.gpt2.modeling_gpt2, torchvision.models.vision_transformer Please provide a custom 'patch_map'. Contributions to the GitHub repository are welcome!