In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from captum.attr import IntegratedGradients, Saliency, DeepLift
import numpy as np
from scipy.spatial.distance import euclidean
from scipy.stats import spearmanr
import os
import timm

In [None]:
# Dataset Preparation
# Number of output classes for every classification head, loader batch size,
# and the square side length every image is resized to.
num_classes = 10
batch_size = 32
image_size = 224

# Device that models and input batches are moved to before inference.
device = torch.device('cpu')

# Resize to a fixed square, convert to a tensor, then normalize each channel
# with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root='coco1400_perclass', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.ImageFolder(root='coco200_perclass', transform=transform)
# Custom sorting to ensure numerical order if filenames don't have leading zeros:
# the key strips 'im' from the file stem and compares the remainder as an integer.
test_dataset.samples.sort(
    key=lambda sample: int(os.path.splitext(os.path.basename(sample[0]))[0].replace('im', ''))
)
# ImageFolder builds `targets` once at construction, so it must be rebuilt after
# reordering `samples`; otherwise targets[i] no longer matches samples[i].
test_dataset.targets = [label for _, label in test_dataset.samples]
# shuffle=False keeps the sorted order, so row i of any per-image output
# corresponds to the i-th image in numeric filename order.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Paths to saved models
# Every checkpoint follows the same naming scheme, models/<name>_fine_tuned.pth,
# so the mapping is derived from the list of model names.
model_paths = {
    arch: f'models/{arch}_fine_tuned.pth'
    for arch in (
        'resnet50',
        'resnet18',
        'alexnet',
        'convnext',
        'vgg19',
        'vgg16',
        'vit',
        'vit_ssl',
        'resnet_ssl',
        'efficientnet',
        'swin',
    )
}

# Initialize model architectures
# The pretrained weights are deliberately NOT downloaded here: the loading loop
# further down calls load_state_dict() with its default strict matching, which
# replaces every parameter and persistent buffer with the fine-tuned checkpoint.
# Fetching the pretrained weights first would only add downloads and start-up
# time for values that are immediately overwritten.
models_dict = {
    'resnet50': models.resnet50(pretrained=False),
    'resnet18': models.resnet18(pretrained=False),
    'alexnet': models.alexnet(pretrained=False),
    'vit': models.vit_b_16(pretrained=False),
    'vgg19': models.vgg19(pretrained=False),
    'vgg16': models.vgg16(pretrained=False),
    'convnext': timm.create_model("convnext_base", pretrained=False),
    'vit_ssl': timm.create_model("vit_small_patch16_224_dino", pretrained=False),
    # torch.hub forwards extra keyword arguments to the hub entry point.
    # NOTE(review): assumes dino_resnet50 accepts `pretrained` - confirm against the DINO hubconf.
    'resnet_ssl': torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=False),
    'efficientnet': timm.create_model("efficientnet_b3", pretrained=False),
    'swin': timm.create_model("swin_base_patch4_window7_224", pretrained=False)
}

# Adjust the final layer
# Swap each model's classification layer for a new Linear layer with
# num_classes outputs, so the module layout matches the fine-tuned checkpoints
# loaded afterwards. Where the existing layer is read, its in_features is reused.
models_dict['resnet50'].fc = nn.Linear(models_dict['resnet50'].fc.in_features, num_classes)
models_dict['resnet18'].fc = nn.Linear(models_dict['resnet18'].fc.in_features, num_classes)
models_dict['alexnet'].classifier[6] = nn.Linear(models_dict['alexnet'].classifier[6].in_features, num_classes)
models_dict['vit'].heads.head = nn.Linear(models_dict['vit'].heads.head.in_features, num_classes)
models_dict['vgg16'].classifier[6] = nn.Linear(models_dict['vgg16'].classifier[6].in_features, num_classes)
models_dict['vgg19'].classifier[6] = nn.Linear(models_dict['vgg19'].classifier[6].in_features, num_classes)
# Input widths for these timm heads are hard-coded.
# NOTE(review): 1024 / 384 are assumed to match convnext_base, swin_base and vit_small - confirm.
models_dict['convnext'].head.fc = nn.Linear(1024, num_classes)
models_dict['vit_ssl'].head = nn.Linear(384, num_classes)
# Both are ResNet-50 backbones, so the width of the torchvision ResNet-50 head
# is reused here rather than building a second ResNet-50 just to read one integer.
models_dict['resnet_ssl'].fc = nn.Linear(models_dict['resnet50'].fc.in_features, num_classes)
models_dict['efficientnet'].classifier = nn.Linear(models_dict['efficientnet'].classifier.in_features, num_classes)
models_dict['swin'].head.fc = nn.Linear(1024, num_classes)

# Load the fine-tuned weights
for model_name, model in models_dict.items():
    # Map checkpoint tensors onto the notebook-wide `device` so that changing
    # the device setting in one place is enough.
    # SECURITY: torch.load unpickles the file, so only load checkpoints from a
    # trusted source (newer PyTorch versions offer weights_only=True).
    state_dict = torch.load(model_paths[model_name], map_location=device)
    # Default strict matching: raises if the checkpoint keys differ from the model's.
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode

print("Models successfully loaded and ready for inference!")


In [None]:
# Generate features for image-level behavior testing
# For every model: run the whole test set, collect the per-image softmax
# probabilities, pass them to create_i1_test, and save the resulting scores.
# NOTE(review): create_i1_test is not defined in the cells shown here - it must
# be defined or imported before this cell runs.

# np.save does not create missing directories, so make sure the target exists.
output_dir = './i1s_clean'
os.makedirs(output_dir, exist_ok=True)

for model_name, model in models_dict.items():
    features = []
    model.eval()
    model.to(device)
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            output = model(inputs)
            # Convert logits to class probabilities along the last dimension.
            output = torch.nn.functional.softmax(output, dim=-1)
            # No .detach() needed: no autograd graph is recorded under no_grad().
            features.append(output.cpu().numpy())
    # Stack the per-batch arrays along axis 0; rows follow test_loader order.
    features = np.concatenate(features, axis=0)

    i1_scores = create_i1_test(features, model_name)
    np.save(f'{output_dir}/i1_{model_name}.npy', i1_scores)
    print(f'I1 scores for {model_name}:', i1_scores.mean())