In [None]:
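'''
How independent is the set of top-attributed neurons from k, the number of
top-activating images fed to Integrated Gradients? Every k is compared with k = 50.
'''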

import torch

import sys
import os
# Make the repository root (parent of this notebook's directory) importable so `utils` resolves.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard keeps the cell idempotent: re-running it must not append duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from utils import get_sub_model, split_model, batch_inference, get_conditional_modules, Model_wrapper, _wrapper, load_model_data, get_crop_data
import numpy as np
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "googlenet"  # one of "resnet50", "googlenet", "alexnet"
# The default is a machine-specific absolute path; set IMAGENET_DIR to run elsewhere.
data_dir = os.environ.get("IMAGENET_DIR", "/mnt/disk1/user/Tue.CM210908/imagenet")
labels_list = list(range(10))  # class labels to evaluate (presumably ImageNet class indices)

from torchvision import models

# Per supported backbone: (torchvision constructor, probed layer names listed from the
# classifier head backward). Downstream code pairs layers[-i] with layers[-i - 1], i.e. each
# layer with its neighbour one step closer to the head, so the order of these lists matters.
_model_specs = {
    "resnet50": (
        models.resnet50,
        ["fc", "layer4.2", "layer4.1", "layer4.0", "layer3.5", "layer3.4", "layer3.3", "layer3.2", "layer3.1", "layer3.0", "layer2.3"],
    ),
    "googlenet": (
        models.googlenet,
        ["fc", "inception5b", "inception5a", "inception4e", "inception4d", "inception4c", "inception4b", "inception4a", "inception3b", "inception3a", "maxpool2"],
    ),
    "alexnet": (
        models.alexnet,
        ["classifier.6", "classifier.5", "classifier.2", "features.12", "features.9", "features.7", "features.5", "features.2"],
    ),
}
if model_name not in _model_specs:
    raise ValueError("Model not supported")

_build_model, layers = _model_specs[model_name]
layers = list(layers)  # private copy so later edits to `layers` never touch the spec table
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
# kept as-is for compatibility with the installed version.
model = _build_model(pretrained=True).eval()
conditional_modules = get_conditional_modules(model_name)

In [None]:
from captum.attr import IntegratedGradients

def get_integrated_gradients(
    syn_data: torch.Tensor,
    fm_id: int,
    netB: torch.nn.Module,
) -> torch.Tensor:
    """Accumulate absolute Integrated Gradients attributions over a set of samples.

    Each row of ``syn_data`` is attributed on its own, with a zero baseline
    (``x * 0``), with respect to output index ``fm_id`` of ``netB``. The
    per-sample attributions are post-processed by ``_wrapper``, concatenated
    along dim 0, and then |.|-summed over that dimension.

    Note: ``netB`` is moved to the module-level ``device`` as a side effect.
    Raises if ``syn_data`` is empty, because ``torch.cat`` rejects an empty list.
    """
    explainer = IntegratedGradients(netB.to(device))

    def _attribute_one(sample: torch.Tensor) -> torch.Tensor:
        # Restore the batch dimension of size 1 that captum expects.
        x = sample.unsqueeze(0).to(device)
        x.requires_grad_(True)
        attribution = explainer.attribute(x, baselines=x * 0, target=fm_id)
        return _wrapper(attribution.detach().cpu())

    per_sample = [_attribute_one(sample) for sample in syn_data]
    stacked = torch.cat(per_sample)
    return stacked.abs().sum(dim=0)

def masked_probing(data, indices, net, device = "cuda", batch_size = 256, reverse = True):
    """Run ``net`` on a copy of ``data`` with some dim-1 entries (channels/neurons) zeroed.

    Args:
        data: input tensor; dim 1 is the axis being masked.
        indices: dim-1 positions (list / numpy array / tensor of ints).
        net: model passed to ``batch_inference``.
        device: device forwarded to ``batch_inference``.
        batch_size: batch size forwarded to ``batch_inference``.
        reverse: if True, keep ONLY ``indices`` and zero every other position;
            if False, zero exactly ``indices``.

    Returns:
        Whatever ``batch_inference`` returns for the masked copy.

    The caller's ``data`` is never modified.
    """
    # .to("cpu") returns the very same tensor when data is already on CPU, so clone
    # before zeroing; otherwise the caller's tensor would be silently mutated.
    masked = data.to("cpu").clone()
    idx = torch.as_tensor(indices, dtype=torch.long, device="cpu")
    if reverse:
        # A boolean keep-mask replaces the O(C * len(indices)) `i not in indices` scan.
        keep = torch.zeros(masked.shape[1], dtype=torch.bool)
        keep[idx] = True
        masked[:, ~keep] = 0
    else:
        masked[:, idx] = 0
    return batch_inference(net, masked, batch_size=batch_size, device = device)

def probing(data, net, device = "cuda", batch_size = 128):
    """Run ``net`` over ``data`` through ``batch_inference`` and return its result unmodified."""
    outputs = batch_inference(net, data, batch_size=batch_size, device=device)
    return outputs

def top_img(probed, num_img = 5):
    """Return the indices of the ``num_img`` largest values of ``probed`` (per ``torch.topk``)."""
    return torch.topk(probed, k=num_img).indices

def overlap(tensor1, tensor2):
    """Count the distinct values that appear in both inputs (array-likes, flattened)."""
    return np.intersect1d(tensor1, tensor2).size

In [None]:
from tqdm import tqdm
test_nodes = list(range(10))  # output indices of netB whose top neurons are compared
list_num_top_imgs = [10, 30, 40, 50, 60, 70, 90, 110, 130, 150, 170, 190]  # values of k swept
tau = 16  # number of top-attributed intermediate neurons kept per (node, k)
baseline_num_top_imgs = 50  # reference k that every other k is compared against
seed = 0  # fixes the random layer choice so Restart & Run All reproduces the numbers

if baseline_num_top_imgs not in list_num_top_imgs:
    raise ValueError("baseline_num_top_imgs must be one of list_num_top_imgs")
random.seed(seed)

all_list_overlap = {num_top_imgs: [] for num_top_imgs in list_num_top_imgs}

for label_id in tqdm(labels_list):
    print(f"Label: {label_id}")
    class_images, class_labels = load_model_data(data_dir, [label_id], model, device)
    concept_data = get_crop_data(class_images)

    # netB spans layers[-layer_index] -> layers[-layer_index - 1], so layer_index must stay
    # <= len(layers) - 1; the previous upper bound len(layers) indexed before layers[0].
    layer_index = random.choice(range(1, len(layers)))
    start_layer, end_layer = layers[-layer_index], layers[-layer_index - 1]

    netA, _ = split_model(model, start_layer, True, conditional_modules=conditional_modules)
    netB = Model_wrapper(get_sub_model(model, start_layer, end_layer, True, conditional_modules))

    intermediate = batch_inference(netA, concept_data, device=device)
    activation = batch_inference(netB, intermediate, device=device)
    # NOTE(review): assumes activation is (num_samples, num_outputs); confirm Model_wrapper's output.
    # Fail before the expensive IG loop rather than midway through it.
    if activation.shape[0] < max(list_num_top_imgs):
        raise ValueError(
            f"label {label_id}: only {activation.shape[0]} samples, "
            f"but k goes up to {max(list_num_top_imgs)}"
        )

    num_node = intermediate.shape[1]  # size of dim 1 of the intermediate representation
    # Per node: one top-tau index array per k, in the order of list_num_top_imgs.
    all_top_neurons = {node: [] for node in test_nodes}
    baselines = {}  # per node: the top-tau index array obtained with k = baseline_num_top_imgs
    for test_node in tqdm(test_nodes):
        for num_top_imgs in list_num_top_imgs:
            top_imgs = torch.topk(activation[:, test_node], num_top_imgs).indices.detach().cpu()
            importance = get_integrated_gradients(
                intermediate[top_imgs], test_node, netB
            ).detach().cpu()
            top_neurons = torch.topk(importance, tau).indices.numpy()
            all_top_neurons[test_node].append(top_neurons)
            if num_top_imgs == baseline_num_top_imgs:
                baselines[test_node] = top_neurons

    for i, num_top_imgs in enumerate(list_num_top_imgs):
        all_list_overlap[num_top_imgs].append(
            [overlap(all_top_neurons[node][i], baselines[node]) for node in test_nodes]
        )

for num_top_imgs in list_num_top_imgs:
    print(f"{num_top_imgs}: {np.mean(all_list_overlap[num_top_imgs])}")