Feature Extraction: extract IT features from the fine-tuned self-supervised (DINO) ResNet-50 in response to the presentation of the EMIs

In [None]:
import os
import torch
from torchvision import datasets, transforms
import numpy as np
import timm
import re

In [None]:
num_classes = 10
batch_size = 32
image_size = 224

device = torch.device('cpu')

# Resize to the network input size and normalise with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _image_index(sample):
    """Return N for an ImageFolder sample whose file is named ``imN.<ext>`` (e.g. ``im17.jpg`` -> 17)."""
    stem = os.path.splitext(os.path.basename(sample[0]))[0]
    return int(stem.replace('im', ''))


test_dataset = datasets.ImageFolder(root='coco200_perclass', transform=transform)
# ImageFolder lists samples class folder by class folder; re-order them by image number so that
# row i of every saved feature array corresponds to the i-th image in ascending numeric order.
test_dataset.samples.sort(key=_image_index)
# `targets` is a separate list built when the dataset is constructed; rebuild it so it stays
# aligned with the re-sorted `samples`.
test_dataset.targets = [target for _, target in test_dataset.samples]
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from torchvision import datasets, models, transforms

# Path to the fine-tuned checkpoint (a state_dict for the DINO ResNet-50 with a new classification head).
model_path = 'models/resnet_ssl_fine_tuned.pth'

# Self-supervised (DINO) ResNet-50 backbone from torch.hub.
model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
# Attach a fresh num_classes-way head. The input width is read from an *untrained* torchvision
# ResNet-50, so no ImageNet weights are downloaded just to look up this number.
model.fc = torch.nn.Linear(models.resnet50().fc.in_features, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # evaluation mode: BatchNorm uses its running statistics

In [None]:
def hook_fn(module, input, output):
    """Forward hook that records the hooked layer's output.

    Appends to the module-level list ``output_features``, which callers reset
    before each extraction run. Tensor outputs are detached and moved to the
    CPU, so the collected features never keep an autograd graph alive and can
    be converted with ``.numpy()`` even when the model runs on a GPU.
    Non-tensor outputs are stored unchanged.
    """
    if isinstance(output, torch.Tensor):
        output = output.detach().cpu()
    output_features.append(output)

def run_inference(model, data_loader, device=None):
    """Run ``model`` over every batch of ``data_loader`` and return the hooked features.

    A forward hook that appends to the global ``output_features`` list (e.g.
    ``hook_fn``) must already be registered on the layer of interest; this
    function only resets that list and drives the forward passes.

    Args:
        model: a ``torch.nn.Module`` (put it in eval mode beforehand).
        data_loader: iterable yielding either ``(images, ...)`` tuples/lists
            (as ``ImageFolder`` loaders do) or bare image tensors.
        device: optional device to move each image batch to before the
            forward pass; ``None`` leaves the batch where the loader put it.

    Returns:
        The ``output_features`` list: one captured entry per hooked forward call.
    """
    global output_features
    output_features = []

    with torch.no_grad():
        for batch in data_loader:
            # Indexing a bare tensor with [0] would silently keep only the first image.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            if device is not None:
                images = images.to(device)
            model(images)

    return output_features

In [None]:
# Capture the output of model.layer4[0] on every forward pass.
# If this cell is re-run, remove the handle from the previous run first, so the layer does not
# end up with two hooks (which would record every feature twice). RemovableHandle.remove() is a
# no-op for a hook that has already been removed.
if 'hook' in globals():
    hook.remove()
hook = model.layer4[0].register_forward_hook(hook_fn)

In [None]:

# Extract features for the clean images with the forward hook registered in the previous cell.
try:
    clean_feature_batches = run_inference(model, test_loader)
finally:
    # Remove the hook even if inference fails; a leftover hook on the same layer would make
    # later extractions capture every feature twice.
    hook.remove()

# An empty list means no hook was active (e.g. this cell was re-run after the hook was removed).
# Fail with an actionable message instead of torch.cat's generic error.
if not clean_feature_batches:
    raise RuntimeError("No features captured: re-run the hook-registration cell before this one.")
clean_features = torch.cat(clean_feature_batches, dim=0)
np.save('./resnet_ssl_features_clean.npy', clean_features.cpu().numpy())

In [None]:
pattern_pemi = re.compile(r'pEMI_(\d+)\.npy')
percentiles = [50, 60, 70, 80, 90, 95]
model_name = "resnet_ssl"
# Attribution methods whose perturbed images (pEMIs) are encoded, in processing order.
attribution_methods = ["NoiseTunnel_Saliency", "Deconvolution"]


def extract_pemi_features(method, p):
    """Encode every pEMI for one attribution method and percentile, then save the hooked features.

    Loads ``./perturbed_images/<model_name>/<method>/<p>/pEMI_<n>.npy`` in ascending ``n``.
    Each one is run through ``model`` with ``hook_fn`` attached to ``model.layer4[0]``, and the
    concatenated features are saved to ``./features/<model_name>_features_<method>_<p>.npy``.

    Returns:
        The concatenated feature tensor.

    Raises:
        FileNotFoundError: if the directory contains no ``pEMI_<n>.npy`` files.
        RuntimeError: if the number of captured tensors differs from the number of files
            (for example, because another hook is still registered on the same layer).
    """
    global output_features

    pemis_dir = f'./perturbed_images/{model_name}/{method}/{p}'
    # fullmatch so names such as "pEMI_3.npy.bak" are not picked up.
    pemis_files = sorted(
        (f for f in os.listdir(pemis_dir) if pattern_pemi.fullmatch(f)),
        key=lambda f: int(pattern_pemi.fullmatch(f).group(1)),
    )
    if not pemis_files:
        raise FileNotFoundError(f'No pEMI_<n>.npy files found in {pemis_dir}')

    output_features = []
    hook = model.layer4[0].register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for p_file in pemis_files:
                # Cast to float32 to match the model weights (a float64 array would make the
                # convolution fail). Assumes each pEMI is a single image without a batch dim;
                # unsqueeze(0) adds it.
                pemi_tensor = torch.tensor(
                    np.load(os.path.join(pemis_dir, p_file)), dtype=torch.float32
                ).unsqueeze(0).to(device)
                model(pemi_tensor)
    finally:
        # Remove the hook even on failure so it cannot leak into later runs.
        hook.remove()

    # One single-image forward pass per file should capture exactly one tensor; more means
    # another hook on the same layer is still appending to output_features.
    if len(output_features) != len(pemis_files):
        raise RuntimeError(
            f'Captured {len(output_features)} feature tensors for {len(pemis_files)} pEMIs in '
            f'{pemis_dir}; is another hook still registered on model.layer4[0]?'
        )

    features = torch.cat(output_features, dim=0)
    os.makedirs('./features', exist_ok=True)
    np.save(f'./features/{model_name}_features_{method}_{p}.npy', features.cpu().numpy())
    return features


for method in attribution_methods:
    for p in percentiles:
        features = extract_pemi_features(method, p)
        print(f'{method} p={p}: saved features of shape {tuple(features.shape)}')

Once the features are extracted, the model-to-monkey predictions are computed with the reverse_pred GitHub repository: https://github.com/vital-kolab/reverse_pred

To proceed, clone the repository (`git clone https://github.com/vital-kolab/reverse_pred.git`) and follow the setup instructions in its README.md. Make sure its `model_to_monkey` module is importable from this notebook, then run:

In [None]:
# Name of the feature set passed to model_to_monkey.main; it also names the output directory.
# A distinct variable name is used so the torch `model` loaded above is not overwritten.
feature_set = "clean"
monkey = "pooled"
out_dir = f'./results_explainability/model2monkey/{feature_set}'
# Root of the prediction data; override with the RESULTS_PREDICTIONS_DIR environment variable.
data_dir = os.environ.get("RESULTS_PREDICTIONS_DIR", "/scratch/smuzelle/results_predictions")
n_images = 200

from model_to_monkey import main

# Skip the prediction run when its result file already exists, as the perturbed-feature cells do.
if not os.path.exists(os.path.join(out_dir, f'forward_{monkey}_ev.npy')):
    main(feature_set, monkey, out_dir, n_images, data_dir)

In [None]:
# Perturbation percentiles; they must match the directory names used during feature extraction.
percentiles = [*range(50, 100, 10), 95]

In [None]:
# Settings shared by every percentile (hoisted out of the loop).
monkey = "pooled"
# Root of the prediction data; override with the RESULTS_PREDICTIONS_DIR environment variable.
data_dir = os.environ.get("RESULTS_PREDICTIONS_DIR", "/scratch/smuzelle/results_predictions")
n_images = 200

for p in percentiles:
    # `feature_set` rather than `model`, so the torch model object is not overwritten.
    feature_set = f"NoiseTunnel_Saliency_{p}"
    out_dir = f'./results_explainability/model2monkey/{feature_set}'
    # Only run the prediction when its result file is missing.
    if not os.path.exists(os.path.join(out_dir, f'forward_{monkey}_ev.npy')):
        print(f'Running model_to_monkey for {feature_set}')
        main(feature_set, monkey, out_dir, n_images, data_dir)
    

In [None]:
# Settings shared by every percentile (hoisted out of the loop).
monkey = "pooled"
# Root of the prediction data; override with the RESULTS_PREDICTIONS_DIR environment variable.
data_dir = os.environ.get("RESULTS_PREDICTIONS_DIR", "/scratch/smuzelle/results_predictions")
n_images = 200

for p in percentiles:
    # `feature_set` rather than `model`, so the torch model object is not overwritten.
    feature_set = f"Deconvolution_{p}"
    out_dir = f'./results_explainability/model2monkey/{feature_set}'
    # Only run the prediction when its result file is missing.
    if not os.path.exists(os.path.join(out_dir, f'forward_{monkey}_ev.npy')):
        print(f'Running model_to_monkey for {feature_set}')
        main(feature_set, monkey, out_dir, n_images, data_dir)
    

In [None]:
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt 
from utils import journal_figure

In [None]:
# Monkey identifier used in the result file names (forward_<monkey>_ev.npy).
monkey = "pooled"

In [None]:
RESULTS_ROOT = './results_explainability/model2monkey'


def load_ev(feature_set, monkey):
    """Load ``forward_<monkey>_ev.npy`` from the results directory of ``feature_set``."""
    return np.load(os.path.join(RESULTS_ROOT, feature_set, f'forward_{monkey}_ev.npy'))


def sem(values):
    """Standard error of the mean over all elements of ``values`` (std with ddof=0)."""
    values = np.asarray(values)
    return values.std() / np.sqrt(values.size)


def plot_smoothed_band(x, y, yerr, color, label=None, window_length=5, polyorder=2, n_boot=2000, seed=0):
    """Plot a Savitzky-Golay fit through ``(x, y)`` together with a bootstrap 95% band.

    The band comes from redrawing every point from ``N(y, yerr)`` ``n_boot`` times,
    re-smoothing each draw with the same filter, and taking the 2.5th / 97.5th percentiles.
    ``window_length`` must not exceed ``len(x)``.
    """
    order = np.argsort(x)
    x, y, yerr = x[order], y[order], yerr[order]
    plt.plot(x, savgol_filter(y, window_length=window_length, polyorder=polyorder), color=color, label=label)

    rng = np.random.default_rng(seed)
    # Row b holds the same draws as the b-th call of rng.normal(y, yerr) in a per-resample loop.
    y_boot = rng.normal(y, yerr, size=(n_boot, len(x)))
    y_boot_fit = savgol_filter(y_boot, window_length, polyorder, axis=-1)
    lo, hi = np.percentile(y_boot_fit, [2.5, 97.5], axis=0)
    plt.fill_between(x, lo, hi, color=color, alpha=0.15, linewidth=0)


# (feature-set prefix, plot colour) for each attribution method compared.
methods = [("NoiseTunnel_Saliency", "deepskyblue"), ("Deconvolution", "limegreen")]

ev_clean = load_ev('clean', monkey)
plt.axhline(ev_clean.mean(), color="black", linestyle="--", label="clean")

# Per-percentile mean EV and its SEM for each method.
ev_means = {method: [] for method, _ in methods}
ev_sems = {method: [] for method, _ in methods}
for p in percentiles:
    for method, color in methods:
        ev = load_ev(f"{method}_{p}", monkey)
        # SEM over the number of EV values actually averaged (ev.size), not a fixed count.
        ev_mean, ev_sem = ev.mean(), sem(ev)
        plt.scatter(p, ev_mean, color=color)
        plt.errorbar(p, ev_mean, yerr=ev_sem, color=color)
        ev_means[method].append(ev_mean)
        ev_sems[method].append(ev_sem)

# Also exposed under these names (best = NoiseTunnel_Saliency, worst = Deconvolution).
evs_best, evs_best_err = ev_means["NoiseTunnel_Saliency"], ev_sems["NoiseTunnel_Saliency"]
evs_worst, evs_worst_err = ev_means["Deconvolution"], ev_sems["Deconvolution"]

x = np.array(percentiles, dtype=float)
for method, color in methods:
    plot_smoothed_band(x, np.array(ev_means[method]), np.array(ev_sems[method]), color, label=method)

plt.xlabel("Perturbation percentile (p)")
plt.ylabel(f"Mean forward EV ({monkey})")
plt.legend()
journal_figure()