In [None]:
import warnings
warnings.filterwarnings('ignore')
from torchvision import transforms
from datasets import load_dataset
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import cv2
import torch
from typing import List, Callable, Optional

In [None]:
from transformers import (
    ViTImageProcessor, ViTForImageClassification,
    AutoImageProcessor, EfficientNetForImageClassification,
    ResNetForImageClassification, AutoModel
)
import models_vit as models

#get model
def get_model(task, model, input_size, nb_classes):
    """Build the classifier selected by ``model`` together with its preprocessor.

    The backbone is chosen by case-insensitive substring match on ``model``,
    so both ``'ResNet-50'`` and ``'microsoft/resnet-50'`` select the ResNet
    branch. Branches are tested in this order: RETFound_mae,
    vit-base-patch16-224, efficientnet-b4, resnet-50.

    Args:
        task: Task identifier. If it contains ``'ADCon'`` the label map is
            ``{0: "control", 1: "ad"}``; otherwise generic ``class_<i>`` names
            are generated for ``range(nb_classes)``.
        model: Name (or path-like string) identifying the backbone.
        input_size: Forwarded as ``img_size`` (RETFound) or ``image_size``
            (ViT, EfficientNet). Not forwarded to the ResNet branch.
        nb_classes: Number of output classes for the classification head.

    Returns:
        ``(network, processor)``. ``processor`` is a ``TransformWrapper``
        around the Hugging Face image processor for the Hugging Face
        backbones, and ``None`` for RETFound_mae.

    Raises:
        ValueError: If ``model`` does not name a supported backbone.
    """
    # NOTE(review): the ADCon label map always has two entries, regardless of
    # nb_classes -- callers presumably pass nb_classes=2 for that task.
    if 'ADCon' in task:
        id2label = {0: "control", 1: "ad"}
    else:
        id2label = {i: f"class_{i}" for i in range(nb_classes)}
    label2id = {name: idx for idx, name in id2label.items()}

    # Lower-case once so display-style names such as 'ResNet-50' are accepted.
    model_key = str(model).lower()
    processor = None

    if 'retfound_mae' in model_key:
        # Built from the local models_vit module; no processor is created here.
        net = models.__dict__['RETFound_mae'](
            img_size=input_size,
            num_classes=nb_classes,
            drop_path_rate=0.2,
            global_pool=True,
        )
    elif 'vit-base-patch16-224' in model_key:
        # ViT-base-patch16-224 preprocessor
        checkpoint = 'google/vit-base-patch16-224'
        processor = TransformWrapper(ViTImageProcessor.from_pretrained(checkpoint))
        net = ViTForImageClassification.from_pretrained(
            checkpoint,
            image_size=input_size,  # Not in tianhao code, default 224
            num_labels=nb_classes,
            hidden_dropout_prob=0.0,  # Not in tianhao code, default 0.0
            attention_probs_dropout_prob=0.0,  # Not in tianhao code, default 0.0
            id2label=id2label,
            label2id=label2id,
            # The pretrained head has a different class count than nb_classes.
            ignore_mismatched_sizes=True,
        )
    elif 'efficientnet-b4' in model_key:
        # EfficientNet-B4 preprocessor
        checkpoint = 'google/efficientnet-b4'
        processor = TransformWrapper(AutoImageProcessor.from_pretrained(checkpoint))
        net = EfficientNetForImageClassification.from_pretrained(
            checkpoint,
            image_size=input_size,
            num_labels=nb_classes,
            dropout_rate=0.0,
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True,
        )
    elif 'resnet-50' in model_key:
        checkpoint = 'microsoft/resnet-50'
        processor = TransformWrapper(AutoImageProcessor.from_pretrained(checkpoint))
        net = ResNetForImageClassification.from_pretrained(
            checkpoint,
            num_labels=nb_classes,
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True,
        )
    else:
        # Previously the input string was returned as if it were a model,
        # which only failed much later (and confusingly) at inference time.
        raise ValueError(
            f"Unsupported model {model!r}; expected a name containing one of "
            "'RETFound_mae', 'vit-base-patch16-224', 'efficientnet-b4', 'resnet-50'."
        )

    return net, processor

In [None]:
# --- Tasks and dataset location ----------------------------------------------
Task_list = ["ADCon", "DME"]
dataset_fname = "sampled_labels01.csv"
dataset_dir = "/blue/ruogu.fang/tienyuchang/OCT_EDA"
# Relative image path template, filled with (label index, OCT image name).
img_p_fmt = "label_%d/%s"

# --- Backbones and fine-tuned checkpoint folders -----------------------------
Model_list = [
    "ResNet-50",
    "efficientnet-b4",
    "vit-base-patch16-224",
    "RETFound_mae",
]
# One entry per Model_list slot; "" where no fine-tuned run is listed.
# NOTE(review): index alignment with Model_list is presumed from the matching
# lengths and the backbone names inside the folder names -- confirm against
# the code that consumes these lists.
ADCon_finetuned = [
    "",  # slot 0
    "",  # slot 1
    "",  # slot 2
    "ad_control_detect_data-IRB2024v5_ADCON_DL_data-all-RETFound_mae-OCT-defaulteval---bal_sampler-/",
]
DME_finetuned = [
    "",  # slot 0
    "",  # slot 1
    "DME_binary_all_split-IRB2024_v5-all-google/vit-base-patch16-224-in21k-OCT-bs16ep50lr5e-4optadamw-defaulteval-trsub0--/",
    "DME_binary_all_split-IRB2024_v5-all-RETFound_mae_natureOCT-OCT-bs16ep50lr5e-4optadamw-roc_auceval-trsub0--/",
]