In [None]:
## Imports
import numpy as np
import torch
from tabulate import tabulate
from PIL import Image
import matplotlib.lines as mlines
import json
from utils.misc.misc import accuracy, accuracy_correct
from utils.scripts.algorithms_text_explanations import *
from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.misc.visualization import visualization_preprocess
from utils.models.prs_hook import hook_prs_logger
from utils.datasets_constants.imagenet_classes import imagenet_classes
from utils.datasets_constants.cifar_10_classes import cifar_10_classes
from utils.datasets_constants.cub_classes import cub_classes, waterbird_classes
import os
from utils.scripts.algorithms_text_explanations import svd_data_approx
from utils.datasets.dataset_helpers import dataset_to_dataloader
from torch.nn import functional as F
from utils.scripts.algorithms_text_explanations_funcs import *


In [2]:
## Parameters
# Run configuration shared by every later cell.
device = 'cpu'
model_name = 'ViT-B-32' # 'ViT-H-14'
seed = 0
num_last_layers_ = 4  # number of final layers analysed (forwarded as --num_of_last_layers)
subset_dim = None  # forwarded as `samples_per_class` when sub-sampling the dataset
tot_samples_per_class = None
dataset_text_name = "top_1500_nouns_5_sentences_imagenet_clean"
datataset_image_name = "binary_waterbirds"
algorithm = "svd_data_approx"
path = './datasets/'

# Pretrained checkpoint tag for every supported architecture.
pretrained_by_model = {
    "ViT-H-14": "laion2B-s32B-b79K",
    "ViT-L-14": "laion2B-s32B-b82K",
    "ViT-B-16": "laion2B-s34B-b88K",
    "ViT-B-32": "laion2B-s34B-b79K",
}

# Fail here with a clear message instead of leaving `pretrained` undefined
# (which would only surface later as a NameError when the model is loaded).
if model_name not in pretrained_by_model:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected one of {sorted(pretrained_by_model)}"
    )
pretrained = pretrained_by_model[model_name]

In [None]:
## Loading Model
# Build the model and its image preprocessing transform; the second returned
# value is not needed here.
model, _, preprocess = create_model_and_transforms(
    model_name,
    pretrained=pretrained,
    cache_dir="../cache",
)
model.to(device)
model.eval()

tokenizer = get_tokenizer(model_name)
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters across all parameter tensors.
param_count = np.sum([int(np.prod(p.shape)) for p in model.parameters()])

print("Model parameters:", f"{param_count:,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model.visual.transformer.resblocks))

# Attach the hooks that log the residual stream.
prs = hook_prs_logger(model, device, spatial=False)

In [None]:
## Run the chosen algorithm on a dataset to derive text explanations
# Assemble the command line for the explanation script from the parameters cell.
# NOTE(review): the image dataset name (`datataset_image_name`) is not forwarded, yet the
# files loaded in the next cell are named after it -- confirm the script's default matches.
command = f"python -m utils.scripts.compute_text_explanations --device {device} --model {model_name} --algorithm {algorithm} --seed {seed} --text_per_princ_comp 20 --num_of_last_layers {num_last_layers_} --text_descriptions {dataset_text_name}"
# IPython shell escape: runs the assembled command in a subshell.
!{command}

In [None]:
# Path of the JSONL file with the per-head principal-component explanations
attention_dataset = f"output_dir/{datataset_image_name}_completeness_{dataset_text_name}_{model_name}_algo_{algorithm}_seed_{seed}.jsonl"

# Load the precomputed arrays (memory-mapped on disk, then copied into torch tensors)
attns_ = torch.tensor(np.load(f"output_dir/{datataset_image_name}_attn_{model_name}_seed_{seed}.npy", mmap_mode="r"))  # [b, l, h, d], attention values
mlps_ = torch.tensor(np.load(f"output_dir/{datataset_image_name}_mlp_{model_name}_seed_{seed}.npy", mmap_mode="r"))  # mlp values; presumably [b, l, d] (summed over axis 1 and added to the [b, d] attention sum later) -- TODO confirm
classifier_ = torch.tensor(np.load(f"output_dir/{datataset_image_name}_classifier_{model_name}.npy", mmap_mode="r"))  # embedding of the labels; presumably [d, M] (its transpose is annotated "M x D" later) -- TODO confirm
labels_ = torch.tensor(np.load(f"output_dir/{datataset_image_name}_labels_{model_name}_seed_{seed}.npy", mmap_mode="r")) # Labels of the samples in the considered dataset
final_embeddings_images = torch.tensor(np.load(f"output_dir/{datataset_image_name}_embeddings_{model_name}_seed_{seed}.npy", mmap_mode="r"))
final_embeddings_texts = torch.tensor(np.load(f"output_dir/{dataset_text_name}_{model_name}.npy", mmap_mode="r"))
# One text description per line, with the newline characters removed
with open( f"utils/text_descriptions/{dataset_text_name}.txt", "r") as f:
    texts_str = np.array([i.replace("\n", "") for i in f.readlines()])
# Get mean ablation
no_heads_attentions_ = attns_.sum(axis=(2))  # Sum over heads dimension
# Index of the first of the `num_last_layers_` final layers
last_ = attns_.shape[1] - num_last_layers_
# Average over samples of the attention output of layers 0..last_ (the slice includes index last_),
# then summed over those layers.
current_mean_ablation_per_head_sum_ = torch.mean(no_heads_attentions_[:, :last_ + 1], axis=0).sum(0)

# Layer and head counts, read from the attention tensor
nr_layers_ = attns_.shape[1]
nr_heads_ = attns_.shape[2]

# Dataset constructors keyed by dataset name. They are wrapped in lambdas so that
# only the selected dataset is actually instantiated.
dataset_factories_ = {
    "imagenet": lambda: ImageNet(root=path+"imagenet/", split="val", transform=visualization_preprocess),
    "binary_waterbirds": lambda: BinaryWaterbirds(root=path+"waterbird_complete95_forest2water2/", split="test", transform=visualization_preprocess),
    "CIFAR100": lambda: CIFAR100(root=path, download=True, train=False, transform=visualization_preprocess),
    "CIFAR10": lambda: CIFAR10(root=path, download=True, train=False, transform=visualization_preprocess),
}
# Any other name falls back to a plain image folder rooted at `path`.
build_dataset_ = dataset_factories_.get(
    datataset_image_name,
    lambda: ImageFolder(root=path, transform=visualization_preprocess),
)
ds_ = build_dataset_()

# Class-name list for the selected dataset. Names missing from this table
# (e.g. "CIFAR100") raise a KeyError here.
class_names_by_dataset_ = {
    'imagenet': imagenet_classes,
    'CIFAR10': cifar_10_classes,
    'waterbirds': cub_classes,
    'binary_waterbirds': waterbird_classes,
    'cub': cub_classes,
}
classes_ = class_names_by_dataset_[datataset_image_name]

# Subset of the dataset used for visualization, controlled by the parameters cell.
ds_vis_ = dataset_subset(
    ds_,
    samples_per_class=subset_dim,
    tot_samples_per_class=tot_samples_per_class,
    seed=seed,
)

# For Waterbirds, collect the background ("place") label of every metadata row with
# split == 2; it is used later to compute worst-group accuracy.
if classes_ == waterbird_classes:
    # Derive the folder from `path`, exactly as the BinaryWaterbirds dataset above does,
    # so that changing `path` in the parameters cell cannot desynchronise the two.
    root = path + "waterbird_complete95_forest2water2/"
    df = pd.read_csv(root + "metadata.csv")
    # presumably split == 2 is the test split (the dataset is built with split="test") -- TODO confirm
    filtered_df = df[df['split'] == 2]

    # (image path, label, background) triples, built column-wise instead of with
    # three `iloc` look-ups per row.
    s = [
        (os.path.join(root, img_filename), y, place)
        for img_filename, y, place in zip(
            filtered_df['img_filename'], filtered_df['y'], filtered_df['place']
        )
    ]
    background_groups_ = [triple[2] for triple in s]

# Average the "rank" field over all entries of the explanation file.
data = get_data(attention_dataset, skip_final=True)
mean_rank_ = sum(entry["rank"] for entry in data) / len(data)

# Print the top Principal Components text-interpretation for each Head

In [None]:
# Threshold forwarded to print_data (presumably the number of principal components
# shown per head -- TODO confirm against print_data).
min_princ_comp = 10

# Read the JSON lines of attention_dataset: layers, heads and their principal
# components (PCs) with associated metrics.
data = get_data(attention_dataset, -1)

# Print the data in a nicely formatted table
print_data(data, min_princ_comp)

# Strongest Principal Components per Dataset

In [None]:
# How many of the strongest principal components to show
top_k = 10
min_heap = []  # not used by the code in this cell

# Load all entries, order them by decreasing `strength_abs`, keep the first top_k
data = get_data(attention_dataset, -1, skip_final=True)
data_by_strength = sort_data_by(data, "strength_abs", descending=True)
top_k_entries = top_data(data_by_strength, top_k=top_k)

# Report which heads were selected, then the entries themselves
print_used_heads(top_k_entries)
print_data(top_k_entries)


# Visualize singular values of a principal component (both text and images)

In [8]:
# Which principal component to inspect: (layer, head, component index)
layer, head, princ_comp = 11, 11, 0
# How many top-scoring, worst-scoring and "continuous" images to show for it
nr_top_imgs, nr_worst_imgs, nr_cont_imgs = 20, 20, 0


In [None]:
## OPT. Visualize textSpan
# Explanation file for the same configuration, but with "text_span" as the algorithm tag.
attention_dataset_ts = (
    f"output_dir/{datataset_image_name}_completeness_{dataset_text_name}_{model_name}_algo_text_span_seed_{seed}.jsonl"
)

visualize_text_span(
    layer,
    head,
    attention_dataset_ts,
    top_k=5,
)


In [None]:
# Show the images and texts for the principal component selected above.
visualize_principal_component(
    layer,
    head,
    princ_comp,
    nr_top_imgs,
    nr_worst_imgs,
    nr_cont_imgs,
    attention_dataset,
    final_embeddings_images,
    final_embeddings_texts,
    seed,
    path,
    texts_str,
    dataset=datataset_image_name,
    samples_per_class=subset_dim,
    tot_samples_per_class=tot_samples_per_class,
)

In [None]:
# Visualize PCs strength
# Reload the explanation entries with get_data's default arguments and plot the
# values for the selected (layer, head).
data = get_data(attention_dataset)
plot_pc_sv(data, layer, head)

# Query a topic or image and NNs on that

### Define the query and analyze each Principal Component and derive a strength metric for reconstruction of the query-embedding

In [None]:
# Put the model in evaluation mode; gradient tracking is switched off separately
# by the torch.no_grad() context below.
model.eval()
query_text = True

# Encode the query (text prompt or image file) into a normalized embedding
with torch.no_grad():
    if not query_text:
        # Image query: reset the hooks, then load and preprocess the image from disk
        prs.reinit()
        text_query = "woman.png"
        image_pil = Image.open(f'images/{text_query}')
        image = preprocess(image_pil)[np.newaxis, :, :, :]  # add a leading batch dimension
        topic_emb = model.encode_image(
            image.to(device),
            attn_method='head_no_spatial',
            normalize=True,
        )
    else:
        # Text query: tokenize the prompt, move the tokens to the device, encode them
        text_query = "An image of a woman"
        text_query_token = tokenizer(text_query).to(device)
        topic_emb = model.encode_text(text_query_token, normalize=True)

### Reconstruct embedding and find contributions from principal components
data = get_data(attention_dataset, -1, skip_final=True)

# Mean image embedding and mean text embedding
mean_final_images = torch.mean(final_embeddings_images, axis=0).to(device)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0).to(device)

# The query is centered with the mean of its own modality
if query_text:
    mean_final = mean_final_texts
else:
    mean_final = mean_final_images

topic_emb_cent = topic_emb - mean_final
final_embeddings_texts_cent = final_embeddings_texts.to(device) - mean_final_texts

# Reconstruct the centered query embedding; `data` is replaced by the second return value
reconstructed_embeddings, data = reconstruct_embeddings(
    data,
    [topic_emb_cent],
    ["text" if query_text else "image"],
    return_princ_comp=True,
    plot=True,
    means=[mean_final],
    device=device,
)
[topic_emb_rec_cent] = reconstructed_embeddings

# Cosine similarity between the reconstructed and the original centered query:
# normalize both to unit length, then take their dot product.
topic_emb_rec_cent_norm = topic_emb_rec_cent / topic_emb_rec_cent.norm(dim=-1, keepdim=True)
topic_emb_cent_norm = topic_emb_cent / topic_emb_cent.norm(dim=-1, keepdim=True)
max_reconstr_score = topic_emb_rec_cent_norm @ topic_emb_cent_norm.T
print(f"We have a max cosine similarity of: {(max_reconstr_score).item():.4f}")


### Use the strength of the previous reconstruction to derive a good enough reconstruction of the query

In [None]:
# Select the principal components that matter most for the current query
top_k = 25  # Maximum number of top entries to retrieve
approx = 1.1  # Target approximation threshold for the reconstruction quality

# Zero tensor shaped like the query embedding (not used further in this cell)
topic_emb_rec_act = torch.zeros_like(topic_emb)

# Order the entries by decreasing `correlation_princ_comp_abs` and keep the first top_k
data = sort_data_by(data, "correlation_princ_comp_abs", descending=True)
top_k_entries = top_data(data, top_k)

embedding_kind = "text" if query_text else "image"
top_k_details = reconstruct_top_embedding(
    top_k_entries,
    topic_emb_cent,
    mean_final,
    embedding_kind,
    max_reconstr_score,
    top_k,
    approx,
    device=device,
)

# Report the query and the selected principal components
print(f"Currently querying the topic: {text_query}")
print_data(top_k_details, is_corr_present=True)

### Prepare scores of images and texts 

In [None]:
## For Reconstructed Embedding
# One record per image: raw score, score after normalizing the reconstruction, image index.
score_dtype_ = [('score', 'f4'), ('score_vis', 'f4'), ('img_index', 'i4')]
nr_images_ = final_embeddings_images.shape[0]

# Similarity of each reconstructed image embedding with the query embedding
scores_array_images = np.empty(nr_images_, dtype=score_dtype_)

# Similarity of each reconstructed image embedding with its own full embedding
scores_array_images_self = np.empty(nr_images_, dtype=score_dtype_)

# Image indexes, stored next to the scores
indexes_images = np.arange(0, nr_images_, 1)

# Mean image embedding
mean_final_images = torch.mean(final_embeddings_images, axis=0).to(device)

# Reconstruct the image embeddings from the selected principal components
images_rec = reconstruct_all_embeddings_mean_ablation_pcs(top_k_entries, mlps_, attns_, attns_, nr_layers_, nr_heads_, last_, ratio=-1, ablation=True, mean_ablate_all=True)

# Unnormalized dot-product scores.
# The self score is the row-wise dot product of each reconstruction with its own
# embedding. This equals diag(images_rec @ final_embeddings_images.T), but avoids
# materializing the N x N similarity matrix just to read its diagonal.
scores_array_images["score"] = (images_rec @ topic_emb.T).squeeze().cpu().numpy()
scores_array_images_self["score"] = (images_rec * final_embeddings_images).sum(dim=-1).cpu().numpy()

# Same scores after scaling every reconstruction to unit length
images_rec /= images_rec.norm(dim=-1, keepdim=True)

scores_array_images["score_vis"] = (images_rec @ topic_emb.T).squeeze().cpu().numpy()
scores_array_images_self["score_vis"] = (images_rec * final_embeddings_images).sum(dim=-1).cpu().numpy()

scores_array_images["img_index"] = indexes_images
scores_array_images_self["img_index"] = indexes_images

In [None]:
# For full CLIP Embedding
# Scores of the original (non-reconstructed) image embeddings against the query
scores_array_images_full = np.empty(
    final_embeddings_images.shape[0],
    dtype=[('score', 'f4'), ('score_vis', 'f4'), ('img_index', 'i4')]
)

# Unnormalized dot-product scores
images = final_embeddings_images.to(device)
scores_array_images_full["score"] = (images @ topic_emb.T).squeeze().cpu().numpy()

# Normalize out-of-place. When the tensor already lives on `device` (e.g. device == 'cpu'),
# `.to(device)` returns the very same tensor, so an in-place `/=` would silently normalize
# `final_embeddings_images` itself and change the results of every later cell and of
# re-running this one.
images = images / images.norm(dim=-1, keepdim=True)
scores_array_images_full["score_vis"] = (images @ topic_emb.T).squeeze().cpu().numpy()

scores_array_images_full["img_index"] = indexes_images

# Define the number of top and worst images to look at for each princ_comp
nr_top_imgs = 20  # Number of top elements
nr_worst_imgs = 20  # Number of worst elements
nr_cont_imgs = 20  # Length of continuous elements

### Visualize

In [None]:
# How many top, worst and "continuous" images to show
nr_top_imgs, nr_worst_imgs, nr_cont_imgs = 20, 20, 0

dbs = create_dbs(scores_array_images, None, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
# Requested size of each group, in the same order as `dbs`
nrs_dbs = [nr_top_imgs, nr_worst_imgs, nr_cont_imgs]
# Drop the groups whose requested size is zero
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis_, texts_str, classes_, text_query)

In [None]:
# Same visualization, ranked by similarity of the full embeddings.
# `nrs_dbs` is the list of group sizes defined in the previous visualization cell.
dbs = create_dbs(scores_array_images_full, None, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis_, texts_str, classes_, text_query)

In [None]:
# Same visualization, ranked by similarity of each reconstruction with its own full embedding.
# `nrs_dbs` is the list of group sizes defined in an earlier visualization cell.
dbs = create_dbs(scores_array_images_self, None, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis_, texts_str, classes_, text_query)

### Optional: visualize using all the PCs that were NOT selected

In [None]:
# All principal components except the ones selected for the query
data = get_data(attention_dataset, -1, skip_final=True)

top_k_other_details = get_remaining_pcs(data, top_k_entries)

# Reconstruct the image embeddings from those remaining components
images_rec = reconstruct_all_embeddings_mean_ablation_pcs(top_k_other_details, mlps_, attns_, attns_, nr_layers_, nr_heads_, last_, ratio=-1, mean_ablate_all=True)

# Unnormalized dot-product scores. This overwrites the score arrays filled by the
# earlier "Prepare scores" cell.
# The self score is the row-wise dot product of each reconstruction with its own
# embedding. This equals diag(images_rec @ final_embeddings_images.T), but avoids
# materializing the N x N similarity matrix just to read its diagonal.
scores_array_images["score"] = (images_rec @ topic_emb.T).squeeze().cpu().numpy()
scores_array_images_self["score"] = (images_rec * final_embeddings_images).sum(dim=-1).cpu().numpy()

# Same scores after scaling every reconstruction to unit length
images_rec /= images_rec.norm(dim=-1, keepdim=True)

scores_array_images["score_vis"] = (images_rec @ topic_emb.T).squeeze().cpu().numpy()
scores_array_images_self["score_vis"] = (images_rec * final_embeddings_images).sum(dim=-1).cpu().numpy()

scores_array_images["img_index"] = indexes_images
scores_array_images_self["img_index"] = indexes_images


In [None]:
# Visualize the query scores obtained from the non-selected PCs
dbs = create_dbs(scores_array_images, None, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
# Requested size of each group, in the same order as `dbs`
nrs_dbs = [nr_top_imgs, nr_worst_imgs, nr_cont_imgs]
# Drop the groups whose requested size is zero
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis_, texts_str, classes_, text_query)

In [None]:
# Visualize the self-similarity scores obtained from the non-selected PCs.
# `nrs_dbs` is the list of group sizes defined in the previous cell.
dbs = create_dbs(scores_array_images_self, None, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)
dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]
visualize_dbs(top_k_details, dbs_new, ds_vis_, texts_str, classes_, text_query)

# Evaluate classification using reconstruction

## Ablation Study

In [None]:
nr_layers = attns_.shape[1]
# Accuracy after mean-ablating the attention output of layers 0..layer_nr (inclusive)
# while keeping the attention of all later layers and every MLP contribution.
accs = []
for layer_nr in range(nr_layers):
    # The ablated slice must end at layer_nr + 1: with `:layer_nr` the layer `layer_nr`
    # was neither replaced by its mean nor kept (the kept slice starts at layer_nr + 1),
    # so its contribution was dropped from the reconstruction altogether. The inclusive
    # slice matches the mean-ablation convention used when the data is loaded.
    current_mean_ablation_per_head_sum = torch.mean(no_heads_attentions_[:, :layer_nr + 1], axis=0).sum(0)
    current_model = (current_mean_ablation_per_head_sum + no_heads_attentions_[:, layer_nr + 1:].sum(1)) + mlps_.sum(axis=1)
    # The label states the layers that are actually ablated (1-based).
    acc, _ = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation of layers 1 to {layer_nr + 1} (of {nr_layers})")
    accs.append(acc)

# x = number of mean-ablated layers for the corresponding entry of `accs`
x_values = range(1, len(accs) + 1)

# Plot
plt.figure(figsize=(8, 5))
plt.plot(x_values, accs, linestyle='-', label=model_name)

# Labeling
plt.xlabel("Accumulated mean-ablated layers")
plt.ylabel("Accuracy")
plt.title("Accuracy vs. Accumulated Mean-Ablated Layers")

# Add legend for the line
plt.legend()

plt.grid(True)
plt.show()

## Proof of concept 1: Use aggregations of PCs of all labels at once

In [None]:
# Sweep over the number of principal components (PCs) kept per class
pcs_per_class_start = 1
pcs_per_class_end = int(nr_heads_*num_last_layers_*mean_rank_)
pcs_per_class_step = 10
max_pcs_per_head = -1
random = False  # if True, use randomly picked PCs instead of the top-ranked ones
class_embeddings = classifier_.T  # M x D


# Baseline accuracy computation: full sum of attention and MLP contributions
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
baseline_acc, idxs = test_accuracy(baseline @ classifier_, labels_, label="Baseline")
output = print_correct_elements(idxs, labels_, classes_)

# Worst-group baseline exists only for Waterbirds
if classes_ == waterbird_classes:
    baseline_worst = test_waterbird_preds(idxs, labels_, background_groups_)

# Mean image embedding and mean text embedding
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

# Class embeddings centered with the mean text embedding
classes_centered = class_embeddings - mean_final_texts.unsqueeze(0)

# (num_images x 2) buffer; it is not used by the code in this cell
all_preds = torch.zeros((final_embeddings_images.shape[0], 2), dtype=torch.double)

# For every class: all PCs ordered by decreasing `correlation_princ_comp_abs`
# with respect to that class embedding.
sorted_data = []
for text_idx in range(classes_centered.shape[0]):
    concept_i_centered = classes_centered[text_idx, :].unsqueeze(0)

    data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)

    _, data_abs = reconstruct_embeddings(
        data,
        [concept_i_centered],
        ["text"],
        return_princ_comp=True,
        plot=False,
        means=[mean_final_texts],
    )

    data_pcs = sort_data_by(data_abs, "correlation_princ_comp_abs", descending=True)
    sorted_data.append(data_pcs)

pcs_per_class_max_worst_acc = pcs_per_class_start
max_worst_acc = 0

worst_class_acc = []
worst_class_nr_pcs = []

total_accuracy = []  # one accuracy per tested value of pcs_per_class

for pcs_per_class in range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step):
    # Top `pcs_per_class` PCs of every class, concatenated
    entries = []
    for text_idx in range(classes_centered.shape[0]):
        data_pcs = sorted_data[text_idx]
        top_k_entries = top_data(data_pcs, pcs_per_class)
        print(f"Currently processing label {text_idx} with nr_pcs_per_class: {pcs_per_class}")
        entries += top_k_entries

    # Remove duplicates, keeping the first occurrence of each (layer, head, princ_comp).
    # A set gives constant-time membership tests, and the key is built without assigning
    # to `layer`, `head` or `princ_comp`, which are notebook-level settings read by the
    # visualization cells and must not be overwritten here.
    seen_keys = set()
    entries_set = []
    for entry in entries:
        entry_key = (entry["layer"], entry["head"], entry["princ_comp"])
        if entry_key not in seen_keys:
            seen_keys.add(entry_key)
            entries_set.append(entry)

    print(f"Total number of unique entries: {len(entries_set)}")

    # If `random` is True, randomly pick PCs instead of the actual top_k
    if random:
        data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)
        entries_set = random_pcs(data, pcs_per_class * len(classes_))

    # Reconstruct the image embeddings from the selected PCs.
    # NOTE(review): `num_last_layers_` is passed in the position where the other calls of
    # this function in the notebook pass `last_` -- confirm which value is intended.
    reconstructed_images = reconstruct_all_embeddings_mean_ablation_pcs(
        entries_set,
        mlps_,
        attns_,
        attns_,
        nr_layers_,
        nr_heads_,
        num_last_layers_,
        ratio=-1,
        mean_ablate_all=False
    )

    reconstructed_images /= reconstructed_images.norm(dim=-1, keepdim=True)
    predictions = reconstructed_images @ classifier_

    # Accuracy over the whole set for this number of PCs
    acc, idxs = test_accuracy(predictions, labels_, label="All classes combined")
    total_accuracy.append(acc)

    print_correct_elements(idxs, labels_, classes_)

    # Worst-class accuracy for Waterbirds, if applicable
    if classes_ == waterbird_classes:
        curr_worst_acc = test_waterbird_preds(idxs, labels_, background_groups_)
        if curr_worst_acc > max_worst_acc:
            max_worst_acc = curr_worst_acc
            pcs_per_class_max_worst_acc = pcs_per_class

        worst_class_acc.append(curr_worst_acc)
        worst_class_nr_pcs.append(pcs_per_class)

# ---------------------------------------
# PLOTTING SECTION for WORST-CLASS ACCURACY
# ---------------------------------------
# `worst_class_acc`, `worst_class_nr_pcs` and `baseline_worst` are only filled for
# Waterbirds. For any other dataset this section is skipped instead of failing on
# max() of an empty list or on an undefined `baseline_worst`.
if classes_ == waterbird_classes:
    max_worst_acc = max(worst_class_acc)
    pcs_per_class_max_worst_acc = worst_class_nr_pcs[np.argmax(worst_class_acc)]

    plt.figure(figsize=(6, 4))

    # Worst-class accuracy curve
    plt.plot(
        worst_class_nr_pcs, worst_class_acc, color='blue',
        linestyle='-',
        label='Worst-Class Accuracy'
    )

    # Baseline line
    plt.axhline(
        y=baseline_worst,
        color='gray',
        linestyle='--',
        linewidth=2,
        label=f'Baseline Worst-Class Acc = {baseline_worst:.2f}'
    )

    # Horizontal and vertical markers of the maximum
    plt.axhline(y=max_worst_acc, color='blue', linestyle=':', linewidth=2)
    plt.axvline(x=pcs_per_class_max_worst_acc, color='blue', linestyle=':', linewidth=2)

    # One custom legend entry for both max lines; the second label line reports the
    # total accuracy measured at the point of maximum worst-class accuracy.
    max_line_legend = mlines.Line2D(
        [], [],
        color='blue',
        linestyle=':',
        linewidth=2,
        label=f'Max Worst-Class Acc = {max_worst_acc:.2f} at PCs={pcs_per_class_max_worst_acc}\n(Worst total accuracy is {total_accuracy[np.argmax(worst_class_acc)]:.2f})'
    )

    # Existing legend entries plus the custom one
    handles, labels = plt.gca().get_legend_handles_labels()
    handles.append(max_line_legend)
    labels.append(max_line_legend.get_label())

    plt.xlabel('Number of PCs per Class')
    plt.ylabel('Worst Accuracy')
    plt.title('Worst-Class Accuracy vs. Number of PCs per Class')
    plt.grid(True)
    plt.legend(handles, labels)
    plt.tight_layout()  # Ensures elements fit within the figure
    plt.savefig(f"plt_1_worst_{model_name}.pdf", bbox_inches='tight', format='pdf')
    plt.show()

    print(f"Max worst accuracy of {max_worst_acc:.2f} found at {pcs_per_class_max_worst_acc} PCs/class.")

# ----------------------------------------------------------------------------------
# PLOTTING SECTION for TOTAL ACCURACY
# ----------------------------------------------------------------------------------
# x values of the sweep. `total_accuracy` receives exactly one entry per value of this
# range, whereas `worst_class_nr_pcs` is only filled for Waterbirds, so the range itself
# is used for the x-axis to keep this plot working for every dataset.
pcs_per_class_values = list(range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step))

max_total_acc = max(total_accuracy)
best_total_idx = int(np.argmax(total_accuracy))
pcs_for_max_total_acc = pcs_per_class_values[best_total_idx]

plt.figure(figsize=(6, 4))

# Total accuracy curve
plt.plot(
    pcs_per_class_values, total_accuracy,
    linestyle='-', color='orange',
    label='Total Accuracy'
)

# Baseline line
plt.axhline(
    y=baseline_acc,
    color='gray',
    linestyle='--',
    linewidth=2,
    label=f'Baseline Total Acc = {baseline_acc:.2f}'
)

# Horizontal and vertical markers of the maximum
plt.axhline(y=max_total_acc, color='orange', linestyle=':', linewidth=2)
plt.axvline(x=pcs_for_max_total_acc, color='orange', linestyle=':', linewidth=2)

# One custom legend entry for both max lines. The worst-class accuracy at that point is
# appended only when it was measured (the list is empty for non-Waterbirds datasets).
max_total_label = f'Max Total Acc = {max_total_acc:.2f} at PCs={pcs_for_max_total_acc}'
if worst_class_acc:
    max_total_label += f'\n(Worst class accuracy is {worst_class_acc[best_total_idx]:.2f})'
max_line_legend_2 = mlines.Line2D(
    [], [],
    color='orange',
    linestyle=':',
    linewidth=2,
    label=max_total_label
)

# Existing legend entries plus the custom one
handles2, labels2 = plt.gca().get_legend_handles_labels()
handles2.append(max_line_legend_2)
labels2.append(max_line_legend_2.get_label())

plt.xlabel('Number of PCs per Class')
plt.ylabel('Total Accuracy (Avg)')
plt.title('Total Accuracy vs. Number of PCs per Class')
plt.grid(True)
plt.legend(handles2, labels2)
plt.tight_layout()  # Ensures elements fit within the figure
plt.savefig(f"plt_1_acc_{model_name}.pdf", bbox_inches='tight', format='pdf')
plt.show()

print(f"Max total accuracy of {max_total_acc:.2f} found at {pcs_for_max_total_acc}")
# -------------------------------
# COMPUTE PARETO OPTIMAL CANDIDATE
# -------------------------------
# A candidate is a sweep point where both total accuracy and worst-class accuracy exceed
# their baselines. Among the candidates the one with the highest worst-class accuracy is
# selected; ties are broken by the higher total accuracy, and then by the earliest point.
total_accuracy_arr = np.array(total_accuracy)
worst_class_acc_arr = np.array(worst_class_acc)
worst_class_nr_pcs_arr = np.array(worst_class_nr_pcs)

# Defaults for "no candidate found", so that all three names are always defined.
pareto_pcs = None
pareto_worst = None
pareto_total = None

# Worst-class accuracies and `baseline_worst` exist only for Waterbirds; for any other
# dataset there is nothing to compare and the defaults above are kept.
if classes_ == waterbird_classes and len(worst_class_acc_arr) > 0:
    candidates_mask = (total_accuracy_arr > baseline_acc) & (worst_class_acc_arr > baseline_worst)
    candidate_indices = np.where(candidates_mask)[0]
    if candidate_indices.size > 0:
        # max() returns the first maximal item, so among exact ties the earliest index wins.
        best_idx = max(
            candidate_indices,
            key=lambda idx: (worst_class_acc_arr[idx], total_accuracy_arr[idx]),
        )
        pareto_pcs = worst_class_nr_pcs_arr[best_idx]
        pareto_worst = worst_class_acc_arr[best_idx]
        pareto_total = total_accuracy_arr[best_idx]

# -------------------------------
# PLOTTING SECTION for WORST-CLASS ACCURACY
# -------------------------------
# Drawn only for Waterbirds, the one dataset with worst-class measurements.
if classes_ == waterbird_classes:
    # Sweep point with the highest worst-class accuracy
    idx_max_worst = np.argmax(worst_class_acc)
    max_worst_acc = max(worst_class_acc)
    pcs_per_class_max_worst_acc = worst_class_nr_pcs[idx_max_worst]

    # Shared line styles of the two kinds of marker
    max_marker_style = dict(color='blue', linestyle=':', linewidth=2)
    pareto_marker_style = dict(color='green', linestyle='--', linewidth=2)

    plt.figure(figsize=(6, 4))
    # Curve and baseline
    plt.plot(worst_class_nr_pcs, worst_class_acc, color='blue', linestyle='-', label='Worst-Class Accuracy')
    plt.axhline(y=baseline_worst, color='gray', linestyle='--', linewidth=2, label=f'Baseline Worst-Class Acc = {baseline_worst:.2f}')
    # Cross-hair on the maximum
    plt.axhline(y=max_worst_acc, **max_marker_style)
    plt.axvline(x=pcs_per_class_max_worst_acc, **max_marker_style)

    # Vertical marker and legend entry of the Pareto candidate, when one was found
    pareto_line_legend = None
    if pareto_pcs is not None:
        plt.axvline(x=pareto_pcs, **pareto_marker_style)
        pareto_line_legend = mlines.Line2D(
            [], [],
            label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}',
            **pareto_marker_style,
        )

    # Legend entry of the maximum; the second label line reports the total accuracy
    # measured at that same sweep point.
    max_line_legend = mlines.Line2D(
        [], [],
        label=f'Max Worst-Class Acc = {max_worst_acc:.2f} at PCs={pcs_per_class_max_worst_acc}\n(Worst total accuracy is {total_accuracy[idx_max_worst]:.2f})',
        **max_marker_style,
    )

    # Legend: entries already on the axes, then the custom ones
    handles, labels = plt.gca().get_legend_handles_labels()
    custom_entries = [max_line_legend]
    if pareto_line_legend is not None:
        custom_entries.append(pareto_line_legend)
    for legend_entry in custom_entries:
        handles.append(legend_entry)
        labels.append(legend_entry.get_label())

    plt.xlabel('Number of PCs per Class')
    plt.ylabel('Worst Accuracy')
    plt.title('Worst-Class Accuracy vs. Number of PCs per Class')
    plt.grid(True)
    plt.legend(handles, labels)
    plt.tight_layout()
    plt.savefig(f"plt_1worst_{model_name}.pdf", bbox_inches='tight', format='pdf')
    plt.show()

    print(f"Max worst accuracy of {max_worst_acc:.2f} found at {pcs_per_class_max_worst_acc} PCs/class.")
    if pareto_pcs is not None:
        print(f"Pareto optimal solution at {pareto_pcs} PCs/class with Worst Acc = {pareto_worst:.2f} and Total Acc = {pareto_total:.2f}.")

# -------------------------------
# PLOTTING SECTION for TOTAL ACCURACY
# -------------------------------
# x values of the sweep. `total_accuracy` receives exactly one entry per value of this
# range, whereas `worst_class_nr_pcs` is only filled for Waterbirds, so the range itself
# is used for the x-axis to keep this plot working for every dataset.
pcs_per_class_values = list(range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step))

# Sweep point with the highest total accuracy
max_total_acc = max(total_accuracy)
best_total_idx = int(np.argmax(total_accuracy))
pcs_for_max_total_acc = pcs_per_class_values[best_total_idx]

plt.figure(figsize=(6, 4))
# Total accuracy curve
plt.plot(
    pcs_per_class_values, total_accuracy,
    linestyle='-', color='orange',
    label='Total Accuracy'
)
# Baseline total accuracy
plt.axhline(
    y=baseline_acc,
    color='gray',
    linestyle='--',
    linewidth=2,
    label=f'Baseline Total Acc = {baseline_acc:.2f}'
)
# Horizontal and vertical markers of the maximum
plt.axhline(y=max_total_acc, color='orange', linestyle=':', linewidth=2)
plt.axvline(x=pcs_for_max_total_acc, color='orange', linestyle=':', linewidth=2)

# --- Vertical line for the Pareto optimal solution (if found) ---
if pareto_pcs is not None:
    plt.axvline(x=pareto_pcs, color='green', linestyle='--', linewidth=2)
    pareto_line_legend_2 = mlines.Line2D(
        [],
        [],
        color='green',
        linestyle='--',
        linewidth=2,
        label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}'
    )
else:
    pareto_line_legend_2 = None

# Custom legend entry for the maximum. The worst-class accuracy at that point is
# appended only when it was measured (the list is empty for non-Waterbirds datasets).
max_total_label = f'Max Total Acc = {max_total_acc:.2f} at PCs={pcs_for_max_total_acc}'
if worst_class_acc:
    max_total_label += f'\n(Worst class accuracy is {worst_class_acc[best_total_idx]:.2f})'
max_line_legend_2 = mlines.Line2D(
    [],
    [],
    color='orange',
    linestyle=':',
    linewidth=2,
    label=max_total_label
)

handles2, labels2 = plt.gca().get_legend_handles_labels()
handles2.append(max_line_legend_2)
labels2.append(max_line_legend_2.get_label())
if pareto_line_legend_2 is not None:
    handles2.append(pareto_line_legend_2)
    labels2.append(pareto_line_legend_2.get_label())

plt.xlabel('Number of PCs per Class')
plt.ylabel('Total Accuracy (Avg)')
plt.title('Total Accuracy vs. Number of PCs per Class')
plt.grid(True)
plt.legend(handles2, labels2)
plt.tight_layout()
# Same file name as the first total-accuracy plot of this cell, which is therefore overwritten.
plt.savefig(f"plt_1_acc_{model_name}.pdf", bbox_inches='tight', format='pdf')
plt.show()

print(f"Max total accuracy of {max_total_acc:.2f} found at {pcs_for_max_total_acc}")

## Proof of concept 2: Remove concepts from the model and use other PCs to reconstruct

In [None]:
# --- Sweep configuration -----------------------------------------------------
# The sweep visits pcs_per_class = start, start + step, ... (end exclusive).
pcs_per_class_start = 1
pcs_per_class_end = int(nr_heads_*num_last_layers_*mean_rank_)
pcs_per_class_step = 10
max_pcs_per_head = -1
random = False  # True -> evaluate randomly drawn PCs instead of the concept-based selection
class_embeddings = classifier_.T  # M x D
concepts_to_remove = ["water background", "land background"]

# --- Encode the concept prompts ------------------------------------------------
# One row of `concepts_emb` per prompt, produced by the model's text encoder.
for k, concept in enumerate(concepts_to_remove):
    with torch.no_grad():
        # Tokenize the prompt and move it to the configured device
        text_query_token = tokenizer(concept).to(device)
        # Normalized text embedding of the prompt
        topic_emb = model.encode_text(text_query_token, normalize=True)
        if k == 0:
            # Allocate the output lazily so the embedding width is taken from the model
            concepts_emb = torch.zeros(len(concepts_to_remove), topic_emb.shape[-1], device=device)
        concepts_emb[k] = topic_emb

# --- Baseline ------------------------------------------------------------------
# Baseline representation: `attns_` summed over axes 1 and 2 plus `mlps_` summed
# over axis 1 (presumably layers/heads and layers -- confirm against the loader).
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
baseline_acc, idxs = test_accuracy(baseline @ classifier_, labels_, label="Baseline")
if classes_ == waterbird_classes:
    # Only defined for Waterbirds; downstream code must not assume it exists otherwise.
    baseline_worst = test_waterbird_preds(idxs, labels_, background_groups_)

# Means of the image and text embeddings; the text mean centers the concepts.
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

concepts_centered = concepts_emb - mean_final_texts.unsqueeze(0)

# --- Rank principal components per concept -------------------------------------
# sorted_data[i] holds the PC entries for concept i, sorted by descending
# "correlation_princ_comp_abs".
sorted_data = []
for text_idx in range(concepts_centered.shape[0]):
    concept_i_centered = concepts_centered[text_idx, :].unsqueeze(0)

    # Reloaded per concept: it is not visible here whether reconstruct_embeddings
    # mutates `data`, so the original per-iteration load is kept.
    data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)

    _, data_abs = reconstruct_embeddings(
        data,
        [concept_i_centered],
        ["text"],
        return_princ_comp=True,
        plot=False,
        means=[mean_final_texts],
    )

    data_pcs = sort_data_by(data_abs, "correlation_princ_comp_abs", descending=True)
    sorted_data.append(data_pcs)

# --- Sweep over the number of PCs taken per concept ------------------------------
pcs_per_class_max_worst_acc = pcs_per_class_start
max_worst_acc = 0

worst_class_acc = []     # worst-group accuracy per sweep step (Waterbirds only)
worst_class_nr_pcs = []  # pcs_per_class of each recorded step (Waterbirds only)

total_accuracy = []  # overall accuracy for each sweep step

for pcs_per_class in range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step):
    # Gather the top `pcs_per_class` PCs of every concept.
    entries = []
    for text_idx in range(concepts_centered.shape[0]):
        data_pcs = sorted_data[text_idx]
        top_k_entries = top_data(data_pcs, pcs_per_class)
        print(f"Currently processing label: {concepts_to_remove[text_idx]} with nr_pcs_per_class: {pcs_per_class}")
        entries += top_k_entries

    # Drop PCs selected by more than one concept, keeping first-seen order.
    # A set gives O(1) membership tests (the list scan was quadratic per step).
    # int() keeps the key hashable by value if the indices are numpy/torch
    # scalars -- assumes layer/head/princ_comp are integer indices.
    seen_pcs = set()
    entries_set = []
    for entry in entries:
        pc_key = (int(entry["layer"]), int(entry["head"]), int(entry["princ_comp"]))
        if pc_key not in seen_pcs:
            seen_pcs.add(pc_key)
            entries_set.append(entry)

    print(f"Total number of unique entries: {len(entries_set)}")

    if random:
        # Random control: draw PCs instead of using the concept-based selection.
        data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)
        top_k_other_details = random_pcs(data, pcs_per_class * len(classes_))
    else:
        # Every PC that is NOT among the concept PCs selected above.
        top_k_other_details = get_remaining_pcs(data, entries_set)

    # Rebuild the image embeddings from the chosen PCs.
    reconstructed_images = reconstruct_all_embeddings_mean_ablation_pcs(
        top_k_other_details,
        mlps_,
        attns_,
        attns_,
        nr_layers_,
        nr_heads_,
        num_last_layers_,
        ratio=-1,
        mean_ablate_all=True
    )

    # L2-normalize each reconstructed embedding before scoring against the classifier.
    reconstructed_images /= reconstructed_images.norm(dim=-1, keepdim=True)
    predictions = reconstructed_images @ classifier_

    acc, idxs = test_accuracy(predictions, labels_, label="All classes combined")
    total_accuracy.append(acc)

    print_correct_elements(idxs, labels_, classes_)

    # Worst-group accuracy is only available for Waterbirds.
    if classes_ == waterbird_classes:
        curr_worst_acc = test_waterbird_preds(idxs, labels_, background_groups_)
        if curr_worst_acc > max_worst_acc:
            max_worst_acc = curr_worst_acc
            pcs_per_class_max_worst_acc = pcs_per_class

        worst_class_acc.append(curr_worst_acc)
        worst_class_nr_pcs.append(pcs_per_class)

# -------------------------------
# COMPUTE PARETO OPTIMAL CANDIDATE
# -------------------------------
# A candidate is a sweep step where both total accuracy and worst-class accuracy
# exceed their baselines. Among candidates the highest worst-class accuracy wins;
# ties are broken by higher total accuracy, then by the earliest sweep step.
total_accuracy_arr = np.array(total_accuracy)
worst_class_acc_arr = np.array(worst_class_acc)
worst_class_nr_pcs_arr = np.array(worst_class_nr_pcs)

# Reset all three so values from a previously executed cell can never leak
# into this cell's plots and messages.
pareto_pcs = None
pareto_worst = None
pareto_total = None

# Worst-class statistics (and `baseline_worst`) are only produced for Waterbirds.
# For any other dataset there is nothing to compare against, so no candidate is
# reported instead of failing on an undefined name or mismatched array shapes.
if classes_ == waterbird_classes and worst_class_acc_arr.size > 0:
    candidates_mask = (total_accuracy_arr > baseline_acc) & (worst_class_acc_arr > baseline_worst)
    candidate_indices = np.flatnonzero(candidates_mask)
    if candidate_indices.size > 0:
        # max() returns the first maximal element, so full ties resolve to the
        # earliest sweep step.
        best_idx = max(
            candidate_indices,
            key=lambda step: (worst_class_acc_arr[step], total_accuracy_arr[step])
        )
        pareto_pcs = worst_class_nr_pcs_arr[best_idx]
        pareto_worst = worst_class_acc_arr[best_idx]
        pareto_total = total_accuracy_arr[best_idx]

# -------------------------------
# PLOTTING SECTION for WORST-CLASS ACCURACY
# -------------------------------
# (Only applicable if classes_ == waterbird_classes)
if classes_ == waterbird_classes:
    # Sweep step with the best worst-class accuracy and the values recorded there.
    step_of_max_worst = int(np.argmax(worst_class_acc))
    max_worst_acc = max(worst_class_acc)
    pcs_per_class_max_worst_acc = worst_class_nr_pcs[step_of_max_worst]

    # Line styles shared between the drawn guide lines and their legend proxies.
    peak_style = dict(color='blue', linestyle=':', linewidth=2)
    pareto_style = dict(color='green', linestyle='--', linewidth=2)

    plt.figure(figsize=(6, 4))
    # Worst-class accuracy as a function of the number of PCs per class.
    plt.plot(
        worst_class_nr_pcs, worst_class_acc,
        color='blue',
        linestyle='-',
        label='Worst-Class Accuracy'
    )
    # Horizontal reference line at the baseline worst-class accuracy.
    plt.axhline(
        y=baseline_worst,
        color='gray',
        linestyle='--',
        linewidth=2,
        label=f'Baseline Worst-Class Acc = {baseline_worst:.2f}'
    )
    # Cross-hair marking the maximum worst-class accuracy.
    plt.axhline(y=max_worst_acc, **peak_style)
    plt.axvline(x=pcs_per_class_max_worst_acc, **peak_style)

    # Vertical marker for the Pareto candidate (if one was found).
    pareto_line_legend = None
    if pareto_pcs is not None:
        plt.axvline(x=pareto_pcs, **pareto_style)
        pareto_line_legend = mlines.Line2D(
            [], [],
            label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}',
            **pareto_style
        )

    # Legend proxy for both dotted "max" lines. The second line reports the
    # TOTAL accuracy at the same sweep step; the previous wording ("Worst total
    # accuracy") mislabelled that value.
    max_line_legend = mlines.Line2D(
        [], [],
        label=f'Max Worst-Class Acc = {max_worst_acc:.2f} at PCs={pcs_per_class_max_worst_acc}\n(Total accuracy is {total_accuracy[step_of_max_worst]:.2f})',
        **peak_style
    )

    # Labelled artists already on the axes first, then the proxies.
    handles, labels = plt.gca().get_legend_handles_labels()
    for proxy_line in (max_line_legend, pareto_line_legend):
        if proxy_line is not None:
            handles.append(proxy_line)
            labels.append(proxy_line.get_label())

    plt.xlabel('Number of PCs per Class')
    plt.ylabel('Worst Accuracy')
    plt.title('Worst-Class Accuracy vs. Number of PCs per Class')
    plt.grid(True)
    plt.legend(handles, labels)
    plt.tight_layout()
    plt.savefig(f"plt_2_worst_{model_name}.pdf", bbox_inches='tight', format='pdf')
    plt.show()

    print(f"Max worst accuracy of {max_worst_acc:.2f} found at {pcs_per_class_max_worst_acc} PCs/class.")
    if pareto_pcs is not None:
        print(f"Pareto optimal solution at {pareto_pcs} PCs/class with Worst Acc = {pareto_worst:.2f} and Total Acc = {pareto_total:.2f}.")

# -------------------------------
# PLOTTING SECTION for TOTAL ACCURACY
# -------------------------------
# Sweep step at which the total accuracy peaks.
step_of_max_total = int(np.argmax(total_accuracy))
max_total_acc = max(total_accuracy)

# x-axis values. `worst_class_nr_pcs` is only filled for Waterbirds, so when it
# does not line up with `total_accuracy` the axis is rebuilt from the sweep range
# (truncated in case the sweep stopped early) instead of indexing an empty list.
if len(worst_class_nr_pcs) == len(total_accuracy):
    pcs_axis = list(worst_class_nr_pcs)
else:
    pcs_axis = list(range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step))[:len(total_accuracy)]
pcs_for_max_total_acc = pcs_axis[step_of_max_total]

# Line styles shared between the drawn guide lines and their legend proxies.
peak_style = dict(color='orange', linestyle=':', linewidth=2)
pareto_style = dict(color='green', linestyle='--', linewidth=2)

plt.figure(figsize=(6, 4))
# Total accuracy as a function of the number of PCs per class.
plt.plot(
    pcs_axis, total_accuracy,
    linestyle='-', color='orange',
    label='Total Accuracy'
)
# Horizontal reference line at the baseline total accuracy.
plt.axhline(
    y=baseline_acc,
    color='gray',
    linestyle='--',
    linewidth=2,
    label=f'Baseline Total Acc = {baseline_acc:.2f}'
)
# Cross-hair marking the maximum total accuracy.
plt.axhline(y=max_total_acc, **peak_style)
plt.axvline(x=pcs_for_max_total_acc, **peak_style)

# Vertical marker for the Pareto candidate (if one was found).
pareto_line_legend_2 = None
if pareto_pcs is not None:
    plt.axvline(x=pareto_pcs, **pareto_style)
    pareto_line_legend_2 = mlines.Line2D(
        [], [],
        label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}',
        **pareto_style
    )

# Legend text for the maximum. The worst-class figure is appended only when
# worst-class accuracies were recorded for every sweep step (Waterbirds).
max_total_label = f'Max Total Acc = {max_total_acc:.2f} at PCs={pcs_for_max_total_acc}'
if len(worst_class_acc) == len(total_accuracy):
    max_total_label += f'\n(Worst class accuracy is {worst_class_acc[step_of_max_total]:.2f})'
max_line_legend_2 = mlines.Line2D([], [], label=max_total_label, **peak_style)

# Labelled artists already on the axes first, then the proxies.
handles2, labels2 = plt.gca().get_legend_handles_labels()
for proxy_line in (max_line_legend_2, pareto_line_legend_2):
    if proxy_line is not None:
        handles2.append(proxy_line)
        labels2.append(proxy_line.get_label())

plt.xlabel('Number of PCs per Class')
plt.ylabel('Total Accuracy (Avg)')
plt.title('Total Accuracy vs. Number of PCs per Class')
plt.grid(True)
plt.legend(handles2, labels2)
plt.tight_layout()
plt.savefig(f"plt_2_acc_{model_name}.pdf", bbox_inches='tight', format='pdf')
plt.show()

print(f"Max total accuracy of {max_total_acc:.2f} found at {pcs_for_max_total_acc}")

## Proof of concept 3: Use as main components of the model what we want

In [None]:
# --- Sweep configuration -----------------------------------------------------
# The sweep visits pcs_per_class = start, start + step, ... (end exclusive).
pcs_per_class_start = 1
pcs_per_class_end = int(nr_heads_*num_last_layers_*mean_rank_)
pcs_per_class_step = 10
max_pcs_per_head = -1
random = False  # True -> evaluate randomly drawn PCs instead of the concept-based selection
class_embeddings = classifier_.T  # M x D
concepts_to_add = ["feet shape", "beak shape"]

# --- Encode the concept prompts ------------------------------------------------
# One row of `concepts_emb` per prompt, produced by the model's text encoder.
for k, concept in enumerate(concepts_to_add):
    with torch.no_grad():
        # Tokenize the prompt and move it to the configured device
        text_query_token = tokenizer(concept).to(device)
        # Normalized text embedding of the prompt
        topic_emb = model.encode_text(text_query_token, normalize=True)
        if k == 0:
            # Allocate the output lazily so the embedding width is taken from the model
            concepts_emb = torch.zeros(len(concepts_to_add), topic_emb.shape[-1], device=device)
        concepts_emb[k] = topic_emb

# --- Baseline ------------------------------------------------------------------
# Baseline representation: `attns_` summed over axes 1 and 2 plus `mlps_` summed
# over axis 1 (presumably layers/heads and layers -- confirm against the loader).
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
baseline_acc, idxs = test_accuracy(baseline @ classifier_, labels_, label="Baseline")
if classes_ == waterbird_classes:
    # Only defined for Waterbirds; downstream code must not assume it exists otherwise.
    baseline_worst = test_waterbird_preds(idxs, labels_, background_groups_)

# Means of the image and text embeddings; the text mean centers the concepts.
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

concepts_centered = concepts_emb - mean_final_texts.unsqueeze(0)

# --- Rank principal components per concept -------------------------------------
# sorted_data[i] holds the PC entries for concept i, sorted by descending
# "correlation_princ_comp_abs".
sorted_data = []
for text_idx in range(concepts_centered.shape[0]):
    concept_i_centered = concepts_centered[text_idx, :].unsqueeze(0)

    # Reloaded per concept: it is not visible here whether reconstruct_embeddings
    # mutates `data`, so the original per-iteration load is kept.
    data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)

    _, data_abs = reconstruct_embeddings(
        data,
        [concept_i_centered],
        ["text"],
        return_princ_comp=True,
        plot=False,
        means=[mean_final_texts],
    )

    data_pcs = sort_data_by(data_abs, "correlation_princ_comp_abs", descending=True)
    sorted_data.append(data_pcs)

# --- Sweep over the number of PCs taken per concept ------------------------------
pcs_per_class_max_worst_acc = pcs_per_class_start
max_worst_acc = 0

worst_class_acc = []     # worst-group accuracy per sweep step (Waterbirds only)
worst_class_nr_pcs = []  # pcs_per_class of each recorded step (Waterbirds only)

total_accuracy = []  # overall accuracy for each sweep step

for pcs_per_class in range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step):
    # Gather the top `pcs_per_class` PCs of every concept.
    entries = []
    for text_idx in range(concepts_centered.shape[0]):
        data_pcs = sorted_data[text_idx]
        top_k_entries = top_data(data_pcs, pcs_per_class)
        print(f"Currently processing label: {concepts_to_add[text_idx]} with nr_pcs_per_class: {pcs_per_class}")
        entries += top_k_entries

    # Drop PCs selected by more than one concept, keeping first-seen order.
    # A set gives O(1) membership tests (the list scan was quadratic per step).
    # int() keeps the key hashable by value if the indices are numpy/torch
    # scalars -- assumes layer/head/princ_comp are integer indices.
    seen_pcs = set()
    entries_set = []
    for entry in entries:
        pc_key = (int(entry["layer"]), int(entry["head"]), int(entry["princ_comp"]))
        if pc_key not in seen_pcs:
            seen_pcs.add(pc_key)
            entries_set.append(entry)

    print(f"Total number of unique entries: {len(entries_set)}")

    # Random control: replace the concept-based selection with randomly drawn PCs.
    if random:
        data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)
        entries_set = random_pcs(data, pcs_per_class * len(classes_))

    # Rebuild the image embeddings from the selected concept PCs only.
    reconstructed_images = reconstruct_all_embeddings_mean_ablation_pcs(
        entries_set,
        mlps_,
        attns_,
        attns_,
        nr_layers_,
        nr_heads_,
        num_last_layers_,
        ratio=-1,
        mean_ablate_all=True
    )

    # L2-normalize each reconstructed embedding before scoring against the classifier.
    reconstructed_images /= reconstructed_images.norm(dim=-1, keepdim=True)
    predictions = reconstructed_images @ classifier_

    acc, idxs = test_accuracy(predictions, labels_, label="All classes combined")
    total_accuracy.append(acc)

    print_correct_elements(idxs, labels_, classes_)

    # Worst-group accuracy is only available for Waterbirds.
    if classes_ == waterbird_classes:
        curr_worst_acc = test_waterbird_preds(idxs, labels_, background_groups_)
        if curr_worst_acc > max_worst_acc:
            max_worst_acc = curr_worst_acc
            pcs_per_class_max_worst_acc = pcs_per_class

        worst_class_acc.append(curr_worst_acc)
        worst_class_nr_pcs.append(pcs_per_class)

# -------------------------------
# COMPUTE PARETO OPTIMAL CANDIDATE
# -------------------------------
# A candidate is a sweep step where both total accuracy and worst-class accuracy
# exceed their baselines. Among candidates the highest worst-class accuracy wins;
# ties are broken by higher total accuracy, then by the earliest sweep step.
total_accuracy_arr = np.array(total_accuracy)
worst_class_acc_arr = np.array(worst_class_acc)
worst_class_nr_pcs_arr = np.array(worst_class_nr_pcs)

# Reset all three so values from a previously executed cell can never leak
# into this cell's plots and messages.
pareto_pcs = None
pareto_worst = None
pareto_total = None

# Worst-class statistics (and `baseline_worst`) are only produced for Waterbirds.
# For any other dataset there is nothing to compare against, so no candidate is
# reported instead of failing on an undefined name or mismatched array shapes.
if classes_ == waterbird_classes and worst_class_acc_arr.size > 0:
    candidates_mask = (total_accuracy_arr > baseline_acc) & (worst_class_acc_arr > baseline_worst)
    candidate_indices = np.flatnonzero(candidates_mask)
    if candidate_indices.size > 0:
        # max() returns the first maximal element, so full ties resolve to the
        # earliest sweep step.
        best_idx = max(
            candidate_indices,
            key=lambda step: (worst_class_acc_arr[step], total_accuracy_arr[step])
        )
        pareto_pcs = worst_class_nr_pcs_arr[best_idx]
        pareto_worst = worst_class_acc_arr[best_idx]
        pareto_total = total_accuracy_arr[best_idx]

# -------------------------------
# PLOTTING SECTION for WORST-CLASS ACCURACY
# -------------------------------
# (Only applicable if classes_ == waterbird_classes)
if classes_ == waterbird_classes:
    # Sweep step with the best worst-class accuracy and the values recorded there.
    step_of_max_worst = int(np.argmax(worst_class_acc))
    max_worst_acc = max(worst_class_acc)
    pcs_per_class_max_worst_acc = worst_class_nr_pcs[step_of_max_worst]

    # Line styles shared between the drawn guide lines and their legend proxies.
    peak_style = dict(color='blue', linestyle=':', linewidth=2)
    pareto_style = dict(color='green', linestyle='--', linewidth=2)

    plt.figure(figsize=(6, 4))
    # Worst-class accuracy as a function of the number of PCs per class.
    plt.plot(
        worst_class_nr_pcs, worst_class_acc,
        color='blue',
        linestyle='-',
        label='Worst-Class Accuracy'
    )
    # Horizontal reference line at the baseline worst-class accuracy.
    plt.axhline(
        y=baseline_worst,
        color='gray',
        linestyle='--',
        linewidth=2,
        label=f'Baseline Worst-Class Acc = {baseline_worst:.2f}'
    )
    # Cross-hair marking the maximum worst-class accuracy.
    plt.axhline(y=max_worst_acc, **peak_style)
    plt.axvline(x=pcs_per_class_max_worst_acc, **peak_style)

    # Vertical marker for the Pareto candidate (if one was found).
    pareto_line_legend = None
    if pareto_pcs is not None:
        plt.axvline(x=pareto_pcs, **pareto_style)
        pareto_line_legend = mlines.Line2D(
            [], [],
            label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}',
            **pareto_style
        )

    # Legend proxy for both dotted "max" lines. The second line reports the
    # TOTAL accuracy at the same sweep step; the previous wording ("Worst total
    # accuracy") mislabelled that value.
    max_line_legend = mlines.Line2D(
        [], [],
        label=f'Max Worst-Class Acc = {max_worst_acc:.2f} at PCs={pcs_per_class_max_worst_acc}\n(Total accuracy is {total_accuracy[step_of_max_worst]:.2f})',
        **peak_style
    )

    # Labelled artists already on the axes first, then the proxies.
    handles, labels = plt.gca().get_legend_handles_labels()
    for proxy_line in (max_line_legend, pareto_line_legend):
        if proxy_line is not None:
            handles.append(proxy_line)
            labels.append(proxy_line.get_label())

    plt.xlabel('Number of PCs per Class')
    plt.ylabel('Worst Accuracy')
    plt.title('Worst-Class Accuracy vs. Number of PCs per Class')
    plt.grid(True)
    plt.legend(handles, labels)
    plt.tight_layout()
    plt.savefig(f"plt_3_worst_{model_name}.pdf", bbox_inches='tight', format='pdf')
    plt.show()

    print(f"Max worst accuracy of {max_worst_acc:.2f} found at {pcs_per_class_max_worst_acc} PCs/class.")
    if pareto_pcs is not None:
        print(f"Pareto optimal solution at {pareto_pcs} PCs/class with Worst Acc = {pareto_worst:.2f} and Total Acc = {pareto_total:.2f}.")

# -------------------------------
# PLOTTING SECTION for TOTAL ACCURACY
# -------------------------------
# Sweep step at which the total accuracy peaks.
step_of_max_total = int(np.argmax(total_accuracy))
max_total_acc = max(total_accuracy)

# x-axis values. `worst_class_nr_pcs` is only filled for Waterbirds, so when it
# does not line up with `total_accuracy` the axis is rebuilt from the sweep range
# (truncated in case the sweep stopped early) instead of indexing an empty list.
if len(worst_class_nr_pcs) == len(total_accuracy):
    pcs_axis = list(worst_class_nr_pcs)
else:
    pcs_axis = list(range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step))[:len(total_accuracy)]
pcs_for_max_total_acc = pcs_axis[step_of_max_total]

# Line styles shared between the drawn guide lines and their legend proxies.
peak_style = dict(color='orange', linestyle=':', linewidth=2)
pareto_style = dict(color='green', linestyle='--', linewidth=2)

plt.figure(figsize=(6, 4))
# Total accuracy as a function of the number of PCs per class.
plt.plot(
    pcs_axis, total_accuracy,
    linestyle='-', color='orange',
    label='Total Accuracy'
)
# Horizontal reference line at the baseline total accuracy.
plt.axhline(
    y=baseline_acc,
    color='gray',
    linestyle='--',
    linewidth=2,
    label=f'Baseline Total Acc = {baseline_acc:.2f}'
)
# Cross-hair marking the maximum total accuracy.
plt.axhline(y=max_total_acc, **peak_style)
plt.axvline(x=pcs_for_max_total_acc, **peak_style)

# Vertical marker for the Pareto candidate (if one was found).
pareto_line_legend_2 = None
if pareto_pcs is not None:
    plt.axvline(x=pareto_pcs, **pareto_style)
    pareto_line_legend_2 = mlines.Line2D(
        [], [],
        label=f'Pareto Optimal: Worst Acc = {pareto_worst:.2f}, Total Acc = {pareto_total:.2f} at PCs={pareto_pcs}',
        **pareto_style
    )

# Legend text for the maximum. The worst-class figure is appended only when
# worst-class accuracies were recorded for every sweep step (Waterbirds).
max_total_label = f'Max Total Acc = {max_total_acc:.2f} at PCs={pcs_for_max_total_acc}'
if len(worst_class_acc) == len(total_accuracy):
    max_total_label += f'\n(Worst class accuracy is {worst_class_acc[step_of_max_total]:.2f})'
max_line_legend_2 = mlines.Line2D([], [], label=max_total_label, **peak_style)

# Labelled artists already on the axes first, then the proxies.
handles2, labels2 = plt.gca().get_legend_handles_labels()
for proxy_line in (max_line_legend_2, pareto_line_legend_2):
    if proxy_line is not None:
        handles2.append(proxy_line)
        labels2.append(proxy_line.get_label())

plt.xlabel('Number of PCs per Class')
plt.ylabel('Total Accuracy (Avg)')
plt.title('Total Accuracy vs. Number of PCs per Class')
plt.grid(True)
plt.legend(handles2, labels2)
plt.tight_layout()
plt.savefig(f"plt_3_acc_{model_name}.pdf", bbox_inches='tight', format='pdf')
plt.show()

print(f"Max total accuracy of {max_total_acc:.2f} found at {pcs_for_max_total_acc}")

## Proof of concept 4: Compare cosine of reconstruction using PCs of one class vs. all classes

In [None]:
# --- Sweep configuration -----------------------------------------------------
# The sweep visits pcs_per_class = start, start + step, ... (end exclusive).
pcs_per_class_start = 1
pcs_per_class_end = int(nr_heads_*num_last_layers_*mean_rank_)
pcs_per_class_step = 10
max_pcs_per_head = -1
random = False  # True -> evaluate randomly drawn PCs instead of the class-based selection
class_embeddings = classifier_.T  # M x D

# --- Baseline ------------------------------------------------------------------
# Baseline representation: `attns_` summed over axes 1 and 2 plus `mlps_` summed
# over axis 1 (presumably layers/heads and layers -- confirm against the loader).
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
baseline_acc, idxs = test_accuracy(baseline @ classifier_, labels_, label="Baseline")
if classes_ == waterbird_classes:
    # Only defined for Waterbirds; downstream code must not assume it exists otherwise.
    baseline_worst = test_waterbird_preds(idxs, labels_, background_groups_)

# Means of the image and text embeddings; the text mean centers the class embeddings.
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

classes_centered = class_embeddings - mean_final_texts.unsqueeze(0)

# --- Rank principal components per class -----------------------------------------
# sorted_data[i] holds the PC entries for class i, sorted by descending
# "correlation_princ_comp_abs".
sorted_data = []
for text_idx in range(classes_centered.shape[0]):
    concept_i_centered = classes_centered[text_idx, :].unsqueeze(0)

    # Reloaded per class: it is not visible here whether reconstruct_embeddings
    # mutates `data`, so the original per-iteration load is kept.
    data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)

    _, data_abs = reconstruct_embeddings(
        data,
        [concept_i_centered],
        ["text"],
        return_princ_comp=True,
        plot=False,
        means=[mean_final_texts],
    )

    data_pcs = sort_data_by(data_abs, "correlation_princ_comp_abs", descending=True)
    sorted_data.append(data_pcs)

# --- Sweep over the number of PCs taken per class --------------------------------
pcs_per_class_max_worst_acc = pcs_per_class_start
max_worst_acc = 0

worst_class_acc = []     # worst-group accuracy per sweep step (Waterbirds only)
worst_class_nr_pcs = []  # pcs_per_class of each recorded step (Waterbirds only)

total_accuracy = []  # accuracy of the combined prediction for each sweep step

num_images = final_embeddings_images.shape[0]
num_classes = classifier_.shape[1]
last_text_idx = len(classes_centered) - 1

for pcs_per_class in range(pcs_per_class_start, pcs_per_class_end, pcs_per_class_step):
    # Per-image running best: column 0 = best score so far, column 1 = class index
    # that produced it. The score starts at -inf so that the first class always
    # registers; starting at 0 left images whose scores are all negative stuck
    # on class 0 regardless of which class actually scored highest.
    all_preds = torch.zeros((num_images, 2), dtype=torch.double)
    all_preds[:, 0] = float("-inf")

    for text_idx in range(classes_centered.shape[0]):
        if random:
            # Random control: draw PCs instead of using this class's top PCs.
            data = get_data(attention_dataset, max_pcs_per_head, skip_final=True)
            top_k_entries = random_pcs(data, pcs_per_class * len(classes_))
        else:
            data_pcs = sorted_data[text_idx]
            top_k_entries = top_data(data_pcs, pcs_per_class)

        # Rebuild the image embeddings from the PCs chosen for this class.
        reconstructed_images = reconstruct_all_embeddings_mean_ablation_pcs(
            top_k_entries,
            mlps_,
            attns_,
            attns_,
            nr_layers_,
            nr_heads_,
            num_last_layers_,
            ratio=-1,
            mean_ablate_all=False
        )

        # L2-normalize, then score every image against this class's embedding only.
        # (No `.T`: the class embedding is 1-D and transposing it is deprecated.)
        reconstructed_images /= reconstructed_images.norm(dim=-1, keepdim=True)
        predictions = reconstructed_images @ class_embeddings[text_idx, :]

        # Keep, per image, the best score seen so far and the class that gave it.
        improved_mask = predictions > all_preds[:, 0]
        all_preds[improved_mask, 0] = predictions[improved_mask].double()
        all_preds[improved_mask, 1] = float(text_idx)

        # Accuracy report for the current class's scores on their own.
        acc, idxs = test_accuracy(predictions.unsqueeze(-1), labels_, label=f"{classes_[text_idx]}")
        print_correct_elements(idxs, labels_, classes_)

        # Turn the running best-class indices into a one-hot score matrix so the
        # usual accuracy helper can evaluate the combined prediction.
        best_class_idxs = all_preds[:, 1].long()
        fictitious_preds = torch.zeros((num_images, num_classes), device=best_class_idxs.device)
        fictitious_preds[torch.arange(num_images), best_class_idxs] = 1.0

        acc_best, idxs_best = test_accuracy(fictitious_preds, labels_, label="Best So Far (One-Hot)")
        # Only the result after the last class covers all classes; record that one.
        if text_idx == last_text_idx:
            total_accuracy.append(acc_best)

        sorted_output = print_correct_elements(idxs_best, labels_, classes_)
        if classes_ == waterbird_classes:
            curr_worst_acc = test_waterbird_preds(idxs_best, labels_, background_groups_)
            if text_idx == last_text_idx:
                if curr_worst_acc > max_worst_acc:
                    max_worst_acc = curr_worst_acc
                    pcs_per_class_max_worst_acc = pcs_per_class

                worst_class_acc.append(curr_worst_acc)
                worst_class_nr_pcs.append(pcs_per_class)

        # Running accuracy over the classes processed so far.
        tot_sum = 0
        for _, el_nr in sorted_output:
            tot_sum += el_nr
        if subset_dim is None:
            print(f"Tot accuracy so far is {tot_sum/len(labels_)}")
        else:
            print(f"Tot accuracy so far is {tot_sum/((text_idx + 1) * subset_dim)}")
# ---------------------------------------
# PLOTTING SECTION for WORST-CLASS ACCURACY
# ---------------------------------------
# Worst-class accuracies and `baseline_worst` are only produced for Waterbirds,
# so the figure is skipped for any other dataset (previously this section
# failed on `max()` of an empty list / an undefined baseline).
if classes_ == waterbird_classes:
    # Sweep step with the best worst-class accuracy and the values recorded there.
    step_of_max_worst = int(np.argmax(worst_class_acc))
    max_worst_acc = max(worst_class_acc)
    pcs_per_class_max_worst_acc = worst_class_nr_pcs[step_of_max_worst]

    # Line style shared between the dotted guide lines and their legend proxy.
    peak_style = dict(color='blue', linestyle=':', linewidth=2)

    plt.figure(figsize=(6, 4))

    # Worst-class accuracy as a function of the number of PCs per class.
    plt.plot(
        worst_class_nr_pcs, worst_class_acc, color='blue',
        linestyle='-',
        label='Worst-Class Accuracy'
    )

    # Horizontal reference line at the baseline worst-class accuracy.
    plt.axhline(
        y=baseline_worst,
        color='gray',
        linestyle='--',
        linewidth=2,
        label=f'Baseline Worst-Class Acc = {baseline_worst:.2f}'
    )

    # Cross-hair marking the maximum worst-class accuracy.
    plt.axhline(y=max_worst_acc, **peak_style)
    plt.axvline(x=pcs_per_class_max_worst_acc, **peak_style)

    # Legend proxy for both dotted "max" lines. The second line reports the
    # TOTAL accuracy at the same sweep step; the previous wording ("Worst total
    # accuracy") mislabelled that value.
    max_line_legend = mlines.Line2D(
        [], [],
        label=f'Max Worst-Class Acc = {max_worst_acc:.2f} at PCs={pcs_per_class_max_worst_acc}\n(Total accuracy is {total_accuracy[step_of_max_worst]:.2f})',
        **peak_style
    )

    # Labelled artists already on the axes first, then the proxy.
    handles, labels = plt.gca().get_legend_handles_labels()
    handles.append(max_line_legend)
    labels.append(max_line_legend.get_label())

    plt.xlabel('Number of PCs per Class')
    plt.ylabel('Worst Accuracy')
    plt.title('Worst-Class Accuracy vs. Number of PCs per Class')
    plt.grid(True)
    plt.legend(handles, labels)
    plt.show()

    print(f"Max worst accuracy of {max_worst_acc:.2f} found at {pcs_per_class_max_worst_acc} PCs/class.")

# ----------------------------------------------------------------------------------
# PLOTTING SECTION for TOTAL ACCURACY
# ----------------------------------------------------------------------------------
# Total accuracy against the number of PCs kept per class: solid curve, dashed
# baseline, dotted guides crossing at the best point.
# Inputs (not defined in this section): total_accuracy, worst_class_nr_pcs,
# worst_class_acc, baseline_acc.

# Sweep position with the highest total accuracy.
best_total_idx = np.argmax(total_accuracy)
max_total_acc = max(total_accuracy)
pcs_for_max_total_acc = worst_class_nr_pcs[best_total_idx]

total_acc_fig, total_acc_ax = plt.subplots(figsize=(6, 4))

# Shared look for the two dotted guides and their single legend proxy.
peak_style = {'color': 'orange', 'linestyle': ':', 'linewidth': 2}

# Main curve
total_acc_ax.plot(
    worst_class_nr_pcs, total_accuracy,
    linestyle='-', color='orange',
    label='Total Accuracy'
)

# Baseline reference
total_acc_ax.axhline(
    y=baseline_acc,
    color='gray',
    linestyle='--',
    linewidth=2,
    label=f'Baseline Total Acc = {baseline_acc:.2f}'
)

# Dotted guides through the maximum (horizontal, then vertical)
total_acc_ax.axhline(y=max_total_acc, **peak_style)
total_acc_ax.axvline(x=pcs_for_max_total_acc, **peak_style)

# One proxy artist stands in for both guides in the legend.
max_line_legend_2 = mlines.Line2D(
    [], [],
    label=f'Max Total Acc = {max_total_acc:.2f} at PCs={pcs_for_max_total_acc}\n(Worst class accuracy is {worst_class_acc[best_total_idx]:.2f})',
    **peak_style
)

# Legend = labelled artists already on the axes + the proxy entry.
handles2, labels2 = total_acc_ax.get_legend_handles_labels()
handles2.append(max_line_legend_2)
labels2.append(max_line_legend_2.get_label())

total_acc_ax.set_xlabel('Number of PCs per Class')
total_acc_ax.set_ylabel('Total Accuracy (Avg)')
total_acc_ax.set_title('Total Accuracy vs. Number of PCs per Class')
total_acc_ax.grid(True)
total_acc_ax.legend(handles2, labels2)
plt.show()

print(f"Max total accuracy of {max_total_acc:.2f} found at {pcs_for_max_total_acc}")

## Comparison with TextSpan on Waterbirds

In [None]:
# Retrieve the TextSpan heads for the current model.
# Each model maps to two lists of (layer, head) pairs, named "setting" and
# "geo" after the original to_mean_ablate_setting / to_mean_ablate_geo lists.
_TEXTSPAN_HEADS = {
    "ViT-H-14": {
        "setting": [(31, 12), (30, 11), (29, 4)],
        "geo": [(31, 8), (30, 15), (30, 12), (30, 6), (29, 14), (29, 8)],
    },
    "ViT-L-14": {
        "setting": [
            (21, 3), (21, 6), (21, 8), (21, 13), (22, 2),
            (22, 12), (22, 15), (23, 1), (23, 3), (23, 5),
        ],
        "geo": [(21, 1), (22, 12), (22, 13), (21, 11), (21, 14), (23, 6)],
    },
    "ViT-B-16": {
        "setting": [(11, 3), (10, 11), (10, 10), (9, 8), (9, 6)],
        "geo": [(11, 6), (11, 0)],
    },
    "ViT-B-32": {
        "setting": [(11, 5), (10, 5), (10, 3), (9, 1)],
        "geo": [(11, 9), (11, 5)],
    },
}

# Fail loudly for a model without recorded heads instead of leaving the lists
# undefined (or silently reusing lists left over from an earlier run).
if model_name not in _TEXTSPAN_HEADS:
    raise ValueError(
        f"No TextSpan heads recorded for model {model_name!r}; "
        f"expected one of {sorted(_TEXTSPAN_HEADS)}"
    )

to_mean_ablate_setting = [
    {"layer": layer, "head": head}
    for layer, head in _TEXTSPAN_HEADS[model_name]["setting"]
]
to_mean_ablate_geo = [
    {"layer": layer, "head": head}
    for layer, head in _TEXTSPAN_HEADS[model_name]["geo"]
]
to_mean_ablate_geo_heads = to_mean_ablate_setting + to_mean_ablate_geo

# Every head of the last num_last_layers_ layers except the ones listed above.
# Listed heads that fall outside those layers simply match nothing.
all_heads = [
    {"layer": layer, "head": head}
    for layer in range(nr_layers_ - num_last_layers_, nr_layers_)
    for head in range(nr_heads_)
    if {"layer": layer, "head": head} not in to_mean_ablate_geo_heads
]

# Rebuild the image embeddings with the project helper, passing the head list
# built above. How the helper treats the listed heads is defined elsewhere and
# is not visible in this cell.
reconstructed_images = reconstruct_all_embeddings_mean_ablation_heads(all_heads, mlps_, attns_, final_embeddings_images,nr_layers_, nr_heads_, num_last_layers_)
# Scale every row to unit L2 norm (in place) before scoring.
reconstructed_images /= reconstructed_images.norm(dim=-1, keepdim=True)
# Score each reconstructed embedding against the classifier matrix.
predictions = reconstructed_images @ classifier_

# Report accuracy of the reconstruction under the "Textspan" label and list the
# correctly handled elements.
acc, idxs = test_accuracy(predictions, labels_, label=f"Textspan")
print_correct_elements(idxs, labels_, classes_)    
# The waterbird-specific breakdown only runs when the active class list is the
# waterbird one.
if classes_ == waterbird_classes:
    test_waterbird_preds(idxs, labels_, background_groups_)



## Test different accuracies

In [None]:
# Print shapes of the tensors for debugging purposes:
# attns_: attention activations
# mlps_: MLP activations
# classifier_: classifier weights
# labels_: ground truth labels
print(attns_.shape, mlps_.shape, classifier_.shape, labels_.shape)


# Baseline: sum the attention contributions over axes 1 and 2 and the MLP
# contributions over axis 1, then classify the result directly.
baseline = attns_.sum(axis=(1, 2)) + mlps_.sum(axis=1)
test_accuracy(baseline @ classifier_, labels_, label="Baseline")
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

# Test accuracy of mean centered data with mean centered text
mean_centered_data = baseline - mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
mean_centered_classifier_ = classifier_ - mean_final_texts.unsqueeze(-1)
# classifier_ is right-multiplied (`images @ classifier_`), so each class
# vector is a column: normalise along dim 0. Normalising along dim -1 would
# rescale every embedding coordinate across classes instead.
mean_centered_classifier_ /= mean_centered_classifier_.norm(dim=0, keepdim=True)

test_accuracy(mean_centered_data @ mean_centered_classifier_, labels_, label="Mean centered data with mean centered text")

# Test accuracy of mean centered data with original text:
# centre, normalise, add the image mean back, normalise again.
mean_centered_data = baseline - mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
mean_centered_data += mean_final_images
mean_centered_data /= mean_centered_data.norm(dim=-1, keepdim=True)
test_accuracy(mean_centered_data @ classifier_, labels_, label="Mean centered data with original (not mean centered) text")

# "Mean ablation" variant for attention: the precomputed per-head mean-ablation
# sum, plus the attention slice after layer last_, plus the MLP contributions.
# Built once and used for both evaluations below.
mean_ablated_model = (current_mean_ablation_per_head_sum_
                      + no_heads_attentions_[:, last_ + 1:].sum(1)) + mlps_.sum(axis=1)

current_model = mean_ablated_model
_, indexes_mean_ablate = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation from layer {last_} until layer {attns_.shape[1]}")

# Same model, but mean-centred on the images and row-normalised before scoring.
# The subtraction creates a new tensor, so mean_ablated_model is not modified.
current_model = mean_ablated_model - mean_final_images
current_model /= current_model.norm(dim=-1, keepdim=True)
_, indexes_mean_ablate = test_accuracy(current_model @ classifier_, labels_, label=f"Mean ablation from layer {last_} until layer {attns_.shape[1]} with mean centered images")

## Test different accuracies using reconstructions

In [None]:
# Centred inputs for the "final embedding" reconstructions.
image_emb_cent_embed = final_embeddings_images - mean_final_images
texts_emb_cent_embed = (classifier_ - mean_final_texts.unsqueeze(-1)).T

# Accumulators, one contribution added per head entry read from the file.
final_embeddings_images_rec_embed = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns_not_mean_centered = torch.zeros_like(final_embeddings_images)
final_embeddings_texts_rec_embed = torch.zeros_like(classifier_.T)

# The attention dataset holds one JSON object per line; each head entry carries
# its projection matrix, vh basis and mean attention values.
with open(attention_dataset, "r") as json_file:
    for entry in map(json.loads, json_file):
        # The entry with head == -1 is kept aside and not used for reconstruction.
        if entry["head"] == -1:
            last_line = entry
            continue

        vh = torch.tensor(entry["vh"])
        project_matrix = torch.tensor(entry["project_matrix"])
        mean_values_att = torch.tensor(entry["mean_values_att"])

        # Final-embedding route: the already-centred image/text embeddings are
        # mapped through vh.T, project_matrix and vh. No mean is added back here.
        final_embeddings_images_rec_embed += image_emb_cent_embed @ vh.T @ project_matrix @ vh
        final_embeddings_texts_rec_embed += texts_emb_cent_embed @ vh.T @ project_matrix @ vh

        # Attention route: centre this head's activations on the head mean and
        # apply the same mapping. The result feeds two accumulators, one with
        # the head mean added back and one without.
        image_emb_cent_attns = attns_[:, entry["layer"], entry["head"], :] - mean_values_att
        head_rec_attns = image_emb_cent_attns @ vh.T @ project_matrix @ vh
        final_embeddings_images_rec_attns += head_rec_attns + mean_values_att
        final_embeddings_images_rec_attns_not_mean_centered += head_rec_attns

# Row-normalised copies of the accumulated reconstructions.
final_embeddings_images_rec_embed_norm = (
    final_embeddings_images_rec_embed
    / final_embeddings_images_rec_embed.norm(dim=-1, keepdim=True)
)

final_embeddings_texts_rec_embed_norm = (
    final_embeddings_texts_rec_embed
    / final_embeddings_texts_rec_embed.norm(dim=-1, keepdim=True)
)

final_embeddings_images_rec_attns_not_mean_centered_norm = (
    final_embeddings_images_rec_attns_not_mean_centered
    / final_embeddings_images_rec_attns_not_mean_centered.norm(dim=-1, keepdim=True)
)

# The centred text embeddings are normalised in place, after the loop has used
# the un-normalised values.
texts_emb_cent_embed /= texts_emb_cent_embed.norm(dim=-1, keepdim=True)
# Each evaluation below builds an image-side model, scores it against the
# unmodified classifier_ and reports accuracy under the given label.

# Row-normalised final-embedding reconstruction with the image mean added back.
current_model = final_embeddings_images_rec_embed_norm + mean_final_images
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation with final embeddings on only the last layers")

# MLP sum + per-head mean-ablation sum + attention reconstruction WITHOUT the
# per-head means added back.
current_model = mlps_.sum(axis=1) + current_mean_ablation_per_head_sum_ + final_embeddings_images_rec_attns_not_mean_centered
_, indexes_approx_activ_only = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation of images with direct contribution of activation space")

# Same as above, but using the attention reconstruction WITH the per-head means
# added back.
current_model = (mlps_.sum(axis=1) + current_mean_ablation_per_head_sum_ + final_embeddings_images_rec_attns)
_, indexes_approx_activ = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation with attention activations")


# Un-normalised final-embedding reconstruction with the image mean added back.
# Overwrites indexes_approx_final from the first evaluation.
current_model = final_embeddings_images_rec_embed+ mean_final_images
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation of images without mean-ablation")


# NOTE(review): identical to the previous evaluation except for the label. The
# label mentions texts, but the scores still use classifier_, not the
# reconstructed texts (final_embeddings_texts_rec_embed) - confirm intent.
current_model = final_embeddings_images_rec_embed+ mean_final_images
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Approximation of images and texts without mean-ablation")

# NOTE(review): same image model as the "Approximation with attention
# activations" evaluation and still scored against classifier_, although the
# label says "original images and approximation of texts" - confirm intent.
# This is the last assignment to indexes_approx_final, so it is the value seen
# by any later code that reads that name.
current_model = (mlps_.sum(axis=1) + current_mean_ablation_per_head_sum_ + final_embeddings_images_rec_attns)
_, indexes_approx_final = test_accuracy(current_model @ classifier_, labels_, label=f"Original images and approximation of texts without mean-ablation")


## Test Bias Correction

In [None]:
cache_dir = "../cache"
top_k = 30  # Maximum number of top entries to retrieve
approx = 1.1  # Target approximation threshold for the reconstruction quality
## Run the chosen algorithm on a dataset to derive text explanations 
command = f"python -m utils.scripts.bias_removal_test \
    --device {device} --model {model_name} --pretrained {pretrained} --seed {seed} \
    --subset_dim {subset_dim} --dataset_text {dataset_text_name} --dataset {datataset_image_name} \
    --device {device} --top_k {top_k} --max_approx {approx} --cache_dir {cache_dir}"
!{command}

## Test bias removal and subset model

In [None]:
# Reconstruct, for the entries in top_k_details only, the part of each image
# embedding carried by the selected principal component of every entry. This is
# done twice: from the centred final embeddings and from the centred attention
# activations. The two "topic" reconstructions are then either subtracted from
# the full reconstructions ("bias removal") or evaluated alone ("subset").

final_embeddings_images_rec_embed_topic = torch.zeros_like(final_embeddings_images)
final_embeddings_images_rec_attns_topic = torch.zeros_like(final_embeddings_images)

image_emb_cent_embed = final_embeddings_images - mean_final_images

# Not used further in this cell.
top_k_other_details = get_remaining_pcs(data, top_k_details)

# Iterate through the top_k entries and reconstruct embeddings
for entry in top_k_details:
    vh = torch.tensor(entry["vh"])
    project_matrix = torch.tensor(entry["project_matrix"])
    princ_comp = torch.tensor(entry["princ_comp"])

    # Final-embedding route: project onto the vh basis, zero every coordinate
    # except the selected one(s), then map back through project_matrix and vh.
    projection_image_embed = image_emb_cent_embed @ vh.T
    mask_images_embed = torch.zeros_like(projection_image_embed)
    mask_images_embed[:, princ_comp] = projection_image_embed[:, princ_comp]
    final_embeddings_images_rec_embed_topic += mask_images_embed @ project_matrix @ vh

    # Attention route: same masking, starting from this head's activations
    # centred on the entry's mean attention values.
    mean_values_att = torch.tensor(entry["mean_values_att"])
    image_emb_cent_attns = attns_[:, entry["layer"], entry["head"], :] - mean_values_att
    projection_images_attns = image_emb_cent_attns @ vh.T
    mask_images_attns = torch.zeros_like(projection_images_attns)
    mask_images_attns[:, princ_comp] = projection_images_attns[:, princ_comp]
    final_embeddings_images_rec_attns_topic += mask_images_attns @ project_matrix @ vh


# Bias removal, final-embedding route: full reconstruction minus the topic part,
# scored against the centred text embeddings.
current_model = final_embeddings_images_rec_embed - final_embeddings_images_rec_embed_topic
_, indexes_approx_final_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic final embeddings (Bias removal)")
print_diff_elements(indexes_approx_final, indexes_approx_final_rem, subset_dim)

# Bias removal, attention route. The label now names the attention activations
# so its output can be told apart from the final-embedding run above.
current_model = final_embeddings_images_rec_attns_not_mean_centered - final_embeddings_images_rec_attns_topic
_, indexs_approx_activ_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic attention activations (Bias removal)")
print_diff_elements(indexes_approx_activ_only, indexs_approx_activ_rem, subset_dim)

# Subset, final-embedding route: classify from the topic part alone.
current_model = final_embeddings_images_rec_embed_topic
_, indexes_approx_final_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic final embeddings (Subset)")
print_correct_elements(indexes_approx_final_rem, labels_, classes_)

# Subset, attention route: classify from the topic part alone.
current_model = final_embeddings_images_rec_attns_topic
_, indexs_approx_activ_rem = test_accuracy(current_model @ texts_emb_cent_embed.T, labels_, label=f"Approximation with current topic attention activations (Subset)")
print_correct_elements(indexs_approx_activ_rem, labels_, classes_)
