In [None]:
# Paths
import os

OUTPUT_FOLDER = ""  # output directory ("" means the current working directory)
# Later cells write PNGs, a parquet file and a tensor into OUTPUT_FOLDER, so make
# sure it exists. os.makedirs("") raises, hence the guard for the empty default.
if OUTPUT_FOLDER:
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

CHECKPOINT_PATH = ""  # training-script checkpoint path  # TODO shouldn't be needed.
IMAGENET_PATH = ""  # folder containing imagenet1k data organized as follows: https://www.kaggle.com/c/imagenet-object-localization-challenge/overview/description

SAE_PATH = ""  # path to SAE folder, might look something like "final_sae_group_wkcn_TinyCLIP-ViT-40M-32-Text-19M-LAION400M_blocks.{layer}.mlp.hook_post_16384"
AUTOENCODER_NAME = ""  # name of the particular SAE within the group (all names get printed below)

# Model specs. TODO these should be inferred from the pretrained model checkpoint (if they aren't already).

# Optional
HUGGINGFACE_CACHE_DIR = None  # forwarded as `cache_dir` when loading the Hugging Face dataset
LAYERS = 9  # forwarded to setup(layers=...) and written to the exported table as `layerIdx`
EXPANSION_FACTOR = 8  # forwarded to setup(expansion_factor=...)
D_IN = 2048  # forwarded to setup(d_in=...)
MODEL_NAME = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
CONTEXT_SIZE = 50  # forwarded to setup(context_size=...); presumably 1 CLS + 7*7 patch tokens -- TODO confirm
PATCH_SIZE = 32
HOOKPOINT = "blocks.{layer}.mlp.hook_post"
LEGACY_LOAD = False


In [None]:
# eval constants

# Cap on how many images the top-activation collection loop scans
# (a later cell assigns it to `this_max` and compares it with batch_idx * BATCH_SIZE).
EVAL_MAX = 10_000 
# Batch size handed to the DataLoader and used for the batch-count arithmetic in later cells.
BATCH_SIZE = 32


In [None]:
from sae.main import setup, ImageNetValidationDataset
import torch
import plotly.express as px
from typing import List

import torch
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torchvision
import einops
from transformers import CLIPProcessor
from vit_prisma.utils.data_utils.imagenet_dict import IMAGENET_DICT
from typing import List
from torch.utils.data import Dataset
from datasets import load_dataset



# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only inference is performed in this notebook, so autograd is switched off globally.
torch.set_grad_enabled(False)


In [None]:
# Load the segmented ImageNet-1k subset via the `datasets` library, using HUGGINGFACE_CACHE_DIR
# as the cache location (None leaves the choice to the library). Later cells use its 'train' split.
default_dataset = load_dataset('Prisma-Multimodal/segmented-imagenet1k-subset', cache_dir =HUGGINGFACE_CACHE_DIR)

In [None]:
class PatchDataset(Dataset):
    """Dataset wrapper that yields model-ready images plus per-patch label lists.

    Each item of the wrapped dataset is read through the keys ``'image'``,
    ``'masks'`` and ``'labels'``. The image is resized to ``(height, width)``,
    converted to RGB, turned into a tensor and normalized with the CLIP
    processor's mean/std. Every mask is resized to the same resolution and
    reduced to a grid of ``height // patch_size`` rows by ``width // patch_size``
    columns; a patch receives a mask's label if any mask pixel inside it is > 0.
    """

    def __init__(self, dataset, patch_size=32, width=224, height=224, return_label = True):
        """
        dataset: indexable collection of dictionaries, one per image (keys used:
            'image', and when labels are requested 'masks' and 'labels').
        patch_size: side length in pixels of one square patch.
        width, height: resolution images and masks are resized to.
        return_label: when True, ``__getitem__`` also returns the per-patch label grid.
        """
        self.dataset = dataset
        clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
        self.transform = torchvision.transforms.Compose([
            # Resize takes (h, w); use the configured size so images and mask grids always agree.
            torchvision.transforms.Resize((height, width)),
            torchvision.transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            torchvision.transforms.ToTensor(),
            # TODO for clip only
            torchvision.transforms.Normalize(
                mean=clip_processor.image_processor.image_mean,
                std=clip_processor.image_processor.image_std,
            ),
        ])
        self.patch_size = patch_size

        self.width = width
        self.height = height
        self.return_label = return_label

    def __len__(self):
        """Number of images in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, idx)`` or, with labels enabled, ``(image, label_array, idx)``.

        ``label_array[i][j]`` is the list of labels whose mask touches the patch
        in grid row ``i`` and grid column ``j`` (one entry per matching mask, so
        a label can appear more than once).
        """
        item = self.dataset[idx]
        image = self.transform(item['image'])
        if not self.return_label:
            return image, idx

        # Rows follow the height and columns the width, so non-square sizes index correctly.
        num_rows = self.height // self.patch_size
        num_cols = self.width // self.patch_size
        label_array = [[[] for _ in range(num_cols)] for _ in range(num_rows)]

        # Labels are assumed to be aligned one-to-one with masks.
        for mask, label in zip(item['masks'], item['labels']):
            # The size passed to resize is (width, height); the resulting array is (height, width).
            mask_array = np.array(mask.resize((self.width, self.height))) > 0
            reduced_mask = self.reduce_mask(mask_array)
            for i, j in np.argwhere(reduced_mask):
                label_array[i][j].append(label)

        # label_array stays a plain nested list; it is not converted to a tensor.
        return image, label_array, idx

    def reduce_mask(self, mask):
        """
        Reduce the mask to patch resolution: entry ``[i, j]`` of the result is True
        when at least one value inside patch ``(i, j)`` is truthy. Rows/columns that
        do not fill a whole patch at the bottom/right edge are ignored.
        """
        patch = self.patch_size
        mask = np.asarray(mask)
        if mask.ndim > 2:
            # Collapse trailing axes (e.g. colour channels) so any truthy channel counts.
            mask = mask.any(axis=tuple(range(2, mask.ndim)))

        new_h = mask.shape[0] // patch
        new_w = mask.shape[1] // patch

        # View the mask as (grid_row, in_patch_row, grid_col, in_patch_col) and OR within patches.
        trimmed = mask[:new_h * patch, :new_w * patch]
        return trimmed.reshape(new_h, patch, new_w, patch).any(axis=(1, 3))

def collate_fn(data):
    """Batch samples whose first two fields are ``(image, index)``.

    Returns the images stacked along a new leading dimension and the indices
    as a plain Python list in the same order.
    """
    image_tensors = []
    sample_ids = []
    for sample in data:
        image_tensors.append(sample[0])
        sample_ids.append(sample[1])
    return torch.stack(image_tensors), sample_ids

# Two views over the same 'train' split: one yielding (image, idx) for batched model
# runs, one yielding (image, label_array, idx) for per-patch label lookups.
patch_label_dataset = PatchDataset(default_dataset['train'], return_label=False)
patch_label_dataset_with_label = PatchDataset(default_dataset['train'], return_label=True)
# Smoke check: first sample without labels.
im, idx = patch_label_dataset[0]
print(im.shape)
print(idx)
# Smoke check: first sample with its per-patch label grid.
im, l, idx = patch_label_dataset_with_label[0]
print(im.shape)
print(l)
print(idx)
# Unshuffled loader over the label-free view; collate_fn returns (stacked images, list of indices).
data_loader = DataLoader(patch_label_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
# Smoke check: first batch.
im, idx = next(iter(data_loader))
print(im.shape)
print(idx)

In [None]:




# setup model
# `setup` (imported from sae.main) is given every path / model constant from the config cell
# and returns the config, the model, an activations loader and the group of SAEs.
cfg ,model, activations_loader, sae_group = setup(checkpoint_path=CHECKPOINT_PATH, 
                                                  imagenet_path=IMAGENET_PATH ,
                                                    pretrained_path=SAE_PATH, layers= LAYERS, expansion_factor=EXPANSION_FACTOR,
                                                    model_name=MODEL_NAME, context_size=CONTEXT_SIZE, d_in=D_IN, hook_point=HOOKPOINT, legacy_load=LEGACY_LOAD)
model = model.to(device)
# List every SAE in the group so a valid AUTOENCODER_NAME can be picked in the config cell.
for i, (name, sae) in enumerate(sae_group):
    hyp = sae.cfg
    print(
        f"{i}: Name: {name} Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}"
    )

# Select the SAE to analyse by name and move it to the same device as the model.
sparse_autoencoder = sae_group.autoencoders[AUTOENCODER_NAME]
sparse_autoencoder = sparse_autoencoder.to(device)
layer_num = sparse_autoencoder.cfg.hook_point_layer
print(f"Chosen layer {layer_num} hook point {sparse_autoencoder.cfg.hook_point}")

In [None]:
# Switch the selected SAE to evaluation mode before running it
# (presumably a torch.nn.Module, so this is nn.Module.eval() -- TODO confirm).
sparse_autoencoder.eval()  

In [None]:

# helper functions
# Keyword names that `hist` forwards to `Figure.update_layout` instead of `px.histogram`.
update_layout_set = {
    # titles / legend
    "title_x",
    "title_y",
    "legend_title_text",
    "showlegend",
    "hovermode",
    "margin",
    # bars
    "bargap",
    "bargroupgap",
    # colour handling
    "colorbar",
    "colorscale",
    "coloraxis",
    "coloraxis_showscale",
    # x axis
    "xaxis_range",
    "xaxis_title",
    "xaxis_tickformat",
    "xaxis_tickmode",
    "xaxis_showgrid",
    "xaxis_gridwidth",
    "xaxis_gridcolor",
    "xaxis_visible",
    # y axis
    "yaxis_range",
    "yaxis_title",
    "yaxis_tickformat",
    "yaxis_tickmode",
    "yaxis_showgrid",
    "yaxis_gridwidth",
    "yaxis_gridcolor",
    "yaxis_visible",
}
def to_numpy(tensor):
    """
    Convert the input to a numpy array.

    Numpy arrays are returned unchanged; torch tensors/parameters are detached and
    moved to the CPU first; lists, tuples and int/float/bool/str scalars go through
    ``np.array``. Any other type raises ``ValueError``.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

def hist(tensor, save_name, show=True, renderer=None, **kwargs):
    '''
    Plot a plotly histogram of `tensor`, save it as a PNG and optionally display it.

    tensor: values to histogram; converted with `to_numpy`.
    save_name: file stem; the image is written to OUTPUT_FOLDER/<save_name>.png.
    show: when True the figure is also displayed.
    renderer: passed to `Figure.show`.
    **kwargs: keys listed in `update_layout_set` are applied through
        `Figure.update_layout`; all others go to `px.histogram`. `bargap` defaults
        to 0.1, and an integer `margin` is expanded to the same value on all four sides.
    '''
    layout_kwargs = {k: v for k, v in kwargs.items() if k in update_layout_set}
    histogram_kwargs = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    layout_kwargs.setdefault("bargap", 0.1)
    margin = layout_kwargs.get("margin")
    if isinstance(margin, int):
        layout_kwargs["margin"] = dict.fromkeys(list("tblr"), margin)

    histogram_fig = px.histogram(x=to_numpy(tensor), **histogram_kwargs)
    histogram_fig.update_layout(**layout_kwargs)

    # Save the figure as a PNG file
    histogram_fig.write_image(os.path.join(OUTPUT_FOLDER, f"{save_name}.png"))
    if show:
        # Display the figure that was just saved instead of building a second identical one.
        histogram_fig.show(renderer)



In [None]:
@torch.no_grad()
def get_feature_probability(
    images,
    model,
    sparse_autoencoder,
):
    '''
    Returns the feature probabilities (i.e. fraction of time the feature is active) for each
    feature in the autoencoder.

    The model is run with caching on `images`, and the activations at the SAE's hook point are
    passed through the SAE. A feature counts as active at a token when its SAE activation is > 0.

    Returns a pair of 1-D tensors, one entry per feature:
      * the active fraction over all `batch * seq` tokens,
      * the active fraction over the first token of each image (index 0 along the sequence axis).
    '''
    _, cache = model.run_with_cache(images)
    # The SAE forward returns several values; only the feature activations (second one) are needed.
    _, feature_acts, *_ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )
    # Count firings rather than averaging magnitudes, so the result is a frequency in [0, 1].
    is_active = (feature_acts > 0).float()
    token_probability = is_active.flatten(0, 1).mean(0)
    class_probability = is_active[:, 0, :].mean(0)

    return token_probability, class_probability

# Sum the per-batch feature statistics over the first `this_max // BATCH_SIZE` batches.
total_acts = None
total_class_acts = None
this_max = 500
# The following cell averages by `this_max // BATCH_SIZE`, so exactly that many batches must be
# accumulated. Checking the limit only after processing a batch summed two batches too many.
# NOTE: if the loader yields fewer batches than this, the later average is still an underestimate.
num_stat_batches = this_max // BATCH_SIZE
for batch_idx, (total_images, _) in tqdm(enumerate(data_loader), total=num_stat_batches):
    if batch_idx >= num_stat_batches:
        break

    total_images = total_images.to(device)
    new, new_class = get_feature_probability(total_images, model, sparse_autoencoder)

    if total_acts is None:
        total_acts = new
        total_class_acts = new_class
    else:
        total_acts = total_acts + new
        total_class_acts = total_class_acts + new_class


In [None]:
# Turn the accumulated sums into per-batch averages (all tokens, and first token only).
feature_probability = total_acts / (this_max // BATCH_SIZE)
feature_probability_class = total_class_acts / (this_max // BATCH_SIZE)

# log10 of the averages; the 1e-10 offset keeps never-firing features finite.
log_freq = torch.log10(feature_probability + 1e-10)
log_freq_class = torch.log10(feature_probability_class + 1e-10)


In [None]:
# Quick look at the averaged per-feature statistic before plotting it.
print(feature_probability)
def visualize_sparsities(log_freq, conditions, condition_texts, name):
    """Plot the log-frequency histogram and, per condition, encoder cosine similarities.

    log_freq: 1-D tensor with one log10 value per SAE feature.
    conditions: boolean masks over the features, each selecting one group.
    condition_texts: one label per mask, used in file names and titles.
    name: prefix for plot titles and saved file names.

    For every mask that selects at least one feature, the matching columns of the
    global `sparse_autoencoder.W_enc` are normalized, all pairwise cosine
    similarities are computed, and 10000 of them (sampled with replacement) are
    plotted. All figures are saved and shown through `hist`.
    """
    common_style = {"histnorm": "percent", "template": "ggplot2"}

    # Overall distribution of the per-feature log frequencies.
    hist(
        log_freq,
        f"{name}_frequency_histogram",
        show=True,
        title=f"{name} Log Frequency of Features",
        labels={"x": "log<sub>10</sub>(freq)"},
        **common_style,
    )

    for feature_mask, mask_text in zip(conditions, condition_texts):
        percentage = (torch.count_nonzero(feature_mask)/log_freq.shape[0]).item()*100
        if percentage == 0:
            # Nothing selected by this mask; no similarity plot to draw.
            continue
        percentage = int(np.round(percentage))

        selected_dirs = sparse_autoencoder.W_enc[:, feature_mask]
        unit_dirs = selected_dirs / selected_dirs.norm(dim=0, keepdim=True)

        # Flattened N*N matrix of pairwise cosine similarities, then a random sample of it.
        all_cos_sims = (unit_dirs.T @ unit_dirs).flatten()
        sample_positions = torch.randint(0, all_cos_sims.shape[0], (10000,))
        sampled_cos_sims = all_cos_sims[sample_positions]

        hist(
            sampled_cos_sims,
            f"{name}_low_prop_similarity_{mask_text}",
            show=True,
            marginal="box",
            title=f"{name} Cosine similarities of random {mask_text} encoder directions with each other ({percentage}% of features)",
            labels={"x": "Cosine sim"},
            **common_style,
        )

#TODO these conditions should be tuned to distribution of your data!
# One boolean mask per frequency bin; here a single bin with -4 < log_freq < -3 (both bounds exclusive).
conditions = [ torch.logical_and(log_freq < -3,log_freq > -4)]
# Labels aligned one-to-one with `conditions`; used in plot titles and file names.
condition_texts = [  "logfreq_[-4,-3]"]
visualize_sparsities(log_freq, conditions, condition_texts, "TOTAL")


In [None]:

# Randomly pick up to `number_features_per` features from each frequency bin.
# Three parallel lists are filled: feature index, its log frequency, and its bin label.
interesting_features_indices = []
interesting_features_values = []
interesting_features_category = []
number_features_per = 25

category_names = [f"TOTAL_{c}" for c in condition_texts]
for condition, condition_text in zip(conditions, category_names):
    # Indices of all features that fall into this bin.
    potential_indices = torch.nonzero(condition, as_tuple=True)[0]

    # Random permutation of the candidates, truncated to the requested amount.
    shuffled_positions = torch.randperm(len(potential_indices))
    sampled_indices = potential_indices[shuffled_positions[:number_features_per]]

    values = log_freq[sampled_indices]

    interesting_features_indices.extend(sampled_indices.tolist())
    interesting_features_values.extend(values.tolist())
    interesting_features_category.extend([condition_text] * len(sampled_indices))

print(set(interesting_features_category))



In [None]:



@torch.no_grad()
def highest_activating_tokens(
    images,
    model,
    sparse_autoencoder,
    W_enc,
    b_enc,
    feature_ids: List[int],
    feature_categories,
    k: int = 10,
):
    '''
    Returns the indices & values of the highest-activating images in the given batch, per feature.

    images: batch of inputs for `model.run_with_cache`.
    model: model whose cache is read at `sparse_autoencoder.cfg.hook_point`.
    sparse_autoencoder: only its `cfg.hook_point` and `b_dec` are used here.
    W_enc, b_enc: encoder weights / bias already restricted to the features of interest,
        in the same order as `feature_ids`.
    feature_ids: ids used as keys of the returned dictionary.
    feature_categories: one string per feature; features whose category contains "CLS_" are
        ranked by their activation on the first token, all others by the mean over tokens.
    k: number of top images to keep (clamped to the batch size).

    Returns `(to_return, per_image_acts)` where `to_return[feature_id]` is
    `(top_indices, top_values)` with indices relative to this batch, and `per_image_acts`
    holds the token-averaged activation of every image for every selected feature.
    '''
    # Get the post activations from the clean run
    _, cache = model.run_with_cache(images)
    inp = cache[sparse_autoencoder.cfg.hook_point]

    # Compute activations explicitly for just the selected features (first part of the SAE
    # forward pass): subtract the decoder bias, project with the encoder, add bias, ReLU.
    sae_in = inp - sparse_autoencoder.b_dec  # Remove decoder bias as per Anthropic
    acts = torch.nn.functional.relu(sae_in @ W_enc + b_enc)

    cls_acts = acts[:, 0, :]
    per_image_acts = acts.mean(1)

    # A batch (e.g. the last one) may hold fewer than k images; topk would raise otherwise.
    top_k = min(k, acts.shape[0])

    to_return = {}
    for column, (feature_id, feature_cat) in enumerate(zip(feature_ids, feature_categories)):
        ranking_source = cls_acts if "CLS_" in feature_cat else per_image_acts
        top_acts_values, top_acts_indices = ranking_source[:, column].topk(top_k)
        to_return[feature_id] = (top_acts_indices, top_acts_values)
    return to_return, per_image_acts
# Scan up to `this_max` images and keep, for every selected feature, the TOP_K images
# with the highest token-averaged activation.
this_max = EVAL_MAX
TOP_K = 25  # number of top-activating images kept per feature

max_indices = {i: None for i in interesting_features_indices}
max_values = {i: None for i in interesting_features_indices}

# for each feature, the total activation of a given image (rows: dataset index, cols: feature)
all_activations_per_image = torch.zeros(len(patch_label_dataset), len(interesting_features_indices)).to(device)

# Encoder parameters restricted to the selected features, in the order of interesting_features_indices.
b_enc = sparse_autoencoder.b_enc[interesting_features_indices]
W_enc = sparse_autoencoder.W_enc[:, interesting_features_indices]

# Number of batches needed to cover `this_max` images (ceil division), capped by the loader length.
# The limit is checked before a batch is processed so the loop does not run one batch too far.
num_eval_batches = min(len(data_loader), -(-this_max // BATCH_SIZE))
for batch_idx, (total_images, total_indices) in tqdm(enumerate(data_loader), total=num_eval_batches):
    if batch_idx >= num_eval_batches:
        break

    total_images = total_images.to(device)
    total_indices = torch.tensor(total_indices).to(device)
    new_stuff, new_per_image_acts = highest_activating_tokens(
        total_images, model, sparse_autoencoder, W_enc, b_enc,
        interesting_features_indices, interesting_features_category, k=TOP_K,
    )
    # Write by dataset index rather than by batch position so rows always match image ids.
    all_activations_per_image[total_indices] = new_per_image_acts

    for feature_id in interesting_features_indices:
        new_indices, new_values = new_stuff[feature_id]
        # Translate batch-relative positions into dataset indices.
        new_indices = total_indices[new_indices]

        if max_indices[feature_id] is None:
            max_indices[feature_id] = new_indices
            max_values[feature_id] = new_values
        else:
            # Merge the running top-k with this batch's candidates and keep the best TOP_K.
            merged_values = torch.cat((max_values[feature_id], new_values))
            merged_indices = torch.cat((max_indices[feature_id], new_indices))
            _, keep = torch.topk(merged_values, min(TOP_K, merged_values.shape[0]))
            max_values[feature_id] = merged_values[keep]
            max_indices[feature_id] = merged_indices[keep]

all_activations_per_image = all_activations_per_image.detach().cpu()
# feature id -> (top values, dataset indices of the corresponding images), both on the CPU
top_per_feature = {i: (max_values[i].detach().cpu(), max_indices[i].detach().cpu()) for i in interesting_features_indices}

In [None]:

@torch.no_grad()
def get_heatmap(

          image,
          model,
          sparse_autoencoder,
          feature_id,
):
    """Return one value per token of `image` for a single SAE feature.

    image: a single (unbatched) image tensor; it is moved to the global `device` and a
        batch dimension is added before the model is run.
    model: model whose cache is read at `sparse_autoencoder.cfg.hook_point`.
    sparse_autoencoder: provides `cfg.hook_point`, `b_dec` and `W_enc`.
    feature_id: column of `W_enc` to project onto.

    Returns a 1-D tensor with one entry per token in the cached activations.

    NOTE(review): unlike `highest_activating_tokens`, this projection neither adds the
    encoder bias nor applies the ReLU, so the values are pre-activations and can be
    negative. Confirm whether that is intended before relying on them as SAE activations.
    """
    image = image.to(device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange( cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute the projection explicitly for only the feature we want (first part of the
    # SAE forward pass).
    sae_in = post_reshaped - sparse_autoencoder.b_dec  # Remove decoder bias as per Anthropic
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts
     
import pandas as pd

# Build a long-format table with one row per (feature, image, token) and save it as parquet.
# For every feature the table covers its top-activating images plus an equal number of
# randomly drawn images.

# Upper bound on the number of rows: features * (25 top + 25 random images) * 50 tokens.
num_rows = len(list(top_per_feature.keys()))*(25+25)*50

# Column names and dtypes of the exported table.
columns = {
    'patchIdx': np.int32,
    'featureIdx': np.int32,
    'imageIdx': np.int32,
    'activationValue': np.float32,
    'label': object,  # Strings are object type in pandas
    'type': object,
    'layerIdx': np.int32
}

# Number of patches per grid row; maps a flat patch index to (row, col) in the label grid.
patches_per_row = patch_label_dataset_with_label.width // patch_label_dataset_with_label.patch_size

# Rows are collected as plain dicts and turned into a DataFrame once at the end. Assigning
# rows one by one into a preallocated frame is slow and left all-zero filler rows in the
# output whenever a feature was skipped.
records = []
df_count = 0
for feature_id in tqdm(top_per_feature, total=len(top_per_feature)):
    _, max_inds = top_per_feature[feature_id]
    max_inds = max_inds.tolist()
    if len(max_inds) != len(set(max_inds)):
        # Duplicate image ids in the top list: skip this feature.
        print("SKIPPING")
        continue

    # One random comparison image per top image.
    random_inds = [
        np.random.randint(0, len(patch_label_dataset_with_label))
        for _ in range(len(max_inds))
    ]
    selected = [(bid, "top") for bid in max_inds] + [(bid, "random") for bid in random_inds]

    for bid, typestr in selected:
        image, all_label, image_ind = patch_label_dataset_with_label[bid]
        assert image_ind == bid

        # One value per token; converted once to Python floats instead of calling .item() per token.
        heatmap = get_heatmap(image, model, sparse_autoencoder, feature_id).tolist()

        for patchIdx, activationValue in enumerate(heatmap):
            if patchIdx == 0:
                # Token 0 is not an image patch, so it has no label.
                label = ""
            else:
                pi, pj = divmod(patchIdx - 1, patches_per_row)
                label = ", ".join(set(all_label[pi][pj]))

            records.append({
                'patchIdx': patchIdx,
                'featureIdx': feature_id,
                'imageIdx': bid,
                'activationValue': activationValue,
                'label': label,
                'type': typestr,
                'layerIdx': LAYERS,
            })
            df_count += 1

df = pd.DataFrame(records, columns=list(columns)).astype(columns)

# Save the table to a Parquet file
df.to_parquet(os.path.join(OUTPUT_FOLDER,'example.parquet'))

   # plt.show()


In [None]:
# Persist the per-image activation matrix (rows: images, columns: selected features) next to the other outputs.
torch.save(all_activations_per_image, os.path.join(OUTPUT_FOLDER, "all_activations.pt"))

In [None]:
#print(all_activations_per_image)

# Sanity check for the first feature only (note the `break`): print the stored top values and
# image indices, then the 25 highest-ranked image indices from a full sort of column 0 of
# `all_activations_per_image`, so the two rankings can be compared by eye.
# Column 0 is hard-coded; it matches the first key of `top_per_feature` because both follow
# the order of `interesting_features_indices`.
for feature_id, (mvalues, mindices) in top_per_feature.items():
    acts, act_indices = torch.sort(all_activations_per_image[:,0], descending=True)
    print(mvalues.tolist())
    print(mindices.tolist())

    print(act_indices[0:25].tolist())
    break

In [None]:
# Read the exported table back and compare its size with the expected upper bound.
parquet_path = os.path.join(OUTPUT_FOLDER, 'example.parquet')
df = pd.read_parquet(parquet_path)
print(parquet_path)
# Upper bound on rows: features * (25 top + 25 random images) * 50 tokens.
print(len(top_per_feature)*(25+25)*50)
# Define the number of rows to display
n = 10
# Number of rows actually produced by the table-building cell.
print(df_count)
# Display the first and last n rows of the DataFrame
print(df.head(n))
print(df.tail(n))