In [1]:
# from huggingface_hub import notebook_login
# notebook_login()

In [2]:
from datasets import load_dataset
# Fetch the 'imagenet-1k' dataset; later cells read its 'validation' split.
# NOTE(review): presumably needs the Hub login from the commented-out cell above — confirm.
ds = load_dataset(path='imagenet-1k')

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/267 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import random
import numpy as np

def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Written to be passed as a DataLoader ``worker_init_fn``; ``worker_id`` is
    accepted for that signature but not used.
    """
    # Keep only the low 32 bits of torch's seed.
    seed_32bit = torch.initial_seed() & 0xFFFFFFFF
    for seeder in (np.random.seed, random.seed):
        seeder(seed_32bit)

# Fixed-seed generator, handed to the validation DataLoader below.
g = torch.Generator().manual_seed(42)

# Per-channel statistics used by the Normalize step.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Evaluation preprocessing: resize, center-crop, tensor conversion, normalisation.
test_transforms = transforms.Compose([
    transforms.Resize(256),      # shorter side -> 256 pixels
    transforms.CenterCrop(224),  # central 224x224 patch
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])


class ImageNetDataset(Dataset):
    """Map-style wrapper exposing ``(image, label)`` pairs from a record dataset.

    Each underlying record must provide an ``'image'`` (with a ``convert``
    method) and a ``'label'`` entry. The image is converted to RGB and, when a
    transform is supplied, passed through it before being returned.
    """

    def __init__(self, dataset, transform=None):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # Force three channels so every sample has the same layout.
        rgb = record['image'].convert('RGB')
        sample = self.transform(rgb) if self.transform else rgb
        return sample, record['label']

# Wrap the validation split so it yields (preprocessed image, label) pairs.
imagenet_val = ImageNetDataset(ds['validation'], transform=test_transforms)
# shuffle=False keeps batches in dataset order, so row i of anything collected
# from this loader corresponds to ds['validation'][i].
val_loader = DataLoader(imagenet_val, batch_size=64, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)

# `pretrained=True` is deprecated in torchvision; request the checkpoint through
# the weights enum instead (IMAGENET1K_V1 is what `pretrained=True` resolved to).
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18 = resnet18.to(device)



In [4]:
import torch
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

def plot_random_cluster_samples(ds, indexes, partition, class_names, sample_count=8, seed=None, fontsize=16):
    """
    Plots randomly selected images from the cluster defined by indexes[partition] in a grid.

    Parameters:
        ds (dict): Dataset containing the images and labels (e.g., ds['validation'][i]['image'] and ['label']).
        indexes (dict or array-like): Data structure where indexes[partition] is a boolean array
                                      indicating membership of each sample in the cluster.
        partition: Key/index for the desired cluster in the indexes structure.
        class_names (list or dict): Mapping from label indices to class names.
        sample_count (int): Number of random images to select and plot. Defaults to 8.
        seed (int, optional): Seed for the sample selection (for reproducibility).
                              NumPy's global random state is left untouched.
        fontsize (int): Font size of each subplot title. Defaults to 16.

    Returns:
        None. Prints a message and returns early when the cluster is empty.

    Raises:
        ValueError: If sample_count is not positive while the cluster is non-empty.
    """
    # A private RandomState draws exactly what `np.random.seed(seed)` followed by
    # `np.random.choice` would, but without reseeding the global generator that
    # other cells may rely on.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Get the indices of images in the desired cluster. Coercing to a boolean
    # array also handles plain lists, where `== True` would compare the whole list.
    member_mask = np.asarray(indexes[partition], dtype=bool)
    cluster_indices = np.flatnonzero(member_mask)
    if len(cluster_indices) == 0:
        print("No images found in this cluster!")
        return

    # Adjust sample_count if the cluster has fewer than sample_count images
    actual_sample_count = min(sample_count, len(cluster_indices))
    if actual_sample_count < 1:
        # Fail with a clear message instead of an error from deep inside plt.subplots.
        raise ValueError("sample_count must be a positive integer")

    # Randomly select the desired number of indices from the cluster
    selected_indices = rng.choice(cluster_indices, size=actual_sample_count, replace=False)

    # Determine grid size: here we use 4 columns
    cols = 4
    rows = int(np.ceil(actual_sample_count / cols))

    # squeeze=False always returns a 2-D axes array, so one flatten covers every row count.
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    axes = axes.flatten()

    # Plot each randomly selected image with its label
    validation = ds['validation']
    for ax, idx in zip(axes, selected_indices):
        # Fetch the row once and reuse it for both the image and the label.
        record = validation[int(idx)]
        ax.imshow(record['image'])
        ax.set_title(f"{class_names[record['label']]}", fontsize=fontsize)
        ax.axis("off")

    # Turn off any extra axes (if any)
    for ax in axes[actual_sample_count:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [5]:
from tqdm.auto import tqdm

# Captured activations, keyed by the name given to get_features.
intermediate_features = {}


def get_features(name):
    """Return a forward hook that records the hooked module's output under `name`."""
    def _capture(module, inputs, output):
        # Store the output detached from the autograd graph.
        intermediate_features[name] = output.detach()
    return _capture

# Record the output of the average-pooling layer on every forward pass.
hook_handle = resnet18.avgpool.register_forward_hook(get_features('avgpool'))

all_inputs = []
all_features = []
all_preds = []
all_labels = []

resnet18.eval()
try:
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device).float()
            preds = resnet18(images)
            # flatten(1) keeps the batch axis even for a one-image batch; squeeze()
            # would also drop it there and make the torch.cat below fail.
            features = intermediate_features['avgpool'].flatten(1)
            # NOTE(review): this keeps every preprocessed input in host memory, which
            # is very large for the full validation split; inputs_tensor is not used
            # by the cells visible in this notebook — consider dropping it.
            all_inputs.append(images.cpu())
            all_features.append(features.cpu())
            all_preds.append(preds.cpu())
            all_labels.append(labels)
finally:
    # Detach the hook even if inference is interrupted, so a re-run of this
    # cell does not stack a second hook on the model.
    hook_handle.remove()

inputs_tensor = torch.cat(all_inputs, dim=0)
features_tensor = torch.cat(all_features, dim=0)
preds_tensor = torch.cat(all_preds, dim=0)
labels_tensor = torch.cat(all_labels, dim=0)

  0%|          | 0/782 [00:00<?, ?it/s]

In [6]:
from local_corex import partition_data
from nn_utils import compute_cluster_accuracies
from nn_plotting import plot_perturved_accuracy_resnet, plot_logit_effects
import pickle
# Switch: True recomputes the partition and rewrites the pickle, False reads the saved file.
from_scratch = False
indexes_path = 'resnet_100_indexes.pkl'

if not from_scratch:
    print("loading clusters")
    # NOTE(review): unpickling can run arbitrary code — only load a file you produced yourself.
    with open(indexes_path, 'rb') as handle:
        indexes = pickle.load(handle)
else:
    print('computing clusters from scratch')
    num_clusters=100
    indexes = partition_data(features_tensor, n_partitions=num_clusters, phate_dim=10, n_jobs=-2, seed=42)
    print('saving clusters')
    with open(indexes_path, 'wb') as handle:
        pickle.dump(indexes, handle)

Install CUDA and cudamat (for python) to enable GPU speedups.
loading clusters


In [7]:
import numpy as np
from collections import defaultdict

# Always bind `labels` in this cell. Previously a non-tensor `labels_tensor`
# left `labels` pointing at whatever an earlier cell had assigned to that name
# (the last batch of the extraction loop).
if isinstance(labels_tensor, torch.Tensor):
    labels = labels_tensor.numpy()
else:
    labels = np.asarray(labels_tensor)
cluster_counts = defaultdict(dict)

# Count how often each class label occurs inside every cluster mask.
for cluster_id, mask in enumerate(indexes):
    cluster_labels = labels[mask]
    unique, counts = np.unique(cluster_labels, return_counts=True)
    # .tolist() gives plain Python ints, so the printed dict does not depend on
    # how the installed NumPy version formats its scalar types.
    cluster_counts[cluster_id] = dict(zip(unique.tolist(), counts.tolist()))

# Report each cluster's classes from most to least frequent, as percentages.
for cluster_id, counts in cluster_counts.items():
    sorted_counts = sorted(counts.items(), key=lambda x: -x[1])
    total = sum(counts.values())
    percentages = {k: f"{(v/total)*100:.1f}%" for k, v in sorted_counts}
    print(f"Cluster {cluster_id} ({total} samples):")
    print("  Class distribution:", percentages)

Cluster 0 (944 samples):
  Class distribution: {159: '5.2%', 243: '5.0%', 253: '4.8%', 242: '4.7%', 172: '4.6%', 211: '4.6%', 168: '4.3%', 171: '4.3%', 158: '4.1%', 254: '4.1%', 180: '4.0%', 246: '3.8%', 179: '3.5%', 209: '3.3%', 245: '3.3%', 173: '3.1%', 151: '2.8%', 195: '2.5%', 237: '2.5%', 208: '2.2%', 176: '2.1%', 236: '2.1%', 163: '2.0%', 167: '1.4%', 178: '1.3%', 210: '1.2%', 165: '1.1%', 162: '1.0%', 227: '0.8%', 234: '0.8%', 161: '0.7%', 225: '0.6%', 262: '0.6%', 164: '0.5%', 676: '0.5%', 166: '0.4%', 238: '0.4%', 170: '0.3%', 182: '0.3%', 215: '0.3%', 263: '0.3%', 268: '0.3%', 205: '0.2%', 214: '0.2%', 217: '0.2%', 241: '0.2%', 264: '0.2%', 267: '0.2%', 273: '0.2%', 811: '0.2%', 184: '0.1%', 191: '0.1%', 204: '0.1%', 212: '0.1%', 216: '0.1%', 226: '0.1%', 232: '0.1%', 235: '0.1%', 239: '0.1%', 240: '0.1%', 285: '0.1%', 330: '0.1%', 338: '0.1%', 354: '0.1%', 434: '0.1%', 465: '0.1%', 656: '0.1%', 678: '0.1%', 750: '0.1%', 808: '0.1%', 882: '0.1%', 912: '0.1%'}
Cluster 1 (587 s

# Looking at Partition 58

In [None]:
from local_corex import LinearCorex

partition = 58
# Membership mask of the partition under study.
in_partition = indexes[partition]
# CorEx input: avgpool features and model outputs side by side for those samples.
x = np.concatenate((features_tensor[in_partition], preds_tensor[in_partition]), axis=1)

# NOTE(review): 20 is presumably the number of latent factors — confirm against local_corex.
corex_58 = LinearCorex(20)
corex_58.fit(x)
print(corex_58.tcs)

In [None]:
# Class names come from the metadata attached to the torchvision weights.
weights = models.ResNet18_Weights.DEFAULT
class_names = weights.meta['categories']

# Show twelve random validation images from partition 58.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=58,
    class_names=class_names,
    sample_count=12,
    seed=124,
)

## Factor 0

In [None]:
# Baseline accuracies and probabilities from the nn_utils helper, reused by later cells.
base_accuracies, base_probs = compute_cluster_accuracies(
    resnet18, val_loader, device, indexes, return_probs=True
)

# Keyword settings for the factor-0 run with corex_58.
perturb_settings = dict(
    base_accuracies=base_accuracies,
    base_probs=base_probs,
    factor_num=0,
    num_drop=100,
    hidden_dim=512,
)
diff_probs = plot_perturved_accuracy_resnet(
    resnet18, corex_58, val_loader, indexes, device, **perturb_settings
)

In [None]:
# Average diff_probs over the rows belonging to the current partition.
plot_logit_effects(
    diff_probs[indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=2,
    top_vals=8,
)

In [None]:
# Same view for cluster 28: average diff_probs over that cluster's rows.
plot_logit_effects(
    diff_probs[indexes[28]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

Note that if you were to subtract off the bird features, the thing you would most likely notice is the stripe pattern of a zebra.

In [None]:
# Twelve random validation images from cluster 28.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=28,
    class_names=class_names,
    sample_count=12,
    seed=124,
)

In [None]:

# Average diff_probs over every row outside the current partition (inverted mask).
plot_logit_effects(
    diff_probs[~indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

## Factor 1

In [None]:
# base_accuracies / base_probs are reused from the Factor 0 cell instead of being
# recomputed: the model, loader and partition are unchanged, so a second
# compute_cluster_accuracies call would only repeat the same pass over the
# validation loader. The partition-78 cells already reuse the baseline this way.
# NOTE(review): assumes plot_perturved_accuracy_resnet does not modify
# base_accuracies / base_probs in place — confirm in nn_plotting.
diff_probs = plot_perturved_accuracy_resnet(
    resnet18,
    corex_58,
    val_loader,
    indexes,
    device,
    base_accuracies=base_accuracies,
    base_probs=base_probs,
    factor_num=1,
    num_drop=100,
    hidden_dim=512
)

In [None]:
# Average diff_probs over the rows belonging to the current partition.
plot_logit_effects(
    diff_probs[indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=2,
    top_vals=8,
)

In [None]:
# Same view for cluster 79: average diff_probs over that cluster's rows.
plot_logit_effects(
    diff_probs[indexes[79]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

Notice that the affected classes all have a net or a white line running across them. The African Grey is often pictured with a perch or rope that has similar features.

In [None]:
# Twelve random validation images from cluster 79.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=79,
    class_names=class_names,
    sample_count=12,
    seed=148,
)

In [None]:
# Average diff_probs over every row outside the current partition (inverted mask).
plot_logit_effects(
    diff_probs[~indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

# Partition 78

In [None]:
# from local_corex.latent_transformer import CorExWrapper

partition = 78
# Membership mask of the partition under study.
in_partition = indexes[partition]
# CorEx input: avgpool features and model outputs side by side for those samples.
x = np.concatenate((features_tensor[in_partition], preds_tensor[in_partition]), axis=1)

# NOTE(review): 20 is presumably the number of latent factors — confirm against local_corex.
corex_78 = LinearCorex(20)
corex_78.fit(x)
print(corex_78.tcs)

In [None]:
# Twelve random validation images from the current partition.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=partition,
    class_names=class_names,
    sample_count=12,
    seed=148,
)

## Factor 0

In [None]:
# Keyword settings for the factor-0 run with corex_78; the baseline comes from the earlier cell.
perturb_settings = dict(
    base_accuracies=base_accuracies,
    base_probs=base_probs,
    factor_num=0,
    num_drop=100,
    hidden_dim=512,
)
diff_probs = plot_perturved_accuracy_resnet(
    resnet18, corex_78, val_loader, indexes, device, **perturb_settings
)

In [None]:
# Average diff_probs over the rows belonging to the current partition.
plot_logit_effects(
    diff_probs[indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=2,
    top_vals=8,
)

In [None]:
# Same view for cluster 15: average diff_probs over that cluster's rows.
plot_logit_effects(
    diff_probs[indexes[15]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

I think this feature was used to find the stem of the fruit/flower, and it can look a bit like the ropes/leather of a muzzle/cart/plow.

In [None]:
# Twelve random validation images from cluster 15.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=15,
    class_names=class_names,
    sample_count=12,
    seed=148,
)

In [None]:
# Average diff_probs over every row outside the current partition (inverted mask).
plot_logit_effects(
    diff_probs[~indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

## Factor 1

In [None]:
# Keyword settings for the factor-1 run with corex_78; the baseline comes from the earlier cell.
perturb_settings = dict(
    base_accuracies=base_accuracies,
    base_probs=base_probs,
    factor_num=1,
    num_drop=100,
    hidden_dim=512,
)
diff_probs = plot_perturved_accuracy_resnet(
    resnet18, corex_78, val_loader, indexes, device, **perturb_settings
)

In [None]:
# Average diff_probs over the rows belonging to the current partition.
plot_logit_effects(
    diff_probs[indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=2,
    top_vals=8,
)

In [None]:
# Same view for cluster 31: average diff_probs over that cluster's rows.
plot_logit_effects(
    diff_probs[indexes[31]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

My guess is that this feature relates to the tail of the stingray, the tubes of the scuba diver, the tentacles of the jellyfish, and the actual part of the snorkel. This could look like a stem or vine connecting the plant.

In [None]:
# Twelve random validation images from cluster 31.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=31,
    class_names=class_names,
    sample_count=12,
    seed=125,
)

In [None]:
# Average diff_probs over every row outside the current partition (inverted mask).
plot_logit_effects(
    diff_probs[~indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

## Factor 2

In [None]:
# Keyword settings for the factor-2 run with corex_78; the baseline comes from the earlier cell.
perturb_settings = dict(
    base_accuracies=base_accuracies,
    base_probs=base_probs,
    factor_num=2,
    num_drop=100,
    hidden_dim=512,
)
diff_probs = plot_perturved_accuracy_resnet(
    resnet18, corex_78, val_loader, indexes, device, **perturb_settings
)

In [None]:
# Average diff_probs over the rows belonging to the current partition.
plot_logit_effects(
    diff_probs[indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=2,
    top_vals=8,
)

In [None]:
# Same view for cluster 82: average diff_probs over that cluster's rows.
plot_logit_effects(
    diff_probs[indexes[82]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)

In [None]:
# Twelve random validation images from cluster 82.
plot_random_cluster_samples(
    ds,
    indexes,
    partition=82,
    class_names=class_names,
    sample_count=12,
    seed=124,
)

In [None]:
# Average diff_probs over every row outside the current partition (inverted mask).
plot_logit_effects(
    diff_probs[~indexes[partition]].mean(dim=0),
    class_names,
    bottom_vals=5,
    top_vals=5,
)