# Evaluating your SAE

Code based off Rob Graham's ([themachinefan](https://github.com/themachinefan)) SAE evaluation code.

In [23]:
import os
os.getcwd()

'/workspace/ViT-Prisma/src/vit_prisma/sae/evals'

In [24]:
import einops
import torch
import torchvision

import plotly.express as px

from tqdm import tqdm

import numpy as np
import os
import requests

# Setup

In [3]:
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig


@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    sae_path: str = '/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/n_images_2600058.pt'
    model_name: str = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
    model_type: str =  "clip"
    patch_size: str = 32

    dataset_path = "/workspace"
    dataset_train_path: str = "/workspace/ILSVRC/Data/CLS-LOC/train"
    dataset_val_path: str = "/workspace/ILSVRC/Data/CLS-LOC/val"

    verbose: bool = True

    device: bool = 'cuda'

    eval_max: int = 50_000 # 50_000
    batch_size: int = 32

    # make the max image output folder a subfolder of the sae path


    @property
    def max_image_output_folder(self) -> str:
        # Get the base directory of sae_checkpoints
        sae_base_dir = os.path.dirname(os.path.dirname(self.sae_path))
        
        # Get the name of the original SAE checkpoint folder
        sae_folder_name = os.path.basename(os.path.dirname(self.sae_path))
        
        # Create a new folder path in sae_checkpoints/images with the original name
        output_folder = os.path.join(sae_base_dir, 'max_images', sae_folder_name)
        output_folder = os.path.join(output_folder, f"layer_{self.hook_point_layer}") # Add layer number

        
        # Ensure the directory exists
        os.makedirs(output_folder, exist_ok=True)
        
        return output_folder

cfg = EvalConfig()

n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 15869
Expansion factor: 16
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder


In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x721aa6c01660>

## Load model

In [5]:
from vit_prisma.models.base_vit import HookedViT

model_name = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True).to(cfg.device)
 

model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Official model name open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Converting OpenCLIP weights
model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
visual projection shape torch.Size([768, 512])
Setting center_writing_weights to False for OpenCLIP
Setting fold_ln to False for OpenCLIP
Loaded pretrained model open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K into HookedTransformer


## Load datasets

In [6]:
import importlib
import vit_prisma
# importlib.reload(vit_prisma.dataloaders.imagenet_dataset)

In [7]:
# load dataset
import open_clip
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_transforms_clip, ImageNetValidationDataset

from torchvision import transforms
from transformers import CLIPProcessor

og_model_name = "hf-hub:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
og_model, _, preproc = open_clip.create_model_and_transforms(og_model_name)
processor = preproc

size=224

data_transforms = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                     std=[0.26862954, 0.26130258, 0.27577711]),
])
    
imagenet_paths = setup_imagenet_paths(cfg.dataset_path)
imagenet_paths["train"] = "/workspace/ILSVRC/Data/CLS-LOC/train"
imagenet_paths["val"] = "/workspace/ILSVRC/Data/CLS-LOC/val"
imagenet_paths["val_labels"] = "/workspace/LOC_val_solution.csv"
imagenet_paths["label_strings"] = "/workspace/LOC_synset_mapping.txt"
print()
train_data = torchvision.datasets.ImageFolder(cfg.dataset_train_path, transform=data_transforms)
val_data = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'], 
                                data_transforms,
                                return_index=True,
)
val_data_visualize = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'],
                                torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),]), return_index=True)

print(f"Validation data length: {len(val_data)}") if cfg.verbose else None



Validation data length: 50000


In [8]:
from vit_prisma.sae.training.activations_store import VisionActivationsStore
# import dataloader
from torch.utils.data import DataLoader

# activations_loader = VisionActivationsStore(cfg, model, train_data, eval_dataset=val_data)
val_dataloader = DataLoader(val_data, batch_size=cfg.batch_size, shuffle=False, num_workers=4)


## Load pretrained SAE to evaluate

In [9]:
from vit_prisma.sae.sae import SparseAutoencoder
sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained("/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/n_images_2600058.pt")
sparse_autoencoder.to(cfg.device)
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


get_activation_fn received: activation_fn=relu, kwargs={}
n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 1586
Expansion factor: 64
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder
get_activation_fn received: activation_fn=relu, kwargs={}


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
  (activation_fn): ReLU()
)

## Clip Labeling AutoInterp

In [10]:
# all_imagenet_class_names

In [11]:
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
ind_to_name = get_imagenet_index_to_name()

all_imagenet_class_names = []
for i in range(len(ind_to_name)):
    all_imagenet_class_names.append(ind_to_name[str(i)][1])

In [12]:
cfg.max_image_output_folder

'/workspace/sae_checkpoints/max_images/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/layer_9'

## Feature steering

In [13]:
def standard_replacement_hook_curry(feat_idx: int = 0, feat_activ: float = 1.0):
    def standard_replacement_hook(activations: torch.Tensor, hook):
        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
        feature_acts = sparse_autoencoder.encode_standard(activations)

        # in all batches and patches, set feature w idx idx to 0
        print(f"feature_acts[:,:,idx].shape: {feature_acts[:,:,feat_idx].shape}")
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        feature_acts[:,:,feat_idx] *= feat_activ
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        print(f"feat activ: {feature_acts.shape}")
        print(f"feat activ: {feature_acts}")
        print("feature_acts[:,:,idx].sum(): (should be batch size x len seq x feat val)", feature_acts[:,:,feat_idx].sum())
        sae_out = sparse_autoencoder.hook_sae_out(
            einops.einsum(
                feature_acts,
                sparse_autoencoder.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            )
            + sparse_autoencoder.b_dec
        )
        
        print(f"sae_out.shape: {sae_out.shape}")
        print(f"sae_out: {sae_out}")

        # allows normalization. Possibly identity if no normalization
        sae_out = sparse_autoencoder.run_time_activation_norm_fn_out(sae_out)
        return sae_out
    return standard_replacement_hook


def steering_hook_fn(
    activations, cfg, hook, sae, steering_indices, steering_strength=1.0, mean_ablation_values=None, include_error=False

):
    sae.to(activations.device)


    sae_input = activations.clone()
    sae_output, feature_activations, *data = sae(sae_input)
    
    steered_feature_activations = feature_activations.clone()
    
    steered_feature_activations[:, :, steering_indices] = steering_strength

    steered_sae_out = einops.einsum(
                steered_feature_activations,
                sae.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            ) + sae.b_dec

    steered_sae_out = sae.run_time_activation_norm_fn_out(steered_sae_out)
    
    print(steered_sae_out.shape)
    print(steered_sae_out.shape)
    print(f"steering norm: {(steered_sae_out - sae_output).norm()}")
    
    

    if include_error:
        error = sae_input - sae_output
        print(f"error.norm(): {error.norm()}")
        return steered_sae_out + error
    return steered_sae_out

In [14]:
random_feat_idxs = np.random.randint(0, high=3000, size=(10))

In [15]:
# for a given feature, set it high/low on maxim activ. imgs and high/low on non-activ images
# hook SAE and replace desired feature with 0 or 1 
from typing import List, Dict, Tuple
import torch
import einops
from tqdm import tqdm

from functools import partial

@torch.no_grad()
def compute_feature_activations_set_feat(
    images: torch.Tensor,
    model: torch.nn.Module,
    sparse_autoencoder: torch.nn.Module,
    encoder_weights: torch.Tensor,
    encoder_biases: torch.Tensor,
    feature_ids: List[int],
    feature_categories: List[str],
    top_k: int = 10
):
    """
    Compute the highest activating tokens for given features in a batch of images.
    
    Args:
        images: Input images
        model: The main model
        sparse_autoencoder: The sparse autoencoder
        encoder_weights: Encoder weights for selected features
        encoder_biases: Encoder biases for selected features
        feature_ids: List of feature IDs to analyze
        feature_categories: Categories of the features
        top_k: Number of top activations to return per feature

    Returns:
        Dictionary mapping feature IDs to tuples of (top_indices, top_values)
    """
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
#     recons_image_embeddings_feat_altered = model.run_with_hooks(
#         images,
#         fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook)],
#     )
    recons_image_embeddings_feat_altered_list = []
    for idx in np.array(range(sparse_autoencoder.W_dec.shape[0]))[random_feat_idxs]:
        print(f"Feature: {idx} ====================")
        
        steering_hook = partial(
            steering_hook_fn,
            cfg=cfg,
            sae=sparse_autoencoder,
            steering_indices=[idx],
            steering_strength=10.0,
            mean_ablation_values = [1.0],
            include_error=True,
            )
        
        
        recons_image_embeddings_feat_altered = model.run_with_hooks(
            images,
#             fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook_curry(idx, 10.0))],
            fwd_hooks=[("blocks.9.hook_mlp_out", steering_hook)],
        )
        recons_image_embeddings_feat_altered_list.append(recons_image_embeddings_feat_altered)

    
    # output is in clip embedding space
    recons_image_embeddings_default = model.run_with_hooks(
        images,
        fwd_hooks=[("blocks.9.hook_mlp_out", lambda x, hook: x)],
    )
    
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default}")
    print(f"recons_image_embeddings_default.shape: {recons_image_embeddings_default.shape}")
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default.shape}")

    print(f"recons_image_embeddings_feat_altered: {recons_image_embeddings_feat_altered}")
    print(f"recons_image_embeddings_feat_altered.shape: {recons_image_embeddings_feat_altered.shape}")

    return recons_image_embeddings_feat_altered_list, recons_image_embeddings_default

In [16]:
max_samples = cfg.eval_max

# top_activations = {i: (None, None) for i in interesting_features_indices}
encoder_biases = sparse_autoencoder.b_enc#[interesting_features_indices]
encoder_weights = sparse_autoencoder.W_enc#[:, interesting_features_indices]

top_k=10
processed_samples = 0
for batch_images, _, batch_indices in tqdm(val_dataloader, total=max_samples // cfg.batch_size):
    batch_images = batch_images.to(cfg.device)
    batch_indices = batch_indices.to(cfg.device)
    batch_size = batch_images.shape[0]

    altered_embeds_list, default_embeds = compute_feature_activations_set_feat(
        batch_images, model, sparse_autoencoder, encoder_weights, encoder_biases,
        None, None, top_k
    )
    # either label embeds or optimize to maximal token in text transformer embedding face
    break

  0%|                                                                                                            | 0/1562 [00:00<?, ?it/s]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6396484375
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.2458190917969
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6795349121094
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 398.08526611328125
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.73394775390625
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.5750732421875
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6285400390625
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8177185058594
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8154907226

  0%|                                                                                                            | 0/1562 [00:01<?, ?it/s]


In [17]:
len(altered_embeds_list), altered_embeds_list[0].shape, default_embeds.shape

(10, torch.Size([32, 512]), torch.Size([32, 512]))

In [18]:
og_model.cuda()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [19]:
# with open("/workspace/clip_dissect_raw.txt", "r") as f:
#     larger_vocab = [line[:-1] for line in f.readlines()][:5000]

with open("/workspace/better_img_desc.txt", "r") as f:
    larger_vocab = [line[:-1] for line in f.readlines()][:5000]

In [20]:
# use clip vocab here and compare embeds
import torch
from PIL import Image

tokenizer = open_clip.get_tokenizer('ViT-B-32')
text = tokenizer(larger_vocab)
text_features = og_model.encode_text(text.cuda())
text_features_normed = text_features/text_features.norm(dim=-1, keepdim=True)

text_probs_altered_list = []
# can probs make this one tensor operation
for altered_embeds in altered_embeds_list:
    with torch.no_grad(), torch.cuda.amp.autocast():
        # might want to still normalize
        
        # already normalized
        # altered_embeds /= altered_embeds.norm(dim=-1, keepdim=True)

        text_probs_altered = (100.0 * altered_embeds @ text_features_normed.T).softmax(dim=-1)
        text_probs_altered_list.append(text_probs_altered)
    # default_embds_norm = default_embeds.norm(dim=-1, keepdim=True)
    text_probs_default = (100.0 * default_embeds @ text_features_normed.T).softmax(dim=-1)

print("Label probs altered:", text_probs_altered.shape)  # prints: [[1., 0., 0.]]
print("Label probs default:", text_probs_default.shape)  # prints: [[1., 0., 0.]]

Label probs altered: torch.Size([32, 3498])
Label probs default: torch.Size([32, 3498])


In [21]:
text_probs_altered

tensor([[3.6889e-06, 9.2544e-07, 2.3448e-06,  ..., 8.9186e-06, 1.7155e-06,
         8.4263e-07],
        [1.7192e-08, 2.2639e-09, 2.1064e-08,  ..., 6.6941e-08, 6.9734e-09,
         6.6023e-09],
        [5.7456e-05, 1.4641e-05, 2.9346e-05,  ..., 2.1972e-07, 5.3250e-04,
         1.1558e-06],
        ...,
        [1.4036e-06, 9.0626e-07, 1.9599e-07,  ..., 9.8966e-06, 5.0836e-07,
         7.6913e-07],
        [2.4344e-06, 9.5878e-08, 2.0334e-08,  ..., 2.2159e-08, 5.6925e-07,
         9.1487e-08],
        [4.8636e-05, 1.3826e-05, 5.5311e-07,  ..., 5.4339e-08, 3.4958e-06,
         9.1908e-07]], device='cuda:0')

In [22]:
# subtract from default, label, and print trends
text_probs_altered.shape

# selected_vocab = all_imagenet_class_names
selected_vocab = larger_vocab

cov_stuff_avgs = []
# cov_stuff_avgs_least = []
for j, text_probs_altered in enumerate(text_probs_altered_list):
    print(f"\n\nFor Feature {random_feat_idxs[j]}")
    logit_diff = text_probs_altered - text_probs_default
    logit_ratio = text_probs_altered/text_probs_default
    
    vals, idxs = torch.topk(logit_diff,k=5)
    vals_least, idxs_least = torch.topk(logit_diff,k=5,largest=False)
    
    ratios, ratios_idxs = torch.topk(logit_ratio,k=5)
    ratios_least, ratios_idxs_least = torch.topk(logit_ratio,k=5,largest=False)
    
    cov_over_images = []
    for i in range(logit_diff.shape[0]//4):
#         print(f"\nImage {i} ========================\nMost Changed, by Absolute Diff\n:{vals[i]}")
#         print(np.array(all_imagenet_class_names)[idxs.cpu()][i])
#         print(vals_least[i])
#         print(np.array(all_imagenet_class_names)[idxs_least.cpu()][i])
        
        print("\nMost Changed, by Ratio:")
        print(ratios[i])
        print(np.array(selected_vocab)[ratios_idxs.cpu()][i])
        print(ratios_least[i])
        print(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        
        text = tokenizer(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         text_least = tokenizer(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        text_features = og_model.encode_text(text.cuda())
        cov_over_images.append(text_features)
#         text_features_least = og_model.encode_text(text_least.cuda())
        print(torch.tril(torch.cov(text_features), diagonal=-1))
        print(torch.tril(torch.cov(text_features), diagonal=-1).sum()/10)
#         cov_stuff = torch.tril(torch.cov(text_features), diagonal=-1).sum()/10
#         cov_stuff_least = torch.tril(torch.cov(text_features_least), diagonal=-1).sum()/10
    print(torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape)
    n = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape[0]
    num_elements = (n**2)/2 - n
    cov_stuff = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).sum()/num_elements
    cov_stuff_avgs.append(cov_stuff)
#         cov_stuff_avgs_least.append(cov_stuff_least)
    if j > 10:
        break
print(torch.tensor(cov_stuff_avgs).mean())
# print(torch.tensor(cov_stuff_avgs_least).mean())



For Feature 239

Most Changed, by Ratio:
tensor([15667.3857, 14397.1787, 10037.7832,  9383.1768,  5207.5737],
       device='cuda:0')
['Detailed illustration of a futuristic virtual reality'
 'Detailed illustration of a futuristic virtual realm'
 'Detailed illustration of a historical scene'
 'Detailed illustration of a futuristic quantum realm'
 'Miniature diorama photography']
tensor([0.0025, 0.0146, 0.0147, 0.0154, 0.0204], device='cuda:0')
['Flowing lines' 'A trunk' 'Shy facial expression' 'A swirling eddy'
 'Image with a trio of friends']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3106, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2572, 0.2823, 0.0000, 0.0000, 0.0000],
        [0.2033, 0.2179, 0.1660, 0.0000, 0.0000],
        [0.1702, 0.1911, 0.2120, 0.1151, 0.0000]], device='cuda:0')
tensor(0.2126, device='cuda:0')

Most Changed, by Ratio:
tensor([2575.2451, 2074.5334, 1748.1523, 1665.1611, 1567.6134],
       device='cuda:0')
['Picture captured in the Canadian 

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1876, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1852, 0.1840, 0.0000, 0.0000, 0.0000],
        [0.1584, 0.2208, 0.1624, 0.0000, 0.0000],
        [0.1918, 0.2615, 0.1796, 0.2346, 0.0000]], device='cuda:0')
tensor(0.1966, device='cuda:0')

Most Changed, by Ratio:
tensor([1124.7959,  603.2236,  577.7757,  539.7874,  439.4832],
       device='cuda:0')
['Image with elemental magic and water' 'intricate gemstone arrangement'
 'An image of a Waiter/Waitress' 'Secluded beach cove'
 'intricate gemstone display']
tensor([0.0144, 0.0330, 0.0351, 0.0420, 0.0452], device='cuda:0')
['an image of samoa' 'A zebra stripe pattern' 'A burst of rays'
 'Golden hour lighting' 'Time-lapse image']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1853, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1858, 0.1714, 0.0000, 0.0000, 0.0000],
        [0.1534, 0.1474, 0.1502, 0.0000, 0.0000],
        [0.1800, 0.2710, 0.1787, 0.1403, 0.0000]], dev

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3134, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2491, 0.2257, 0.0000, 0.0000, 0.0000],
        [0.2999, 0.2992, 0.2368, 0.0000, 0.0000],
        [0.2592, 0.2367, 0.2340, 0.2218, 0.0000]], device='cuda:0')
tensor(0.2576, device='cuda:0')
torch.Size([40, 40])


For Feature 756

Most Changed, by Ratio:
tensor([370.9428, 281.9456, 157.2586, 149.1168, 146.9364], device='cuda:0')
['A frame from a movie' 'Artwork featuring retro TV test patterns'
 'A photo of a vibrant festival' 'A pentagon'
 'Photo featuring a vibrant cultural procession']
tensor([0.0079, 0.0092, 0.0113, 0.0121, 0.0132], device='cuda:0')
['Subdued beauty' 'Photo taken in the Hawaiian beaches'
 'Graceful wings in motion' 'Photo with high key lighting'
 'High-key contrast']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2193, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2370, 0.2033, 0.0000, 0.0000, 0.0000],
        [0.1818, 0.1492, 0.1748, 0.0000, 0.0

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2199, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1855, 0.1885, 0.0000, 0.0000, 0.0000],
        [0.2359, 0.2796, 0.1944, 0.0000, 0.0000],
        [0.2450, 0.2558, 0.1841, 0.2569, 0.0000]], device='cuda:0')
tensor(0.2246, device='cuda:0')

Most Changed, by Ratio:
tensor([194.2768, 167.0349, 162.9559, 153.3828,  91.9009], device='cuda:0')
['Artificial lighting' 'weathered religious icon' 'Reflective surfaces'
 'cutting-edge technology' 'Translucent materials']
tensor([0.0056, 0.0094, 0.0110, 0.0121, 0.0132], device='cuda:0')
['A pendulum' 'Golden hour glow' 'Ocean sunset silhouette'
 'Photo taken in Machu Picchu, Peru' 'Image with a pink color']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1608, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2336, 0.1522, 0.0000, 0.0000, 0.0000],
        [0.2332, 0.1670, 0.2236, 0.0000, 0.0000],
        [0.2263, 0.1674, 0.2377, 0.2181, 0.0000]], device='cuda:0')
tensor(0.2020, dev

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2089, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2135, 0.1829, 0.0000, 0.0000, 0.0000],
        [0.1976, 0.1863, 0.2306, 0.0000, 0.0000],
        [0.2508, 0.2375, 0.2114, 0.1990, 0.0000]], device='cuda:0')
tensor(0.2118, device='cuda:0')

Most Changed, by Ratio:
tensor([769.8625, 592.6477, 483.2134, 427.3602, 364.6651], device='cuda:0')
['tranquil beach sunset' 'Tranquil waterfall scene '
 'serene waterfall scene' 'Serene beach sunset '
 'Serene countryside sunrise']
tensor([0.0177, 0.0208, 0.0274, 0.0277, 0.0315], device='cuda:0')
['A photo of Monaco' 'Point of view from above' 'Enigmatic forms'
 'A spirograph-like shape' 'Image with octagon tessellation']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1814, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1723, 0.3022, 0.0000, 0.0000, 0.0000],
        [0.2544, 0.1813, 0.1912, 0.0000, 0.0000],
        [0.2152, 0.2034, 0.2178, 0.2335, 0.0000]], device='cuda:0')
tens

tensor([0.0009, 0.0021, 0.0027, 0.0028, 0.0037], device='cuda:0')
['Play of light and shadow' 'A shadow' 'Dynamic shadows'
 'Dappled sunlight' 'Dramatic shadows']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2035, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3267, 0.2009, 0.0000, 0.0000, 0.0000],
        [0.2132, 0.1751, 0.2134, 0.0000, 0.0000],
        [0.2190, 0.2056, 0.2297, 0.1830, 0.0000]], device='cuda:0')
tensor(0.2170, device='cuda:0')

Most Changed, by Ratio:
tensor([2921.9075, 1311.1094,  917.1436,  890.2854,  874.4189],
       device='cuda:0')
['Image with shattered crystal sculptures'
 'Image with shattered crystal structures'
 'Photo taken in the Swiss chocolate factories'
 'Image with shattered crystal shards'
 'Image snapped in the Swiss chocolate factories']
tensor([0.0119, 0.0179, 0.0223, 0.0250, 0.0283], device='cuda:0')
['Hands in an embrace' 'Majestic animal' 'A hand' 'Subdued beauty'
 "Nature's embrace"]
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.000

## Stop

In [None]:
top_k_imgnet_labels = 10

In [None]:
from matplotlib import pyplot as plt
from collections import defaultdict, Counter

feat_autolabels_default = defaultdict(Counter)
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=top_k_imgnet_labels)
    print(i)
    for k, idx in enumerate(idxs):
        feat_autolabels_default[i][all_imagenet_class_names[idx]] += vals[k]
        print("\t", all_imagenet_class_names[idx])
feat_autolabels_default

In [None]:
# subtract from default, label, and print trends
text_probs_altered.shape

for text_probs_altered in text_probs_altered_list:
    logit_diff = text_probs_altered - text_probs_default
    print(logit_diff)
    vals, idxs = torch.topk(logit_diff,k=5)
    print(vals, np.array(all_imagenet_class_names[idxs])
    break

In [None]:
all_imagenet_class_names

In [None]:
from collections import defaultdict, Counter

feat_autolabels_altered_list = []
for text_probs_altered in text_probs_altered_list:
    feat_autolabels_altered = defaultdict(Counter)
    for i in range(text_probs_altered.shape[0]):
        vals, idxs = torch.topk(text_probs_altered[i],k=top_k_imgnet_labels)
#         print(i)
        for k, idx in enumerate(idxs):
            feat_autolabels_altered[i][all_imagenet_class_names[idx]] += vals[k]
#             print("\t", all_imagenet_class_names[idx])
    feat_autolabels_altered_list.append(feat_autolabels_altered)

start_idx = 9
end_idx = 10
    
h = 0
for key in feat_autolabels_default:
    print(f"\nfeat_autolabels_default img {key}:\n {feat_autolabels_default[key]}\n")
    h += 1
    if h > end_idx:
        break
for i, f_a_a in enumerate(feat_autolabels_altered_list):
    print("============= feature number ", i, "====================")
    h = 0
    for key in range(start_idx, end_idx):
#         print("\n", key)
#         for item in f_a_a[key]:
#             print("\t", item, f_a_a[key][item].cpu().item())
        print(f"\nf_a_a img {key}:\n {f_a_a[key]}\n")


In [None]:
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=1000)
    print(i, ind_to_name[str(idxs[0].cpu().item())][1])
    fig, ax = plt.subplots(figsize=(10, 10))
#     ax.xaxis.set_ticks((1000))
#     ax.set_xticks(list(range(1000)), [ind_to_name[str(idxs[idx].cpu().item())][1] for idx in idxs])
    plt.bar(idxs.cpu(), vals.cpu(), width=5)
    break

In [None]:
plt.imshow(batch_images[2].cpu().permute((1,2,0)).numpy())

In [None]:

@torch.no_grad()
def get_heatmap(
          image,
          model,
          sparse_autoencoder,
          feature_id,
): 
    image = image.to(cfg.device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange(cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This code is copied from the first part of the 'forward' method of the AutoEncoder class
    sae_in =  post_reshaped - sparse_autoencoder.b_dec # Remove decoder bias as per Anthropic
    print(f"sae_in.shape: {sae_in.shape}")
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts 
     
def image_patch_heatmap(activation_values,image_size=224, pixel_num=14):
    activation_values = activation_values.detach().cpu().numpy()
    activation_values = activation_values[1:]
    activation_values = activation_values.reshape(pixel_num, pixel_num)

    # Create a heatmap overlay
    heatmap = np.zeros((image_size, image_size))
    patch_size = image_size // pixel_num

    for i in range(pixel_num):
        for j in range(pixel_num):
            heatmap[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = activation_values[i, j]

    return heatmap


In [None]:
from matplotlib import pyplot as plt

grid_size = 1
fig, axs = plt.subplots(int(np.ceil(len(images)/grid_size)), grid_size, figsize=(15, 15))
name=  f"Category: uhh,  Feature: {0}"
fig.suptitle(name)#, y=0.95)
for ax in axs.flatten():
    ax.axis('off')
complete_bid = []

heatmap = get_heatmap(batch_images[2], model,sparse_autoencoder, 10000)
heatmap = image_patch_heatmap(heatmap, pixel_num=224//cfg.patch_size)

display = batch_images[2].cpu().numpy().transpose(1, 2, 0)

has_zero = False

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)