# Evaluating your SAE

Code based off Rob Graham's ([themachinefan](https://github.com/themachinefan)) SAE evaluation code.

In [51]:
import os
os.getcwd()

'/workspace/ViT-Prisma/src/vit_prisma/sae/evals'

In [52]:
import einops
import torch
import torchvision

import plotly.express as px

from tqdm import tqdm

import numpy as np
import os
import requests

# Setup

In [53]:
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig


@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    sae_path: str = '/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-11-hook_resid_post-l1-0.0001/n_images_2600058.pt'
    model_name: str = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
    model_type: str =  "clip"
    patch_size: str = 32

    dataset_path = "/workspace"
    dataset_train_path: str = "/workspace/ILSVRC/Data/CLS-LOC/train"
    dataset_val_path: str = "/workspace/ILSVRC/Data/CLS-LOC/val"

    verbose: bool = True

    device: bool = 'cuda'

    eval_max: int = 50_000 # 50_000
    batch_size: int = 32

    # make the max image output folder a subfolder of the sae path


    @property
    def max_image_output_folder(self) -> str:
        # Get the base directory of sae_checkpoints
        sae_base_dir = os.path.dirname(os.path.dirname(self.sae_path))
        
        # Get the name of the original SAE checkpoint folder
        sae_folder_name = os.path.basename(os.path.dirname(self.sae_path))
        
        # Create a new folder path in sae_checkpoints/images with the original name
        output_folder = os.path.join(sae_base_dir, 'max_images', sae_folder_name)
        output_folder = os.path.join(output_folder, f"layer_{self.hook_point_layer}") # Add layer number

        
        # Ensure the directory exists
        os.makedirs(output_folder, exist_ok=True)
        
        return output_folder

cfg = EvalConfig()

n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 15869
Expansion factor: 16
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder


In [54]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7ef6b8585e40>

## Load model

In [55]:
from vit_prisma.models.base_vit import HookedViT

model_name = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True).to(cfg.device)
 

model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Official model name open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Converting OpenCLIP weights
model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
visual projection shape torch.Size([768, 512])
Setting center_writing_weights to False for OpenCLIP
Setting fold_ln to False for OpenCLIP
Loaded pretrained model open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K into HookedTransformer


## Load datasets

In [56]:
import importlib
import vit_prisma
# importlib.reload(vit_prisma.dataloaders.imagenet_dataset)

In [57]:
# load dataset
import open_clip
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_transforms_clip, ImageNetValidationDataset

from torchvision import transforms
from transformers import CLIPProcessor

og_model_name = "hf-hub:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
og_model, _, preproc = open_clip.create_model_and_transforms(og_model_name)
processor = preproc

size=224

data_transforms = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                     std=[0.26862954, 0.26130258, 0.27577711]),
])
    
imagenet_paths = setup_imagenet_paths(cfg.dataset_path)
imagenet_paths["train"] = "/workspace/ILSVRC/Data/CLS-LOC/train"
imagenet_paths["val"] = "/workspace/ILSVRC/Data/CLS-LOC/val"
imagenet_paths["val_labels"] = "/workspace/LOC_val_solution.csv"
imagenet_paths["label_strings"] = "/workspace/LOC_synset_mapping.txt"
print()
train_data = torchvision.datasets.ImageFolder(cfg.dataset_train_path, transform=data_transforms)
val_data = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'], 
                                data_transforms,
                                return_index=True,
)
val_data_visualize = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'],
                                torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),]), return_index=True)

print(f"Validation data length: {len(val_data)}") if cfg.verbose else None



Validation data length: 50000


In [58]:
from vit_prisma.sae.training.activations_store import VisionActivationsStore
# import dataloader
from torch.utils.data import DataLoader

# activations_loader = VisionActivationsStore(cfg, model, train_data, eval_dataset=val_data)
val_dataloader = DataLoader(val_data, batch_size=cfg.batch_size, shuffle=False, num_workers=4)


## Load pretrained SAE to evaluate

In [59]:
from vit_prisma.sae.sae import SparseAutoencoder
sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained("/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-11-hook_resid_post-l1-0.0001/n_images_2600058.pt")
sparse_autoencoder.to(cfg.device)
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


get_activation_fn received: activation_fn=relu, kwargs={}
n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 1586
Expansion factor: 64
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder
get_activation_fn received: activation_fn=relu, kwargs={}


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
  (activation_fn): ReLU()
)

## Clip Labeling AutoInterp

In [60]:
# all_imagenet_class_names

In [61]:
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
ind_to_name = get_imagenet_index_to_name()

all_imagenet_class_names = []
for i in range(len(ind_to_name)):
    all_imagenet_class_names.append(ind_to_name[str(i)][1])

In [62]:
cfg.max_image_output_folder

'/workspace/sae_checkpoints/max_images/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-11-hook_resid_post-l1-0.0001/layer_9'

## Feature steering

In [63]:
def standard_replacement_hook_curry(feat_idx: int = 0, feat_activ: float = 1.0):
    def standard_replacement_hook(activations: torch.Tensor, hook):
        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
        feature_acts = sparse_autoencoder.encode_standard(activations)

        # in all batches and patches, set feature w idx idx to 0
        print(f"feature_acts[:,:,idx].shape: {feature_acts[:,:,feat_idx].shape}")
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        feature_acts[:,:,feat_idx] *= feat_activ
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        print(f"feat activ: {feature_acts.shape}")
        print(f"feat activ: {feature_acts}")
        print("feature_acts[:,:,idx].sum(): (should be batch size x len seq x feat val)", feature_acts[:,:,feat_idx].sum())
        sae_out = sparse_autoencoder.hook_sae_out(
            einops.einsum(
                feature_acts,
                sparse_autoencoder.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            )
            + sparse_autoencoder.b_dec
        )
        
        print(f"sae_out.shape: {sae_out.shape}")
        print(f"sae_out: {sae_out}")

        # allows normalization. Possibly identity if no normalization
        sae_out = sparse_autoencoder.run_time_activation_norm_fn_out(sae_out)
        return sae_out
    return standard_replacement_hook


def steering_hook_fn(
    activations, cfg, hook, sae, steering_indices, steering_strength=1.0, mean_ablation_values=None, include_error=False

):
    sae.to(activations.device)


    sae_input = activations.clone()
    sae_output, feature_activations, *data = sae(sae_input)
    
    steered_feature_activations = feature_activations.clone()
    
    steered_feature_activations[:, :, steering_indices] = steering_strength

    steered_sae_out = einops.einsum(
                steered_feature_activations,
                sae.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            ) + sae.b_dec

    steered_sae_out = sae.run_time_activation_norm_fn_out(steered_sae_out)
    
    print(steered_sae_out.shape)
    print(steered_sae_out.shape)
    print(f"steering norm: {(steered_sae_out - sae_output).norm()}")
    
    

    if include_error:
        error = sae_input - sae_output
        print(f"error.norm(): {error.norm()}")
        return steered_sae_out + error
    return steered_sae_out

In [64]:
random_feat_idxs = np.random.randint(0, high=3000, size=(25))

In [89]:
# for a given feature, set it high/low on maxim activ. imgs and high/low on non-activ images
# hook SAE and replace desired feature with 0 or 1 
from typing import List, Dict, Tuple
import torch
import einops
from tqdm import tqdm

from functools import partial

@torch.no_grad()
def compute_feature_activations_set_feat(
    images: torch.Tensor,
    model: torch.nn.Module,
    sparse_autoencoder: torch.nn.Module,
    encoder_weights: torch.Tensor,
    encoder_biases: torch.Tensor,
    feature_ids: List[int],
    feature_categories: List[str],
    top_k: int = 10
):
    """
    Compute the highest activating tokens for given features in a batch of images.
    
    Args:
        images: Input images
        model: The main model
        sparse_autoencoder: The sparse autoencoder
        encoder_weights: Encoder weights for selected features
        encoder_biases: Encoder biases for selected features
        feature_ids: List of feature IDs to analyze
        feature_categories: Categories of the features
        top_k: Number of top activations to return per feature

    Returns:
        Dictionary mapping feature IDs to tuples of (top_indices, top_values)
    """
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
#     recons_image_embeddings_feat_altered = model.run_with_hooks(
#         images,
#         fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook)],
#     )
    recons_image_embeddings_feat_altered_list = []
    for idx in np.array(range(sparse_autoencoder.W_dec.shape[0]))[random_feat_idxs]:
        print(f"Feature: {idx} ====================")
        
        steering_hook = partial(
            steering_hook_fn,
            cfg=cfg,
            sae=sparse_autoencoder,
            steering_indices=[idx],
            steering_strength=50.0,
            mean_ablation_values = [1.0],
            include_error=True,
            )
        
        
        recons_image_embeddings_feat_altered = model.run_with_hooks(
            images,
#             fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook_curry(idx, 10.0))],
            fwd_hooks=[("blocks.9.hook_mlp_out", steering_hook)],
        )
        recons_image_embeddings_feat_altered_list.append(recons_image_embeddings_feat_altered)

    
    # output is in clip embedding space
    recons_image_embeddings_default = model.run_with_hooks(
        images,
        fwd_hooks=[("blocks.9.hook_mlp_out", lambda x, hook: x)],
    )
    
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default}")
    print(f"recons_image_embeddings_default.shape: {recons_image_embeddings_default.shape}")
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default.shape}")

    print(f"recons_image_embeddings_feat_altered: {recons_image_embeddings_feat_altered}")
    print(f"recons_image_embeddings_feat_altered.shape: {recons_image_embeddings_feat_altered.shape}")

    return recons_image_embeddings_feat_altered_list, recons_image_embeddings_default

In [90]:
from collections import defaultdict
max_samples = cfg.eval_max

# top_activations = {i: (None, None) for i in interesting_features_indices}
encoder_biases = sparse_autoencoder.b_enc#[interesting_features_indices]
encoder_weights = sparse_autoencoder.W_enc#[:, interesting_features_indices]

top_k=10
processed_samples = 0
default_embeds_list = []
feature_steered_embeds = defaultdict(list)
l = 0
for batch_images, _, batch_indices in tqdm(val_dataloader, total=max_samples // cfg.batch_size):
    batch_images = batch_images.to(cfg.device)
    batch_indices = batch_indices.to(cfg.device)
    batch_size = batch_images.shape[0]

    altered_embeds_list, default_embeds = compute_feature_activations_set_feat(
        batch_images, model, sparse_autoencoder, encoder_weights, encoder_biases,
        None, None, top_k
    )
    default_embeds_list.append(default_embeds)
    for j, altered_embeds in enumerate(altered_embeds_list):
        feature_steered_embeds[random_feat_idxs[j]].extend(altered_embeds)
    # either label embeds or optimize to maximal token in text transformer embedding face
    l += 1
    if l >= 5:
        break

  0%|                                                                                                            | 0/1562 [00:00<?, ?it/s]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 1999.88671875
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.000244140625
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torc

  0%|                                                                                                    | 1/1562 [00:01<49:29,  1.90s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 561.3306884765625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 561.3306884765625
recons_image_embeddings_default: tensor([[ 0.0352,  0.0083, -0.0740,  ..., -0.0311,  0.0275,  0.0019],
        [-0.0101, -0.0539, -0.0622,  ...,  0.0199, -0.0555, -0.0743],
        [-0.0206,  0.0059, -0.0366,  ..., -0.0307,  0.0756, -0.0016],
        ...,
        [ 0.0099, -0.0045, -0.0059,  ..., -0.0521,  0.0647, -0.0225],
        [-0.0422,  0.0518, -0.0482,  ...,  0.0098,  0.0418,  0.0290],
        [-0.0411, -0.0590,  0.0014,  ..., -0.0432, -0.0089, -0.0449]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0528, -0.0427, -0.0113,  ..., -0.0205,  0.0527,  0.0599],
        [ 0.0414, -0.0500, -0.0162,  ..., -0.0282,  0.0233,  

  0%|▏                                                                                                   | 2/1562 [00:03<43:04,  1.66s/it]

recons_image_embeddings_default: tensor([[ 0.0146, -0.0148, -0.0460,  ...,  0.0118,  0.0082,  0.0083],
        [-0.0018,  0.0212, -0.0113,  ...,  0.0519, -0.0585, -0.0361],
        [-0.0171, -0.0393, -0.0432,  ...,  0.0160,  0.0028,  0.0136],
        ...,
        [-0.0224, -0.0082, -0.0361,  ..., -0.0352,  0.0784,  0.0265],
        [-0.0062,  0.0247, -0.0572,  ...,  0.0121, -0.0083,  0.0222],
        [-0.0130,  0.0321, -0.0363,  ...,  0.0437,  0.0279, -0.0109]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0335, -0.0481, -0.0024,  ..., -0.0238,  0.0495,  0.0592],
        [ 0.0450, -0.0409,  0.0037,  ..., -0.0079,  0.0234,  0.0526],
        [ 0.0535, -0.0529, -0.0222,  ..., -0.0185,  0.0463,  0.0485],
        ...,
        [ 0.0366, -0.0527, -0.0098,  ..., -0.0252,  0.0559,  0.0742],
        [ 0.0292, -0.0445, -0.0072,  ..., -0.0206,  0.0475,  0

  0%|▏                                                                                                   | 3/1562 [00:04<40:56,  1.58s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 559.0997314453125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0
error.norm(): 559.0997314453125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 559.0997314453125
recons_image_embeddings_default: tensor([[ 0.0294,  0.0383,  0.0048,  ..., -0.0036,  0.0256,  0.0279],
        [ 0.0004,  0.0353, -0.0868,  ..., -0.0146,  0.0002,  0.0059],
        [ 0.0709, -0.0185, -0.0175,  ...,  0.0050,  0.0293,  0.0257],
        ...,
        [-0.0168, -0.0003, -0.0274,  ..., -0.0302,  0.0601, -0.0477],
        [ 0.0075,  0.0213, -0.0235,  ..., -0.0346,  0.0216,  0.0487],
        [ 0.0059, -0.0119, -0.0019,  ...,  0.0249, -0.0424,  0.0157]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 4.6

  0%|▎                                                                                                   | 4/1562 [00:06<39:56,  1.54s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 2000.0001220703125
error.norm(): 564.427001953125
recons_image_embeddings_default: tensor([[ 0.0495,  0.0061, -0.0375,  ..., -0.0073, -0.0049,  0.0464],
        [ 0.0656,  0.0185, -0.0169,  ..., -0.0542,  0.0806,  0.0280],
        [ 0.0439,  0.0136,  0.0194,  ..., -0.0279,  0.0640, -0.0370],
        ...,
        [ 0.0259,  0.0402, -0.0065,  ..., -0.0289,  0.0129,  0.0450],
        [ 0.0245,  0.0248, -0.0074,  ..., -0.0344,  0.0273, -0.0038],
        [ 0.0167,  0.0346, -0.0975,  ...,  0.0074,  0.0849, -0.0346]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0520, -0.0430, -0.0186,  ..., -0.0307,  0.0407,  0.0708],
        [ 0.0430, -0.0385, -0.0082,  ..., -0.0425,  0.0528,  0.0652],
        [ 0.0313, -0.0452,  0.0096,  ..., -0.0211,  0.0687,  0.0537],
        ...,
        [ 0.033

  0%|▎                                                                                                   | 4/1562 [00:07<51:40,  1.99s/it]


In [91]:
len(feature_steered_embeds[random_feat_idxs[0]])

160

In [92]:
default_embeds.shape
len(default_embeds_list)
default_embeds = torch.cat(default_embeds_list)
default_embeds.shape

torch.Size([160, 512])

In [93]:
len(altered_embeds_list), altered_embeds_list[0].shape, default_embeds.shape

(25, torch.Size([32, 512]), torch.Size([160, 512]))

In [94]:
og_model.cuda()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [95]:
with open("/workspace/clip_dissect_raw.txt", "r") as f:
    larger_vocab = [line[:-1] for line in f.readlines()][:5000]

# with open("/workspace/better_img_desc.txt", "r") as f:
#     larger_vocab = [line[:-1] for line in f.readlines()][:5000]

In [96]:
# use clip vocab here and compare embeds
import torch
from PIL import Image

tokenizer = open_clip.get_tokenizer('ViT-B-32')
text = tokenizer(larger_vocab)
text_features = og_model.encode_text(text.cuda())
text_features_normed = text_features/text_features.norm(dim=-1, keepdim=True)


print(f"text_features_normed.shape: {text_features_normed.shape}")
text_probs_altered_list = []
# can probs make this one tensor 
with torch.no_grad(), torch.cuda.amp.autocast():
    for key in feature_steered_embeds:
        print(key)
        # embeds already have L2 norm of 1
        text_probs_altered = (100.0 * torch.stack(feature_steered_embeds[key]) @ text_features_normed.T).softmax(dim=-1)
        text_probs_altered_list.append(text_probs_altered)
    text_probs_default = (100.0 * default_embeds @ text_features_normed.T).softmax(dim=-1)


# for altered_embeds in altered_embeds_list:
#     with torch.no_grad(), torch.cuda.amp.autocast():
#         # might want to still normalize
        
#         # already normalized
#         # altered_embeds /= altered_embeds.norm(dim=-1, keepdim=True)

#         text_probs_altered = (100.0 * altered_embeds @ text_features_normed.T).softmax(dim=-1)
#         text_probs_altered_list.append(text_probs_altered)
#     # default_embds_norm = default_embeds.norm(dim=-1, keepdim=True)
#     text_probs_default = (100.0 * default_embeds @ text_features_normed.T).softmax(dim=-1)

print("Label probs altered:", text_probs_altered.shape)  # prints: [[1., 0., 0.]]
print("Label probs default:", text_probs_default.shape)  # prints: [[1., 0., 0.]]

text_features_normed.shape: torch.Size([5000, 512])
919
258
1266
995
1706
996
391
2750
2807
1757
747
2681
912
164
2021
1331
1763
146
2769
929
949
283
1780
2180
1323
Label probs altered: torch.Size([160, 5000])
Label probs default: torch.Size([160, 5000])


In [97]:
feature_steered_embeds.keys()

dict_keys([919, 258, 1266, 995, 1706, 996, 391, 2750, 2807, 1757, 747, 2681, 912, 164, 2021, 1331, 1763, 146, 2769, 929, 949, 283, 1780, 2180, 1323])

In [98]:
text_probs_altered

tensor([[3.0703e-03, 3.9035e-04, 2.0367e-05,  ..., 5.0409e-05, 2.4376e-05,
         1.5616e-05],
        [3.1685e-03, 2.6834e-04, 2.2027e-05,  ..., 1.1451e-04, 2.5353e-05,
         1.5139e-05],
        [1.8051e-03, 2.0253e-04, 8.9683e-06,  ..., 7.5091e-05, 3.8054e-05,
         6.0683e-06],
        ...,
        [2.2572e-03, 1.3556e-04, 5.7952e-06,  ..., 5.1451e-05, 2.1115e-05,
         3.1451e-05],
        [3.1749e-03, 1.4169e-04, 8.5761e-06,  ..., 1.8480e-04, 4.1559e-05,
         1.0105e-05],
        [4.2913e-03, 3.5502e-04, 9.9926e-06,  ..., 1.8855e-04, 2.1488e-05,
         9.4608e-06]], device='cuda:0')

### Summed Logit Difference

In [101]:
# subtract from default, label, and print trends
text_probs_altered.shape

# selected_vocab = all_imagenet_class_names
selected_vocab = larger_vocab

# switch to datacomp eval??
# we run this for each feature over all of imagenet eval and create average absolute/ratio
# difference vectors
for j, text_probs_altered in enumerate(text_probs_altered_list):
    print(f"{'============================================'*2}\n\nFor Feature {random_feat_idxs[j]}")
    logit_diff = text_probs_altered - text_probs_default
#     print(f"logit_diff.shape: {logit_diff.shape}")
    logit_diff_aggregate = logit_diff.sum(dim=0)
#     print(f"logit_diff_aggregate.shape: {logit_diff_aggregate.shape}")
    
    logit_ratio = text_probs_altered/text_probs_default
    logit_ratio_aggregate = logit_ratio.mean(dim=0)
    
    print(f"text_probs_altered.softmax(): {text_probs_altered.softmax(1).shape}")
    text_probs_altered_softmax = text_probs_altered.softmax(1)
    vals_softmax, idxs_softmax = torch.topk(text_probs_altered_softmax,k=10)
    
    print(f"\nSoftmax Over {text_probs_altered.shape[0]} Images:\n{vals_softmax}")
    print(np.array(selected_vocab)[idxs_softmax.cpu()])
    for i in range(vals_softmax.shape[0]):
        print(vals_softmax[i], "\n", np.array(selected_vocab)[idxs_softmax.cpu()][i])
        break
    
    vals_agg, idxs_agg = torch.topk(logit_diff_aggregate,k=10)
    vals_least_agg, idxs_least_agg = torch.topk(logit_diff_aggregate,k=10,largest=False)
    
    ratios_agg, ratios_idxs_agg = torch.topk(logit_ratio_aggregate,k=10)
    ratios_least_agg, ratios_idxs_least_agg = torch.topk(logit_ratio_aggregate,k=10,largest=False)
    
    vals, idxs = torch.topk(logit_diff,k=5)
    vals_least, idxs_least = torch.topk(logit_diff,k=5,largest=False)
    
    ratios, ratios_idxs = torch.topk(logit_ratio,k=5)
    ratios_least, ratios_idxs_least = torch.topk(logit_ratio,k=5,largest=False)
    
    print(f"\nMost Changed, by Absolute Diff Over {logit_diff.shape[0]} Images:\n{vals_agg}")
    print(np.array(selected_vocab)[idxs_agg.cpu()])
    print(vals_least_agg)
    print(np.array(selected_vocab)[idxs_least_agg.cpu()])
    
    print(f"\nMost Changed, by Ratio Over {logit_diff.shape[0]} Images:")
    print(ratios_agg)
    print(np.array(selected_vocab)[ratios_idxs_agg.cpu()])
    print(vals_least_agg)
    print(np.array(selected_vocab)[ratios_idxs_least_agg.cpu()])
        
    
#     for i in range(logit_diff.shape[0]):
# #         print(f"\nImage {i} ========================\nMost Changed, by Absolute Diff\n:{vals[i]}")
# #         print(np.array(selected_vocab)[idxs.cpu()][i])
# #         print(vals_least[i])
# #         print(np.array(selected_vocab)[idxs_least.cpu()][i])
#         print(f"\nImage {i} Most Changed, by Ratio:")
#         print(ratios[i])
#         print(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         print(ratios_least[i])
#         print(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        
# #         text = tokenizer(np.array(selected_vocab)[ratios_idxs.cpu()][i])
# #         text_least = tokenizer(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
# #         text_features = og_model.encode_text(text.cuda())
# #         text_features_least = og_model.encode_text(text_least.cuda())


For Feature 919
text_probs_altered.softmax(): torch.Size([160, 5000])

Softmax Over 160 Images:
tensor([[0.0003, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0003, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        ...,
        [0.0004, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0003, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002]],
       device='cuda:0')
[['parts' 'coverage' 'kit' ... 'layout' 'dodge' 'seat']
 ['parts' 'coverage' 'headlines' ... 'set' 'trailers' 'shown']
 ['parts' 'kit' 'wing' ... 'dodge' 'trailers' 'paint']
 ...
 ['parts' 'templates' 'custom' ... 'trailers' 'modified' 'automotive']
 ['parts' 'kit' 'set' ... 'wing' 'seat' 'shown']
 ['parts' 'coverage' 'kit' ... 'headlines' 'dodge' 'shown']]
torch.Size([10]) (10,)
tensor([0.0003, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002,
        0.0002], d


Softmax Over 160 Images:
tensor([[0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        ...,
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002]],
       device='cuda:0')
[['birth' 'john' 'tom' ... 'asian' 'lingerie' 'matt']
 ['tom' 'want' 'john' ... 'feet' 'ski' 'jon']
 ['tom' 'birth' 'want' ... 'jennifer' 'princess' 'womens']
 ...
 ['want' 'tom' 'gun' ... 'launched' 'these' 'matt']
 ['tom' 'want' 'birth' ... 'matt' 'these' 'walker']
 ['want' 'tom' 'john' ... 'korea' 'these' 'oh']]
torch.Size([10]) (10,)
tensor([0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002,
        0.0002], device='cuda:0') ['birth' 'john' 'tom' 'feet' 'want' 'womens' 'jon' 'asian' 'lingerie'
 'matt']

Most Changed, by A

## Stop

### Covariance Analysis

In [None]:
# subtract from default, label, and print trends
text_probs_altered.shape

# selected_vocab = all_imagenet_class_names
selected_vocab = larger_vocab

cov_stuff_avgs = []

# switch to datacomp eval??
# we run this for each feature over all of imagenet eval and create average absolute/ratio
# difference vectors
per_feat_avg_vectors = []
for j, text_probs_altered in enumerate(text_probs_altered_list):
    print(f"\n\nFor Feature {random_feat_idxs[j]}")
    print(f"\n\nFor Feature {random_feat_idxs[j]}")
    logit_diff = text_probs_altered - text_probs_default
    logit_ratio = text_probs_altered/text_probs_default
    
    vals, idxs = torch.topk(logit_diff,k=5)
    vals_least, idxs_least = torch.topk(logit_diff,k=5,largest=False)
    
    ratios, ratios_idxs = torch.topk(logit_ratio,k=5)
    ratios_least, ratios_idxs_least = torch.topk(logit_ratio,k=5,largest=False)
    
    feat_avg_vectors = []
    for i in range(logit_diff.shape[0]):
#         print(f"\nImage {i} ========================\nMost Changed, by Absolute Diff\n:{vals[i]}")
#         print(np.array(all_imagenet_class_names)[idxs.cpu()][i])
#         print(vals_least[i])
#         print(np.array(all_imagenet_class_names)[idxs_least.cpu()][i])
        
        print("\nMost Changed, by Ratio:")
        print(ratios[i])
        print(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         print(ratios_least[i])
#         print(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        
        text = tokenizer(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         text_least = tokenizer(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        text_features = og_model.encode_text(text.cuda())
        cov_over_images.append(text_features)
#         text_features_least = og_model.encode_text(text_least.cuda())
        print(torch.tril(torch.cov(text_features), diagonal=-1))
        print(torch.tril(torch.cov(text_features), diagonal=-1).sum()/10)
#         cov_stuff = torch.tril(torch.cov(text_features), diagonal=-1).sum()/10
#         cov_stuff_least = torch.tril(torch.cov(text_features_least), diagonal=-1).sum()/10
    print(torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape)
    n = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape[0]
    num_elements = (n**2)/2 - n
    cov_stuff = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).sum()/num_elements
    cov_stuff_avgs.append(cov_stuff)
#         cov_stuff_avgs_least.append(cov_stuff_least)
    if j > 10:
        break
print(torch.tensor(cov_stuff_avgs).mean())
# print(torch.tensor(cov_stuff_avgs_least).mean())

In [None]:
top_k_imgnet_labels = 10

In [None]:
from matplotlib import pyplot as plt
from collections import defaultdict, Counter

feat_autolabels_default = defaultdict(Counter)
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=top_k_imgnet_labels)
    print(i)
    for k, idx in enumerate(idxs):
        feat_autolabels_default[i][all_imagenet_class_names[idx]] += vals[k]
        print("\t", all_imagenet_class_names[idx])
feat_autolabels_default

In [None]:
# subtract from default, label, and print trends
text_probs_altered.shape

for text_probs_altered in text_probs_altered_list:
    logit_diff = text_probs_altered - text_probs_default
    print(logit_diff)
    vals, idxs = torch.topk(logit_diff,k=5)
    print(vals, np.array(all_imagenet_class_names[idxs])
    break

In [None]:
all_imagenet_class_names

In [None]:
from collections import defaultdict, Counter

feat_autolabels_altered_list = []
for text_probs_altered in text_probs_altered_list:
    feat_autolabels_altered = defaultdict(Counter)
    for i in range(text_probs_altered.shape[0]):
        vals, idxs = torch.topk(text_probs_altered[i],k=top_k_imgnet_labels)
#         print(i)
        for k, idx in enumerate(idxs):
            feat_autolabels_altered[i][all_imagenet_class_names[idx]] += vals[k]
#             print("\t", all_imagenet_class_names[idx])
    feat_autolabels_altered_list.append(feat_autolabels_altered)

start_idx = 9
end_idx = 10
    
h = 0
for key in feat_autolabels_default:
    print(f"\nfeat_autolabels_default img {key}:\n {feat_autolabels_default[key]}\n")
    h += 1
    if h > end_idx:
        break
for i, f_a_a in enumerate(feat_autolabels_altered_list):
    print("============= feature number ", i, "====================")
    h = 0
    for key in range(start_idx, end_idx):
#         print("\n", key)
#         for item in f_a_a[key]:
#             print("\t", item, f_a_a[key][item].cpu().item())
        print(f"\nf_a_a img {key}:\n {f_a_a[key]}\n")


In [None]:
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=1000)
    print(i, ind_to_name[str(idxs[0].cpu().item())][1])
    fig, ax = plt.subplots(figsize=(10, 10))
#     ax.xaxis.set_ticks((1000))
#     ax.set_xticks(list(range(1000)), [ind_to_name[str(idxs[idx].cpu().item())][1] for idx in idxs])
    plt.bar(idxs.cpu(), vals.cpu(), width=5)
    break

In [None]:
plt.imshow(batch_images[2].cpu().permute((1,2,0)).numpy())

In [None]:

@torch.no_grad()
def get_heatmap(
          image,
          model,
          sparse_autoencoder,
          feature_id,
): 
    image = image.to(cfg.device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange(cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This code is copied from the first part of the 'forward' method of the AutoEncoder class
    sae_in =  post_reshaped - sparse_autoencoder.b_dec # Remove decoder bias as per Anthropic
    print(f"sae_in.shape: {sae_in.shape}")
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts 
     
def image_patch_heatmap(activation_values,image_size=224, pixel_num=14):
    activation_values = activation_values.detach().cpu().numpy()
    activation_values = activation_values[1:]
    activation_values = activation_values.reshape(pixel_num, pixel_num)

    # Create a heatmap overlay
    heatmap = np.zeros((image_size, image_size))
    patch_size = image_size // pixel_num

    for i in range(pixel_num):
        for j in range(pixel_num):
            heatmap[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = activation_values[i, j]

    return heatmap


In [None]:
from matplotlib import pyplot as plt

grid_size = 1
fig, axs = plt.subplots(int(np.ceil(len(images)/grid_size)), grid_size, figsize=(15, 15))
name=  f"Category: uhh,  Feature: {0}"
fig.suptitle(name)#, y=0.95)
for ax in axs.flatten():
    ax.axis('off')
complete_bid = []

heatmap = get_heatmap(batch_images[2], model,sparse_autoencoder, 10000)
heatmap = image_patch_heatmap(heatmap, pixel_num=224//cfg.patch_size)

display = batch_images[2].cpu().numpy().transpose(1, 2, 0)

has_zero = False

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)