In [None]:
'''
Pairwise overlap (F1) between the neurons selected by NeurFlow, NeuronMCT
(integrated gradients) and NeuCEPT (knockoff filter).
'''

import torch

import os
import sys

# Treat the parent of the current working directory as the project root and
# append it to sys.path; the project-local imports that follow (utils,
# NeurFlow) presumably rely on this.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(project_root)

from utils import split_model, batch_inference, load_data, get_conditional_modules, Model_wrapper, _wrapper
import numpy as np
from itertools import combinations

from NeurFlow import Framework

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# --- Experiment configuration ------------------------------------------------
# Folder holding the per-label `store_<model>_label<id>_tau<tau>.pth` files.
directory = "./fidelity_of_neuron/full_16/"
# NOTE(review): machine-specific absolute path; adjust for your environment.
data_dir = "/mnt/disk1/user/Tue.CM210908/imagenet"
# Backbone to analyse; the model-building cell accepts "resnet50" or "googlenet".
model_name = "resnet50"
# Class labels to evaluate (0 through 9).
label_list = list(range(10))
# Value embedded in the store file names as `tau<value>`.
tau = 16

In [None]:
import numpy as np

from torchvision import models
# Build the pretrained backbone selected by `model_name` (switched to eval mode)
# and list its block names, ordered from the deepest block to the shallowest.
# The explicit `weights=` enums replace the deprecated `pretrained=True` flag;
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
if model_name == "resnet50":
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval()
    all_layers = ["layer4.2", "layer4.1", "layer4.0", "layer3.5", "layer3.4",
                  "layer3.3", "layer3.2", "layer3.1", "layer3.0", "layer2.3",
                  "layer2.2", "layer2.1", "layer2.0", "layer1.2", "layer1.1", "layer1.0"]
elif model_name == "googlenet":
    model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1).eval()
    all_layers = ["inception5b", "inception5a", "inception4e", "inception4d",
                  "inception4c", "inception4b", "inception4a", "inception3b",
                  "inception3a"]
else:
    raise ValueError("Model not supported")

# Only reached for a supported model (the else-branch raises), so the lookup
# is shared by both branches instead of being repeated in each.
conditional_modules = get_conditional_modules(model_name)

In [None]:
from captum.attr import IntegratedGradients
from knockpy.knockoff_filter import KnockoffFilter
from tqdm import tqdm

def get_integrated_gradients(
    syn_data: torch.Tensor,
    fm_id: int,
    netB: torch.nn.Module,
    batch_size: int = 8
) -> torch.Tensor:
    """Accumulate absolute Integrated-Gradients attributions of ``netB`` over ``syn_data``.

    Args:
        syn_data: Inputs fed to ``netB``; sliced along dim 0 in chunks of
            ``batch_size``.
        fm_id: Passed to Captum as ``target`` (the output index to attribute).
        netB: Module to attribute through. It is moved to the global
            ``device`` before use.
        batch_size: Number of samples attributed per Captum call.

    Returns:
        The sum over dim 0 of the absolute per-sample attributions after they
        have been passed through ``_wrapper``; the result lives on the CPU.
        NOTE(review): the exact shape depends on ``_wrapper`` (imported from
        ``utils``); callers rank the result with ``torch.topk``, so it is
        presumably one score per neuron/channel -- confirm.
    """
    integrated_gradients = IntegratedGradients(netB.to(device))
    importance = []
    for start in range(0, len(syn_data), batch_size):
        batch_data = syn_data[start:start + batch_size].to(device)
        batch_data.requires_grad = True
        # An all-zero tensor of the same shape serves as the IG baseline.
        attributions_ig = integrated_gradients.attribute(
            batch_data, baselines=batch_data * 0, target=fm_id
        )
        # Keep only a detached CPU copy of the attributions. The former
        # `batch_data.cpu().detach()` statement discarded its result and only
        # cost an extra device-to-host copy, so it has been dropped.
        importance.append(_wrapper(attributions_ig.detach().cpu()))
    return torch.sum(torch.abs(torch.cat(importance)), dim=0)

def get_knockoff(
    syn_data: torch.Tensor,
    fm_id: int,
    activation_target: torch.Tensor,
) -> torch.Tensor:
    """Score the features of ``syn_data`` with a knockoff filter against one target column.

    Args:
        syn_data: Inputs that are passed through ``_wrapper`` and used as the
            design matrix ``X`` of the filter.
            NOTE(review): presumably ``_wrapper`` yields a 2-D
            (samples, features) tensor -- confirm against ``utils``.
        fm_id: Column of ``activation_target`` used as the response ``y``.
        activation_target: Tensor indexed as ``activation_target[:, fm_id]``.

    Returns:
        Absolute values of the filter's ``W`` statistics as a tensor.
    """
    knockoff_filter = KnockoffFilter(ksampler='gaussian', fstat='lasso')

    features = _wrapper(syn_data).detach().cpu().numpy()
    response = activation_target[:, fm_id].detach().cpu().numpy()

    # Only the W statistics are consumed; the value returned by forward()
    # itself is ignored.
    knockoff_filter.forward(X=features, y=response, fdr=1.0)
    w_statistics = torch.from_numpy(knockoff_filter.W)

    return torch.abs(w_statistics)

def calculate_f1(set1, set2):
    """Return the F1 overlap between two sets.

    ``set1`` is treated as the predicted set and ``set2`` as the reference:

        TP = |set1 & set2|,  FP = |set1 - set2|,  FN = |set2 - set1|
        F1 = 2*TP / (2*TP + FP + FN) = 2*|set1 & set2| / (|set1| + |set2|)

    Raw counts are used for all three terms. Normalising TP, FP and FN by
    different set sizes only yields a true F1 when both sets are the same
    size, and gives an asymmetric score otherwise.

    Args:
        set1: First set of hashable items.
        set2: Second set of hashable items.

    Returns:
        The F1 score in [0, 1]; 0 when both sets are empty.
    """
    true_pos = len(set1.intersection(set2))
    false_pos = len(set1) - true_pos
    false_neg = len(set2) - true_pos
    denominator = 2 * true_pos + false_pos + false_neg
    return 2 * true_pos / denominator if denominator > 0 else 0

In [None]:
# Pairwise F1 overlap between the neuron sets chosen by each method, collected
# over every label in `label_list` and every compared layer.
f1_scores = {}

for label_id in label_list:
    # Load the stored result for this label and keep only the parts used below.
    path = f"{directory}store_{model_name}_label{label_id}_tau{tau}.pth"
    # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
    store = torch.load(path, map_location="cpu", weights_only=False)
    FW, layers = store["FW"], store["layers"]
    del store

    all_images, all_labels = load_data(data_dir, [label_id])
    with torch.no_grad():
        outputs = batch_inference(model, all_images, device=device)

    # exclude the last layer
    num_layers = len(FW.layers.keys()) - 1
    layer_indices = range(1, num_layers + 1)
    num_nodes_per_layer = [len(FW.layers[idx].keys()) for idx in layer_indices]
    num_nodes = sum(num_nodes_per_layer)

    all_neuronmct = {}  # layer index -> integrated-gradients scores
    all_knockoff = {}   # layer index -> knockoff scores
    for layer_index in tqdm(layer_indices):
        # `layers` is indexed from the end: FW layer k pairs with layers[-(k + 1)].
        netA, netB = split_model(model, layers[-layer_index-1], True, conditional_modules)
        intermediate = batch_inference(netA, all_images, device=device)
        # neuronmct (IG but to the output of the model)
        all_neuronmct[layer_index] = get_integrated_gradients(intermediate, label_id, netB)
        # knockoff score
        all_knockoff[layer_index] = get_knockoff(intermediate, label_id, outputs)

    # Each baseline keeps as many top-scoring neurons as FW.layers holds for that layer.
    top_neurons_neuronmct = {}
    top_neurons_neurflow = {}
    top_neurons_knockoff = {}
    for layer_index in all_neuronmct:
        neurflow_nodes = FW.layers[layer_index]
        budget = len(neurflow_nodes)
        ig_top = torch.topk(all_neuronmct[layer_index], budget).indices
        knockoff_top = torch.topk(all_knockoff[layer_index], budget).indices
        top_neurons_neuronmct[layer_index] = set(ig_top.numpy())
        top_neurons_neurflow[layer_index] = set(neurflow_nodes.keys())
        top_neurons_knockoff[layer_index] = set(knockoff_top.numpy())

    for layer_index in layer_indices:
        sets = [
            top_neurons_neuronmct[layer_index],
            top_neurons_neurflow[layer_index],
            top_neurons_knockoff[layer_index],
        ]
        set_names = ["NeuronMCT", "NeurFlow", "Knockoff"]

        # Score every unordered pair of methods for this layer.
        for (name_a, set_a), (name_b, set_b) in combinations(zip(set_names, sets), 2):
            pair_name = f"{name_a}-{name_b}"
            f1_scores.setdefault(pair_name, []).append(calculate_f1(set_a, set_b))

# Report the mean F1 of each method pair over all labels and layers.
for pair in f1_scores:
    mean_f1 = np.mean(f1_scores[pair])
    print(f"Mean F1 Score for {pair}: {mean_f1:.4f}")