In [None]:
## Imports
import numpy as np
import torch
from tabulate import tabulate
from PIL import Image
import json
from utils.misc.misc import accuracy, accuracy_correct
from utils.scripts.algorithms_text_explanations import *
from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.misc.visualization import visualization_preprocess
from utils.models.prs_hook import hook_prs_logger
from utils.datasets_constants.imagenet_classes import imagenet_classes
from utils.scripts.algorithms_text_explanations import svd_data_approx
from utils.datasets.dataset_helpers import dataset_to_dataloader
from torch.nn import functional as F
from utils.scripts.algorithms_text_explanations_funcs import *


In [2]:
## Parameters
# Model / checkpoint selection. The trailing comments are the author's alternative
# values (presumably the matching ViT-H-14 checkpoint tag -- confirm before switching).
model_name = 'ViT-B-32' # 'ViT-H-14'
pretrained = 'laion2b_s34b_b79k' # 'laion2b_s32b_b79k'
device = 'cpu'
seed = 0

# Analysis settings.
num_last_layers = 4  # later used as attns_.shape[1] - num_last_layers
subset_dim = 10  # later passed as samples_per_class
algorithm = "svd_data_approx"

# Dataset identifiers; they are interpolated into the file names read below.
dataset_text_name = "top_1500_nouns_5_sentences_imagenet_clean"
datataset_image_name = "imagenet"  # (sic) spelling kept: every later cell reads this exact name

# Only needed for the nn search.
batch_size = 16
imagenet_path = './datasets/imagenet/'

In [None]:
## Loading Model
# Build the model and its preprocessing transform; the middle return value is not used here.
model, _, preprocess = create_model_and_transforms(
    model_name,
    pretrained=pretrained,
    cache_dir="../cache",
)
model.to(device)
model.eval()

tokenizer = get_tokenizer(model_name)
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters: product of each parameter's shape, summed over all parameters.
param_count = np.sum([int(np.prod(p.shape)) for p in model.parameters()])
print("Model parameters:", f"{param_count:,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model.visual.transformer.resblocks))

# Attach the hook logger to the model (per the original note: used to read the residual stream).
prs = hook_prs_logger(model, device, spatial=False)

In [None]:
## Run the chosen algorithm on a dataset to derive text explanations 
command = f"python -m utils.scripts.compute_text_explanations --device {device} --model {model_name} --algorithm {algorithm} --seed {seed} --text_per_princ_comp 20 --num_of_last_layers {num_last_layers} --text_descriptions {dataset_text_name}"
!{command}

In [5]:
# Path of the JSON-lines file that later cells read through get_data(...) / json.loads(...)
attention_dataset = f"output_dir/{datataset_image_name}_completeness_{dataset_text_name}_{model_name}_algo_{algorithm}_seed_{seed}.jsonl"


def _load_npy_as_tensor(path):
    """Open the .npy file at `path` as a read-only memory map and copy it into a torch tensor."""
    return torch.tensor(np.load(path, mmap_mode="r"))


# Load necessary data
_image_prefix = f"output_dir/{datataset_image_name}"
# Attention values; indexed later as attns_[:, layer, head, :] and summed over axes (1, 2).
attns_ = _load_npy_as_tensor(f"{_image_prefix}_attn_{model_name}_seed_{seed}.npy")
# MLP values; later summed over axis 1 and added to the summed attention values.
mlps_ = _load_npy_as_tensor(f"{_image_prefix}_mlp_{model_name}_seed_{seed}.npy")
# Label embeddings; later right-multiplied (`x @ classifier_`) and its transpose is annotated "M x D".
classifier_ = _load_npy_as_tensor(f"{_image_prefix}_classifier_{model_name}.npy")
# Ground-truth labels passed to test_accuracy.
labels_ = _load_npy_as_tensor(f"{_image_prefix}_labels_{model_name}_seed_{seed}.npy")
final_embeddings_images = _load_npy_as_tensor(f"{_image_prefix}_embeddings_{model_name}_seed_{seed}.npy")
final_embeddings_texts = _load_npy_as_tensor(f"output_dir/{dataset_text_name}_{model_name}.npy")

# One text description per line of the file, newline characters removed.
with open(f"utils/text_descriptions/{dataset_text_name}.txt", "r") as f:
    texts_str = np.array([line.replace("\n", "") for line in f])

# Save important stuff
nr_layers_ = attns_.shape[1]
nr_heads_ = attns_.shape[2]

# Get mean ablation
no_heads_attentions_ = attns_.sum(axis=(2))  # Sum over heads dimension
last_ = nr_layers_ - num_last_layers
# Sum of the per-layer means (over samples) for layers 0..last_ inclusive; later cells keep
# layers last_ + 1 onwards intact.
# NOTE(review): the inclusive slice leaves num_last_layers - 1 layers intact, not
# num_last_layers -- confirm this matches the layers stored in the attention dataset.
_ablated_layers = no_heads_attentions_[:, :last_ + 1]
current_mean_ablation_per_head_sum_ = torch.mean(_ablated_layers, axis=0).sum(0)

# Print the top Principal Components text-interpretation for each Head

In [None]:
# Value handed to print_data as its second argument (minimum number of principal components).
min_princ_comp = 10

# Read JSON lines from attention_dataset
# This file contains data about layers, heads, and their principal components (PCs) with associated metrics.
# (No placeholder `data = []` is needed: the name is bound directly by get_data.)
data = get_data(attention_dataset, -1)

# Print the data in a nice formatted table
print_data(data, min_princ_comp)

# Strongest Principal Components per Dataset

In [None]:
# Number of top entries to retrieve
top_k = 10

# Retrieve data (skip_final=True is forwarded to get_data)
data = get_data(attention_dataset, -1, skip_final=True)

# Sort data entries in descending order of strength_abs of the principal component,
# then keep the first top_k of them
entries_by_strength = sort_data_by(data, "strength_abs", descending=True)
top_k_entries = top_data(entries_by_strength, top_k=top_k)

# Print the top_k entries in a nice formatted table
print_used_heads(top_k_entries)
print_data(top_k_entries)


# Visualize singular values of a principal component (both text and images)

In [8]:
# Which (layer, head) and which principal component to inspect
layer, head = 10, 7
princ_comp = 1

# How many elements to look at for the selected principal component
nr_top_imgs = 20    # top elements
nr_worst_imgs = 20  # worst elements
nr_cont_imgs = 0    # length of the continuous block of elements


In [None]:
## OPT. Visualize textSpan
# Same naming scheme as attention_dataset, with the algorithm part fixed to "text_span".
attention_dataset_ts = (
    f"output_dir/{datataset_image_name}_completeness_{dataset_text_name}"
    f"_{model_name}_algo_text_span_seed_{seed}.jsonl"
)

visualize_text_span(layer, head, attention_dataset_ts, top_k=top_k)


In [None]:
# Visualise the selected principal component of (layer, head), using the settings chosen above.
visualize_principal_component(
    layer,
    head,
    princ_comp,
    nr_top_imgs,
    nr_worst_imgs,
    nr_cont_imgs,
    attention_dataset,
    final_embeddings_images,
    final_embeddings_texts,
    seed,
    imagenet_path,
    texts_str,
    imagenet_classes,
    samples_per_class=subset_dim,
)

In [None]:
# Visualize PCs strength
# Reload the dataset entries with get_data's default arguments, then plot them
# for the (layer, head) selected above.
data = get_data(attention_dataset)
plot_pc_sv(data, layer, head)

# Test accuracy of reconstruction of text and images using only the final embedding and their projections

In [None]:
# Encode one image and one text, reconstruct both embeddings from the principal components,
# and print how similar originals and reconstructions are.
image = preprocess(Image.open('images/woman.png'))[np.newaxis, :, :, :]  # Add batch dimension
text_query = "An image of a woman."

prs.reinit()  # Reinitialize the residual stream hook before the forward pass

# Encode image and text with no gradient calculation
with torch.no_grad():
    image_emb = model.encode_image(
        image.to(device),
        attn_method='head_no_spatial',
        normalize=True)

    text_query_token = tokenizer(text_query).to(device)  # Tokenize the text query
    topic_emb = model.encode_text(text_query_token, normalize=True)  # Encode the text query

# Retrieve data
data = get_data(attention_dataset, -1, skip_final=True)

# Get mean of data and texts
mean_final_images = torch.mean(final_embeddings_images, axis=0).to(device)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0).to(device)

# Mean center the embeddings (from here on `topic_emb` / `image_emb` are the centered versions)
topic_emb -= mean_final_texts
image_emb -= mean_final_images

# Iterate through the attention dataset and reconstruct embeddings
[topic_emb_rec, image_emb_rec], _ = reconstruct_embeddings(data, [topic_emb, image_emb], ["text", "image"], device=device)

# Print norms to understand magnitude before normalization
print("Norm of topic_emb_rec before normalization:", topic_emb_rec.norm().item())
print("Norm of image_emb_rec before normalization:", image_emb_rec.norm().item())

# Normalize all four embeddings so they lie on the unit sphere; dot products below are cosines
topic_emb_rec /= topic_emb_rec.norm(dim=-1, keepdim=True)
image_emb_rec /= image_emb_rec.norm(dim=-1, keepdim=True)
topic_emb /= topic_emb.norm(dim=-1, keepdim=True)
image_emb /= image_emb.norm(dim=-1, keepdim=True)

# (message, left operand, right operand): first original-vs-reconstruction for each modality,
# then the text-image cross-similarities before and after reconstruction.
similarity_report = [
    ("Cosine similarity between original topic_emb and reconstructed topic_emb_rec:", topic_emb, topic_emb_rec),
    ("Cosine similarity between original image_emb and reconstructed image_emb_rec:", image_emb, image_emb_rec),
    ("Cosine similarity between original topic_emb and original image_emb:", topic_emb, image_emb),
    ("Cosine similarity between original topic_emb and reconstructed image_emb_rec:", topic_emb, image_emb_rec),
    ("Cosine similarity between reconstructed topic_emb_rec and original image_emb:", topic_emb_rec, image_emb),
    ("Cosine similarity between reconstructed topic_emb_rec and reconstructed image_emb_rec:", topic_emb_rec, image_emb_rec),
]
for message, left, right in similarity_report:
    print(message, (left @ right.T).item())


# Query a topic or an image and retrieve its nearest neighbours

### Define the query and analyze each Principal Component and derive a strength metric for reconstruction of the query-embedding

In [None]:
# Put the model in evaluation mode. Gradient tracking is switched off separately,
# by the torch.no_grad() context below.
model.eval()
query_text = True

# Retrieve an embedding for the query
with torch.no_grad():
    if not query_text:
        # Image query: load the file from images/, preprocess it and add a batch dimension
        prs.reinit()  # Reinitialize any hooks if required
        text_query = "woman.png"
        image_pil = Image.open(f'images/{text_query}')
        image = preprocess(image_pil)[np.newaxis, :, :, :]
        topic_emb = model.encode_image(
            image.to(device),
            attn_method='head_no_spatial',
            normalize=True
        )
    else:
        # Text query: tokenize the prompt, move it to the device and encode it
        text_query = "woman"
        text_query_token = tokenizer(text_query).to(device)
        topic_emb = model.encode_text(text_query_token, normalize=True)

### Reconstruct embedding and find contributions from principal components
# Retrieve data
data = get_data(attention_dataset, -1, skip_final=True)

# Means of the stored image and text embeddings
mean_final_images = torch.mean(final_embeddings_images, axis=0).to(device)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0).to(device)

# Pick the mean that matches the modality of the query
if query_text:
    mean_final = mean_final_texts
else:
    mean_final = mean_final_images

# Mean center the query and the stored text embeddings
topic_emb_cent = topic_emb - mean_final
final_embeddings_texts_cent = final_embeddings_texts.to(device) - mean_final_texts

# Reconstruct the centered query; the second return value replaces `data`
[topic_emb_rec_cent], data = reconstruct_embeddings(
    data,
    [topic_emb_cent],
    ["text" if query_text else "image"],
    return_princ_comp=True,
    plot=True,
    means=[mean_final],
    device=device,
)

# Unit-normalize the reconstruction and the centered original; their dot product is the
# cosine similarity between them, used below as the maximum reconstruction score.
topic_emb_rec_cent_norm = topic_emb_rec_cent / topic_emb_rec_cent.norm(dim=-1, keepdim=True)
topic_emb_cent_norm = topic_emb_cent / topic_emb_cent.norm(dim=-1, keepdim=True)
max_reconstr_score = topic_emb_rec_cent_norm @ topic_emb_cent_norm.T

# Print out the cosine similarity between the original and reconstructed embeddings
print(f"We have a max cosine similarity of: {(max_reconstr_score).item():.4f}")


### Use the strength of the previous reconstruction to derive a good enough reconstruction of the query

In [None]:
# Settings for selecting principal components based on the reconstruction of the query
top_k = 50  # Maximum number of top entries to retrieve
approx = 1.1  # Target approximation threshold for the reconstruction quality

# Zero tensor shaped like the query embedding (not read again in the cells shown here)
topic_emb_rec_act = torch.zeros_like(topic_emb)

### Extract relevant details from the top k entries
# Order entries by descending "correlation_princ_comp_abs" and keep the first top_k
data = sort_data_by(data, "correlation_princ_comp_abs", descending=True)
top_k_entries = top_data(data, top_k)

query_kind = "text" if query_text else "image"
top_k_details = reconstruct_top_embedding(
    top_k_entries,
    topic_emb_cent,
    mean_final,
    query_kind,
    max_reconstr_score,
    top_k,
    approx,
    device=device,
)

# Report which query is being analysed, followed by the table of selected components
print(f"Currently querying the topic: {text_query}")
print_data(top_k_details, is_corr_present=True)

### Prepare scores of images and texts 

In [None]:
# VISUALIZE OTHERS PCs (uncomment here)
#data = get_data(attention_dataset, -1, skip_final=True)
#top_k_entries = get_remaining_pcs(data, top_k_entries)


def _rowwise_dot(a, b):
    """Return the dot product of matching rows of `a` and `b`.

    This equals the diagonal of `a @ b.T`, but never builds the full
    rows-by-rows matrix just to read its diagonal.
    """
    return (a * b).sum(dim=-1)


## For Reconstructed Embedding
# Visualize ds
ds_vis = create_dataset_imagenet(imagenet_path, visualization_preprocess, samples_per_class=subset_dim, tot_samples_per_class=50, seed=seed)

# Record layouts: raw score, score computed after unit-normalising the reconstruction, item index
image_score_dtype = [('score', 'f4'), ('score_vis', 'f4'), ('img_index', 'i4')]
text_score_dtype = [('score', 'f4'), ('score_vis', 'f4'), ('txt_index', 'i4')]
n_images = final_embeddings_images.shape[0]
n_texts = final_embeddings_texts.shape[0]

# Scores of each reconstructed item against the (centered) query embedding
scores_array_images = np.empty(n_images, dtype=image_score_dtype)
scores_array_texts = np.empty(n_texts, dtype=text_score_dtype)
# Scores of each reconstructed item against its own centered embedding
scores_array_images_self = np.empty(n_images, dtype=image_score_dtype)
scores_array_texts_self = np.empty(n_texts, dtype=text_score_dtype)

# Create arrays of indexes for referencing images and texts.
indexes_images = np.arange(0, n_images, 1)
indexes_texts = np.arange(0, n_texts, 1)

# Get mean of data and texts
mean_final_images = torch.mean(final_embeddings_images, axis=0).to(device)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0).to(device)

images_centered = final_embeddings_images.to(device) - mean_final_images
texts_centered = final_embeddings_texts.to(device) - mean_final_texts

# Reconstruct all centered embeddings from the selected entries only
[texts_rec_cent, images_rec_cent], _ = reconstruct_embeddings(top_k_entries, [texts_centered, images_centered], ["text", "image"], return_princ_comp=False, device=device)

# Aliases, not copies: the in-place normalisation below also changes *_rec_cent.
texts_rec = texts_rec_cent
images_rec = images_rec_cent

# Compute scores for images. "score" uses the un-normalised reconstruction.
scores_array_images["score"] = (images_rec @ topic_emb_cent.T).squeeze().cpu().numpy()
scores_array_images_self["score"] = _rowwise_dot(images_rec, images_centered).squeeze().cpu().numpy()

# "score_vis" uses the reconstruction scaled to unit norm.
images_rec /= images_rec.norm(dim=-1, keepdim=True)
scores_array_images["score_vis"] = (images_rec @ topic_emb_cent.T).squeeze().cpu().numpy()
scores_array_images_self["score_vis"] = _rowwise_dot(images_rec, images_centered).squeeze().cpu().numpy()

scores_array_images["img_index"] = indexes_images
scores_array_images_self["img_index"] = indexes_images

# Compute scores for texts (same scheme as for images)
scores_array_texts["score"] = (texts_rec @ topic_emb_cent.T).squeeze().cpu().numpy()
scores_array_texts_self["score"] = _rowwise_dot(texts_rec, texts_centered).squeeze().cpu().numpy()

texts_rec /= texts_rec.norm(dim=-1, keepdim=True)
scores_array_texts["score_vis"] = (texts_rec @ topic_emb_cent.T).squeeze().cpu().numpy()
scores_array_texts_self["score_vis"] = _rowwise_dot(texts_rec, texts_centered).squeeze().cpu().numpy()

scores_array_texts["txt_index"] = indexes_texts
scores_array_texts_self["txt_index"] = indexes_texts

In [20]:
# For full CLIP Embedding
# Scores of the stored (not reconstructed) embeddings against the query embedding
scores_array_images_full = np.empty(
    final_embeddings_images.shape[0],
    dtype=[('score', 'f4'), ('score_vis', 'f4'), ('img_index', 'i4')]
)
scores_array_texts_full = np.empty(
    final_embeddings_texts.shape[0],
    dtype=[('score', 'f4'), ('score_vis', 'f4'), ('txt_index', 'i4')]
)

# Compute scores for images
images = final_embeddings_images.to(device)
scores_array_images_full["score"] = (images @ topic_emb.T).squeeze().cpu().numpy()

# Normalise out of place: when the tensor is already on `device`, `.to(device)` returns the
# very same tensor, so an in-place `/=` would silently rescale final_embeddings_images.
images = images / images.norm(dim=-1, keepdim=True)
scores_array_images_full["score_vis"] = (images @ topic_emb.T).squeeze().cpu().numpy()

scores_array_images_full["img_index"] = indexes_images

# Compute scores for texts
texts = final_embeddings_texts.to(device)
scores_array_texts_full["score"] = (texts @ topic_emb.T).squeeze().cpu().numpy()

# Out of place for the same reason as above (protects final_embeddings_texts).
texts = texts / texts.norm(dim=-1, keepdim=True)
scores_array_texts_full["score_vis"] = (texts @ topic_emb.T).squeeze().cpu().numpy()

scores_array_texts_full["txt_index"] = indexes_texts

# Define the number of top and worst images to look at for each princ_comp
nr_top_imgs = 20  # Number of top elements
nr_worst_imgs = 20  # Number of worst elements
nr_cont_imgs = 20  # Length of continuous elements

### Visualize

In [None]:
# How many elements to show per group
nr_top_imgs = 20  # Number of top elements
nr_worst_imgs = 20  # Number of worst elements
nr_cont_imgs = 0  # Length of continuous elements

dbs = create_dbs(scores_array_images, scores_array_texts, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)

# Requested size of each group, in the same order as the sizes passed to create_dbs
# (assumes dbs[i] corresponds to nrs_dbs[i] -- confirm against create_dbs).
nrs_dbs = [nr_top_imgs, nr_worst_imgs, nr_cont_imgs]
# Drop every group whose requested size is zero
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis, texts_str, imagenet_classes, text_query)

In [None]:
# Visualize the similarity scores computed from the full embeddings
dbs = create_dbs(scores_array_images_full, scores_array_texts_full, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
# Keep only the groups whose requested size (nrs_dbs, set in the previous visualisation cell) is non-zero
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis, texts_str, imagenet_classes, text_query)

In [None]:
# Visualize the "self" scores (each reconstruction scored against its own centered embedding)
dbs = create_dbs(scores_array_images_self, scores_array_texts_self, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
# Keep only the groups whose requested size (nrs_dbs, set in an earlier visualisation cell) is non-zero
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis, texts_str, imagenet_classes, text_query)

# Evaluate classification using reconstruction

## Ablation Study

In [None]:
nr_layers = attns_.shape[1]
mlp_contribution = mlps_.sum(axis=1)  # identical for every iteration, so computed once

# Accuracy after mean-ablating an increasing number of leading attention layers.
# At step `layer_nr`, layers [0, layer_nr) are replaced by their mean over samples and
# layers [layer_nr, nr_layers) are kept, so every layer is counted exactly once
# (step 0 ablates nothing).
accs = []
for layer_nr in range(nr_layers):
    current_mean_ablation_per_head_sum = torch.mean(no_heads_attentions_[:, :layer_nr], axis=0).sum(0)
    kept_layers_sum = no_heads_attentions_[:, layer_nr:].sum(1)
    current_model = current_mean_ablation_per_head_sum + kept_layers_sum + mlp_contribution
    acc, _ = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation of the first {layer_nr} of {nr_layers} layers")
    accs.append(acc)

# x position = number of mean-ablated layers for the matching entry of accs
x_values = range(len(accs))

# Plot
# NOTE(review): `plt` is not imported explicitly in this notebook; presumably it arrives
# through one of the star imports -- confirm.
plt.figure(figsize=(8, 5))
plt.plot(x_values, accs, linestyle='-', label=model_name)

# Labeling
plt.xlabel("Accumulated mean-ablated layers")
plt.ylabel("Accuracy")
plt.title("Accuracy vs. Accumulated Mean-Ablated Layers")

# Add legend for the line
plt.legend()

plt.grid(True)
plt.show()

## Proof of concept

In [None]:
pcs_per_class = 400
max_approx_per_class = 1
class_embeddings = classifier_.T  # M x D

# Baseline accuracy computation:
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
acc, _ = test_accuracy(baseline @ classifier_, labels_, label="Baseline")

# Get mean of data and texts
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

images_centered = final_embeddings_images - mean_final_images.unsqueeze(0)
classes_centered = class_embeddings - mean_final_texts.unsqueeze(0)

num_images = final_embeddings_images.shape[0]
num_classes = classifier_.shape[1]  # classifier_.T is M x D, so shape[1] is the number of classes

# (num_images x 2) tracker: [best_score_so_far, class_index_for_that_score].
# The best score starts at -inf so that a negative score can still become the best one.
all_preds = torch.zeros((num_images, 2), dtype=torch.double)
all_preds[:, 0] = float("-inf")

# Loop-local names (class_*) are used on purpose: the notebook-level `data`,
# `top_k_entries`, `top_k_details` and `max_reconstr_score` belong to the query cells
# and are read again by later cells.
for text_idx in range(classes_centered.shape[0]):
    # Perform query system on entry
    class_i_centered = classes_centered[text_idx, :].unsqueeze(0)

    # Re-read per class: reconstruct_embeddings returns a (presumably annotated) replacement
    # of the entries, so each class starts from a fresh copy.
    class_data = get_data(attention_dataset, -1, skip_final=True)

    [class_i_rec], class_data = reconstruct_embeddings(
        class_data,
        [class_i_centered],
        ["text"],
        return_princ_comp=True,
        plot=False,
        means=[mean_final_texts]
    )

    # Maximum reconstruction score: cosine similarity between the centered reconstruction
    # and the centered original (same definition as in the single-query cell).
    class_i_rec_cent_norm = class_i_rec / class_i_rec.norm(dim=-1, keepdim=True)
    class_i_centered_norm = class_i_centered / class_i_centered.norm(dim=-1, keepdim=True)
    class_max_reconstr_score = class_i_rec_cent_norm @ class_i_centered_norm.T

    # Un-centered, unit-norm reconstruction of the class embedding, used for prediction
    class_i_rec_norm = class_i_rec_cent_norm + mean_final_texts
    class_i_rec_norm = class_i_rec_norm / class_i_rec_norm.norm(dim=-1, keepdim=True)

    # Extract relevant details from the top k entries
    class_data_pcs = sort_data_by(class_data, "correlation_princ_comp_abs", descending=True)
    class_top_entries = top_data(class_data_pcs, pcs_per_class)

    class_top_details = reconstruct_top_embedding(
        class_top_entries,
        class_i_centered,
        mean_final_texts,
        "text",
        class_max_reconstr_score,
        pcs_per_class,
        max_approx_per_class,
        plot=False
    )

    # Reconstruct final_embeddings_images from the entries selected for this class
    [images_rec_cent], _ = reconstruct_embeddings_proj(
        class_top_entries,
        [images_centered],
        ["image"],
        return_princ_comp=False
    )

    # Unit-normalise, add the image mean back, and score every image against this class
    images_rec_cent = images_rec_cent / images_rec_cent.norm(dim=-1, keepdim=True)
    images_rec_cent = images_rec_cent + mean_final_images
    predictions = images_rec_cent @ class_i_rec_norm.squeeze(0)  # one score per image

    # Update "best so far" scores in all_preds
    improved_mask = predictions > all_preds[:, 0]
    all_preds[improved_mask, 0] = predictions[improved_mask].double()
    all_preds[improved_mask, 1] = float(text_idx)

    # Optionally, check accuracy for the current text_idx predictions
    acc, idxs = test_accuracy(predictions.unsqueeze(-1), labels_, label=f"{imagenet_classes[text_idx]}")
    print_correct_elements(idxs, subset_dim)

    # Build a fictitious one-hot matrix from all_preds: 1.0 at the best class of each image
    best_class_idxs = all_preds[:, 1].long()
    fictitious_preds = torch.zeros((num_images, num_classes), device=best_class_idxs.device)
    fictitious_preds[torch.arange(num_images), best_class_idxs] = 1.0

    # Test accuracy on these "hard" predictions
    acc_best, idxs_best = test_accuracy(fictitious_preds, labels_, label="Best So Far (One-Hot)")
    sorted_output = print_correct_elements(idxs_best, subset_dim)

    # Print overall accuracy so far: counts returned by print_correct_elements, divided by
    # subset_dim items for each class processed so far
    tot_sum = sum(el_nr for _, el_nr in sorted_output)
    print(f"Tot accuracy so far is {tot_sum/((text_idx + 1) * subset_dim)}")


## Test different accuracies

In [None]:
# Print shapes of the tensors for debugging purposes:
# attns_: attention activations
# mlps_: MLP activations
# classifier_: classifier weights
# labels_: ground truth labels
print(attns_.shape, mlps_.shape, classifier_.shape, labels_.shape)

# Baseline accuracy computation:
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
test_accuracy(baseline @ classifier_, labels_, label="Baseline")
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

# Test accuracy of mean centered data with mean centered text
mean_centered_data = baseline - mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
# classifier_ holds one class embedding per COLUMN (it is used as `x @ classifier_`), so the
# text mean is subtracted along dim 0 and each class vector is normalised along dim 0.
mean_centered_classifier_ = classifier_ - mean_final_texts.unsqueeze(-1)
mean_centered_classifier_ /= mean_centered_classifier_.norm(dim=0, keepdim=True)

test_accuracy(mean_centered_data @ mean_centered_classifier_, labels_, label="Mean centered data with mean centered text")

# Test accuracy of mean centered data with original text:
# center, normalise, add the mean back, normalise again.
mean_centered_data = baseline - mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
mean_centered_data += mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
test_accuracy(mean_centered_data @ classifier_, labels_, label="Mean centered data with original (not mean centered) text")

# "Mean ablation" for attention: layers 0..last_ are replaced by their mean
# (current_mean_ablation_per_head_sum_), layers after last_ are kept as they are.
current_model = (current_mean_ablation_per_head_sum_
                 + no_heads_attentions_[:, last_ + 1:].sum(1)) + mlps_.sum(axis=1)
_, indexes_mean_ablate = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation from layer {last_} until layer {attns_.shape[1]}")

# Same mean-ablated model, additionally centered with the image mean and unit-normalised
current_model = (current_mean_ablation_per_head_sum_
                 + no_heads_attentions_[:, last_ + 1:].sum(1)) + mlps_.sum(axis=1)
current_model -= mean_final_images
current_model /= current_model.norm(dim=-1, keepdim=True)
_, indexes_mean_ablate = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation from layer {last_} until layer {attns_.shape[1]} with mean centered images")

## Test different accuracies using reconstructions

In [None]:
# Accumulators for the reconstructions, filled by summing over the dataset entries
final_embeddings_images_rec_embed = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns_not_mean_centered = torch.zeros_like(final_embeddings_images)
image_emb_cent_embed = final_embeddings_images - mean_final_images

final_embeddings_texts_rec_embed = torch.zeros_like(classifier_.T)
texts_emb_cent_embed = (classifier_ - mean_final_texts.unsqueeze(-1)).T

# Open the attention dataset to retrieve projection matrices and mean values
with open(attention_dataset, "r") as json_file:
    for line in json_file:
        entry = json.loads(line)
        # The final-embedding entry (head == -1) carries no per-head data; skip it.
        if entry["head"] == -1:
            continue

        project_matrix = torch.tensor(entry["project_matrix"])
        vh = torch.tensor(entry["vh"])
        mean_values_att = torch.tensor(entry["mean_values_att"])

        # Reconstruction from the centered final embeddings (images and class texts):
        # project with vh.T, apply project_matrix, map back with vh.
        final_embeddings_images_rec_embed += (image_emb_cent_embed) @ vh.T @ project_matrix @ vh
        final_embeddings_texts_rec_embed += (texts_emb_cent_embed) @ vh.T @ project_matrix @ vh

        # Reconstruction from this head's attention activations, centered with the entry's
        # own mean. One accumulator adds the mean back, the other does not.
        image_emb_cent_attns = attns_[:, entry["layer"], entry["head"], :] - mean_values_att
        attns_reconstruction = (image_emb_cent_attns) @ vh.T @ project_matrix @ vh
        final_embeddings_images_rec_attns += attns_reconstruction + mean_values_att
        final_embeddings_images_rec_attns_not_mean_centered += attns_reconstruction

final_embeddings_images_rec_embed_norm = final_embeddings_images_rec_embed/final_embeddings_images_rec_embed.norm(dim=-1, keepdim=True)

final_embeddings_texts_rec_embed_norm = final_embeddings_texts_rec_embed/final_embeddings_texts_rec_embed.norm(dim=-1, keepdim=True)

final_embeddings_images_rec_attns_not_mean_centered_norm = final_embeddings_images_rec_attns_not_mean_centered/final_embeddings_images_rec_attns_not_mean_centered.norm(dim=-1, keepdim=True)

texts_emb_cent_embed /= texts_emb_cent_embed.norm(dim=-1, keepdim=True)

# Reconstructed class texts with the text mean added back, laid out like classifier_
# (one class per column) so it can replace classifier_ in the evaluations below.
classifier_rec_ = (final_embeddings_texts_rec_embed + mean_final_texts).T

# Normalised final-embedding reconstruction of the images, with the image mean added back
current_model = final_embeddings_images_rec_embed_norm + mean_final_images
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation with final embeddings on only the last layers")

# MLPs + mean-ablated layers + attention reconstruction without its mean
current_model = mlps_.sum(axis=1) + current_mean_ablation_per_head_sum_ + final_embeddings_images_rec_attns_not_mean_centered
_, indexes_approx_activ_only = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation of images with direct contribution of activation space")

# MLPs + mean-ablated layers + attention reconstruction with its mean added back
current_model = (mlps_.sum(axis=1) + current_mean_ablation_per_head_sum_ + final_embeddings_images_rec_attns)
_, indexes_approx_activ = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation with attention activations")

# Un-normalised final-embedding reconstruction of the images against the original texts.
# `indexes_approx_final` keeps this result: it is the final-embedding baseline that the
# bias-removal cell compares against.
current_model = final_embeddings_images_rec_embed + mean_final_images
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation of images without mean-ablation")

# Reconstructed images against reconstructed texts
# NOTE(review): uses classifier_rec_ so the computation matches the label -- confirm intent.
current_model = final_embeddings_images_rec_embed + mean_final_images
_, indexes_approx_final_texts = test_accuracy(current_model @ classifier_rec_, labels_, label=f"Approximation of images and texts without mean-ablation")

# Original image embeddings against reconstructed texts
# NOTE(review): uses the stored final image embeddings as "original images" -- confirm intent.
current_model = final_embeddings_images
_, indexes_orig_images_approx_texts = test_accuracy(current_model @ classifier_rec_, labels_, label=f"Original images and approximation of texts without mean-ablation")


## Test Bias Correction

In [None]:
cache_dir = "../cache"
top_k = 30  # Maximum number of top entries to retrieve
approx = 1.1  # Target approximation threshold for the reconstruction quality
## Run the chosen algorithm on a dataset to derive text explanations 
command = f"python -m utils.scripts.bias_removal_test \
    --device {device} --model {model_name} --pretrained {pretrained} --seed {seed} \
    --subset_dim {subset_dim} --dataset_text {dataset_text_name} --dataset {datataset_image_name} \
    --device {device} --top_k {top_k} --max_approx {approx} --cache_dir {cache_dir}"
!{command}

## Test bias removal and subset model

In [None]:
# Next, we isolate the contribution of the selected principal components (top_k_details),
# once starting from the final embeddings and once from the attention activations. These
# "topic" reconstructions are then either subtracted (bias removal) or used alone (subset).


def _single_component_reconstruction(centered, vh, project_matrix, pc_index):
    """Reconstruct `centered` using only coordinate `pc_index` of its projection onto `vh`.

    The rows of `centered` are projected with `vh.T`, every coordinate except `pc_index`
    is zeroed, and the result is mapped back through `project_matrix @ vh`.
    """
    projection = centered @ vh.T
    masked = torch.zeros_like(projection)
    masked[:, pc_index] = projection[:, pc_index]
    return masked @ project_matrix @ vh


final_embeddings_images_rec_embed_topic = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns_topic = torch.zeros_like(final_embeddings_images)

image_emb_cent_embed = final_embeddings_images - mean_final_images

# Only referenced by the disabled block further down.
top_k_other_details = get_remaining_pcs(data, top_k_details)

# Iterate through the top_k entries and accumulate their single-component reconstructions
for entry in top_k_details:
    vh = torch.tensor(entry["vh"])
    project_matrix = torch.tensor(entry["project_matrix"])
    # Loop-local name: the notebook-level `princ_comp` (visualisation setting) stays untouched.
    pc_index = torch.tensor(entry["princ_comp"])

    # From the centered final embeddings
    final_embeddings_images_rec_embed_topic += _single_component_reconstruction(
        image_emb_cent_embed, vh, project_matrix, pc_index
    )

    # From this head's attention activations, centered with the entry's own mean
    mean_values_att = torch.tensor(entry["mean_values_att"])
    image_emb_cent_attns = attns_[:, entry["layer"], entry["head"], :] - mean_values_att
    final_embeddings_images_rec_attns_topic += _single_component_reconstruction(
        image_emb_cent_attns, vh, project_matrix, pc_index
    )

# The string literal below is a disabled code block (it has no effect when the cell runs).
""" # Mean ablate the other components
for entry in top_k_other_details:
    # Reconstruct embeddings focusing on each principal component:
    # 1. Start from the final embeddings, center them, and extract the component of interest.
    vh = torch.tensor(entry["vh"])
    project_matrix = torch.tensor(entry["project_matrix"])
    princ_comp = torch.tensor(entry["princ_comp"])
    mean_values_att = torch.tensor(entry["mean_values_att"]).unsqueeze(0)

    projection_image_embed = mean_values_att @ vh.T
    mask_images_embed = torch.zeros_like(projection_image_embed)
    mask_images_embed[:, princ_comp] = projection_image_embed[:, princ_comp]
    final_embeddings_images_rec_embed_topic += mask_images_embed @ project_matrix @ vh

    # Repeat for attention-based activations:
    image_emb_cent_attns = mean_values_att
    projection_images_attns = image_emb_cent_attns @ vh.T
    mask_images_attns = torch.zeros_like(projection_images_attns)
    mask_images_attns[:, princ_comp] = projection_images_attns[:, princ_comp]
    final_embeddings_images_rec_attns_topic += mask_images_attns @ project_matrix @ vh """

# Bias removal, final embeddings: full reconstruction minus the topic components.
current_model = final_embeddings_images_rec_embed - final_embeddings_images_rec_embed_topic
_, indexes_approx_final_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic final embeddings (Bias removal)")
print_diff_elements(indexes_approx_final, indexes_approx_final_rem, subset_dim)

# Bias removal, attention activations: full reconstruction minus the topic components.
current_model = final_embeddings_images_rec_attns_not_mean_centered - final_embeddings_images_rec_attns_topic
_, indexs_approx_activ_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic attention activations (Bias Removal)")
print_diff_elements(indexes_approx_activ_only, indexs_approx_activ_rem, subset_dim)

# Subset, final embeddings: the topic components alone.
current_model = final_embeddings_images_rec_embed_topic
_, indexes_approx_final_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic final embeddings (Subset)")
print_correct_elements(indexes_approx_final_rem, subset_dim)

# Subset, attention activations: the topic components alone.
current_model = final_embeddings_images_rec_attns_topic
_, indexs_approx_activ_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic attention activations (Subset)")
print_correct_elements(indexs_approx_activ_rem, subset_dim)
