In [None]:
# --- Paths (fill these in before running the notebook) ---
# Directory where every figure produced by this notebook is written.
# NOTE: must be set to a non-empty path; os.makedirs("") raises FileNotFoundError.
OUTPUT_FOLDER = ""
import os
os.makedirs(OUTPUT_FOLDER,exist_ok=True)

# Path to the training-script checkpoint. TODO: shouldn't be needed.
CHECKPOINT_PATH= ""
# Folder containing ImageNet-1k data organized as described here:
# https://www.kaggle.com/c/imagenet-object-localization-challenge/overview/description
IMAGENET_PATH = ""

# Path to the SAE folder; might look something like
# "final_sae_group_wkcn_TinyCLIP-ViT-40M-32-Text-19M-LAION400M_blocks.{layer}.mlp.hook_post_16384"
SAE_PATH = ""
# Name of the particular SAE within the group (all names get printed when the group is loaded below).
AUTOENCODER_NAME = ""

# --- Model specs ---
# TODO: these should be inferred from the pretrained model checkpoint (if they aren't already).
# All of them except PATCH_SIZE are forwarded to `setup(...)` below.
LAYERS =  9
EXPANSION_FACTOR = 8
D_IN = 2048
MODEL_NAME = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
CONTEXT_SIZE = 50 
# Used below as 224 // PATCH_SIZE to size the per-patch heatmap grid.
PATCH_SIZE = 32
HOOKPOINT = "blocks.{layer}.mlp.hook_post"
LEGACY_LOAD= False


In [None]:
# eval constants

# Image budget for the evaluation loops below: each loop copies this into `this_max`
# and stops once roughly this many validation images have been processed.
EVAL_MAX = 50_000 
# Batch size of the evaluation DataLoader; also used to convert EVAL_MAX into a batch count.
BATCH_SIZE = 32


In [None]:
from sae.main import setup, ImageNetValidationDataset
import torch
import plotly.express as px
from typing import List

import torch
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torchvision
import einops
from transformers import CLIPProcessor
from vit_prisma.utils.data_utils.imagenet_dict import IMAGENET_DICT
from typing import List

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# This notebook only evaluates models, so autograd is switched off globally.
torch.set_grad_enabled(False)


In [None]:
# Evaluation data: ImageNet validation images, once preprocessed for the model
# and once left un-normalized for plotting.
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)

data_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        # TODO: this normalization is specific to CLIP models.
        torchvision.transforms.Normalize(
            mean=clip_processor.image_processor.image_mean,
            std=clip_processor.image_processor.image_std,
        ),
    ]
)
# Same resize, but no normalization, so the tensors can be shown with imshow directly.
visualize_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ]
)

# Assumes the directory layout described here:
# https://www.kaggle.com/c/imagenet-object-localization-challenge/overview/description
imagenet_val_path = os.path.join(IMAGENET_PATH, "ILSVRC/Data/CLS-LOC/val")
imagenet_val_labels = os.path.join(IMAGENET_PATH, "LOC_val_solution.csv")
imagenet_label_strings = os.path.join(IMAGENET_PATH, "LOC_synset_mapping.txt")

imagenet_data = ImageNetValidationDataset(
    imagenet_val_path,
    imagenet_label_strings,
    imagenet_val_labels,
    data_transforms,
    return_index=True,
)
imagenet_data_visualize = ImageNetValidationDataset(
    imagenet_val_path,
    imagenet_label_strings,
    imagenet_val_labels,
    visualize_transforms,
    return_index=True,
)

# shuffle=False keeps the batch order deterministic across the evaluation loops below.
data_loader = DataLoader(imagenet_data, batch_size=BATCH_SIZE, shuffle=False)


# Map class index (line number in LOC_synset_mapping.txt) -> human-readable class name.
# Each non-empty line is expected to look like
#   "<synset id> <name>, <synonym>, <synonym>, ..."
# e.g. "n01484850 great white shark, white shark, man-eater".
ind_to_name = {}

with open(os.path.join(IMAGENET_PATH, "LOC_synset_mapping.txt"), 'r') as file:
    for line_num, line in enumerate(file):
        line = line.strip()
        if not line:
            continue
        # Split on the FIRST space only: class names may themselves contain spaces,
        # and splitting on every space would truncate "great white shark" to "great".
        parts = line.split(' ', 1)
        # Keep only the first of the comma-separated synonyms.
        label = parts[1].split(',')[0].strip()
        ind_to_name[line_num] = label


# setup model
# Everything `setup` needs, gathered from the configuration cell at the top of the notebook.
setup_kwargs = dict(
    checkpoint_path=CHECKPOINT_PATH,
    imagenet_path=IMAGENET_PATH,
    pretrained_path=SAE_PATH,
    layers=LAYERS,
    expansion_factor=EXPANSION_FACTOR,
    model_name=MODEL_NAME,
    context_size=CONTEXT_SIZE,
    d_in=D_IN,
    hook_point=HOOKPOINT,
    legacy_load=LEGACY_LOAD,
)
cfg, model, activations_loader, sae_group = setup(**setup_kwargs)
model = model.to(device)

# List every SAE in the loaded group so a valid AUTOENCODER_NAME can be picked.
for i, (name, sae) in enumerate(sae_group):
    hyp = sae.cfg
    summary = f"{i}: Name: {name} Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}"
    print(summary)

# The single SAE that the rest of the notebook evaluates.
sparse_autoencoder = sae_group.autoencoders[AUTOENCODER_NAME].to(device)
layer_num = sparse_autoencoder.cfg.hook_point_layer
print(f"Chosen layer {layer_num} hook point {sparse_autoencoder.cfg.hook_point}")

## Test the Autoencoder


### L0 Test

In [None]:
# Eval mode: prevents an error if the SAE expects a dead-neuron mask (only needed when training with grads).
sparse_autoencoder.eval()
with torch.no_grad():
    batch_tokens, labels = activations_loader.get_val_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens)
    sae_outputs = sparse_autoencoder(cache[sparse_autoencoder.cfg.hook_point])
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sae_outputs
    del sae_outputs
    # The cached activations are no longer needed once the SAE has run.
    del cache

    # Skip position 0 (the first token), then count how many features fired at each
    # remaining position; the mean below averages over batch and position.
    is_active = feature_acts[:, 1:] > 0
    l0 = is_active.float().sum(-1).detach()
    print("average l0", l0.mean().item())
    l0_histogram = px.histogram(l0.flatten().cpu().numpy())
    l0_histogram.show()

### ARENA stuff
https://arena3-chapter1-transformer-interp.streamlit.app/[1.4]_Superposition_&_SAEs
First, compute the feature probabilities (how often each SAE feature is active).

In [None]:

# helper functions
# Keyword arguments that `hist` routes to `fig.update_layout(...)`;
# every other keyword argument is passed to `px.histogram(...)`.
update_layout_set = {
    "xaxis_range",
    "yaxis_range",
    "hovermode",
    "xaxis_title",
    "yaxis_title",
    "colorbar",
    "colorscale",
    "coloraxis",
    "coloraxis_showscale",
    "title_x",
    "title_y",
    "bargap",
    "bargroupgap",
    "xaxis_tickformat",
    "yaxis_tickformat",
    "legend_title_text",
    "xaxis_showgrid",
    "xaxis_gridwidth",
    "xaxis_gridcolor",
    "yaxis_showgrid",
    "yaxis_gridwidth",
    "yaxis_gridcolor",
    "showlegend",
    "xaxis_tickmode",
    "yaxis_tickmode",
    "margin",
    "xaxis_visible",
    "yaxis_visible",
}
def to_numpy(tensor):
    """
    Convert `tensor` to a numpy array.

    Accepts numpy arrays (returned as-is, not copied), torch tensors / parameters
    (detached and moved to the CPU first), lists, tuples, and int / float / bool / str scalars.

    Raises:
        ValueError: if `tensor` is of any other type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

def hist(tensor, save_name, show=True, renderer=None, **kwargs):
    '''
    Plot a histogram of `tensor` with plotly express, save it as a PNG and optionally display it.

    Args:
        tensor: values to plot; anything accepted by `to_numpy`.
        save_name: file name (without extension); the image is written to
            OUTPUT_FOLDER/<save_name>.png.
        show: if True, also display the figure.
        renderer: passed to `fig.show(...)` when `show` is True.
        **kwargs: keys listed in `update_layout_set` are applied through
            `fig.update_layout(...)`; all other keys go to `px.histogram(...)`.
            `bargap` defaults to 0.1, and an integer `margin` is expanded to the same
            margin on all four sides.
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    kwargs_post.setdefault("bargap", 0.1)
    if isinstance(kwargs_post.get("margin"), int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])

    histogram_fig = px.histogram(x=to_numpy(tensor), **kwargs_pre)
    histogram_fig.update_layout(**kwargs_post)

    # Save the figure as a PNG file
    histogram_fig.write_image(os.path.join(OUTPUT_FOLDER, f"{save_name}.png"))
    if show:
        # Show the figure that was just saved instead of building an identical one from scratch.
        histogram_fig.show(renderer)



In [None]:
@torch.no_grad()
def get_feature_probability(
    images,
    model,
    sparse_autoencoder,
):
    '''
    Returns the feature probabilities (i.e. fraction of time the feature is active) for each feature in the
    autoencoder.

    Args:
        images: batch of model inputs, passed to `model.run_with_cache`.
        model: model exposing `run_with_cache(images) -> (output, cache)`.
        sparse_autoencoder: SAE whose second return value is the feature activations at
            `sparse_autoencoder.cfg.hook_point` (3-D: batch, sequence position, feature).

    Returns:
        A pair of 1-D tensors with one entry per feature:
        - fraction of all `batch * seq` tokens on which the feature is active (> 0);
        - the same fraction computed over position 0 only (treated here as the class token).
    '''
    _, cache = model.run_with_cache(images)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )
    # A feature counts as "active" on a token when its activation is strictly positive.
    # Averaging this 0/1 indicator gives a frequency; averaging the raw activations
    # (what this function used to do) gives a mean magnitude instead.
    is_active = (feature_acts > 0).float()
    class_active = is_active[:, 0, :]
    # Merge the batch and sequence axes so the mean runs over every token.
    all_tokens_active = is_active.flatten(0, 1)

    return all_tokens_active.mean(0), class_active.mean(0)

# Sum the per-batch feature probabilities over the evaluation set.
total_acts = None
total_class_acts = None
this_max = EVAL_MAX
# The sums are later divided by `this_max // BATCH_SIZE`, so exactly that many batches
# must be accumulated here for the result to be a true average.
num_eval_batches = this_max // BATCH_SIZE
for batch_idx, (total_images, total_labels, total_indices) in tqdm(enumerate(data_loader), total=num_eval_batches):
    # Check the budget BEFORE running the model: checking afterwards (as this loop used to)
    # accumulates extra batches that the divisor does not account for.
    if batch_idx >= num_eval_batches:
        break

    total_images = total_images.to(device)
    new, new_class = get_feature_probability(total_images, model, sparse_autoencoder)

    if total_acts is None:
        total_acts = new
        total_class_acts = new_class
    else:
        total_acts = total_acts + new
        total_class_acts = total_class_acts + new_class


In [None]:
# Turn the per-batch sums into averages over the number of evaluated batches.
feature_probability = total_acts / (this_max // BATCH_SIZE)
feature_probability_class = total_class_acts / (this_max // BATCH_SIZE)

# log10 of the averages; the small epsilon keeps features that never fired finite (they map to -10).
log_freq = torch.log10(feature_probability + 1e-10)
log_freq_class = torch.log10(feature_probability_class + 1e-10)


In [None]:
print(feature_probability)
def visualize_sparsities(log_freq, conditions, condition_texts, name):
    """
    Plot the log-frequency histogram of the features and, for every non-empty bin,
    the distribution of pairwise cosine similarities between the encoder directions in that bin.

    Args:
        log_freq: 1-D tensor of log10 feature frequencies.
        conditions: boolean masks over the features, one per bin.
        condition_texts: label of each bin, used in file names and plot titles.
        name: prefix for file names and titles (e.g. "TOTAL" or "CLS").

    Reads the encoder weights from the notebook-level `sparse_autoencoder`.
    """
    shared_plot_kwargs = dict(histnorm="percent", template="ggplot2")

    # Overall distribution of feature frequencies.
    hist(
        log_freq,
        f"{name}_frequency_histogram",
        show=True,
        title=f"{name} Log Frequency of Features",
        labels={"x": "log<sub>10</sub>(freq)"},
        **shared_plot_kwargs,
    )

    num_features = log_freq.shape[0]
    for mask, mask_text in zip(conditions, condition_texts):
        percentage = (torch.count_nonzero(mask)/num_features).item()*100
        if percentage == 0:
            # Nothing falls in this bin.
            continue
        percentage = int(np.round(percentage))

        # Unit-norm encoder directions of the features in this bin.
        directions = sparse_autoencoder.W_enc[:, mask]
        unit_directions = directions / directions.norm(dim=0, keepdim=True)

        # All pairwise cosine similarities, then a random sample of 10000 of them.
        all_cos_sims = (unit_directions.T @ unit_directions).flatten()
        sample_positions = torch.randint(0, all_cos_sims.shape[0], (10000,))

        hist(
            all_cos_sims[sample_positions],
            f"{name}_low_prop_similarity_{mask_text}",
            show=True,
            marginal="box",
            title=f"{name} Cosine similarities of random {mask_text} encoder directions with each other ({percentage}% of features)",
            labels={"x": "Cosine sim"},
            **shared_plot_kwargs,
        )

#TODO these conditions should be tuned to distribution of your data!
# Each condition is a boolean mask over the features; bin edges are exclusive.
conditions = [
    (log_freq > -5) & (log_freq < -4),
    (log_freq > -4) & (log_freq < -2),
    log_freq > -2,
    log_freq < -8,
    (log_freq > -6.5) & (log_freq < -4),
    (log_freq > -8) & (log_freq < -6.5),
]
condition_texts = [
    "logfreq_[-5,-4]",
    "logfreq_[-4,-2]",
    "logfreq_[-2,inf]",
    "logfreq_[-inf,-8]",
    "logfreq_[-6.5,-4]",
    "logfreq_[-8,-6.5]",
]
visualize_sparsities(log_freq, conditions, condition_texts, "TOTAL")

# Same idea, using the frequencies measured on the class token only.
conditions_class = [
    (log_freq_class > -8) & (log_freq_class < -4),
    log_freq_class < -9,
    log_freq_class > -4,
]
condition_texts_class = [
    "logfreq_[-8,-4]",
    "logfreq_[-inf,-9]",
    "logfreq_[-4,inf]",
]
visualize_sparsities(log_freq_class, conditions_class, condition_texts_class, "CLS")

### Reconstruction and substitution loss

In [None]:

@torch.no_grad()
def get_reconstruction_loss(
    images,
    model,
    autoencoder,
):
    '''
    Runs `autoencoder` on the activations of `model` for the given batch of images, prints a few
    diagnostics, and returns the autoencoder's MSE reconstruction loss as a Python float.

    Args:
        images: batch of model inputs, passed to `model.run_with_cache`.
        model: model exposing `run_with_cache(images) -> (output, cache)`.
        autoencoder: SAE to evaluate; its `cfg.hook_point` selects the cached activations
            (3-D: batch, sequence position, neuron).

    Returns:
        `mse_loss.item()` as reported by the autoencoder.
    '''
    logits, cache = model.run_with_cache(images)
    # Use the autoencoder that was passed in, not a notebook-level global.
    acts = cache[autoencoder.cfg.hook_point]
    sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = autoencoder(acts)

    # NOTE: despite the label, the printed number is the mean squared activation.
    print("Avg L2 norm of acts: ", acts.pow(2).mean().item())

    # Cosine similarity between each neuron's original and reconstructed activations,
    # taken across all (batch, position) tokens and then averaged over neurons.
    print("Avg cos sim of neuron reconstructions: ", torch.cosine_similarity(acts.flatten(0, 1),
                                                                              sae_out.flatten(0, 1),
                                                                              dim=0).mean(-1).tolist())
    print("l1", l1_loss.sum().item())
    return mse_loss.item()

# Spot-check the reconstruction quality on the first few validation batches.
# `this_max` is the index of the last batch evaluated, so this_max + 1 batches are processed.
this_max = 4
count = 0
print(sparse_autoencoder.cfg.hook_point)
for batch_idx, batch in enumerate(data_loader):
    total_images, total_labels, total_indices = batch
    total_images = total_images.to(device)

    reconstruction_loss = get_reconstruction_loss(total_images, model, sparse_autoencoder)
    print("mse", reconstruction_loss)

    if batch_idx >= this_max:
        break

Notes:
Language model results for comparison:
Avg L2 norm of acts:  0.11062075197696686
Avg cos sim of neuron reconstructions:  0.8348199129104614
l1 19.51767921447754
mse 0.043452925980091095



In [None]:

# get random features from different bins

# Three parallel lists: feature index, its log10 frequency, and the bin it was sampled from.
interesting_features_indices = []
interesting_features_values = []
interesting_features_category = []
number_features_per = 50

# (mask, label, log-frequency tensor the mask was computed from) for every bin.
# CLS bins must look their values up in `log_freq_class`, not in `log_freq`.
feature_bins = (
    [(cond, f"TOTAL_{text}", log_freq) for cond, text in zip(conditions, condition_texts)]
    + [(cond, f"CLS_{text}", log_freq_class) for cond, text in zip(conditions_class, condition_texts_class)]
)

# The code that consumes these lists keys its results by feature id, so every feature
# may be sampled at most once even though the bins overlap.
already_sampled = set()
for condition, condition_text, bin_log_freq in feature_bins:
    potential_indices = torch.nonzero(condition, as_tuple=True)[0]

    # Shuffle the candidates, drop the ones already taken by an earlier bin,
    # and keep at most `number_features_per` of the rest.
    shuffled_indices = potential_indices[torch.randperm(len(potential_indices))].tolist()
    sampled_indices = [idx for idx in shuffled_indices if idx not in already_sampled][:number_features_per]
    already_sampled.update(sampled_indices)

    values = [bin_log_freq[idx].item() for idx in sampled_indices]

    interesting_features_indices.extend(sampled_indices)
    interesting_features_values.extend(values)
    interesting_features_category.extend([f"{condition_text}"] * len(sampled_indices))

print(set(interesting_features_category))



In [None]:



@torch.no_grad()
def highest_activating_tokens(
    images,
    model,
    sparse_autoencoder,
    W_enc,
    b_enc,
    feature_ids: List[int],
    feature_categories,
    k: int = 10,
):
    '''
    For each requested feature, returns the k images of the batch on which it activates most.

    Args:
        images: batch of model inputs, passed to `model.run_with_cache`.
        model: model exposing `run_with_cache(images) -> (output, cache)`.
        sparse_autoencoder: SAE providing `cfg.hook_point` and the decoder bias `b_dec`.
        W_enc: encoder weights restricted to the requested features, shape (d_in, n_selected).
        b_enc: encoder bias restricted to the requested features, shape (n_selected,).
        feature_ids: ids of the requested features; entry i corresponds to column i of `W_enc`.
        feature_categories: category label of each feature. If it contains "CLS_", images are
            ranked by the activation at position 0; otherwise by the mean activation over all positions.
        k: number of images to keep per feature (must not exceed the batch size).

    Returns:
        dict mapping feature id -> (indices within the batch, activation values), best first.
    '''
    # Get the post activations from the clean run
    _, cache = model.run_with_cache(images)
    inp = cache[sparse_autoencoder.cfg.hook_point]

    # Compute activations (not from a fwd pass, but explicitly, by taking only the features we want).
    # This mirrors the first part of the SAE's forward pass.
    sae_in = inp - sparse_autoencoder.b_dec  # Remove decoder bias as per Anthropic
    # The matmul broadcasts over (batch, seq), so no flatten / un-flatten round trip is needed.
    acts = torch.nn.functional.relu(sae_in @ W_enc + b_enc)

    cls_acts = acts[:, 0, :]
    per_image_acts = acts.mean(1)

    to_return = {}
    for i, (feature_id, feature_cat) in enumerate(zip(feature_ids, feature_categories)):
        scores = cls_acts[:, i] if "CLS_" in feature_cat else per_image_acts[:, i]
        top_acts_values, top_acts_indices = scores.topk(k)
        to_return[feature_id] = (top_acts_indices, top_acts_values)
    return to_return
this_max = EVAL_MAX
top_k = 16  # number of top-activating images kept per feature

# Running top-k per feature: dataset indices of the images and the matching activation values.
max_indices = {i:None for i in interesting_features_indices}
max_values =  {i:None for i in interesting_features_indices}
# Encoder parameters restricted to the sampled features (column i <-> interesting_features_indices[i]).
b_enc = sparse_autoencoder.b_enc[interesting_features_indices]
W_enc = sparse_autoencoder.W_enc[:, interesting_features_indices]
for batch_idx, (total_images, total_labels, total_indices) in tqdm(enumerate(data_loader), total=this_max//BATCH_SIZE):
    # Stop before running the model once the image budget is used up.
    if batch_idx*BATCH_SIZE >= this_max:
        break

    total_images = total_images.to(device)
    total_indices = total_indices.to(device)
    # A final, smaller batch may hold fewer than `top_k` images; topk would raise on it.
    batch_k = min(top_k, total_images.shape[0])
    new_stuff = highest_activating_tokens(total_images, model, sparse_autoencoder, W_enc, b_enc, interesting_features_indices, interesting_features_category, k=batch_k)

    # Iterate over the UNIQUE feature ids: merging the same batch twice for a repeated id
    # would put duplicate images into that feature's top-k.
    for feature_id in max_indices:
        new_indices, new_values = new_stuff[feature_id]
        # Translate positions within the batch into dataset indices.
        new_indices = total_indices[new_indices]

        if max_indices[feature_id] is None:
            max_indices[feature_id] = new_indices
            max_values[feature_id] = new_values
        else:
            # Merge the running top-k with this batch's candidates and keep the best `top_k`.
            ABvals = torch.cat((max_values[feature_id], new_values))
            ABinds = torch.cat((max_indices[feature_id], new_indices))
            _, inds = torch.topk(ABvals, min(top_k, ABvals.shape[0]))
            max_values[feature_id] = ABvals[inds]
            max_indices[feature_id] = ABinds[inds]

top_per_feature = {i:(max_values[i].detach().cpu(), max_indices[i].detach().cpu()) for i in interesting_features_indices}

In [None]:

@torch.no_grad()
def get_heatmap(

          image,
          model,
          sparse_autoencoder,
          feature_id,
):
    '''
    Returns, for a single image, the pre-activation of one SAE feature at every sequence position.

    Args:
        image: one preprocessed image (no batch dimension); a batch dimension of 1 is added here.
        model: model exposing `run_with_cache(images) -> (output, cache)`.
        sparse_autoencoder: SAE providing `cfg.hook_point`, `b_dec` and `W_enc`.
        feature_id: index of the feature (a column of `W_enc`).

    Returns:
        1-D tensor with one value per sequence position. Neither the encoder bias nor the
        ReLU is applied, so values can be negative.
    '''
    image = image.to(device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange( cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This mirrors the first part of the SAE's forward pass.
    sae_in =  post_reshaped - sparse_autoencoder.b_dec # Remove decoder bias as per Anthropic
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts
     
def image_patch_heatmap(activation_values,image_size=224, pixel_num=14):
    """
    Expand per-position activations into an (image_size, image_size) heatmap.

    The first entry of `activation_values` is dropped; the remaining
    pixel_num * pixel_num values are laid out row by row on a square grid and every
    grid cell is blown up to a (image_size // pixel_num)-pixel square. Pixels not
    covered by the grid (when image_size is not a multiple of pixel_num) stay 0.

    Args:
        activation_values: 1-D torch tensor of length 1 + pixel_num ** 2.
        image_size: side length of the output heatmap in pixels.
        pixel_num: number of patches along each side of the image.

    Returns:
        numpy array of shape (image_size, image_size).
    """
    patch_grid = activation_values.detach().cpu().numpy()[1:].reshape(pixel_num, pixel_num)

    patch_size = image_size // pixel_num
    covered = pixel_num * patch_size

    heatmap = np.zeros((image_size, image_size))
    # Repeat every grid value patch_size times along both axes.
    upsampled = np.repeat(np.repeat(patch_grid, patch_size, axis=0), patch_size, axis=1)
    heatmap[:covered, :covered] = upsampled
    return heatmap


# One figure per sampled feature: its top-activating images with the per-patch activations overlaid.
# Iterate over the three parallel lists directly. Zipping `top_per_feature.keys()` against them
# misaligns features and categories as soon as a feature id occurs twice, because dict keys are unique.
for feature_ids, cat, logfreq in tqdm(zip(interesting_features_indices, interesting_features_category, interesting_features_values), total=len(interesting_features_category)):
    max_vals, max_inds = top_per_feature[feature_ids]
    images = []
    model_images = []
    gt_labels = []
    for bid in max_inds:
        # Un-normalized image for display ...
        image, label, image_ind = imagenet_data_visualize[bid]

        assert image_ind.item() == bid
        images.append(image)

        # ... and the normalized version of the same image for the model.
        model_img, _, _ = imagenet_data[bid]
        model_images.append(model_img)
        gt_labels.append(ind_to_name[label])

    grid_size = int(np.ceil(np.sqrt(len(images))))
    # squeeze=False keeps `axs` 2-D even when the grid has a single row or column.
    fig, axs = plt.subplots(int(np.ceil(len(images)/grid_size)), grid_size, figsize=(15, 15), squeeze=False)
    name = f"Category: {cat},  Feature: {feature_ids}"
    fig.suptitle(name)
    for ax in axs.flatten():
        ax.axis('off')

    # Dataset indices already drawn, so the same image is not plotted twice.
    complete_bid = set()

    for i, (image_tensor, label, val, bid, model_img) in enumerate(zip(images, gt_labels, max_vals, max_inds, model_images)):
        bid = int(bid)
        if bid in complete_bid:
            continue
        complete_bid.add(bid)

        row = i // grid_size
        col = i % grid_size
        heatmap = get_heatmap(model_img, model, sparse_autoencoder, feature_ids)
        heatmap = image_patch_heatmap(heatmap, pixel_num=224//PATCH_SIZE)

        # (channels, height, width) -> (height, width, channels) for imshow.
        display = image_tensor.numpy().transpose(1, 2, 0)

        # TODO: never set to True; kept so the title format stays unchanged.
        has_zero = False

        axs[row, col].imshow(display)
        axs[row, col].imshow(heatmap, cmap='viridis', alpha=0.3)  # Overlaying the heatmap
        axs[row, col].set_title(f"{label} {val.item():0.03f} {'class token!' if has_zero else ''}")
        axs[row, col].axis('off')

    plt.tight_layout()
    folder = os.path.join(OUTPUT_FOLDER, f"{cat}")
    os.makedirs(folder, exist_ok=True)
    plt.savefig(os.path.join(folder, f"neglogfreq_{-logfreq}feauture_id_{feature_ids}.png"))
    # Close this specific figure so figures do not pile up in memory across the loop.
    plt.close(fig)
