In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

from load import *
import torchmetrics
from tqdm import tqdm
import torch.nn as nn
from torch import optim
import numpy as np

from explainer import gradCAM, interpret
from torch.utils.data import DataLoader, RandomSampler
from scipy.ndimage import filters
from loss import BatchSeparationLoss, BatchConsistencyLoss, SparsityLoss

# Seed the random number generators before any data loading or model code runs,
# using the seed stored in the shared hyper-parameter dict.
# NOTE(review): `seed_everything` and `hparams` are not defined in this notebook;
# presumably they are provided by `from load import *` — confirm.
seed_everything(hparams['seed'])

In [2]:
def get_dataloader(dataset, seed, bs, num_workers=16, pin_memory=False):
    """Build a DataLoader that visits `dataset` in a reproducible shuffled order.

    Args:
        dataset: Any map-style dataset (must support ``len`` and indexing).
        seed: Integer seed for the private generator that drives the shuffle.
            Two loaders built with the same seed yield the same order.
        bs: Batch size.
        num_workers: Number of loader worker processes. Defaults to 16 (the
            previously hard-coded value); pass 0 to load in the main process,
            e.g. on machines with few cores or when debugging.
        pin_memory: Forwarded to ``DataLoader``; defaults to False as before.

    Returns:
        A ``torch.utils.data.DataLoader`` sampling without replacement.
    """
    # A dedicated generator keeps the shuffle independent of the global RNG
    # state, so other random calls in the notebook cannot change the order.
    shuffle_rng = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(
        dataset,
        replacement=False,
        num_samples=None,
        generator=shuffle_rng,
    )
    return DataLoader(
        dataset,
        batch_size=bs,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

In [3]:
# Evaluation dataset. Exactly one `dataset = ...` assignment should be active;
# the commented-out lines are alternative datasets kept for quick switching.
# NOTE(review): ImageNet, CUBDataset, test_set, tfms and hparams are not defined
# in this notebook — presumably they come from `from load import *`; confirm.
dataset = ImageNet(hparams['data_dir'], split='val', transform=tfms)
# dataset = CUBDataset(hparams['data_dir'], train=False, transform=tfms)
# dataset = torchvision.datasets.OxfordIIITPet(root=hparams['data_dir'], transform=tfms, split='test')
# dataset = test_set # EuroSAT
# dataset = torchvision.datasets.Food101(root=hparams['data_dir'], transform=tfms, split='test')

# Batch size handed to get_dataloader.
bs = 8
# Seed for the shuffling order only; it is a literal here, not read from hparams['seed'].
seed = 123
dataloader = get_dataloader(dataset, seed, bs)

In [None]:
# Load the CLIP backbone, then replace its weights with the ones stored under
# 'model_state_dict' in the checkpoint file, and switch to inference mode.
device = torch.device(hparams['device'])
model, preprocess = clip.load(hparams['model_size'], device=device, jit=False) #Best model use ViT-B/32
# map_location remaps every stored tensor onto `device`. Without it, torch.load
# tries to restore tensors on the device they were saved from, which fails when
# that CUDA index is not visible in this process (or when running on CPU).
checkpoint = torch.load("/path/to/your/model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [5]:
def loss_to_metric(loss, k=50, threshold=23):
    """Squash a scalar loss into [0, 1] with a logistic curve centred on `threshold`.

    Args:
        loss: Scalar loss, either a one-element torch tensor or a plain number
            (numbers are wrapped with ``torch.tensor``).
        k: Steepness of the logistic curve.
        threshold: Loss value that maps to exactly 0.5.

    Returns:
        A Python float: ``1 - sigmoid(-k * (loss - threshold))``. It approaches
        1 for losses well above the threshold and 0 for losses well below it.
    """
    # Plain Python numbers are promoted to tensors so torch.sigmoid accepts them.
    value = loss if isinstance(loss, torch.Tensor) else torch.tensor(loss)
    gate = torch.sigmoid(-k * (value - threshold))
    return (1 - gate).item()

In [6]:
def metric(loss):
    """Linearly map a loss to a score via ``1 - loss / 100``.

    Args:
        loss: A scalar object exposing ``.item()`` (for example a one-element
            torch tensor or a NumPy scalar).

    Returns:
        The score as a plain Python number.
    """
    scaled = loss / 100
    score = 1 - scaled
    return score.item()

In [7]:
# Loss object applied in the evaluation loop to the per-image lists of concept
# attention maps; its value (divided by 100) is what gets averaged there.
# NOTE(review): the exact definition lives in loss.py and is not visible here.
loss_sep = BatchSeparationLoss()

In [None]:
# Explanation disentanglability over the whole dataloader.
# For every image, an attention map is computed for the class name plus up to
# five of its GPT-generated concept descriptions; the separation loss over the
# concept maps is then averaged across batches.

# Fail fast on an unknown backbone: previously an unmatched model_size silently
# left `attn_map` empty and the loss was evaluated on empty lists.
resnet_backbones = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64']
vit_backbones = ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
if hparams['model_size'] not in resnet_backbones + vit_backbones:
    raise ValueError(
        "Unsupported model_size for attention-map extraction: " + repr(hparams['model_size'])
    )
use_gradcam = hparams['model_size'] in resnet_backbones

# Loop-invariant: build the label -> class-name lookup array once, not per batch.
classname_lookup = np.array(label_to_classname)

total_disentanglability = 0.0
count = 0
for batch_number, batch in enumerate(tqdm(dataloader)):
    images, labels = batch
    texts = classname_lookup[labels].tolist()

    tokenized_concepts_list = []
    # NOTE(review): `rich_labels` is built but never read in this cell; it is
    # kept because notebook globals may be inspected from other cells.
    rich_labels = []
    for i in range(len(texts)):
        # The slice yields a fresh sequence, so the insert below does not
        # modify the entry stored in gpt_descriptions.
        concepts = gpt_descriptions[texts[i]][:5]
        concatenated_concepts = ', '.join(concepts)
        label = hparams['label_before_text'] + wordify(texts[i]) + hparams['label_after_text'] + " It may contains " + concatenated_concepts
        rich_labels.append(label)

        # Row 0 is the class name itself; rows 1.. are its concepts.
        concepts.insert(0, texts[i])
        tokenized_concepts = clip.tokenize(concepts)
        tokenized_concepts_list.append(tokenized_concepts)

    images = images.to(device)
    # NOTE(review): the tokenized class names are not consumed below in this cell.
    texts = clip.tokenize(texts)
    texts = texts.to(device)

    attn_map = []
    if use_gradcam:
        # ResNet backbones: Grad-CAM on the visual encoder's `layer4`, with the
        # image repeated once per text so each text gets its own heatmap.
        for k in range(len(images)):
            num_texts = tokenized_concepts_list[k].shape[0]
            repeated_image = images[k].unsqueeze(0).repeat(num_texts, 1, 1, 1)
            heatmap = gradCAM(
                model.visual,
                repeated_image,
                model.encode_text(tokenized_concepts_list[k].to(device)),
                getattr(model.visual, "layer4")
            )
            attn_map.append(heatmap)
    else:
        # ViT backbones: relevance from `interpret`, reshaped into square maps
        # whose side is the square root of the first row's element count.
        for k in range(len(images)):
            R_image = interpret(model=model, image=images[k].unsqueeze(0), texts=tokenized_concepts_list[k].to(device), device=device)
            image_relevance = R_image[0]
            dim = int(image_relevance.numel() ** 0.5)
            R_image = R_image.reshape(-1, dim, dim)
            attn_map.append(R_image)

    # Split each image's maps into the first map and the remaining ones.
    # Presumably index 0 is the class-name map and the rest are concept maps,
    # assuming the explainers keep the text order — TODO confirm.
    attn_map_label = [item[0] for item in attn_map]
    attn_map_concepts = [item[1:] for item in attn_map]

    # Detach before accumulating: summing a tensor that still carries an
    # autograd graph would keep every batch's graph alive for the whole run.
    batch_loss = loss_sep(attn_map_concepts)
    total_disentanglability += batch_loss.detach() / 100
    count += 1

if count == 0:
    raise RuntimeError("dataloader yielded no batches; cannot average the disentanglability")

avg_disen = total_disentanglability / count
print ('The final explanation disentanglability: ' + str(avg_disen))

In [14]:
def sigmoid_metric(loss, k=10, M=0.2, temp=1.0):
    """
    Map a loss onto (0, 1) with a decreasing logistic curve.

    Computes ``1 / (1 + exp(-k * (M - loss / temp)))``.

    Args:
    - loss: Scalar or NumPy array of loss values.
    - k: Steepness of the curve.
    - M: Midpoint; the result is exactly 0.5 when ``loss / temp == M``.
    - temp: Divisor applied to the loss before it is compared with M.

    Returns:
    - A value (or array) between 0 and 1; smaller losses give larger scores.
    """
    scaled_loss = loss / temp
    exponent = -k * (M - scaled_loss)
    return 1 / (1 + np.exp(exponent))

In [None]:
# Final metric value: pass the averaged disentanglability through sigmoid_metric
# with its default k, M and temp. The tensor is detached and moved to the CPU
# first so it can be converted to a NumPy value.
sigmoid_metric(avg_disen.detach().cpu().numpy())