In [51]:
import torch 

from torch.utils.data import DataLoader
import numpy as np 
import matplotlib.pyplot as plt
import os 
from tqdm import tqdm
from typing import List, Dict, Tuple
import einops
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import ImageNetValidationDataset
from vit_prisma.transforms.open_clip_transforms import get_clip_val_transforms
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.sae.sae import SparseAutoencoder
from vit_prisma.sae.sae_utils import download_sae_from_huggingface
from vit_prisma.sae.evals.evals import EvalConfig
from transformers import CLIPModel 
from transformers import CLIPProcessor
from typing import List 
import urllib.request
from fancy_einsum import einsum
import torchvision
from functools import partial
import torch.nn.functional as F


torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x13cfbb59960>

helper functions

In [52]:

def get_imagenet_val_dataset(dataset_path):

    data_transforms = get_clip_val_transforms()
    imagenet_paths = setup_imagenet_paths(dataset_path)
    return ImageNetValidationDataset(imagenet_paths['val'], 
                                    imagenet_paths['label_strings'], 
                                    imagenet_paths['val_labels'], 
                                    data_transforms, return_index=True
    )

def get_imagenet_val_dataset_visualize(dataset_path):
    imagenet_paths = setup_imagenet_paths(dataset_path)

    return ImageNetValidationDataset(imagenet_paths['val'], 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'],
                                torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),]), return_index=True)

def get_clip_model(model_name='open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K'):                
    model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True, fold_ln=False, center_writing_weights=False) 
    return model

def load_sae(download_dir):
    repo_name = 'soniajoseph/updated-sae-weights'
    file_id = 'UPDATED-final_sae_group_wkcn_TinyCLIP-ViT-40M-32-Text-19M-LAION400M_blocks.9.hook_mlp_out_8192.pt'
    download_sae_from_huggingface(repo_name, file_id, download_dir)

    sae_path= os.path.join(download_dir, file_id)
    sae = SparseAutoencoder(EvalConfig()).load_from_pretrained_legacy_saelens_v2(sae_path)  #TODO may need option to modify cfg
    return sae 

class ClipTextStuff:

    def __init__(self, model_name, device="cuda"):

        self.device= device 

        self.clip_processor = CLIPProcessor.from_pretrained(model_name)



        full_clip_model = CLIPModel.from_pretrained(model_name)

        self.text_model = full_clip_model.text_model.to(device)
        self.text_projection = full_clip_model.text_projection.to(device)
        self.logit_scale = full_clip_model.logit_scale
        with urllib.request.urlopen("https://raw.githubusercontent.com/yossigandelsman/clip_text_span/main/text_descriptions/image_descriptions_general.txt") as response:
            self.labels = response.read().decode('utf-8').split('\n')

        self.labels_projected = self.get_text_embeds(self.labels)

        # for context here is how similarity is computed:
        # image_embeds = vision_outputs[1]
        # image_embeds = self.visual_projection(image_embeds)

        # text_embeds = text_outputs[1]
        # text_embeds = self.text_projection(text_embeds)

        # # normalized features
        # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        # text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # # cosine similarity as logits
        # logit_scale = self.logit_scale.exp()
        # logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        # logits_per_image = logits_per_text.t()

        pass

    def get_text_embeds(self, list_of_strings:List[str]):
        labels_input = self.clip_processor(text=list_of_strings, return_tensors='pt',  padding=True).input_ids
        labels_input = labels_input.to(self.device)

        labels_projections = self.text_projection(self.text_model(labels_input)[1])

        # normalize
        labels_projections = labels_projections/ labels_projections.norm(p=2, dim=-1, keepdim=True)

        return labels_projections

setup

In [53]:
output_folder = r"F:/ViT-Prisma_fork/data/textspans/output"
imagenet_dataset_path = r"F:/prisma_data/imagenet-object-localization-challenge"  
sae_download_dir =   r'F:/ViT-Prisma_fork/data/textspans/attempted_models/new'

model_name = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"

sae = load_sae(sae_download_dir)
model = get_clip_model(model_name)

os.makedirs(output_folder, exist_ok=True)


hook_layer = 9
hook_point = f"blocks.{hook_layer}.hook_mlp_out"
#F:\ViT-Prisma_fork\data\vision_sae_checkpoints\comparisons\mlp\mlp_16
feature_ids = [2381, 6405, 5379, 6827, 3471, 1436, 264, 6999, 1554, 2498, 4285 ]
descriptions = ["mustache", "bike", "food market (CLS TOKEN!)", "black and white photo (CLS TOKEN!)",
                 "computer hardware (CLS TOKEN!)", "image boundaries", "The text 'in' and 'on'",
                 "Polysemantic? animals in trees and certain products in foreground with plain backgrounds",
                 "Chess?",
                 "Sport players? Orange jerseys?",
                 "Legs of sport players",
                 
                 ]
cls_token_or_not = [False, False, True, True, True, False, False, False, False, False,False]

assert len(feature_ids) == len(descriptions) == len(cls_token_or_not)

stop_at_layer = hook_layer+1 


dataset = get_imagenet_val_dataset(imagenet_dataset_path)
visualize_dataset = get_imagenet_val_dataset_visualize(imagenet_dataset_path)



n_ctx = 50 

device= "cuda"



model = model.to(device)
sae = sae.to(device)
clip_text_stuff = ClipTextStuff(model_name,device=device)

File downloaded successfully to: F:\ViT-Prisma_fork\data\textspans\attempted_models\new\UPDATED-final_sae_group_wkcn_TinyCLIP-ViT-40M-32-Text-19M-LAION400M_blocks.9.hook_mlp_out_8192.pt
File size: 33592403 bytes
n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 15869
Total training images: 1300000
Total wandb updates: 158
Expansion factor: 16
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 15 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder
get_activation_fn received: activation_fn=relu, kwargs={}
n_tokens_per_buffer (millions): 0.0512
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 15869
Total training images: 1300000
Total wandb updates: 158
Expansion factor: 16
n_token


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M into HookedTransformer


steering hook

In [54]:
def steering_hook(x, hook, sparse_autoencoder, feature, amount):


    reconstruction = sparse_autoencoder(x)[0]
    error = x - reconstruction

    boosted_feature_acts = sparse_autoencoder.encode_standard(x)
    
    #boosted_feature_acts[:,:,feature] = amount
    #boosted_feature_acts[:,:,feature] =  boosted_feature_acts[:,:,feature] *amount
    boosted_feature_acts[:,0,feature] = amount  
    boosted_sae_out = einops.einsum(
            boosted_feature_acts,
            sparse_autoencoder.W_dec,
            "... d_sae, d_sae d_in -> ... d_in",
        ) + sparse_autoencoder.b_dec
    

    boosted_sae_out = sparse_autoencoder.run_time_activation_norm_fn_out(boosted_sae_out)

    return boosted_sae_out + error

def run_model_with_steering(images,feature,amount ):
    return F.normalize(model.run_with_hooks(
                images,
                fwd_hooks=[
                    (hook_point, partial(steering_hook, sparse_autoencoder=sae, feature=feature, amount=amount))
                ],
                clear_contexts=True
            ), p=1, dim=-1)

In [55]:
labels = ["mustache", "facial hair", "dog", "cat", "sports", "black and white", "past", "40s", "photo", "computer", "hardware", "market" ,"crowd", 
          "bikes", "cycle", "giraffe", "sports", "organge", "collage", "mix up", "cut", "boundary", "chess", "knees", "legs",
          "animals in tree", "trees", "animals", "in", "on", "text"]

labels_projected = clip_text_stuff.get_text_embeds(labels)

Simple steering experiments. We'll see which labels get the biggest boost? 

In [56]:
dataloader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0)
random_batch = next(iter(dataloader))
random_batch = random_batch[0].to(device)

random_output = model(random_batch)

random_output = F.normalize(random_output, p=2, dim=-1)

default_scores = einsum( "I D, T D -> I T", random_output, clip_text_stuff.labels_projected)

In [62]:
boost_amount = 50
top_k=10

for i, feature_id in enumerate(feature_ids):
   # if not cls_token_or_not[i]:
   #     continue
    print("FEATURE", feature_id, "HUMAN DESCRIPTION", descriptions[i])
    outputs = run_model_with_steering(random_batch, feature_id,boost_amount)



    scores = einsum( "I D, T D -> I T", outputs,  clip_text_stuff.labels_projected)

    diff = scores - default_scores

    avg_diffs = torch.mean(scores, dim=0) 
    top_k_values, top_k_indices = torch.topk(avg_diffs, top_k, dim=0)
   # max_diffs = torch.max(scores, dim=0)[0]
   # top_k_values, top_k_indices = torch.topk(max_diffs, top_k, dim=0)
    for j, ind in enumerate(top_k_indices):
        ind = ind.item()
        print(f"{j}. {clip_text_stuff.labels[ind]}")



FEATURE 2381 HUMAN DESCRIPTION mustache
0. Detailed illustration
1. A gem
2. A painting
3. classic fine art piece
4. An elegant photo
5. Playful juxtaposition
6. Timeless fine art piece
7. A photo with a texture of mammals
8. Striking juxtaposition
9. Playful scenes
FEATURE 6405 HUMAN DESCRIPTION bike
0. An elegant photo
1. A boot
2. A gem
3. A floor
4. A zoomed out photo
5. A beautiful photo
6. An animal
7. A low-resolution image
8. Ocean
9. A leg
FEATURE 5379 HUMAN DESCRIPTION food market (CLS TOKEN!)
0. A leg
1. A close-up shot
2. A gem
3. A badge
4. A low-resolution image
5. A high-resolution image
6. Close-up view
7. Ocean
8. An animal
9. A portrait
FEATURE 6827 HUMAN DESCRIPTION black and white photo (CLS TOKEN!)
0. A leg
1. A zoomed out photo
2. A gem
3. Ocean
4. A low-resolution image
5. A zoomed in photo
6. A photograph of a big object
7. Posed shot
8. A close-up shot
9. A beautiful photo
FEATURE 3471 HUMAN DESCRIPTION computer hardware (CLS TOKEN!)
0. A beautiful photo
1. Oce