In [None]:
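'''
Integrated gradients vs. L2-norm of weights.

For randomly chosen output nodes, compare the input channels ranked most
important by integrated gradients against those ranked by the L2 norm of the
node's convolution weights.
'''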

import torch

import sys
import os
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

from utils import split_model, batch_inference, get_conditional_modules, Model_wrapper, _wrapper, load_model_data, get_crop_data, accuracy
import numpy as np
import random

# Seed every RNG the notebook has imported so that "Restart Kernel -> Run All"
# reproduces the same randomly selected test nodes (they are drawn with
# np.random.choice further down).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Backbone to analyse; the model-selection code below accepts
# "resnet50", "googlenet" and "alexnet".
model_name = "googlenet"

# Dataset root. The machine-specific default is kept for backward
# compatibility; set the IMAGENET_DIR environment variable to point elsewhere
# without editing the notebook.
data_dir = os.environ.get("IMAGENET_DIR", "/mnt/disk1/user/Tue.CM210908/imagenet")

# Class ids whose images are loaded, one class at a time, by the experiment loop.
labels_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

from torchvision import models

# Build the pretrained torchvision backbone named by `model_name`, together
# with the ordered list of layer names used for this architecture.
# Unsupported names are rejected before any model is constructed.
if model_name not in ("resnet50", "googlenet", "alexnet"):
    raise ValueError("Model not supported")

if model_name == "resnet50":
    model = models.resnet50(pretrained=True)
    layers = [
        "fc",
        "layer4.2", "layer4.1", "layer4.0",
        "layer3.5", "layer3.4", "layer3.3", "layer3.2", "layer3.1", "layer3.0",
        "layer2.3",
    ]
elif model_name == "googlenet":
    model = models.googlenet(pretrained=True)
    layers = [
        "fc",
        "inception5b", "inception5a",
        "inception4e", "inception4d", "inception4c", "inception4b", "inception4a",
        "inception3b", "inception3a",
        "maxpool2",
    ]
else:  # "alexnet" is the only remaining supported name
    model = models.alexnet(pretrained=True)
    layers = [
        "classifier.6", "classifier.5", "classifier.2",
        "features.12", "features.9", "features.7", "features.5", "features.2",
    ]

# Every branch needs the network in inference mode and the per-architecture
# conditional modules, so both steps are done once here.
model = model.eval()
conditional_modules = get_conditional_modules(model_name)



In [None]:
from captum.attr import IntegratedGradients

def get_integrated_gradients(
    syn_data: torch.Tensor,
    fm_id: int,
    netB: torch.nn.Module,
) -> torch.Tensor:
    """Accumulate absolute integrated-gradients attributions over a set of inputs.

    Each element of ``syn_data`` is attributed on its own (as a batch of one)
    against an all-zero baseline, with ``fm_id`` as the attribution target.
    Every attribution is passed through ``_wrapper``, the results are
    concatenated, and their absolute values are summed along dimension 0.

    Args:
        syn_data: tensor whose first dimension enumerates the inputs fed to ``netB``.
        fm_id: target index handed to Captum's ``attribute`` call.
        netB: module to explain; it is moved to the global ``device``.

    Returns:
        The summed absolute attributions as a CPU tensor. Its shape is
        determined by ``_wrapper`` (defined in ``utils``; presumably a
        per-channel reduction - TODO confirm).
    """
    explainer = IntegratedGradients(netB.to(device))

    def _attribute_one(sample: torch.Tensor) -> torch.Tensor:
        # Captum is given a leading batch dimension of size one.
        batched = sample.unsqueeze(0).to(device)
        batched.requires_grad = True
        attribution = explainer.attribute(
            batched, baselines=batched * 0, target=fm_id
        )
        # Move to the CPU right away so per-sample results do not pile up on the device.
        return _wrapper(attribution.detach().cpu())

    per_sample = [_attribute_one(sample) for sample in syn_data]
    return torch.cat(per_sample).abs().sum(dim=0)

def masked_probing(data, indices, net, device = "cuda", batch_size = 256, reverse = True):
    """Zero selected channels of ``data`` and run ``net`` on the masked copy.

    Args:
        data: tensor addressed as ``data[:, channel, ...]``. It is copied to
            the CPU before masking, so the caller's tensor is never modified.
        indices: channel indices along dimension 1 (list, NumPy array or
            tensor of integers; negative indices count from the end).
        net: model forwarded to ``batch_inference``.
        device: device forwarded to ``batch_inference``.
        batch_size: batch size forwarded to ``batch_inference``.
        reverse: if True, keep only the channels in ``indices`` and zero every
            other channel; if False, zero exactly the channels in ``indices``.

    Returns:
        Whatever ``batch_inference`` returns for the masked data.
    """
    # copy=True guarantees a private buffer: a plain .to("cpu") returns the
    # very same tensor when `data` is already on the CPU, and the in-place
    # masking below would then corrupt the caller's data.
    masked = data.to("cpu", copy=True)
    selected = torch.as_tensor(indices, dtype=torch.long, device="cpu").reshape(-1)

    if reverse:
        # Boolean keep-mask instead of an `i not in indices` scan: linear time
        # and correct for negative indices.
        keep = torch.zeros(masked.shape[1], dtype=torch.bool)
        keep[selected] = True
        masked[:, ~keep] = 0
    else:
        masked[:, selected] = 0
    return batch_inference(net, masked, batch_size=batch_size, device = device)

def probing(data, net, device = "cuda", batch_size = 128):
    """Run ``net`` over ``data`` without masking.

    Thin convenience wrapper: the arguments are handed to ``batch_inference``
    unchanged and its result is returned as-is.
    """
    inference_options = {"batch_size": batch_size, "device": device}
    return batch_inference(net, data, **inference_options)

def top_img(probed, num_img = 5):
    """Return the indices of the ``num_img`` largest entries of ``probed``.

    The search runs along the last dimension (``torch.topk`` default) and the
    indices come back ordered from the largest value to the smallest.
    """
    return probed.topk(num_img).indices

In [None]:
from tqdm import tqdm

# Sweep sizes: how many top-ranked input channels are masked in each trial (1..16).
nums_top_neurons = list(range(1, 17))
# Number of randomly drawn output nodes examined per (layer, class) pair.
num_test_nodes = 1
# Number of highest-activating images compared before and after masking.
num_top_imgs = 50

# Only the two architectures below are handled by this experiment.
if model_name not in ("resnet50", "googlenet"):
    raise ValueError("Model not supported")

# Layer names ordered from the deepest layer to the shallowest one.
if model_name == "resnet50":
    all_layers = [
        "layer4.2", "layer4.1", "layer4.0",
        "layer3.5", "layer3.4", "layer3.3", "layer3.2", "layer3.1", "layer3.0",
        "layer2.3",
    ]
else:
    all_layers = [
        "inception5b", "inception5a",
        "inception4e", "inception4d", "inception4c", "inception4b", "inception4a",
        "inception3b", "inception3a",
        "maxpool2",
    ]

def _masked_topk_overlap(features, masked_channels, net, node, reference_top, num_imgs):
    """Mask channels, re-probe one output node and score the new top images.

    Zeroes ``masked_channels`` in a copy of ``features``, runs ``net`` on it,
    takes the ``num_imgs`` highest-scoring images for ``node`` and returns
    ``accuracy(reference_top, new_top)``.
    """
    probed = masked_probing(
        features.clone(), masked_channels, net, device = device, reverse=False, batch_size=64
    )[:, node].detach().cpu()
    return accuracy(reference_top, top_img(probed, num_img=num_imgs).numpy())


# One entry per (layer, class, node, k): accuracy after masking the top-k
# channels ranked by integrated gradients minus the accuracy after masking the
# top-k channels ranked by weight L2 norm.
diff_acc = []
for layer_index in range(len(all_layers)-1):
    print("Layer:", all_layers[layer_index])

    # Consecutive pair from `all_layers`: the shallower entry starts the
    # segment under study, the deeper one ends it.
    start_layer = all_layers[layer_index+1]
    end_layer = all_layers[layer_index]

    # Name parts of the convolution inside `end_layer` whose filters are tested.
    if model_name == "resnet50":
        conv_end_layer = [end_layer.replace(".", "_"), "conv1"]
    elif model_name == "googlenet":
        conv_end_layer = [end_layer, "branch1", "conv"]
    else:
        raise ValueError("Model not supported")

    # Split at `start_layer`, then split the remainder again at the target
    # convolution and keep only the first part as netB.
    # NOTE(review): exact split semantics live in utils.split_model - confirm there.
    netA, temp = split_model(model, start_layer, True, conditional_modules=conditional_modules)
    netB = Model_wrapper(
        split_model(
            temp, '.'.join(conv_end_layer),
            True,
            conditional_modules=conditional_modules
        )[0]
    )

    for label_id in tqdm(labels_list):
        class_images, class_labels = load_model_data(data_dir, [label_id], model, device)
        concept_data = get_crop_data(class_images)

        # netA output feeds netB; `activation` is indexed as [image, node] below.
        intermediate = batch_inference(netA, concept_data, device=device)
        activation = batch_inference(netB, intermediate, device=device)

        num_node = activation.shape[1]
        # Draw distinct nodes: sampling with replacement could test the same
        # node twice once num_test_nodes > 1.
        test_nodes = np.random.choice(
            num_node, min(num_test_nodes, num_node), replace=False
        ).tolist()

        for test_node in test_nodes:
            # Reference ranking: images that activate the node most strongly.
            top_imgs = torch.topk(activation[:, test_node], num_top_imgs)[1].detach().cpu().numpy()
            importance = get_integrated_gradients(intermediate[torch.from_numpy(top_imgs)], test_node, netB).detach().cpu().numpy()
            if (np.sum(importance) == 0):
                print("No importance!")
            # Converted once here rather than on every sweep step.
            importance_scores = torch.from_numpy(importance)

            # L2 norm of the node's filter, taken separately for each slice
            # along the first axis of that filter.
            weights = netB.model._modules['_'.join(conv_end_layer)].weight[test_node].clone() # type: ignore
            l2_norm = torch.norm(weights.reshape(weights.shape[0], -1), dim=1)

            for num_top_neurons in nums_top_neurons:
                top_neurons = torch.topk(importance_scores, num_top_neurons)[1].detach().cpu().numpy()
                top_filters = torch.topk(l2_norm, num_top_neurons)[1].detach().cpu().numpy()

                mine_acc = _masked_topk_overlap(
                    intermediate, top_neurons, netB, test_node, top_imgs, num_top_imgs
                )
                them_acc = _masked_topk_overlap(
                    intermediate, top_filters, netB, test_node, top_imgs, num_top_imgs
                )

                diff_acc.append(mine_acc - them_acc)

# np.mean([]) would emit a RuntimeWarning; report NaN quietly when nothing was collected.
print("Mean diff acc:", np.mean(diff_acc) if diff_acc else float("nan"))