In [None]:
from captum.concept import TCAV, BTCAV
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import IterableDataset
from typing import Iterator
from captum.concept._utils.data_iterator import dataset_to_dataloader
from captum.concept import TCAV, BTCAV
from captum.concept import Concept
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients
from captum.concept._utils.data_iterator import dataset_to_dataloader, CustomIterableDataset
import torchvision
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
from captum.concept._utils.common import concepts_to_str
import json
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os, shutil
import time
import cv2
import glob
from sklearn import linear_model
from sklearn.model_selection import train_test_split
from captum.concept._utils.classifier import Classifier
from torch.utils.data import DataLoader
import gc
from collections import defaultdict

# Local imports
from batcave.BayesianLogistic import VBLogisticRegression

In [None]:
# Run on the GPU when one is available; set FORCE_CPU = True to keep the model and inputs on the CPU.
FORCE_CPU = False
device = torch.device("cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu")
device

In [None]:
# Standard ImageNet per-channel mean / std used for normalisation
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation-time preprocessing: fixed 224x224 resize (no cropping), PIL -> tensor, then normalise
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

In [None]:
def remove_cav_folder(path='cav'):
    """Delete the on-disk CAV directory so the next experiment starts from a clean cache.

    Args:
        path: directory to remove (defaults to 'cav', the folder this notebook clears between runs).
            NOTE(review): presumably this is where TCAV/BTCAV saves its CAVs; confirm against the
            `save_path` the TCAV instance uses.

    Only directories are removed; a missing path is a no-op, and a regular file with that name is
    left untouched rather than raising from `shutil.rmtree`.
    """
    print("Removing cav folder")
    if os.path.isdir(path):
        shutil.rmtree(path)

def get_tensor_from_filename(filename):
    """Read one image file, force it to RGB, apply `test_transform`, and return it on `device`.

    Used as the filename -> tensor callback for the concept datasets.
    """
    with Image.open(filename) as raw:
        rgb = raw.convert("RGB")
    return test_transform(rgb).to(device)

def load_image_tensors(class_name, root_path='dataset/general blocked unblocked/test', transform=True, pattern='*.jpg'):
    """Load every image matching `pattern` from `root_path/class_name`.

    Args:
        class_name: sub-folder of `root_path` that holds the images.
        root_path: directory of the dataset split.
        transform: if True, return `test_transform`-ed tensors (left on the CPU); otherwise return
            the RGB PIL images.
        pattern: glob pattern for file names inside the folder (default '*.jpg').

    Returns:
        A list of tensors (or PIL images) in sorted file-name order. The list is empty when the
        folder is missing or nothing matches.
    """
    folder = os.path.join(root_path, class_name)
    # Escape the folder part so characters like '[' in directory names match literally. Sort because
    # glob returns files in arbitrary, filesystem-dependent order, which would make runs irreproducible.
    filenames = sorted(glob.glob(os.path.join(glob.escape(folder), pattern)))

    images = []
    for filename in filenames:
        with Image.open(filename) as img:
            rgb = img.convert('RGB')
        images.append(test_transform(rgb) if transform else rgb)

    return images

def assemble_concept(name, id, concepts_path="concepts/"):
    """Wrap the image folder `concepts_path/<name>/` as a captum `Concept`.

    Args:
        name: concept folder name, also used as the Concept's name.
        id: integer id given to the Concept.
        concepts_path: parent directory that holds one folder per concept.

    Returns:
        A `Concept` whose data iterator yields tensors produced by `get_tensor_from_filename`.
    """
    # NOTE(review): the trailing "/" looks required by CustomIterableDataset's path handling; confirm
    # before removing it.
    folder = os.path.join(concepts_path, name) + "/"
    loader = dataset_to_dataloader(CustomIterableDataset(get_tensor_from_filename, folder))
    return Concept(id=id, name=name, data_iter=loader)


def format_float(f):
    """Round `f` for reporting and return it as a float.

    Values with magnitude >= 0.0005 are rounded to 3 decimal places. Smaller values are kept in
    scientific notation with 3 digits after the decimal point, so they do not collapse to 0.0.
    """
    spec = '{:.3f}' if abs(f) >= 0.0005 else '{:.3e}'
    return float(spec.format(f))

def extract_btcav_scores(experimental_sets, tcav_scores, layers):
    """Flatten BTCAV sign-count scores into parallel per-concept lists.

    Args:
        experimental_sets: list of concept lists, as passed to `interpret`.
        tcav_scores: indexable sequence where
            `tcav_scores[j][concepts_key][layer]['sign_count'][i]` is the score of concept `i`.
            Presumably there is one entry `j` per BTCAV sample; confirm.
        layers: not used; kept so existing calls keep working.

    Returns:
        `(ex_sets, cons, vals)`. There is one entry per concept of every experimental set, holding:
        - the set index;
        - the concept name;
        - a nested list where `vals[k][j][l]` is the rounded sign_count for entry `j` and layer `l`.
          Layers follow the score dict's iteration order.
    """
    ex_sets, cons, vals = [], [], []
    for set_idx, concept_set in enumerate(experimental_sets):
        key = concepts_to_str(concept_set)
        for concept_idx, concept in enumerate(concept_set):
            per_sample = [
                [
                    format_float(layer_scores['sign_count'][concept_idx])
                    for layer_scores in tcav_scores[j][key].values()
                ]
                for j in range(len(tcav_scores))
            ]
            ex_sets.append(set_idx)
            cons.append(concept.name)
            vals.append(per_sample)

    return ex_sets, cons, vals

In [None]:
class CustomClassifier(Classifier):
    """captum CAV classifier backed by a variational Bayesian logistic regression.

    `train_and_eval` prepends a column of ones to every activation vector. The fitted coefficients
    therefore carry an intercept as their first entry, which is why the underlying model is built
    with `fit_intercept=False`.

    Args:
        test_size: fraction of samples held out when `evaluate_test` is True.
        evaluate_test: if True, fit on a train split and report accuracy on the held-out split.
            Otherwise, fit on and report accuracy for all samples.
        random_state: seed forwarded to `train_test_split`, making held-out splits reproducible.
    """

    def __init__(self, test_size=0.33, evaluate_test=False, random_state=None):
        # self.lm = linear_model.LogisticRegression(max_iter=1000)
        self.lm = VBLogisticRegression(fit_intercept=False)  # intercept column is added manually in train_and_eval
        self.test_size = test_size
        self.evaluate_test = evaluate_test
        self.random_state = random_state
        self.metrics = None

    def train_and_eval(self, dataloader: DataLoader, **kwargs):
        """Fit the regression on `(X, y)` batches from `dataloader` and return `{'accs': accuracy}`.

        Args:
            dataloader: yields `(X, y)` batches. X is assumed 2-D `(n_samples, n_features)`, since
                the intercept column is concatenated along dim 1.
            **kwargs: accepted for captum's interface; unused.

        Returns:
            dict with key 'accs' holding `self.lm.score` on the test split (if `evaluate_test`) or
            on the training data. It is also stored on `self.metrics`.
        """
        inputs, labels = [], []
        for X, y in dataloader:
            # The data ends up in NumPy for the sklearn-style model, so there is no reason to visit the GPU
            X = X.detach().cpu()
            intercept = torch.ones((X.shape[0], 1))  # intercept term; only needed for the CAV classifier
            inputs.append(torch.cat((intercept, X), dim=1))
            labels.append(y.detach().cpu())

        inputs = torch.cat(inputs).numpy()
        labels = torch.cat(labels).numpy()

        if self.evaluate_test:
            X_train, X_test, y_train, y_test = train_test_split(
                inputs, labels, test_size=self.test_size, random_state=self.random_state
            )
        else:
            X_train, y_train = inputs, labels

        self.lm.fit(X_train, y_train)

        if self.evaluate_test:
            self.metrics = {'accs': self.lm.score(X_test, y_test)}
            return self.metrics
        self.metrics = {'accs': self.lm.score(X_train, y_train)}
        print(self.metrics)
        return self.metrics

    def weights(self):
        """Return the fitted coefficients as a tensor, one row per concept.

        A binary fit exposes a single coefficient row; it is mirrored into `[-w, w]` so that each of
        the two concepts gets its own row. The tensor is placed on CUDA when available, else the CPU.
        The CPU fallback avoids the crash that a hard-coded 'cuda' caused on CPU-only machines.
        """
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        coef = self.lm.coef_
        if len(coef) == 1:
            # if there are two concepts, there is only one label. We split it in two.
            coef = np.array([-1 * coef[0], coef[0]])
        return torch.tensor(coef).to(target_device)

    def classes(self):
        """Return the class labels seen during fitting (`self.lm.classes_`)."""
        return self.lm.classes_

    def get_metrics(self):
        """Return the metrics dict from the last `train_and_eval` call, or None before any training."""
        return self.metrics

In [None]:
def _load_finetuned_alexnet(model_name):
    """Build an AlexNet with a 2-way head, load the fine-tuned weights, and return it on `device` in eval mode."""
    model = models.alexnet()
    # The fine-tuned checkpoint has a 2-class head, so swap the last layer before loading it
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints, and consider
    # weights_only=True if the installed torch supports it.
    state_dict = torch.load(model_name, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval().to(device)


def get_scores(model_name, experimental_set_rand, class_name='blocked', target=0, layers=None, n_samples=1000):
    """Run BTCAV for one model checkpoint and return per-concept score samples.

    Args:
        model_name: path to the fine-tuned AlexNet state dict.
        experimental_set_rand: list of concept lists to evaluate.
        class_name: test-set folder whose images are explained (default 'blocked').
        target: output index to attribute (default 0).
        layers: layer names to compute CAVs for (default ['classifier.4']).
        n_samples: number of BTCAV samples (default 1000).

    Returns:
        `(ex_sets, cons, vals, tcav_classifier_accuracy)`. The first three come from
        `extract_btcav_scores`. The accuracy is the 'accs' metric of the CAV classifier after
        `interpret`; it reflects the last CAV trained.

    Raises:
        ValueError: if no images are found for `class_name`.
    """
    layers = ['classifier.4'] if layers is None else list(layers)

    # Clear CAVs left behind by an earlier (possibly interrupted) run before the BTCAV instance is
    # created, so CAVs from a different model are never reused.
    remove_cav_folder()

    # Class tensors
    class_images = load_image_tensors(class_name)
    if not class_images:
        raise ValueError(f"No images found for class '{class_name}'")
    class_tensor = torch.stack(class_images).to(device)

    model = _load_finetuned_alexnet(model_name)
    classifier = CustomClassifier()
    mytcav = BTCAV(model=model,
                   layers=layers,
                   classifier=classifier,
                   n_samples=n_samples,
                   layer_attr_method=None
                   )

    try:
        scores = mytcav.interpret(
            inputs=class_tensor,
            experimental_sets=experimental_set_rand,
            target=target,
        )
        tcav_classifier_accuracy = mytcav.classifier.get_metrics()['accs']
    finally:
        # Always drop this run's CAVs, even if interpret fails
        remove_cav_folder()

    ex_sets, cons, vals = extract_btcav_scores(experimental_set_rand, scores, layers)

    # Release the model and inputs before the next experiment
    del model, mytcav, classifier, class_tensor
    gc.collect()

    return ex_sets, cons, vals, tcav_classifier_accuracy

In [None]:
# Set concepts: one Concept per folder under concepts/, with ids assigned in this order
(dark_concept, light_concept, orange_concept,
 d30_concept, d45_concept, d60_concept,
 random_0_concept) = [
    assemble_concept(concept_name, concept_id)
    for concept_id, concept_name in enumerate(
        ["dark", "light", "orange", "distance 30", "distance 45", "distance 60", "random"]
    )
]

# Every colour and distance concept against the shared random baseline, plus pairwise colour comparisons
experimental_set_rand = [
    [dark_concept, random_0_concept],
    [light_concept, random_0_concept],
    [orange_concept, random_0_concept],
    [dark_concept, light_concept],
    [dark_concept, orange_concept],
    [light_concept, orange_concept],
    [d30_concept, random_0_concept],
    [d45_concept, random_0_concept],
    [d60_concept, random_0_concept],
]

In [None]:
model_name = 'models/full finetuned models/alexnet_finetuned_100_epochs.pth' # Replace this with your model path
data = {}

for i, exp_set in enumerate(experimental_set_rand):
    print(f"Experiment {i+1}/{len(experimental_set_rand)}")

    # One experimental set per call, so each set gets a freshly loaded model and classifier
    ex_sets, cons, vals, tcav_accuracy = get_scores(model_name, [exp_set])

    # vals is nested as [concept][sample][layer]; summarise over samples for the first layer only
    first_layer_scores = np.array(vals)[:, :, 0]

    # .tolist() / float() produce plain Python floats, which json.dump always accepts
    # (numpy scalars such as float32 would make json.dump raise)
    data[i] = {
        "experimental_set": ex_sets,
        "concepts": cons,
        "tcav_classifier_accuracy": float(tcav_accuracy),
        "tcav_avg": np.mean(first_layer_scores, axis=1).tolist(),
        "tcav_std": np.std(first_layer_scores, axis=1).tolist(),
        "tcav_upper_bound": np.percentile(first_layer_scores, 75, axis=1).tolist(),  # 75th percentile
        "tcav_lower_bound": np.percentile(first_layer_scores, 25, axis=1).tolist(),  # 25th percentile
        "tcav_median": np.median(first_layer_scores, axis=1).tolist(),
    }

In [None]:
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

# Persist the per-experiment summaries; json.dump turns the integer dict keys into strings
output_path = f'{results_dir}/full finetuned.json'
with open(output_path, 'w') as f:
    json.dump(data, f)