In [None]:
# Restrict this process to one GPU. The variable is set before the deep-learning
# imports below so that it is already in place when CUDA gets initialised.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# NOTE(review): the star import presumably provides torch, clip, F, hparams, tfms,
# the dataset classes, label_to_classname, gpt_descriptions and seed_everything,
# all of which are used below without an explicit import -- confirm against load.py.
from load import *
import torchmetrics
from tqdm import tqdm
import torch.nn as nn
from torch import optim
import numpy as np

# gradCAM / interpret produce the heatmaps for ResNet / ViT backbones respectively.
from explainer import gradCAM, interpret
from torch.utils.data import DataLoader, RandomSampler
# NOTE(review): `scipy.ndimage.filters` is a deprecated namespace in recent SciPy
# releases and `filters` is not referenced in the visible cells -- verify it is needed.
from scipy.ndimage import filters

# Seed the random number generators with the configured seed before any work is done.
seed_everything(hparams['seed'])

In [2]:
def get_dataloader(dataset, seed, bs, num_workers=16, pin_memory=False):
    """Build a DataLoader that visits `dataset` in a reproducible random order.

    Args:
        dataset: Map-style dataset to draw samples from.
        seed: Seed for the private generator that drives the shuffle.
        bs: Batch size.
        num_workers: Number of loader worker processes (defaults to the
            previously hard-coded 16). Use 0 to load in the main process.
        pin_memory: Forwarded to `DataLoader` (defaults to the previous False).

    Returns:
        A `DataLoader` sampling without replacement through a seeded
        `RandomSampler`.
    """
    # A dedicated generator keeps the shuffle order independent of whatever
    # else consumes the global torch RNG.
    shuffle_rng = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(dataset, replacement=False, generator=shuffle_rng)
    return DataLoader(
        dataset,
        batch_size=bs,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

In [3]:
# Evaluation settings: batch size and the seed that fixes the shuffle order.
bs = 8
seed = 123

# Active benchmark: the ImageNet validation split.
dataset = ImageNet(hparams['data_dir'], split='val', transform=tfms)
# Other benchmarks -- enable exactly one of these instead of the line above:
# dataset = CUBDataset(hparams['data_dir'], train=False, transform=tfms)
# dataset = torchvision.datasets.OxfordIIITPet(root=hparams['data_dir'], transform=tfms, split='test')
# dataset = test_set # EuroSAT
# dataset = torchvision.datasets.Food101(root=hparams['data_dir'], transform=tfms, split='test')

dataloader = get_dataloader(dataset=dataset, seed=seed, bs=bs)

In [None]:
# Load the CLIP backbone and overwrite its weights with the fine-tuned checkpoint.
device = torch.device(hparams['device'])
# Original author's note: the best model uses ViT-B/32.
model, preprocess = clip.load(hparams['model_size'], device=device, jit=False)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved on a different CUDA index cannot be deserialised when only one GPU is
# visible to this process.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load("/path/to/your/model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [6]:
def compute_insertion_fidelity_batched_blur(model, images, texts, heatmaps, insertion_steps=5):
    """Insertion fidelity of heatmaps, measured against a blurred baseline.

    Starting from a blurred copy of each image (bilinear down-sampling by 2 and
    up-sampling back), the most important pixels according to `heatmaps` are
    restored in `insertion_steps` equal increments. After each increment the
    image/text cosine similarity is compared with the similarity of the fully
    blurred image, normalised by the blurred-vs-original gap and clamped to
    [0, 1]. The fidelity is the mean of these normalised changes.

    Args:
        model: Object exposing `encode_image` and `encode_text`.
        images: Tensor of shape (B, C, H, W).
        texts: Batch of B inputs accepted by `model.encode_text`, aligned with
            `images`.
        heatmaps: Importance scores, either one map per image (B * H * W
            elements, restored pixels are copied across all channels) or one
            score per image element (B * C * H * W elements).
        insertion_steps: Number of insertion increments (>= 1).

    Returns:
        Tensor of shape (B,) with one fidelity per image. An entry is NaN when
        the blurred and original similarities coincide (0 / 0); callers are
        expected to handle that case.

    Raises:
        ValueError: If `insertion_steps` < 1 or `heatmaps` has an incompatible
            number of elements.
    """
    if insertion_steps < 1:
        raise ValueError("insertion_steps must be a positive integer")

    device = images.device
    batch_size, C, H, W = images.shape
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    # Rank every heatmap position: rank 0 is the most important one.
    flat_heatmaps = heatmaps.detach().to(device).reshape(batch_size, -1)
    total_pixels = flat_heatmaps.shape[1]
    _, indices = torch.sort(flat_heatmaps, descending=True, dim=1)
    # argsort of a permutation is its inverse: ranks[b, p] is the position of
    # pixel p in the descending-importance order of image b.
    ranks = torch.argsort(indices, dim=1)

    if total_pixels == H * W:
        # One score per spatial location: broadcast the ranking over channels.
        # (Indexing the C*H*W-flattened image with H*W indices would only ever
        # touch the first channel.)
        ranks = ranks.view(batch_size, 1, H, W)
    elif total_pixels == C * H * W:
        ranks = ranks.view(batch_size, C, H, W)
    else:
        raise ValueError(
            "heatmaps must hold H*W or C*H*W scores per image, got "
            + str(total_pixels) + " for images of shape " + str(tuple(images.shape))
        )

    # Only forward passes are needed; disabling autograd avoids building (and
    # keeping alive) a graph for every encode call.
    with torch.no_grad():
        # Blurred baseline: down-sample by 2, then up-sample to the original size.
        blurred_images = torch.nn.functional.interpolate(
            images, scale_factor=0.5, mode='bilinear', align_corners=False)
        blurred_images = torch.nn.functional.interpolate(
            blurred_images, size=(H, W), mode='bilinear', align_corners=False)

        text_embedding = model.encode_text(texts.to(device))
        blurred_output = cos(model.encode_image(blurred_images), text_embedding)
        original_output = cos(model.encode_image(images), text_embedding)

        # Largest possible change; used to normalise the per-step changes.
        max_differences = torch.abs(original_output - blurred_output)

        differences = torch.zeros((batch_size, insertion_steps), device=device)
        for step in range(1, insertion_steps + 1):
            # Restore the top `num_insert` positions from the original image.
            num_insert = int(step / insertion_steps * total_pixels)
            inserted_images = torch.where(ranks < num_insert, images, blurred_images)

            inserted_output = cos(model.encode_image(inserted_images), text_embedding)
            differences[:, step - 1] = torch.abs(inserted_output - blurred_output)

        normalized_changes = torch.clamp(differences / max_differences.unsqueeze(1), 0, 1)

        # Average over the insertion steps.
        fidelity = normalized_changes.mean(dim=1)

    return fidelity

In [None]:
# Explanation-fidelity evaluation.
# For every image, the first five textual concepts of its class are tokenised,
# one heatmap per (image, concept) pair is computed (gradCAM for ResNet
# backbones, `interpret` for ViT backbones), and the insertion fidelity of those
# heatmaps is measured. The reported number is the mean of the per-batch means;
# batches whose mean is NaN are left out.
# NOTE(review): label_to_classname and gpt_descriptions are presumably provided
# by `from load import *` -- confirm.
resnet_backbones = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64']
vit_backbones = ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
uses_resnet = hparams['model_size'] in resnet_backbones
if not uses_resnet and hparams['model_size'] not in vit_backbones:
    # Fail early instead of hitting an empty-list error further down.
    raise ValueError("Unsupported model_size for explanation: " + str(hparams['model_size']))

total_fidelity = 0.0
count = 0
for images, labels in tqdm(dataloader):
    images = images.to(device)
    # Take the resolution from the batch so backbones with inputs other than
    # 224x224 work as well.
    spatial_size = images.shape[2:]

    class_names = np.array(label_to_classname)[labels].tolist()

    # One (n_concepts, context_length) token tensor per image.
    tokenized_concepts_list = [
        clip.tokenize(gpt_descriptions[name][:5]).to(device)
        for name in class_names
    ]

    attn_map = []
    for j in range(len(images)):
        concept_tokens = tokenized_concepts_list[j]
        if uses_resnet:
            # gradCAM gets one copy of the image per concept of *this* image.
            repeated_input = images[j].unsqueeze(0).repeat(concept_tokens.shape[0], 1, 1, 1)
            attn = gradCAM(
                model.visual,
                repeated_input,
                model.encode_text(concept_tokens).float(),
                getattr(model.visual, "layer4")
            )
        else:
            R_image = interpret(model=model, image=images[j].unsqueeze(0), texts=concept_tokens, device=device)
            # Each relevance vector is reshaped to a square patch grid.
            dim = int(R_image[0].numel() ** 0.5)
            attn = R_image.reshape(-1, dim, dim)

        # Up-sample the per-concept maps to the image resolution.
        attn = F.interpolate(
            attn.unsqueeze(0),
            spatial_size,
            mode='bicubic',
            align_corners=False)
        # Detach so the autograd graph of the explainer can be released.
        attn_map.append(attn.squeeze(0).detach())

    # Flatten to one row per (image, concept) pair, grouped by image:
    # img0-c0, img0-c1, ..., img1-c0, ...
    attn_map = torch.cat(attn_map)
    concept_counts = torch.tensor(
        [tokens.shape[0] for tokens in tokenized_concepts_list], device=images.device)
    # repeat_interleave keeps the same image-major order as the heatmaps and the
    # tokens; Tensor.repeat would tile the batch (img0, img1, ..., img0, ...) and
    # pair images with the wrong concepts.
    repeated_images = torch.repeat_interleave(images, concept_counts, dim=0)
    tokenized_concepts_list = torch.cat(tokenized_concepts_list)

    fidelities = compute_insertion_fidelity_batched_blur(
        model=model,
        images=repeated_images,
        texts=tokenized_concepts_list,
        heatmaps=attn_map
    )
    batch_fidelity = torch.mean(fidelities).item()
    # A NaN mean (some pair had identical blurred/original similarity) drops
    # the whole batch from the average.
    if not np.isnan(batch_fidelity):
        total_fidelity += batch_fidelity
        count += 1

    del images, labels, attn_map, repeated_images, tokenized_concepts_list, concept_tokens, fidelities
    torch.cuda.empty_cache()

if count == 0:
    raise RuntimeError("No batch produced a finite fidelity; cannot compute the final score.")
print ('The final explanation fidelity: ' + str(total_fidelity / count))