In [3]:
import torch
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

from src.vit_recipro_cam_metric import *

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Test with: {device}")

# CAM model's accuracy is measured with ILSVRC2012_cls validation dataset
IMAGE_PATH = '../Data/ILSVRC2012_cls/val'
Height = 224
Width = 224
batch_size = 5

# Per-channel mean/std normalisation, applied after ToTensor().
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Each image is resized to 256x256, center-cropped to (Height, Width),
# converted to a tensor and normalised. Order is kept fixed (shuffle=False).
data_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(
        IMAGE_PATH,
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.CenterCrop((Height, Width)),
            transforms.ToTensor(),
            normalize,
        ]),
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    # Pinned (page-locked) host memory only speeds up host-to-GPU copies;
    # on a CPU-only run it is not used and PyTorch emits a warning, so
    # request it only when the batches will actually be moved to CUDA.
    pin_memory=(device == 'cuda'),
)

print(f'Total number of data: {len(data_loader.dataset.imgs)}')


Test with: cpu
Total number of data: 35


In [4]:
# Load ViT base model
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Build ViT-B/16 with the IMAGENET1K_V1 weights and move it to the selected
# device. nn.Module instances are created in training mode, so switch to
# eval mode explicitly: the model is only used for inference below.
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1).to(device).eval()

In [5]:
# Average Drop / Average Increase metric, computed over every batch
# yielded by data_loader at the configured input resolution.
avg_drop, avg_inc = average_drop_increase(
    model,
    data_loader,
    Height,
    Width,
    batch_size,
    device=device,
)
print(f"Average Drop: {avg_drop}, Average Increase: {avg_inc}")

Average Drop: 8.554266929626465, Average Increase: 68.57142639160156


In [7]:
# Deletion AUC (DAUC) and Insertion AUC (IAUC) metric, using the same
# model, loader and input size as the other evaluation cells.
DAUC, IAUC = dauc_iauc(
    model,
    data_loader,
    Height,
    Width,
    batch_size,
    device=device,
)
print(f"Deletion AUC: {DAUC}, Insertion AUC: {IAUC}")

Deletion AUC: 0.006671550885046736, Insertion AUC: 0.029048736376109736


In [6]:
# ADCC metric: the call returns drop, increase, coherency and complexity
# together with the combined ADCC score.
# Note: avg_drop and avg_inc are rebound here to the values returned by ADCC.
avg_drop, avg_inc, coherency, complexity, adcc = ADCC(
    model,
    data_loader,
    Height,
    Width,
    batch_size,
    device=device,
)
print(f"Drop: {avg_drop}, Increase: {avg_inc}, Coherency: {coherency}, Complexity: {complexity}, ADCC: {adcc}")

Drop: 15.181533268519809, Increase: 74.28571428571429, Coherency: 83.77099173409597, Complexity: 45.07743290492466, ADCC: 71.53989475135123
