In [72]:
## Imports
import numpy as np
import torch
from tabulate import tabulate
import pandas as pd
from PIL import Image
import heapq
import json
import gc
from sortedcontainers import SortedList
from torch.utils.data import DataLoader
from utils.factory import create_model_and_transforms, get_tokenizer
from utils.visualization import image_grid, visualization_preprocess
from prs_hook import hook_prs_logger
from matplotlib import pyplot as plt
from utils.imagenet_classes import imagenet_classes
from compute_complete_text_set import svd_data_approx

In [2]:
## Hyperparameters
# Every value a reader may want to tune lives in this cell; the cells below
# only read these names.

device = 'cpu'
# The trailing comments on the next two lines hold an alternative value.
# NOTE(review): presumably 'laion2b_s32b_b79k' goes with 'ViT-H-14' — confirm.
pretrained = 'laion2b_s34b_b79k' # 'laion2b_s32b_b79k'
model_name = 'ViT-B-32' # 'ViT-H-14'
seed = 42
dataset_text_name = "image_descriptions_general"
dataset_image_name = "imagenet"
# Misspelled legacy name, kept as an alias because other cells still read it.
datataset_image_name = dataset_image_name
algorithm = "svd_data_approx"
batch_size = 16 # only needed for the nn search
imagenet_path = './datasets/imagenet/' # only needed for the nn search

In [3]:
## Loading Model

# Build the network together with its image preprocessing pipeline, move it to
# the configured device and switch it to evaluation mode.
model, _, preprocess = create_model_and_transforms(model_name, pretrained=pretrained)
model.to(device)
model.eval()

# Text-side settings exposed by the model, plus the tokenizer for this model name.
context_length, vocab_size = model.context_length, model.vocab_size
tokenizer = get_tokenizer(model_name)

# Short summary of what was loaded.
print(
    "Model parameters:",
    format(np.sum([int(np.prod(weight.shape)) for weight in model.parameters()]), ","),
)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model.visual.transformer.resblocks))

# Attach the hook that records the residual stream; `prs` is used by later cells.
prs = hook_prs_logger(model, device, spatial=False)

Using local files


  checkpoint = torch.load(checkpoint_path, map_location=map_location)


Model parameters: 151,277,313
Context length: 77
Vocab size: 49408
Len of res: 12


In [62]:
## Run algorithm on a dataset to derive eigenvectors
# Builds the command line for `compute_complete_text_set.py` from the
# hyperparameters (device, model, algorithm, seed, text description set) with
# `--num_of_last_layers` fixed to 4, then runs it through IPython's `!` shell
# escape, so this cell only works inside IPython/Jupyter.
# NOTE(review): the values are interpolated into a shell string without
# quoting; fine for the constants defined above, but do not feed it untrusted
# input.
command = f"python compute_complete_text_set.py --device {device} --model {model_name} --algorithm {algorithm} --seed {seed} --num_of_last_layers 4 --text_descriptions {dataset_text_name}"
!{command}

Number of layers: 12
100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 53.20it/s]
  0%|                                                     | 0/4 [00:00<?, ?it/s]
Layer [8], Head: 0

Layer [8], Head: 1

Layer [8], Head: 2

Layer [8], Head: 3

Layer [8], Head: 4

Layer [8], Head: 5

Layer [8], Head: 6

Layer [8], Head: 7

Layer [8], Head: 8

Layer [8], Head: 9

Layer [8], Head: 10

Layer [8], Head: 11
 25%|███████████▎                                 | 1/4 [00:04<00:13,  4.42s/it]
Layer [9], Head: 0

Layer [9], Head: 1

Layer [9], Head: 2

Layer [9], Head: 3

Layer [9], Head: 4

Layer [9], Head: 5

Layer [9], Head: 6

Layer [9], Head: 7

Layer [9], Head: 8

Layer [9], Head: 9

Layer [9], Head: 10

Layer [9], Head: 11
 50%|██████████████████████▌                      | 2/4 [00:08<00:08,  4.09s/it]
Layer [10], Head: 0

Layer [10], Head: 1

Layer [10], Head: 2

Layer [10], Head: 3

Layer [10], Head: 4

Layer [10], Head: 5

Layer [10], Head: 6

Layer [10], Head: 7

Layer

In [63]:
# Path of the JSON-lines file that the analysis cells below open and parse.
# NOTE(review): presumably this is the file written by the
# `compute_complete_text_set.py` run above — confirm the naming scheme there.
attention_dataset = f"output_dir/{datataset_image_name}_completeness_{dataset_text_name}_{model_name}_algo_{algorithm}_seed_{seed}.jsonl"

# Strongest Contributions per Dataset

In [None]:
# Collect the `top_k` eigenvectors with the largest `strength_abs` over every
# record of the attention dataset, then print, for each of them, the two
# halves of its text list side by side.

# Number of top entries to retrieve
top_k = 40
# Min-heap of (strength_abs, tie_breaker, eigenvector_index, entry); its root
# is always the weakest candidate currently kept.
min_heap = []
n_seen = 0  # running count of candidates, only used to break ties

# Read JSON lines
with open(attention_dataset, "r") as json_file:
    for line in json_file:
        entry = json.loads(line)  # Parse each line as a JSON object

        if entry["head"] == -1:  # Skip the record flagged with head == -1
            continue

        # Analyze each eigenvector
        for i, eigenvector_data in enumerate(entry["embeddings_sort"]):
            # The unique tie-breaker guarantees that tuple comparison never
            # reaches `entry` (a dict, which does not support `<`) when two
            # candidates share the same strength. It is negated so that, on a
            # tie, the candidate seen first is the one that survives.
            candidate = (eigenvector_data["strength_abs"], -n_seen, i, entry)
            n_seen += 1
            if len(min_heap) < top_k:
                heapq.heappush(min_heap, candidate)
            else:
                # Push the candidate and drop whichever item is now the weakest
                heapq.heappushpop(min_heap, candidate)

# Strongest first; the tie-breaker is dropped again so every item keeps the
# (strength_abs, eigenvector_index, entry) shape.
top_k_entries = [
    (strength_abs, i, entry)
    for strength_abs, _, i, entry in sorted(min_heap, key=lambda item: item[:2], reverse=True)
]

# Extract relevant details from the top k entries
top_k_details = [{
    "layer": entry["layer"],
    "head": entry["head"],
    "eigenvector": i,
    "strength_abs": entry["embeddings_sort"][i]["strength_abs"],
    "texts": entry["embeddings_sort"][i]["text"]
} for _, i, entry in top_k_entries]

# Display the results
top_k_df = pd.DataFrame(top_k_details)

for row in top_k_df.itertuples():
    texts = row.texts
    half_length = len(texts) // 2

    # Each item of `texts` is a dict whose first value is the description and
    # whose second value is its signed score. The sign of the very first score
    # decides which half gets the "Positive" label.
    # NOTE(review): this assumes each half holds scores of a single sign — confirm
    # against the script that writes the file.
    is_positive_first = list(texts[0].values())[1] > 0

    # Pair the k-th item of the first half with the k-th item of the second half
    output_rows = []
    for left, right in zip(texts[:half_length], texts[half_length:]):
        left_text, left_val = list(left.values())[:2]
        right_text, right_val = list(right.values())[:2]
        output_rows.append([left_text, left_val, right_text, right_val])

    print(f"Layer {row.layer}, Head {row.head}, Eigenvector {row.eigenvector}, Strength {row.strength_abs}")
    # Create a DataFrame for the output
    if is_positive_first:
        columns = ["Positive", "Positive_Strength", "Negative", "Negative_Strength"]
    else:
        columns = ["Negative", "Negative_Strength", "Positive", "Positive_Strength"]
    output_df = pd.DataFrame(output_rows, columns=columns)

    print(tabulate(output_df, headers='keys', tablefmt='psql'))


Layer 11, Head 7, Eigenvector 0, Strength 4.590941429138184
+----+-------------------------+---------------------+---------------------------------------+---------------------+
|    | Negative                |   Negative_Strength | Positive                              |   Positive_Strength |
|----+-------------------------+---------------------+---------------------------------------+---------------------|
|  0 | An image with dogs      |           -0.248193 | Image with a dragonfly                |           0.125204  |
|  1 | A dog                   |           -0.212832 | A scorpion                            |           0.102048  |
|  2 | Photo of a furry animal |           -0.13023  | Artwork featuring zebra stripe motifs |           0.0806464 |
|  3 | A wolf                  |           -0.124596 | A snail                               |           0.0799246 |
|  4 | A paw                   |           -0.123882 | Captivating curves                    |           0.0786279 |
+---

# Query a topic

In [None]:
# Rank the eigenvectors of every head by how strongly they correlate with the
# embedding of a free-text query, then print the texts attached to the
# `top_k` best matches.

# Number of top entries to retrieve
top_k = 10
# Only the first `top_k_eigv` eigenvectors of each record are inspected
top_k_eigv = 20
text_query = "The concept of music"

# Evaluate clip embedding for the given text. No gradients are needed here,
# so the encoder runs under no_grad to avoid building an autograd graph.
with torch.no_grad():
    text_query_token = tokenizer(text_query).to(device)  # tokenize
    topic_emb = model.encode_text(text_query_token)
    topic_emb /= topic_emb.norm(dim=-1, keepdim=True)  # normalize

# Items are (correlation, (eigenvector_index, corr_sign, entry)). Ordering uses
# the correlation only: without the key, two equal correlations would make the
# container compare the payloads and eventually the `entry` dicts, which raises
# TypeError.
top_entries = SortedList(key=lambda item: item[0])
# Read JSON lines
with open(attention_dataset, "r") as json_file:
    for line in json_file:
        entry = json.loads(line)  # Parse each line as a JSON object

        if entry["head"] == -1:  # Skip the final embedding entry
            continue
        # Get necessary reconstruction data, on the same device as the query embedding
        mean_text = torch.tensor(entry["mean_values_text"], device=device)
        vh = torch.tensor(entry["vh"], device=device)

        # Get projection of text on the head
        topic_emb_proj = (topic_emb - mean_text) @ vh.T
        # Analyze each eigenvector
        for i, eigenvector_data in enumerate(entry["embeddings_sort"][:top_k_eigv]):
            # Get the eigenvector direction
            eigen_v_emb = torch.tensor(eigenvector_data["eigen_v_emb"], device=device) @ vh.T

            # Signed correlation between the eigenvector and the projected query;
            # `.item()` requires the product to hold exactly one element.
            text_corr = (eigen_v_emb @ topic_emb_proj.T).item()

            # Keep the magnitude for ranking and the sign as a plain bool
            corr_sign = text_corr > 0
            correlation = abs(text_corr)

            top_entries.add((correlation, (i, corr_sign, entry)))

            if len(top_entries) > top_k:
                top_entries.pop(0)  # Remove the smallest correlation

# Extract relevant details from the top k entries
top_k_entries = list(top_entries)[::-1]  # Reverse to have largest first

# `texts` is re-sorted by score; the direction flips depending on whether the
# correlation sign agrees with the sign of the first stored score.
top_k_details = [{
    "layer": entry["layer"],
    "head": entry["head"],
    "eigenvector": i,
    "eigenvector_strength": entry["embeddings_sort"][i]["strength_abs"],
    "correlation": correlation if corr_sign else -correlation,
    "texts": sorted(
        entry["embeddings_sort"][i]["text"],
        key=lambda x: list(x.values())[1],
        reverse=(corr_sign != (list(entry["embeddings_sort"][i]["text"][0].values())[1] > 0)),
    ),
} for (correlation, (i, corr_sign, entry)) in top_k_entries]

top_k_df = pd.DataFrame(top_k_details)
# Display the results
for row in top_k_df.itertuples():
    texts = row.texts
    half_length = len(texts) // 2

    # Sign of the correlation and sign of the first (re-sorted) score
    is_positive_corr = row.correlation > 0
    is_positive_first = list(texts[0].values())[1] > 0

    # The left-hand columns follow the sign of the correlation; pick the half
    # of `texts` whose label matches each side.
    if is_positive_corr:
        columns = ["Positive", "Positive_Strength", "Negative", "Negative_Strength"]
        first = texts[:half_length] if is_positive_first else texts[half_length:]
        second = texts[half_length:] if is_positive_first else texts[:half_length]
    else:
        columns = ["Negative", "Negative_Strength", "Positive", "Positive_Strength"]
        first = texts[half_length:] if is_positive_first else texts[:half_length]
        second = texts[:half_length] if is_positive_first else texts[half_length:]

    # Each item is a dict: first value is the description, second its score
    output_rows = []
    for left, right in zip(first, second):
        left_text, left_val = list(left.values())[:2]
        right_text, right_val = list(right.values())[:2]
        output_rows.append([left_text, left_val, right_text, right_val])

    print(f"Layer {row.layer}, Head {row.head}, Eigenvector {row.eigenvector}, Eigenvector Strength {row.eigenvector_strength}, Correlation {row.correlation}")

    output_df = pd.DataFrame(output_rows, columns=columns)

    print(tabulate(output_df, headers='keys', tablefmt='psql'))


Layer 9, Head 10, Eigenvector 5, Eigenvector Strength 0.5183016061782837, Correlation -0.13885445892810822
+----+------------------------------+---------------------+----------------------------------------+---------------------+
|    | Negative                     |   Negative_Strength | Positive                               |   Positive_Strength |
|----+------------------------------+---------------------+----------------------------------------+---------------------|
|  0 | Intense facial expression    |          -0.0906498 | Striking fashion silhouette            |           0.110057  |
|  1 | An image of a Paramedic      |          -0.0908404 | Photograph taken in a fashion boutique |           0.10244   |
|  2 | Determined facial expression |          -0.0961331 | Urban street fashion                   |           0.0975344 |
|  3 | Grumpy facial expression     |          -0.0965072 | Sunlit meadow path                     |           0.0940897 |
|  4 | Skeptical facial expressi

# Test accuracy of reconstruction using only the basis

In [79]:
# Rebuild an image embedding and a text embedding from the per-head bases
# stored in the attention dataset and compare the rebuilt vectors with the
# originals through dot products of the normalized vectors.

# Number of top entries to retrieve
# NOTE(review): `top_k` and `min_heap` are not read anywhere in this cell; they
# are kept so the notebook's global state stays as before.
top_k = 20
min_heap = []

# Prepare both text and image for the query

# Image
# The context manager closes the file handle once preprocessing is done.
# NOTE(review): assumes `preprocess` returns data that no longer depends on the open file — confirm.
with Image.open('images/woman.png') as pil_image:
    image = preprocess(pil_image)[np.newaxis, :, :, :]  # add a leading batch axis

# Text
text_query = "A beautiful woman."

## Run the image and the text through the model (no gradients are needed):
prs.reinit() # Reinitialize the residual stream hook
with torch.no_grad():
    representation = model.encode_image(image.to(device),
                                        attn_method='head_no_spatial',
                                        normalize=False)
    # attentions: [b, l, h, d] & mlps: [b, l + 1 (class), d] according to the
    # original author's note — shapes are not checked here
    attentions, mlps = prs.finalize(representation)

    image_emb = representation / representation.norm(dim=-1, keepdim=True)  # normalize

    text_query_token = tokenizer(text_query).to(device)  # tokenize
    topic_emb = model.encode_text(text_query_token)
    topic_emb /= topic_emb.norm(dim=-1, keepdim=True)  # normalize

# Reconstructions, accumulated over every head record of the file
image_emb_rec = torch.zeros_like(topic_emb)
topic_emb_rec = torch.zeros_like(topic_emb)
# Read JSON lines
with open(attention_dataset, "r") as json_file:
    for line in json_file:
        entry = json.loads(line)  # Parse each line as a JSON object
        if entry["head"] == -1:  # Skip the final embedding entry
            last_line = entry  # remembered but not used further in this cell
            continue

        project_matrix = torch.tensor(entry["project_matrix"], device=device)
        vh = torch.tensor(entry["vh"], device=device)
        mean_text = torch.tensor(entry["mean_values_text"], device=device)
        mean_att = torch.tensor(entry["mean_values_att"], device=device)

        # Center, map into the head basis, apply the projection matrix, map
        # back and re-add the mean; the per-head results are summed.
        topic_emb_rec += (topic_emb - mean_text) @ vh.T @ project_matrix @ vh + mean_text
        image_emb_rec += (image_emb - mean_att) @ vh.T @ project_matrix @ vh + mean_att

# Norms of the summed reconstructions before normalization (text, then image)
print(topic_emb_rec.norm())
print(image_emb_rec.norm())
topic_emb_rec /= topic_emb_rec.norm(dim=-1, keepdim=True)  # normalize
image_emb_rec /= image_emb_rec.norm(dim=-1, keepdim=True)  # normalize
# Original vs. reconstruction (text, then image)
print(topic_emb @ topic_emb_rec.T)
print(image_emb @ image_emb_rec.T)

# Text/image similarity for every combination of original and reconstructed vectors
print(topic_emb @ image_emb.T)
print(topic_emb @ image_emb_rec.T)
print(topic_emb_rec @ image_emb.T)
print(topic_emb_rec @ image_emb_rec.T)

tensor(35.3541, grad_fn=<LinalgVectorNormBackward0>)
tensor(10.6186)
tensor([[0.8530]], grad_fn=<MmBackward0>)
tensor([[0.8454]])
tensor([[0.2579]], grad_fn=<MmBackward0>)
tensor([[0.1799]], grad_fn=<MmBackward0>)
tensor([[0.2091]], grad_fn=<MmBackward0>)
tensor([[0.1413]], grad_fn=<MmBackward0>)
