# Imports

In [2]:
# install dependencies (once)
!pip install transformers  # maybe also datasets, torchvision, etc.
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-ud2kuc5g
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-ud2kuc5g
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=49a697fa5da442d90ca1ab751e8c131d274206df2d9808be40a12ea6d7f1c31a
  Stored in directory: /tmp/pip-ephem-wheel-cache-qr4l25_9/wheels/35/3e/df/3d24cbfb3b6a06f17

In [3]:
import torch
from transformers import CLIPModel, CLIPProcessor, AutoModel

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
from torchvision import transforms
from tqdm import tqdm
import clip

device = "cuda"

# Helper Functions

In [4]:


def unnormalize(x):
    """
    x: tensor of shape (B,3,H,W) or (3,H,W), CLIP-normalized.
    Returns: unnormalized image in [0,1].
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073],
                        device=x.device)[None, :, None, None]
    std  = torch.tensor([0.26862954, 0.26130258, 0.27577711],
                        device=x.device)[None, :, None, None]

    # If x is (3,H,W), add batch dimension
    if x.dim() == 3:
        x = x.unsqueeze(0)

    x_unnorm = x * std + mean
    return x_unnorm.clamp(0, 1)

def layer_norm(x):
    ln = nn.LayerNorm(768, elementwise_affine=False).to(device)
    return ln(x)

def print_stats(x):
    print("mean:", x.mean().item())
    print("std:", x.std().item())
    print("min:", x.min().item())
    print("max:", x.max().item())


# Models

##Sparse Autoencoder

In [5]:
'''
Model Architecture (taken from HF)
Input Dimension: 768
SAE Dimension: 49,152
Expansion Factor: x64 (vanilla architecture)
Activation Function: ReLU
Initialization: encoder_transpose_decoder
Context Size: 50 tokens
'''

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=49152):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.encoder = nn.Parameter(torch.zeros(input_dim, hidden_dim))
        self.decoder = nn.Parameter(torch.zeros(hidden_dim, input_dim))

        self.b_enc = nn.Parameter(torch.zeros(hidden_dim))
        self.b_dec = nn.Parameter(torch.zeros(input_dim))

    def encode(self, x):
        return self.forward(x)["acts"]

    def decode(self, acts):
        recon = acts @ self.decoder + self.b_dec
        return recon

    def forward(self, x):
        encoded = x @ self.encoder + self.b_enc
        acts = F.relu(encoded)

        # Decode using tied weights
        recon = acts @ self.decoder + self.b_dec

        return {
            "acts": acts,
            "reconstruction": recon
        }


In [6]:
def load_sae_from_pt(path):
    raw_state = torch.load(path, map_location="cpu")

    # Infer sizes
    input_dim, hidden_dim = raw_state["W_enc"].shape
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim)

    # Adapt keys to match your model's expected names
    state = {
        "encoder": raw_state["W_enc"],
        "decoder": raw_state["W_dec"],
        "b_enc":   raw_state["b_enc"],
        "b_dec":   raw_state["b_dec"],
    }

    sae.load_state_dict(state)
    return sae

In [7]:
model = load_sae_from_pt("weights.pt").to(device)

model.eval()

SparseAutoencoder()

In [8]:
state = torch.load("weights.pt", map_location="cpu")


w = state["W_enc"]
print("mean:", w.mean().item())
print("std:", w.std().item())
print("min:", w.min().item())
print("max:", w.max().item())

mean: -5.68432587897405e-05
std: 0.02647518552839756
min: -0.6612777709960938
max: 0.7138082385063171


##CLIP Classifier

In [9]:
# model
class CLIPClassifier(nn.Module):
    def __init__(self, clip_model_name="ViT-B/32", num_classes=10, device=device):
        super().__init__()

        # Load CLIP
        self.clip_model, _ = clip.load(clip_model_name, device=device)
        self.clip_model = self.clip_model.float()

        # freeze CLIP parameters
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # create classification head
        embed_dim = self.clip_model.visual.output_dim  # 512 for ViT-B/32
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, images):
        # extract CLIP image features
        feats = self.clip_model.encode_image(images)

        # classifier
        logits = self.classifier(feats)
        return logits

    def classify(self, feats):
        return self.classifier(feats)

    def ln(self, x):
        return self.clip_model.visual.ln_post(x)

    def precompute(self, x):
        return self.clip_model.encode_image(x)

    def head(self, images, split_at=11):
        V = self.clip_model.visual

        # Patch embedding
        x = V.conv1(images)
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # [B, N, C]

        # CLS token
        cls = V.class_embedding.to(x.dtype)
        cls = cls.unsqueeze(0).unsqueeze(1).expand(x.shape[0], -1, -1)  # [B,1,C]
        x = torch.cat([cls, x], dim=1)  # [B, 1+N, C]

        # Positional embedding + pre-LN
        x = x + V.positional_embedding.to(x.dtype)
        x = V.ln_pre(x)

        # Switch to CLIP transformer format: [L, B, C]
        x = x.permute(1, 0, 2)

        # Run layers 0..split_at
        for i in range(split_at + 1):
            x = V.transformer.resblocks[i](x)

        # Convert back to batch-first for user-facing return
        x = x.permute(1, 0, 2)  # [B, L, C]

        return x

    def tail(self, x, split_at=11):
        V = self.clip_model.visual

        # x arrives as batch-first: [B, L, C]
        # Convert to transformer format
        x = x.permute(1, 0, 2)  # [L, B, C]

        # Run remaining layers
        for i in range(split_at + 1, len(V.transformer.resblocks)):
            x = V.transformer.resblocks[i](x)

        # Back to batch-first
        x = x.permute(1, 0, 2)  # [B, L, C]

        # Final post-LN should operate on the sequence
        x = V.ln_post(x)

        # CLS token
        cls = x[:, 0, :]

        # Projection
        if hasattr(V, "proj") and V.proj is not None:
            if isinstance(V.proj, torch.nn.Parameter):
                cls = cls @ V.proj
            else:
                cls = V.proj(cls)

        # Classifier
        logits = self.classifier(cls)

        return logits

In [10]:
clip_classifier = CLIPClassifier().to(device)
clip_classifier.eval()

100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 60.8MiB/s]


CLIPClassifier(
  (clip_model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQu

In [11]:
clip_classifier.classifier.load_state_dict(torch.load("classifier_weights.pt"), strict=True)

<All keys matched successfully>

# Data

In [12]:
# load cifar-10

def get_cifar10_loaders(batch_size=128, num_workers=0, val_ratio=0.1):
    """
    Returns train, val, and test DataLoaders for CIFAR-10.
    Train split is further split into train + val.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # CLIP expects 224x224
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711)
        )
    ])

    torch.manual_seed(67)

    # Load the full 50k CIFAR-10 training set
    full_train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform,)

    # Split indices
    val_size = int(len(full_train_dataset) * val_ratio)  # 5000 if val_ratio=0.1
    train_size = len(full_train_dataset) - val_size      # 45000

    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)  # for reproducibility
)

    # Test dataset (10k)
    test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True,transform=transform)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers,pin_memory=True)
    val_loader = DataLoader(val_dataset,batch_size=batch_size, shuffle=False,num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader, test_dataset

train_loader, val_loader, test_loader, test_dataset = get_cifar10_loaders()

100%|██████████| 170M/170M [00:04<00:00, 35.6MB/s]


# Train classifier

## CLIP Embeddings (for training classification head)

In [8]:
# precompute train embeddings

all_feats = []
all_labels = []

print("Precomputing CLIP embeddings...")
with torch.no_grad():
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        feats = clip_classifier.precompute(images)  # shape [batch, 512]

        all_feats.append(feats.cpu())
        all_labels.append(labels)

all_feats = torch.cat(all_feats)
all_labels = torch.cat(all_labels)

print("Feature tensor shape:", all_feats.shape)

# dataset of precomputed embeddings
feature_dataset = TensorDataset(all_feats, all_labels)
feature_loader = DataLoader(feature_dataset, batch_size=128, shuffle=True)

Precomputing CLIP embeddings...


100%|██████████| 352/352 [01:50<00:00,  3.19it/s]

Feature tensor shape: torch.Size([45000, 512])





In [9]:
# precompute test embeddings

all_feats = []
all_labels = []

print("Test CLIP embeddings...")
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        feats = clip_classifier.precompute(images)  # shape [batch, 512]

        all_feats.append(feats.cpu())
        all_labels.append(labels)

all_feats = torch.cat(all_feats)
all_labels = torch.cat(all_labels)

print("Feature tensor shape:", all_feats.shape)

# dataset of precomputed embeddings
test_feature_dataset = TensorDataset(all_feats, all_labels)
test_feature_loader = DataLoader(test_feature_dataset, batch_size=128, shuffle=True)

Test CLIP embeddings...


100%|██████████| 79/79 [00:24<00:00,  3.23it/s]

Feature tensor shape: torch.Size([10000, 512])





## Training loop

In [None]:
# training loop
optimizer = optim.Adam(clip_classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
num_epochs = 20


for epoch in range(num_epochs):
    clip_classifier.train()
    avg_loss = 0

    for images, labels in tqdm(feature_loader):
        images, labels = images.to(device), labels.to(device)

        logits = clip_classifier.classify(images)

        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()

    print(f"Epoch {epoch+1}, Average Loss: {avg_loss/images.shape[0]:.4f}")

In [20]:
# evaluation

clip_classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_feature_loader:
        images, labels = images.to(device), labels.to(device)
        logits = clip_classifier.classify(images)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("Accuracy:", correct / total)

torch.save(clip_classifier.classifier.state_dict(), "classifier_weights.pt")

Accuracy: 0.9447


# Projected Gradient Descent

In [13]:
# pgd

def PGD(
    h,                      # [B, T, D] clean tokens
    y,                      # labels
    model_tail,             # tail(z) -> logits
    loss_fn=nn.CrossEntropyLoss(),
    scale=False,
    eps=0.1,
    alpha=1e-3,
    num_steps=40,
):
    """
    PGD attack in CLIP token space.
    h must be [B, T, D] and output will match exactly.
    """

    # detach ensures h is a constant baseline, but keep shape
    h_orig = h.detach()

    if scale:
        scale_hp = h_orig.pow(2).mean().sqrt() # rms norm over ALL tokens
        eps = eps * scale_hp
        alpha = eps * scale_hp

    # initialize perturbation with correct shape
    delta = torch.zeros_like(h_orig).uniform_(-eps, eps).to(device)

    logits = model_tail(h_orig)
    l0 = loss_fn(logits, y).detach().cpu().item()

    for _ in range(num_steps):
        delta.requires_grad_(True)

        h_adv = h_orig + delta          # preserve shape [B, T, D]
        logits = model_tail(h_adv)
        loss = loss_fn(logits, y)

        # compute gradient w.r.t delta only
        (grad,) = torch.autograd.grad(loss, delta)

        # PGD update in L∞
        delta = delta + alpha * grad.sign()

        # project back into L∞ ball — preserves shape
        delta = torch.clamp(delta, -eps, eps)

        # detach to prevent graph buildup
        delta = delta.detach()

    logits = model_tail(h_adv)
    l_final = loss_fn(logits, y).detach().cpu().item()
    loss_diff = l_final-l0

    # print(f"Loss difference:{loss_diff} \n")
    # final adv tokens (same shape as h_orig)
    return (h_orig + delta).detach(), loss_diff

# Get activations and features

In [50]:
'''
Input: x (image)

Output:
- h & SAE(h) : base
- PGD(h) & SAE(h_adv) : function 1 aka which features set the model off the most
- h(x_adv) & SAE(h(x_adv)) : function 2 aka which features get set off by perturbations the most
'''

def function_1(x, y, model, clip_classifier, alpha=1e-2, print_stats=False):
    h = clip_classifier.head(x)
    h_features = model.encode(layer_norm(h))
    h_adv, loss_diff = PGD(h, y, clip_classifier.tail, scale=True)
    h_adv_features = model.encode(layer_norm(h_adv))

    if print_stats:
        x = layer_norm(h_adv)
        ('Activation shape: ' + h.shape)
        print("Distribution before SAE")
        print_stats(x)

    return {'h': h.detach().cpu(),
            'h_features': h_features.detach().cpu(),
            'h_adv':h_adv.detach().cpu(),
            'h_adv_features': h_adv_features.detach().cpu(),
            'loss_diff': loss_diff}

def function_2(x, y, model, clip_classifier, alpha = 1e-2, print_stats=False):
    x_adv, loss_diff = PGD(x, y, clip_classifier)
    hx_adv = clip_classifier.head(x_adv)
    hx_adv_features = model.encode(layer_norm(hx_adv))


    if print_stats:
        x = layer_norm(hx_adv)
        print('Activation shape: ' + hx_adv.shape)
        print("Distribution before SAE")
        print_stats(x)

    return {'x_adv': x_adv.detach().cpu(),
            'hx_adv':hx_adv.detach().cpu(),
            'hx_adv_features': hx_adv_features.detach().cpu(),
            'loss_diff': loss_diff}

# non-functional
def get_all_outputs(x, y, model, clip_classifier):
    dict_1 = function_1(x, y, model, clip_classifier)
    dict_2 = function_2(x, y, model, clip_classifier)
    return {**dict_1, **dict_2}

# Data collection (2hr)

## Tests

In [201]:
test_iter = iter(test_loader)

In [None]:
image, label, idx = next(test_iter)

image, label, idx = image[0].unsqueeze(0).to(device), label[0].unsqueeze(0).to(device), idx

In [268]:
sample = function_1(image, label, model, clip_classifier)
sample_2 = function_2(image, label, model, clip_classifier)

print(sample['h_adv_features'].shape)
print(sample_2['hx_adv_features'].shape)

torch.Size([1, 50, 49152])
torch.Size([1, 50, 49152])


In [1]:
import io
from contextlib import redirect_stdout
import os
import shutil
from datetime import datetime

from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [None]:
def count_feature_fires(
    class_loaders,
    model,
    clip_classifier,
    activation_type: str,
    function=function_2,
    threshold=0.1,
):
    """
    Returns:
        class_feature_counts: dict[label -> (num_features,) tensor]
    """

    # Detect number of features from model
    num_features = model.hidden_dim

    # Dictionary: class_label → feature firing counts
    class_feature_counts = {}

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    save_dir = f"/content/drive/MyDrive/deep_learning/text_outputs_{activation_type}_{threshold}_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    # Loop over all classes
    for label, loader in class_loaders.items():
        print(f"\n===== Evaluating Class {label} =====")

        # Initialize class-specific counter
        feature_counts = torch.zeros(num_features, dtype=torch.long)
        loss_diff_list = []

        # Loop over batches
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            out = function(x_batch, y_batch, model, clip_classifier)

            loss_diff_list.append(out['loss_diff'])
            feats = out[activation_type]       # (B, tokens, features)
            binary_feats = (feats > threshold).sum(dim=(0,1))

            feature_counts += binary_feats.cpu()

            # free GPU memory
            del x_batch, y_batch, feats, binary_feats, out
            torch.cuda.empty_cache()

        # Save per-class results
        class_feature_counts[label] = feature_counts


        # ---- Print top 10 features for that class ----
        top_vals, top_idx = torch.topk(feature_counts, 50)

        buf = io.StringIO()
        buf2 = io.StringIO()

        with redirect_stdout(buf):
            print(f"Top 50 firing features for class {label} [rank, feature, count]:")
            for rank in range(50):
                print(f"  [{rank+1}, {top_idx[rank].item()}, {top_vals[rank].item()}]")

        with redirect_stdout(buf2):
            print(f"Loss difference from perturbations for class {label} + \n")
            for loss in loss_diff_list:
                print(str(loss) + '\n')

        text_output = buf.getvalue() # contain the prints yurr
        loss_output = buf2.getvalue()

        filepath = os.path.join(save_dir, f"top_50_{activation_type}_class_{label}.txt")
        filepath_loss = os.path.join(save_dir, f"loss_diff_values_class_{label}.txt")

        with open(filepath, "w") as f:
            f.write(text_output)
        with open(filepath_loss, "w") as f:
            f.write(loss_output)

    zip_path = shutil.make_archive(f"{save_dir}_zip", "zip", save_dir)
    print("Saved ZIP to:", zip_path)

    return class_feature_counts

In [None]:
def build_class_loaders(class_images, batch_size=100):
    class_loaders = {}

    for label, imgs in class_images.items():
        imgs_tensor = torch.stack(imgs)   # (N, C, H, W)
        ds = TensorDataset(imgs_tensor, torch.full((len(imgs),), label))
        class_loaders[label] = DataLoader(ds, batch_size=batch_size, shuffle=False)

    return class_loaders

In [None]:
# sort test_dataset by class
from collections import defaultdict

class_images = defaultdict(list)

for img, label in test_dataset:
    class_images[label].append(img)

In [None]:
class_loaders = build_class_loaders(class_images, batch_size=100)

In [None]:
# varying threshold

thresholds = torch.linspace(start=1, end=3, steps=5).tolist()

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

my_path = f"/content/drive/MyDrive/deep_learning/feature_counts_{timestamp}"
os.makedirs(my_path, exist_ok=True)

activation_types = ['hx_adv_features', 'h_features', 'h_adv_features']# 2 x 5 x 24 min = 2 hrs
for threshold in thresholds: # 5 x 24 min
    for activation_type in activation_types: # 3 activations
        if activation_type == 'hx_adv_features':
            function = function_2
        else:
            function = function_1

        feature_counts = count_feature_fires( # approx 24 mins
        class_loaders=class_loaders,
        model=model,
        clip_classifier=clip_classifier,
        activation_type=activation_type,
        function=function,
        threshold=threshold)

        threshold_str = f"{threshold:.2f}"


        torch.save(feature_counts, f"{my_path}/{activation_type}_counts_{threshold_str}.pt")


In [259]:

if (sample_2['hx_adv'] != 0).any():
    print("Tensor has at least one non-zero value.")
else:
    print("Tensor is all zeros.")

Tensor has at least one non-zero value.


## Tune the alphas

In [None]:
# alpha tuning

alphas = np.linspace(0.5, 1e-3, num=10).tolist()

loss_1_diffs = []
loss_2_diffs = []

for alpha in alphas:
    sample = function_1(image, label, model, clip_classifier, alpha=alpha)
    loss_1_diffs.append(sample['loss_diff'])
    sample_2 = function_2(image, label, model, clip_classifier, alpha=alpha)
    loss_2_diffs.append(sample_2['loss_diff'])

print(alphas)
print(loss_1_diffs)
print(loss_2_diffs)

# Visualizations non-feature

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import math

### Heatmap of adversarial perturbations

In [None]:
def comparison_grid(images, h_batch, h_adv_batch, num_rows=10):
    """
    Create a grid of size `num_rows × 4`:
      col1: image
      col2: ‖h‖ heatmap
      col3: ‖h_adv‖ heatmap
      col4: ‖h_adv - h‖ heatmap

    Arguments:
        images:      tensor (B,3,H,W)  batch of input images
        h_batch:     tensor (B,50,dim) clean tokens
        h_adv_batch: tensor (B,50,dim) adversarial tokens
        num_rows:    number of images to include (default=10)

    Returns:
        fig, axes
    """
    B = images.shape[0]
    assert B >= num_rows, f"Batch only has {B} images but num_rows={num_rows}"

    # Create figure
    fig, axes = plt.subplots(num_rows, 4, figsize=(20, 4 * num_rows))

    if num_rows == 1:
        axes = axes.reshape(1, 4)

    for row in range(num_rows):
        # image
        img = unnormalize(images[row]).squeeze().cpu().numpy()  # (3,H,W)
        img = np.transpose(img, (1,2,0))              # (H,W,3)

        # clean and adv tokens
        h = h_batch[row].cpu().numpy()       # (50,dim)
        h_adv = h_adv_batch[row].cpu().numpy()

        # remove CLS
        h_img = h[1:]         # (49, dim)
        h_adv_img = h_adv[1:]

        # compute norms
        h_norm = np.linalg.norm(h_img, axis=-1)
        h_adv_norm = np.linalg.norm(h_adv_img, axis=-1)
        diff = np.linalg.norm(h_adv_img - h_img, axis=-1)

        # reshape into 7×7 patch grid
        h_map = h_norm.reshape(7,7)
        h_adv_map = h_adv_norm.reshape(7,7)
        diff_map = diff.reshape(7,7)

        # upscale to image-size for visualization
        def upscale(m):
            return cv2.resize(m, (224,224), interpolation=cv2.INTER_NEAREST)

        h_big = upscale(h_map)
        h_adv_big = upscale(h_adv_map)
        diff_big = upscale(diff_map)

        axes[row,0].imshow(img)
        axes[row,0].set_title(f"Image {row}")
        axes[row,0].axis("off")

        # clean heatmap
        im1 = axes[row,1].imshow(h_big, cmap="viridis")
        axes[row,1].set_title("‖h‖")
        axes[row,1].axis("off")
        fig.colorbar(im1, ax=axes[row,1], fraction=0.046, pad=0.04)

        # adv heatmap
        im2 = axes[row,2].imshow(h_adv_big, cmap="viridis")
        axes[row,2].set_title("‖h_adv‖")
        axes[row,2].axis("off")
        fig.colorbar(im2, ax=axes[row,2], fraction=0.046, pad=0.04)

        # diff heatmap
        im3 = axes[row,3].imshow(diff_big, cmap="viridis")
        axes[row,3].set_title("‖h_adv − h‖")
        axes[row,3].axis("off")
        fig.colorbar(im3, ax=axes[row,3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()
    # return fig, axes

In [None]:

# for k, v in eval_dataset.items():
k, v = list(eval_dataset.items())[4]
graph_images = torch.stack(v, dim=0).to(device)
k_list = torch.full((len(v),), k, device=device)
sample = function_1(graph_images, k_list, model, clip_classifier)
sample_2 = function_2(graph_images, k_list, model, clip_classifier)

graph_h = sample['h'].detach().cpu().squeeze(0)
graph_h_adv = sample['h_adv'].detach().cpu().squeeze(0)

graph_x_adv = sample_2['x_adv'].detach().cpu().squeeze(0)
graph_hx_adv = sample_2['hx_adv'].detach().cpu().squeeze(0)

comparison_grid(graph_images, graph_h, graph_h_adv)
comparison_grid(graph_x_adv, graph_h, graph_hx_adv)








AttributeError: 'list' object has no attribute 'items'

In [None]:
def avg_norm_difference(h, h_adv):
    # h and h_adv have shape (50, 768)

    # 1. Remove CLS token (token 0)
    h_img = h[1:]              # (49, 768)
    h_adv_img = h_adv[1:]      # (49, 768)

    # 2. Compute per-token norms
    h_norms = torch.norm(h_img, dim=-1)          # (49,)
    h_adv_norms = torch.norm(h_adv_img, dim=-1)  # (49,)

    # 3. Average norms across all spatial tokens
    avg_h = h_norms.mean()
    avg_h_adv = h_adv_norms.mean()

    # 4. Average difference norm
    avg_diff = torch.norm(h_adv_img - h_img, dim=-1).mean()
    return avg_diff


In [None]:
print(avg_norm_difference(graph_h, graph_h_adv))
print(avg_norm_difference(graph_h, graph_hx_adv))

### Feature visualization of perturbations

In [None]:
# heatmap 1d: x, y = position of activation, color = how strongly a feature fires per patch (for h, h_adv)

import
def feature_visualization(h, features, k=5, cols=3):
    """
    Visualize CLIP patch activations (2D) and top-k SAE features (2D)
    in a grid layout with `cols` images per row.

    Parameters
    ----------
    h : tensor (50, 768)
        CLIP hidden activations. Token 0 = CLS, 1–49 = 7×7 patches.
    features : tensor (50, F)
        SAE encoder output features for each token.
    k : int
        Number of top SAE features to visualize.
    cols : int
        Number of heatmaps per row (default: 3).
    """

    # Prepare patch-level norm for h
    h_img = h[1:]                                 # (49, 768)
    h_norm = torch.norm(h_img, dim=-1).cpu()      # (49,)
    h_map = h_norm.reshape(7, 7).numpy()          # (7, 7)

    # Rank SAE features (top-k)
    feat_img = features[1:]                       # (49, F)
    feature_scores = feat_img.abs().sum(dim=0)    # (F,)
    topk_scores, topk_idx = torch.topk(feature_scores, k)

    # Total number of images: 1 (for h) + k (features)
    total = 1 + k
    rows = math.ceil(total / cols)

    # Create large grid figure
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = axes.flatten()  # always flatten for easy indexing

    # First heatmap = h patch magnitude
    ax = axes[0]
    hm = ax.imshow(h_map, cmap="viridis")
    ax.set_title("CLIP h — Patch Magnitude", fontsize=12)
    ax.axis("off")
    fig.colorbar(hm, ax=ax, fraction=0.046, pad=0.04)

    # plot the top-k features in 2D
    for i in range(k):
        feat_id = topk_idx[i].item()
        score = topk_scores[i].item()

        activation = feat_img[:, feat_id].cpu().numpy()  # (49,)
        activation_2d = activation.reshape(7, 7)          # (7, 7)

        ax = axes[i+1]
        hm = ax.imshow(activation_2d, cmap="magma")
        ax.set_title(
            f"Top Feature #{i+1}\nID {feat_id} | Score {score:.2f}",
            fontsize=11
        )
        ax.axis("off")
        fig.colorbar(hm, ax=ax, fraction=0.046, pad=0.04)

    # Hide unused axes
    for j in range(total, len(axes)):
        axes[j].axis("off")

    plt.tight_layout()
    # return fig

# Analysis

## Analysis dataset

In [147]:
'''
structure:
dataset = [
  {
    "image"
    "label"
    "meta"
  },
...
]
'''

def get_images_per_class(dataset, num_per_class=5):
    counts = {label: 0 for label in range(10)}
    items = []  # list of dicts

    for idx in range(len(dataset)):
        img, label = dataset[idx]

        if counts[label] < num_per_class:
            item = {
                "image": img,
                "label": label,
                "meta": {}
            }
            items.append(item)
            counts[label] += 1

        # stop when all classes collected
        if all(counts[l] >= num_per_class for l in counts):
            break

    return items


In [160]:
num_per_class = 5
eval_dataset = get_images_per_class(test_dataset, num_per_class)

In [161]:
features = [1064, 2420, 2642, 5167, 6847, 7636, 8709, 9028, 10216, 13978, 16979,19030, 20359, 21971, 24248, 25461, 25989, 29216, 29390, 31041, 40471,43948, 44551, 47241]

feature_to_index = {f: i for i, f in enumerate(features)}

In [162]:
def get_feature_outputs(x, y, model, clip_classifier, feature_list):
  results = get_all_outputs(x, y, model, clip_classifier)
  results["h_features"] = results["h_features"][:, :, feature_list]
  results["h_adv_features"] = results["h_adv_features"][:,:, feature_list]
  results["hx_adv_features"] = results["hx_adv_features"][:,:, feature_list]
  return results

In [163]:
'''
for each image in the dataset:
  - pre-adversarial: x, h, SAE(h)
  - adversarial (on inputs): x_adv, h(x_adv), SAE(h(x_adv)) (shape 1, 50, 49k)
  - adversarial (on hidden state): h, h_adv, SAE(h_adv)
'''

for n in eval_dataset:
    n['meta'] = get_feature_outputs(n['image'].to(device).unsqueeze(0), torch.tensor(n['label']).unsqueeze(0).to(device),
                                model, clip_classifier, features)

In [158]:
eval_dataset = sorted(eval_dataset, key=lambda x: x["label"])

In [153]:
torch.save(eval_dataset, f"eval_dataset_{num_per_class}.pt")