In [1]:
# get a folder full of images whre the image-classifiers are wrong
# prepare VIT models
# compute alignment scores
# get spearman correlation 
# plot the correlation matrix

In [2]:
from datasets import load_dataset
import random

# Load the dataset
ds = load_dataset("cais/imagenet-o")

# Randomly sample 100 image indices
random_indices = random.sample(range(len(ds['test'])), 100)
sampled_images = [ds['test'][idx]['image'] for idx in range(len(ds))]

Resolving data files:   0%|          | 0/2000 [00:00<?, ?it/s]

In [3]:
import timm
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# List of models to pre-download weights for
model_list = [
    "vit_tiny_patch16_224.augreg_in21k",
    "vit_small_patch16_224.augreg_in21k",
    "vit_base_patch16_224.augreg_in21k",
    "vit_large_patch16_224.augreg_in21k",
    "vit_base_patch16_224.mae",
    "vit_large_patch16_224.mae",
    "vit_huge_patch14_224.mae",
    "vit_small_patch14_dinov2.lvd142m",
    "vit_base_patch14_dinov2.lvd142m",
    "vit_large_patch14_dinov2.lvd142m",
    "vit_giant_patch14_dinov2.lvd142m",
    "vit_base_patch16_clip_224.laion2b",
    "vit_large_patch14_clip_224.laion2b",
    "vit_huge_patch14_clip_224.laion2b",
    "vit_base_patch16_clip_224.laion2b_ft_in12k",
    "vit_large_patch14_clip_224.laion2b_ft_in12k",
    "vit_huge_patch14_clip_224.laion2b_ft_in12k",
]

# Define a function to download weights for a single model
def download_model_weights(model_name):
    try:
        # Load the model to download weights and immediately delete to free memory
        model = timm.create_model(model_name, pretrained=True)
        del model
        return f"Downloaded weights for {model_name}"
    except Exception as e:
        return f"Failed to download weights for {model_name}: {e}"

# Use ThreadPoolExecutor for parallel downloading
print("Pre-downloading weights for all models in parallel...")
with ThreadPoolExecutor() as executor:
    # Submit all download tasks
    futures = {executor.submit(download_model_weights, model_name): model_name for model_name in model_list}

    # Use tqdm to track progress of the downloads
    for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading model weights", leave=True):
        result = future.result()
        print(result)

print("All model weights pre-downloaded.")

Pre-downloading weights for all models in parallel...


Downloading model weights:   6%|▌         | 1/17 [00:32<08:45, 32.83s/it]

Downloaded weights for vit_small_patch14_dinov2.lvd142m


Downloading model weights:  12%|█▏        | 2/17 [00:35<03:42, 14.80s/it]

Downloaded weights for vit_base_patch16_224.augreg_in21k
Downloaded weights for vit_tiny_patch16_224.augreg_in21k


Downloading model weights:  24%|██▎       | 4/17 [00:35<01:14,  5.71s/it]

Downloaded weights for vit_base_patch14_dinov2.lvd142m


Downloading model weights:  29%|██▉       | 5/17 [00:36<00:50,  4.21s/it]

Downloaded weights for vit_base_patch16_224.mae


Downloading model weights:  35%|███▌      | 6/17 [00:37<00:33,  3.09s/it]

Downloaded weights for vit_small_patch16_224.augreg_in21k


Downloading model weights:  41%|████      | 7/17 [00:38<00:24,  2.48s/it]

Downloaded weights for vit_base_patch16_clip_224.laion2b


Downloading model weights:  47%|████▋     | 8/17 [00:38<00:16,  1.84s/it]

Downloaded weights for vit_base_patch16_clip_224.laion2b_ft_in12k


model.safetensors:  22%|##2       | 283M/1.26G [00:00<?, ?B/s]

Downloading model weights:  53%|█████▎    | 9/17 [01:13<01:34, 11.75s/it]

Downloaded weights for vit_large_patch14_dinov2.lvd142m


Downloading model weights:  59%|█████▉    | 10/17 [01:14<00:59,  8.56s/it]

Downloaded weights for vit_large_patch16_224.mae


Downloading model weights:  65%|██████▍   | 11/17 [01:16<00:40,  6.71s/it]

Downloaded weights for vit_large_patch16_224.augreg_in21k


Downloading model weights:  71%|███████   | 12/17 [01:17<00:24,  4.96s/it]

Downloaded weights for vit_large_patch14_clip_224.laion2b


model.safetensors:  30%|##9       | 765M/2.58G [00:00<?, ?B/s]

open_clip_pytorch_model.bin:  77%|#######6  | 3.03G/3.94G [00:00<?, ?B/s]

Downloading model weights:  76%|███████▋  | 13/17 [01:33<00:32,  8.22s/it]

Downloaded weights for vit_huge_patch14_224.mae


Downloading model weights:  82%|████████▏ | 14/17 [01:43<00:26,  8.92s/it]

Downloaded weights for vit_giant_patch14_dinov2.lvd142m


Downloading model weights:  88%|████████▊ | 15/17 [08:17<04:08, 124.32s/it]

Downloaded weights for vit_large_patch14_clip_224.laion2b_ft_in12k


Downloading model weights:  94%|█████████▍| 16/17 [08:41<01:34, 94.27s/it] 

Downloaded weights for vit_huge_patch14_clip_224.laion2b


Downloading model weights: 100%|██████████| 17/17 [11:01<00:00, 38.93s/it] 

Downloaded weights for vit_huge_patch14_clip_224.laion2b_ft_in12k
All model weights pre-downloaded.





In [None]:
import os
import torch
import gc
import timm
from tqdm import tqdm
from torchvision import transforms
from PIL import Image

# Directory to store saved features
FEATURES_DIR = "model_features"
os.makedirs(FEATURES_DIR, exist_ok=True)

# Check if GPU is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def convert_to_rgb(image):
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image

def extract_and_save_features(model_name, images, batch_size=8):
    feature_path = os.path.join(FEATURES_DIR, f"{model_name}_features.pt")
    if os.path.exists(feature_path):
        print(f"Features for {model_name} already exist. Loading...")
        return torch.load(feature_path, map_location='cpu')

    print(f"Extracting features for {model_name}...")
    
    # Load the model with debug messages
    try:
        print(f"Loading model {model_name}...")
        model = timm.create_model(model_name, pretrained=True)
        print(f"Model {model_name} loaded successfully.")
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None
    
    # Move the model to the GPU
    model.to(device)
    model.eval()
    
    # Check if model is on GPU
    if next(model.parameters()).is_cuda:
        print(f"{model_name} is on GPU.")
    else:
        print(f"{model_name} is not on GPU as expected.")

    # Get the model's expected input size
    input_size = model.default_cfg.get('input_size', (3, 224, 224))
    img_height, img_width = input_size[1], input_size[2]

    # Define a transform for preprocessing images
    transform = transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    features = []
    with torch.no_grad():
        for i in tqdm(range(0, len(images), batch_size), desc=f"Extracting features for {model_name}"):
            batch_images = images[i:i + batch_size]
            batch_tensors = torch.stack([transform(convert_to_rgb(img)) for img in batch_images])
            batch_tensors = batch_tensors.to(device)

            try:
                feats = model(batch_tensors).flatten(start_dim=1)
                features.append(feats.cpu())
            except Exception as e:
                print(f"Error during inference with {model_name}: {e}")
                return None

            # Clear CUDA cache to free up memory
            torch.cuda.empty_cache()
            gc.collect()

    # Concatenate and normalize features
    features = torch.cat(features, dim=0)
    features = torch.nn.functional.normalize(features, dim=-1)

    # Save the features for future use
    torch.save(features, feature_path)
    print(f"Features for {model_name} saved to {feature_path}")

    # Clear model from memory after saving
    del model
    torch.cuda.empty_cache()
    gc.collect()

    return features

# Extract and save features for all models with batching
model_features = {}
for model_name in tqdm(model_list, desc="Extracting features for all models"):
    model_features[model_name] = extract_and_save_features(model_name, sampled_images, batch_size=16)

In [None]:
from metrics import AlignmentMetrics
import itertools
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

alignment_scores = {}

# Define a function to compute alignment score for a pair of models
def compute_alignment_score(model_pair):
    model_a, model_b = model_pair
    feats_a = extract_features(model_a, sampled_images)
    feats_b = extract_features(model_b, sampled_images)
    score = AlignmentMetrics.cknna(feats_a, feats_b, topk=10)
    return (model_a, model_b), score

# List of all model pairs
model_pairs = list(itertools.combinations(model_list, 2))

# Use ProcessPoolExecutor for parallel execution
with ProcessPoolExecutor() as executor:
    # Submit tasks to the executor
    futures = {executor.submit(compute_alignment_score, pair): pair for pair in model_pairs}
    
    # Use tqdm to track progress of the tasks as they complete
    for future in tqdm(as_completed(futures), total=len(futures), desc="Calculating alignment scores", leave=True):
        model_pair, score = future.result()
        alignment_scores[model_pair] = score

In [None]:
import pandas as pd

# Create a DataFrame for storing scores
score_df = pd.DataFrame(
    list(alignment_scores.values()),
    index=pd.MultiIndex.from_tuples(alignment_scores.keys(), names=["Model A", "Model B"]),
    columns=["Score"]
).unstack()

# Compute Spearman correlation matrix
correlation_matrix = score_df.corr(method='spearman')


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', square=True)
plt.title('Spearman Correlation of Alignment Scores')
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', square=True)
plt.title('Spearman Correlation of Alignment Scores')
plt.tight_layout()
plt.show()