In [1]:
# 1. Get a set of images on which image classifiers are wrong.
# 2. Prepare ViT models.
# 3. Compute pairwise alignment scores between the models.
# 4. Compute the Spearman correlation between the alignment metrics.
# 5. Plot the correlation matrix.

In [2]:
from datasets import load_dataset
import random
import timm
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import torch
import torch.nn.functional as F
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import seaborn as sns

# -------------------- Initialization --------------------

# ImageNet-O from the Hugging Face Hub.
ds = load_dataset("cais/imagenet-o")

# Decode every image of the test split into memory (the full split, not a sample).
print("Loading all images from the test split...")
sampled_images = [row['image'] for row in tqdm(ds['test'], desc="Loading images")]

# timm checkpoints whose CLS-token features are extracted and compared.
model_list = [
    # AugReg ImageNet-21k checkpoints
    "vit_tiny_patch16_224.augreg_in21k",
    "vit_small_patch16_224.augreg_in21k",
    "vit_base_patch16_224.augreg_in21k",
    "vit_large_patch16_224.augreg_in21k",
    # MAE checkpoints
    "vit_base_patch16_224.mae",
    "vit_large_patch16_224.mae",
    "vit_huge_patch14_224.mae",
    # DINOv2 (lvd142m) checkpoints
    "vit_small_patch14_dinov2.lvd142m",
    "vit_base_patch14_dinov2.lvd142m",
    "vit_large_patch14_dinov2.lvd142m",
    "vit_giant_patch14_dinov2.lvd142m",
    # CLIP (laion2b) checkpoints
    "vit_base_patch16_clip_224.laion2b",
    "vit_large_patch14_clip_224.laion2b",
    "vit_huge_patch14_clip_224.laion2b",
    # CLIP (laion2b) checkpoints with the ft_in12k tag
    "vit_base_patch16_clip_224.laion2b_ft_in12k",
    "vit_large_patch14_clip_224.laion2b_ft_in12k",
    "vit_huge_patch14_clip_224.laion2b_ft_in12k",
]

# On-disk cache for per-model feature tensors (created if missing).
FEATURES_DIR = "model_features"
os.makedirs(FEATURES_DIR, exist_ok=True)

# -------------------- Feature Extraction Functions --------------------

def convert_to_rgb(image):
    """
    Return ``image`` in RGB mode.

    Images already in "RGB" mode are returned unchanged (same object);
    anything else is passed through ``image.convert("RGB")``.
    """
    return image if image.mode == "RGB" else image.convert("RGB")

def extract_features(model_name, images, batch_size=32):
    """
    Extract L2-normalised CLS-token features for ``images`` with a timm model.

    Results are cached in ``FEATURES_DIR`` as ``<model_name>_features.pt``. If
    that file already exists, it is loaded onto the CPU instead of running the
    model again. The cache is keyed only by model name, so delete the file to
    re-extract after changing preprocessing.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        images: Indexable sequence of PIL images. Each image is converted to
            RGB before the model's eval transform is applied.
        batch_size: Number of images per forward pass.

    Returns:
        A CPU tensor with one row per image. Returns ``None`` if the model
        could not be loaded, ``images`` is empty, or any batch failed. Partial
        results are never returned or cached, so row ``i`` always corresponds
        to ``images[i]``.
    """
    feature_path = os.path.join(FEATURES_DIR, f"{model_name}_features.pt")

    # Reuse cached features. map_location='cpu' lets GPU-saved tensors load anywhere.
    if os.path.exists(feature_path):
        print(f"Loading existing features for {model_name}...")
        return torch.load(feature_path, map_location='cpu')

    print(f"Extracting features for {model_name}...")

    try:
        model = timm.create_model(model_name, pretrained=True)
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None

    # Use the checkpoint's own preprocessing: input size, interpolation,
    # crop_pct and mean/std. Supplying only input_size makes create_transform
    # fall back to ImageNet defaults for the rest, which do not match every
    # checkpoint (e.g. CLIP normalisation).
    data_config = timm.data.resolve_data_config({}, model=model)
    transform = timm.data.transforms_factory.create_transform(**data_config, is_training=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    features = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch_images = images[start:start + batch_size]
            try:
                batch = torch.stack(
                    [transform(convert_to_rgb(img)) for img in batch_images]
                ).to(device)

                # Index 0 of the token axis is taken as the CLS token.
                # Assumes forward_features returns (batch, tokens, dim) with a
                # class token first; confirm for any newly added model.
                outputs = model.forward_features(batch)
                cls_tokens = F.normalize(outputs[:, 0, :], dim=-1)
            except Exception as e:
                print(f"Error processing batch {start // batch_size} for {model_name}: {e}")
                # Skipping the batch would shift every later row and misalign
                # these features with other models' features, so abort rather
                # than return or cache a partial result.
                print(f"Aborting feature extraction for {model_name}.")
                return None
            features.append(cls_tokens.cpu())

    if not features:
        print(f"No features extracted for {model_name}.")
        return None

    features = torch.cat(features, dim=0)  # (num_images, feature_dim)
    torch.save(features, feature_path)
    print(f"Features for {model_name} saved to {feature_path}")
    return features

def load_or_extract_features(model_name, images, batch_size=32):
    """
    Return the feature tensor for ``model_name``, preferring the on-disk cache.

    On a cache hit, the tensor at ``FEATURES_DIR/<model_name>_features.pt`` is
    loaded onto the CPU and its shape is reported. On a miss, the work is
    delegated to ``extract_features``, and whatever it returns (a tensor or
    ``None``) is passed through.
    """
    cache_file = os.path.join(FEATURES_DIR, f"{model_name}_features.pt")
    if not os.path.exists(cache_file):
        return extract_features(model_name, images, batch_size)

    print(f"Loading features for {model_name} from {cache_file}...")
    cached = torch.load(cache_file, map_location='cpu')
    print(f"Features for {model_name} loaded successfully, shape: {cached.shape}")
    return cached

# -------------------- Feature Extraction Phase --------------------

# model name -> feature tensor. Models that yield no features are reported and left out.
model_features = {}

for model_name in tqdm(model_list, desc="Extracting features for all models"):
    features = load_or_extract_features(model_name, sampled_images, batch_size=32)
    if features is None:
        print(f"Features for {model_name} could not be extracted.")
        continue
    model_features[model_name] = features

# -------------------- Verification Phase --------------------

# Every model should yield exactly one feature row per image, in image order.
# The pairwise alignment metrics presumably compare row i of one model with
# row i of another, so a model with a different row count cannot be paired.
expected_num_samples = len(sampled_images)
print("\nVerification of extracted features:")
for model_name, features in model_features.items():
    num_samples, feature_dim = features.size()
    print(f"Model {model_name}: {num_samples} samples, {feature_dim} features.")
    if num_samples != expected_num_samples:
        print(f"  Warning: Expected {expected_num_samples} samples, but got {num_samples}.")
        print("  -> Excluded from the alignment analysis.")

# Keep models whose rows line up with the image list and that have at least
# 10 samples (the downstream k-NN metrics are called with topk=10).
sufficient_models = [
    model_name
    for model_name, features in model_features.items()
    if features is not None
    and features.size(0) == expected_num_samples
    and features.size(0) >= 10
]

print(f"\nModels with sufficient samples ({len(sufficient_models)}):")
for model_name in sufficient_models:
    print(f" - {model_name}: {model_features[model_name].size(0)} samples")

Resolving data files:   0%|          | 0/2000 [00:00<?, ?it/s]

Loading all images from the test split...


Loading images: 100%|██████████| 2000/2000 [00:06<00:00, 314.33it/s]
  features = torch.load(feature_path, map_location='cpu')
Extracting features for all models: 100%|██████████| 17/17 [00:00<00:00, 85.98it/s]


Loading features for vit_tiny_patch16_224.augreg_in21k from model_features/vit_tiny_patch16_224.augreg_in21k_features.pt...
Features for vit_tiny_patch16_224.augreg_in21k loaded successfully, shape: torch.Size([2000, 192])
Loading features for vit_small_patch16_224.augreg_in21k from model_features/vit_small_patch16_224.augreg_in21k_features.pt...
Features for vit_small_patch16_224.augreg_in21k loaded successfully, shape: torch.Size([2000, 384])
Loading features for vit_base_patch16_224.augreg_in21k from model_features/vit_base_patch16_224.augreg_in21k_features.pt...
Features for vit_base_patch16_224.augreg_in21k loaded successfully, shape: torch.Size([2000, 768])
Loading features for vit_large_patch16_224.augreg_in21k from model_features/vit_large_patch16_224.augreg_in21k_features.pt...
Features for vit_large_patch16_224.augreg_in21k loaded successfully, shape: torch.Size([2000, 1024])
Loading features for vit_base_patch16_224.mae from model_features/vit_base_patch16_224.mae_features.p

In [None]:
# -------------------- Alignment Score Calculation --------------------

from metrics import AlignmentMetrics  # Ensure this is correctly implemented

def compute_alignment_score(model_a, model_b, metric='cknna', topk=10):
    """
    Score how aligned the representations of two models are.

    Fetches both models' feature tensors from the module-level
    ``model_features`` dict. Both tensors and ``topk`` are then passed to the
    ``AlignmentMetrics`` method named by ``metric``.

    Raises:
        KeyError: If either model has no entry in ``model_features``.
        ValueError: If ``metric`` is neither ``'cknna'`` nor ``'mutual_knn'``.
    """
    first, second = model_features[model_a], model_features[model_b]

    supported_metrics = ('cknna', 'mutual_knn')
    if metric not in supported_metrics:
        raise ValueError(f"Unknown metric: {metric}")

    metric_fn = getattr(AlignmentMetrics, metric)
    return metric_fn(first, second, topk=topk)

# Alignment metrics evaluated for every model pair. Any name added here must
# also be handled by compute_alignment_score.
metrics_list = ['cknna', 'mutual_knn']

# Every unordered pair of distinct models that passed verification.
model_pairs = list(itertools.combinations(sufficient_models, 2))
print(f"\nTotal model pairs to compute alignment scores: {len(model_pairs)}")

# Result store: metric name -> {(model_a, model_b): score}.
alignment_scores = {name: {} for name in metrics_list}

def compute_scores_for_pair(pair):
    """
    Evaluate every metric in ``metrics_list`` for one ``(model_a, model_b)`` pair.

    A metric that raises is reported and recorded as ``None``, so one failure
    does not stop the remaining metrics.

    Returns:
        ``((model_a, model_b), {metric_name: score_or_None})``.
    """
    model_a, model_b = pair
    results = {}
    for name in metrics_list:
        try:
            value = compute_alignment_score(model_a, model_b, metric=name, topk=10)
            results[name] = value
            print(f"Alignment score for ({model_a}, {model_b}) using '{name}': {value}")
        except Exception as err:
            print(f"Error computing {name} for ({model_a}, {model_b}): {err}")
            results[name] = None
    return (model_a, model_b), results

# Compute alignment scores for all model pairs concurrently.
# Threads are used instead of processes:
#   - compute_scores_for_pair lives in this notebook's __main__. Worker
#     processes started with "spawn"/"forkserver" (the default on macOS and
#     Windows) must import __main__ and cannot find it, so the pool breaks.
#   - Forking after CUDA may have been initialised (extract_features moves
#     models to the GPU when available) is unsupported by the CUDA runtime.
#   - Threads share model_features without pickling it, and their print()
#     output reaches the notebook.
print("\nCalculating alignment scores for all model pairs...")

with ThreadPoolExecutor() as executor:
    # Submit one task per model pair.
    futures = {executor.submit(compute_scores_for_pair, pair): pair for pair in model_pairs}

    # Collect results as they finish, keeping only successful (non-None) scores.
    for future in tqdm(as_completed(futures), total=len(futures), desc="Computing alignment scores"):
        pair, scores = future.result()
        for metric, score in scores.items():
            if score is not None:
                alignment_scores[metric][pair] = score

print("Alignment score calculation completed.")

# -------------------- DataFrame Construction --------------------

# Flatten {metric: {(model_a, model_b): score}} into one record per (pair, metric).
records = [
    {'Model A': model_a, 'Model B': model_b, 'Metric': metric, 'Score': score}
    for metric, pairs_scores in alignment_scores.items()
    for (model_a, model_b), score in pairs_scores.items()
]

score_df = pd.DataFrame(records)

# Quick look at the first few rows.
print("\nSample records:")
print(score_df.head())

# -------------------- Spearman Correlation Calculation --------------------

if score_df.empty:
    print("No valid alignment scores found. The DataFrame is empty.")
else:
    # Reshape to one row per (Model A, Model B) pair and one column per metric.
    pivot_df = score_df.pivot(index=['Model A', 'Model B'], columns='Metric', values='Score')
    print("\nSample of the pivoted DataFrame:")
    print(pivot_df.head())

    # Rank correlation between metrics, computed across model pairs.
    correlation_matrix = pivot_df.corr(method='spearman')
    print("\nSpearman Correlation Matrix:")
    print(correlation_matrix)

    # -------------------- Visualization --------------------

    # Annotated heatmap of the metric-vs-metric Spearman correlations.
    fig, ax = plt.subplots(figsize=(8, 6))
    heatmap_style = dict(
        annot=True,
        cmap='coolwarm',
        square=True,
        center=0,
        cbar_kws={'label': 'Spearman Correlation'},
    )
    sns.heatmap(correlation_matrix, ax=ax, **heatmap_style)
    ax.set_title('Spearman Correlation of Alignment Scores Across Metrics')
    fig.tight_layout()
    plt.show()