In [None]:
!nvidia-smi

In [None]:
import re
import torch
from PIL import Image
import matplotlib.pyplot as plt
import gc
import requests
import copy
from PIL import Image
from io import BytesIO
from torch.nn.functional import mse_loss
import numpy as np
import einops
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
from llava.conversation import conv_templates
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from utils.scripts.utils_llava import *
from utils.misc.misc import accuracy, accuracy_correct
from utils.scripts.algorithms_text_explanations import *
from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.misc.visualization import visualization_preprocess
from utils.models.prs_hook import hook_prs_logger
from utils.datasets_constants.imagenet_classes import imagenet_classes
from utils.datasets_constants.cifar_10_classes import cifar_10_classes
from utils.datasets_constants.cub_classes import cub_classes, waterbird_classes
import os
from utils.datasets.dataset_helpers import dataset_to_dataloader
from utils.scripts.algorithms_text_explanations_funcs import *
import tqdm
from torchvision.transforms import ToPILImage
import copy

In [2]:
## Parameters
# Notebook-wide settings read by the cells below.
device = 'cuda'  # device string used for `.to(device)` calls and handed to the PRS hook
seed = 0  # part of the cached .npy file names; also passed to the dataset subsetting helper
num_last_layers_ = 4  # NOTE(review): not referenced in the visible cells — confirm it is still needed
subset_dim = 10  # passed as `samples_per_class` when subsetting the dataset
tot_samples_per_class = 50  # passed as `tot_samples_per_class` when subsetting the dataset
dataset_text_name = "top_1500_nouns_5_sentences_imagenet_clean"  # stem of the text-description file and of the text-embedding cache
datataset_image_name = "imagenet"  # selects the image dataset branch and the cached arrays (spelling kept: later cells use this name)
cache_dir = "../cache"  # cache directory; the same value is given to create_model_and_transforms below
path = './datasets/'  # root directory prepended to the dataset locations

In [3]:
# Helper functions
def image_parser(image_file, sep=","):
    """Split a `sep`-delimited string of image locations into a list of strings."""
    return image_file.split(sep)


def load_image(image_file):
    """Load an image from an HTTP(S) URL or a local path as an RGB PIL image.

    Args:
        image_file: URL starting with ``http://`` or ``https://``, or a path
            on the local file system.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the download does not complete within 30 seconds.
    """
    # Match full URL schemes so a local file such as "http_cache.png" is not
    # mistaken for a URL.
    if image_file.startswith(("http://", "https://")):
        # Bounded wait and explicit status check: fail with a clear HTTP error
        # rather than hanging or handing an error page to PIL.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(image_file) as handle:
        return handle.convert("RGB")


def load_images(image_files):
    """Load every entry of `image_files` with `load_image`, preserving order."""
    return [load_image(image_file) for image_file in image_files]

device = "cuda"
model_name = "liuhaotian/llava-v1.5-7b"
# Local snapshot of the LLaVA checkpoint. The cluster path stays the default,
# but can be overridden with the LLAVA_MODEL_PATH environment variable so the
# notebook also runs on machines with a different layout.
model_path = os.environ.get(
    "LLAVA_MODEL_PATH",
    "/cluster/work/vogtlab/Group/vstrozzi/cache/models--liuhaotian--llava-v1.5-7b/snapshots/4481d270cc22fd5c4d1bb5df129622006ccd9234/",
)

## Load the LLaVA model

In [None]:
### IMPORTANT: On Biomedcluster change .config under model_path to point towards correct vision_tower clip path
# Load the LLaVA tokenizer, model, image processor and context length from the
# local snapshot; the model name is derived from `model_name`.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_name),
)
# Keep the model on the CPU for now; later cells call model.to("cuda") before generating.
model.to("cpu")

In [None]:
# CLIP model
model_CLIP_name = 'ViT-L-14-336'
pretrained = "openai"
precision = "fp16"

torch.cuda.empty_cache()
# Use the notebook-wide `cache_dir` setting instead of repeating the literal,
# so the cache location is configured in one place only.
model_CLIP, _, preprocess_clip = create_model_and_transforms(model_CLIP_name, pretrained=pretrained, precision=precision, cache_dir=cache_dir)

model_CLIP.eval()
context_length = model_CLIP.context_length
# Only reported in the summary printed below.
vocab_size = model_CLIP.vocab_size
tokenizer_CLIP = get_tokenizer(model_CLIP_name)

print(model_CLIP)
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model_CLIP.visual.parameters()]):,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model_CLIP.visual.transformer.resblocks))
# Hook configuration: no projection onto the shared space, no spatial tokens in the output
# (i.e. contribution of attention to tokens), and hidden outputs of all tokens.
prs = hook_prs_logger(model_CLIP, device, spatial=False, vision_projection=False, full_output=True) # This attaches the hook that records the residual stream

# Cached image embeddings, read from output_dir and moved to `device`.
final_embeddings_images = torch.tensor(np.load(f"output_dir/{datataset_image_name}_embeddings_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r")).to(device)
# Cached text embeddings, cast to the dtype of the image embeddings.
final_embeddings_texts = torch.tensor(np.load(f"output_dir/{dataset_text_name}_{model_CLIP_name}.npy", mmap_mode="r")).to(device, dtype=final_embeddings_images.dtype)

# Mean activations used as replacement values for mean ablation (no `.to(device)` is applied here).
attns_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_attns_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r")) # [l, n, h, d], attention values
mlps_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_mlps_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r"))  # [l + 1, n, d], mlp values

# Build the visualisation dataset selected by `datataset_image_name`.
if datataset_image_name == "imagenet":
    ds_ = ImageNet(root=path+"imagenet/", split="val", transform=visualization_preprocess)
elif datataset_image_name == "binary_waterbirds":
    ds_ = BinaryWaterbirds(root=path+"waterbird_complete95_forest2water2/", split="test", transform=visualization_preprocess)
elif datataset_image_name == "CIFAR100":
    ds_ = CIFAR100(
        root=path, download=True, train=False, transform=visualization_preprocess
    )
elif datataset_image_name == "CIFAR10":
    ds_ = CIFAR10(
        root=path, download=True, train=False, transform=visualization_preprocess
    )
else:
    ds_ = ImageFolder(root=path, transform=visualization_preprocess)

# Class names for the datasets that ship a constant list. The dataset chain
# above also accepts "CIFAR100" and arbitrary image folders, which have no
# entry here; for those, fall back to the dataset's own `classes` attribute
# (if it has one) instead of failing with a bare KeyError.
_known_classes = {
        'imagenet': imagenet_classes,
        'CIFAR10': cifar_10_classes,
        'waterbirds': cub_classes,
        'binary_waterbirds': waterbird_classes,
        'cub': cub_classes}
classes_ = _known_classes.get(datataset_image_name)
if classes_ is None:
    classes_ = getattr(ds_, "classes", None)
if classes_ is None:
    raise KeyError(f"No class names available for dataset {datataset_image_name!r}")

# Subset of the dataset used for visualisation.
ds_vis_ = dataset_subset(
    ds_,
    samples_per_class=subset_dim,
    tot_samples_per_class=tot_samples_per_class,  # or whatever you prefer
    seed=seed,
)

# Read the text descriptions, one per line, dropping the newline characters.
with open(f"utils/text_descriptions/{dataset_text_name}.txt", "r") as f:
    texts_str = np.array([line.replace("\n", "") for line in f])

## Experimenting with LLaVA

In [20]:
def mean_ablate_head(attentions, mlps, select_layer, layers=None, heads=None, 
                     attentions_mean_abl=None, mlps_mean_abl=None, mean_ablate_mlps=False, mean_ablate_attns=True):
    """Replace selected attention-head (and optionally MLP) outputs with mean values.

    Args:
        attentions: tensor indexed as [batch, layer, token, head, dim].
        mlps: tensor indexed as [batch, layer + 1, token, dim]; index 0 is the
            extra entry that precedes the first layer.
        select_layer: unused here; kept so existing callers keep working.
        layers, heads: parallel sequences; entry i selects head ``heads[i]`` of
            layer ``layers[i]``. ``None`` or empty selects nothing.
        attentions_mean_abl: replacement values indexed as [layer, token, head, dim].
            Required only when attention heads are actually ablated.
        mlps_mean_abl: replacement values indexed as [layer + 1, token, dim].
            Required only when MLPs are actually ablated.
        mean_ablate_mlps: also replace the MLP output of every selected layer.
        mean_ablate_attns: replace the selected attention heads.

    Returns:
        ``(attentions, mlps)`` as new tensors; the inputs are never modified.
        When no (layer, head) pair is selected both are plain copies.

    Raises:
        ValueError: if an ablation is requested but its mean tensor is ``None``.
    """
    # `None` defaults avoid sharing one mutable list between calls.
    layers = [] if layers is None else list(layers)
    heads = [] if heads is None else list(heads)
    pairs = list(zip(layers, heads))

    # Clone the input tensors to prevent in-place modifications from affecting future calls
    attentions = attentions.clone()
    mlps = mlps.clone()

    if mean_ablate_attns and pairs:
        if attentions_mean_abl is None:
            raise ValueError("attentions_mean_abl is required when mean_ablate_attns is True")
        for layer, head in pairs:
            # The mean has no batch axis; the assignment broadcasts it over the batch.
            attentions[:, layer, :, head, :] = attentions_mean_abl[layer, :, head, :]

    if mean_ablate_mlps and pairs:
        if mlps_mean_abl is None:
            raise ValueError("mlps_mean_abl is required when mean_ablate_mlps is True")
        # A layer usually appears once per head; replace its MLP output only once.
        for layer in dict.fromkeys(layer for layer, _ in pairs):
            mlps[:, layer + 1, :, :] = mlps_mean_abl[layer + 1, :, :]
        # The MLP tensor has one extra leading entry: ablate it as well when the
        # selected layers form the contiguous range 0..k-1. An empty selection
        # never reaches this point, so "ablate nothing" leaves index 0 alone.
        layers_set = set(layers)
        if layers_set == set(range(len(layers_set))):
            mlps[:, 0, :, :] = mlps_mean_abl[0, :, :]

    return attentions, mlps

def llava_pred(attentions, mlps, select_layer):
    """Sum attention and MLP contributions over the leading `select_layer + 1` slice.

    The attention tensor is summed over its axis 1 and then over axis 2 of the
    result; the MLP tensor is summed over its axis 1. The two sums are added.
    """
    stop = select_layer + 1
    attn_part = attentions[:, :stop, :, :, :].sum(1).sum(2)
    mlp_part = mlps[:, :stop, :, :].sum(1)
    return attn_part + mlp_part

# Project 
def remove(a, b):
    """Subtract from `a` its projection onto `b`; the result gets a leading axis of size 1."""
    vec = a.squeeze()
    direction = b.squeeze()
    coeff = torch.dot(vec, direction) / torch.dot(direction, direction)
    return (vec - coeff * direction).unsqueeze(0)

def remove_patches(p, b):
    """Project the direction `b` out of every patch vector in `p`.

    Args:
        p: patch vectors whose last axis is the feature dimension, e.g.
            [1, n, d] or [n, d].
        b: direction to remove; flattened to a single vector of length d.

    Returns:
        A new tensor of shape [1, n, d]. `p` itself is left untouched (the
        previous loop wrote through a squeezed view and so overwrote the
        caller's tensor), and a single patch ([1, 1, d]) is handled as well.
    """
    patches = p.reshape(-1, p.shape[-1])
    direction = b.reshape(-1)
    # Projection coefficient of each patch onto the direction.
    coeffs = (patches @ direction) / torch.dot(direction, direction)
    return (patches - coeffs.unsqueeze(-1) * direction).unsqueeze(0)

In [None]:
def llava_infer(prompt, pil_image, images_embeds=False, mean_ablate=False, up_to_layer_ablate = 10, from_layer_ablate = False, mean_ablate_mlps=False, mean_ablate_attns=True, attentions_mean_abl=None, mlps_mean_abl=None): # If provided image embeds
    """Run LLaVA on one image, optionally feeding it (mean-ablated) CLIP features.

    Relies on the notebook globals `model`, `tokenizer`, `image_processor`,
    `model_name`, `model_CLIP`, `prs` and `device`.

    Args:
        prompt: user question; the image token is inserted automatically.
        pil_image: the PIL image to describe.
        images_embeds: when True, the image is encoded with the hooked CLIP
            model and the summed attention/MLP contributions are handed to
            `model.generate` in place of raw pixels.
        mean_ablate: apply `mean_ablate_head` to the CLIP contributions first
            (only used when `images_embeds` is True).
        up_to_layer_ablate: ablate layers ``0 .. up_to_layer_ablate - 1`` or,
            with `from_layer_ablate`, layers ``up_to_layer_ablate .. 23``.
        from_layer_ablate: see `up_to_layer_ablate`.
        mean_ablate_mlps, mean_ablate_attns, attentions_mean_abl, mlps_mean_abl:
            forwarded to `mean_ablate_head`.

    Returns:
        ``(text, attentions, mlps)``; the two tensors are ``None`` when
        `images_embeds` is False.
    """
    # Layer where to extract infos on patches
    select_layer = -2
    max_new_tokens = 512
    num_beams = 1 # number of decoding paths; fewer is faster
    # Head/layer counts used to enumerate the (layer, head) pairs to ablate.
    nr_heads = 16
    nr_layers = 24

    ## Insert the image token into the prompt
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        prompt = re.sub(IMAGE_PLACEHOLDER, image_token, prompt)
    else:
        prompt = image_token + "\n" + prompt

    ## Pick the conversation template from the model name
    lowered_name = model_name.lower()
    if "llama-2" in lowered_name:
        conv_mode = "llava_llama_2"
    elif "mistral" in lowered_name:
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in lowered_name:
        conv_mode = "chatml_direct"
    elif "v1" in lowered_name:
        conv_mode = "llava_v1"
    elif "mpt" in lowered_name:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    ## Load conversation mode standard template
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Size of the image that was actually passed in (previously this read the
    # notebook-global `img`, ignoring the argument).
    image_sizes = [pil_image.size]

    images_tensor = process_images(
        [pil_image],
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    ## Tokenize prompt
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )

    attentions = None
    mlps = None
    ## Use CLIP Model
    if images_embeds:
        prs.reinit()
        model_CLIP.eval()
        with torch.no_grad():
            # Only one of the two models sits on the GPU at a time.
            model.to("cpu")
            model_CLIP.to("cuda")
            # The hook records the per-layer contributions during this call.
            spatial_features = model_CLIP.encode_image(
                    images_tensor.to(device),
                    attn_method='head_no_spatial',
                    normalize=False
                )

            model_CLIP.to("cpu")

            attentions, mlps = prs.finalize(spatial_features)  # attentions: [b, l, n, h, d], mlps: [b, l + 1, n, d]

            if mean_ablate:
                if from_layer_ablate:
                    layer_range = range(up_to_layer_ablate, nr_layers)
                else:
                    layer_range = range(0, up_to_layer_ablate)
                # Every head of every selected layer, as two parallel lists.
                layers = [layer for _ in range(nr_heads) for layer in layer_range]
                heads = [head for head in range(nr_heads) for _ in layer_range]
                attentions, mlps = mean_ablate_head(
                    attentions, mlps, select_layer,
                    layers=layers,
                    heads=heads,
                    attentions_mean_abl=attentions_mean_abl,
                    mlps_mean_abl=mlps_mean_abl,
                    mean_ablate_mlps=mean_ablate_mlps,
                    mean_ablate_attns=mean_ablate_attns)

            # Skip the CLS token (index 0), as the accuracy loop further down
            # does; the previous slice here was a discarded expression, so the
            # CLS token was passed along with the patches.
            images_tensor = llava_pred(attentions, mlps, select_layer)[:, 1:]

    ## Generate an answer by using full model LLava
    model.to("cuda")
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            images_embeds = images_embeds # If want to give images embeds already precomputed TODO: Only support 1 image
        )

    ## Decode the generated tokens
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

    return outputs, attentions, mlps


# Params
# Classification-style prompt; the trailing backslash continues the string literal on the next line.
prompt = "You are a vision-language expert. Analyze the given image and classify it into one of the following categories:  \
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']. Answer with only the most appropriate category."
image_file = "images/catdog.png"

## Visualize image
img = Image.open(image_file)
plt.imshow(img)
plt.axis('off')  # Hide axis ticks and labels
plt.show()

# Smoke test: feed CLIP-derived embeddings to LLaVA with attention heads and MLPs of
# layers 0-19 mean-ablated. Being the last expression of the cell, the returned
# (text, attentions, mlps) tuple is displayed by the notebook.
llava_infer(prompt, img, images_embeds = True, mean_ablate = True, up_to_layer_ablate = 20, from_layer_ablate = False, attentions_mean_abl=attns_hid_mean, mlps_mean_abl=mlps_hid_mean,  mean_ablate_mlps=True, mean_ablate_attns=True)

In [26]:
# The image features a brown and white dog and a brown and black cat sitting together on a carpeted floor. They appear to be relaxed and comfortable in each other's company. The dog is positioned on the left side of the cat, with both animals facing the same direction.
# In the background, there is a bookshelf with several books on it, adding a cozy and lived-in atmosphere to the scene.

## Qualitative test of MLP ablations of LLaVA on some images

In [27]:
def visualize_most_similar_texts_images_clip(clip_output, final_embeddings_images, final_embeddings_texts, ds_vis, classes, texts_str):
    """Show the cached images and texts that score highest against `clip_output`.

    Two scores are computed for every cached image/text embedding: the raw dot
    product (``score_vis``) and the dot product of the L2-normalised vectors
    (``score``). The score arrays are handed to `create_dbs`, and the resulting
    non-empty groups are displayed with `visualize_dbs_no_data`.

    The normalisation is done on copies: none of the three tensor arguments is
    modified, so the caller's cached embeddings stay intact across repeated
    calls (the previous in-place `/=` overwrote them on the first call).

    Args:
        clip_output: CLIP output vector(s) to compare against.
        final_embeddings_images: cached image embeddings, one row per image.
        final_embeddings_texts: cached text embeddings, one row per text.
        ds_vis, classes, texts_str: forwarded to `visualize_dbs_no_data`.
    """
    nr_images = final_embeddings_images.shape[0]
    nr_texts = final_embeddings_texts.shape[0]

    # Structured arrays holding both scores together with the item index.
    scores_array_images = np.empty(
        nr_images,
        dtype=[('score', 'f4'), ('score_vis', 'f4'), ('img_index', 'i4')]
    )
    scores_array_texts = np.empty(
        nr_texts,
        dtype=[('score', 'f4'), ('score_vis', 'f4'), ('txt_index', 'i4')]
    )

    # Raw (un-normalised) similarity.
    scores_array_images["score_vis"] = (final_embeddings_images @ clip_output.T).squeeze().cpu().numpy()
    scores_array_texts["score_vis"] = (final_embeddings_texts @ clip_output.T).squeeze().cpu().numpy()

    # Similarity of the normalised vectors, computed out-of-place.
    clip_output_unit = clip_output / clip_output.norm(dim=-1, keepdim=True)
    images_unit = final_embeddings_images / final_embeddings_images.norm(dim=-1, keepdim=True)
    texts_unit = final_embeddings_texts / final_embeddings_texts.norm(dim=-1, keepdim=True)

    scores_array_images["score"] = (images_unit @ clip_output_unit.T).squeeze().cpu().numpy()
    scores_array_texts["score"] = (texts_unit @ clip_output_unit.T).squeeze().cpu().numpy()

    scores_array_images["img_index"] = np.arange(nr_images)
    scores_array_texts["txt_index"] = np.arange(nr_texts)

    # Number of top, worst and contiguous elements requested from create_dbs.
    nr_top_imgs = 8  # Number of top elements
    nr_worst_imgs = 0  # Number of worst elements
    nr_cont_imgs = 0  # Length of continuous elements

    dbs = create_dbs(scores_array_images, scores_array_texts, nr_top_imgs, nr_worst_imgs, nr_cont_imgs)

    # Keep only the groups for which a non-zero number of elements was requested.
    nrs_dbs = [nr_top_imgs, nr_worst_imgs, nr_cont_imgs]
    dbs_new = [db for i, db in enumerate(dbs) if nrs_dbs[i] != 0]

    visualize_dbs_no_data(dbs_new, ds_vis, texts_str, classes)

In [None]:
# Prompt and images
prompt = "Describe me in details the following image."
images_files = ["images/catdog.png", "images/four_people.png"]

# Ablation range: layers below `up_to_layer_ablate` are ablated, or, when
# `from_layer_ablate` is True, the layers from `up_to_layer_ablate` upwards.
# With 0 / True, llava_infer builds the layer range 0..23.
up_to_layer_ablate = 0
from_layer_ablate = True
mean_ablate_mlps = True
mean_ablate_attns = True

# Copies of CLIP's final layer norm and projection, kept on the GPU
ln_post = copy.deepcopy(model_CLIP.visual.ln_post).to("cuda")
proj = copy.deepcopy(model_CLIP.visual.proj).to("cuda")
# Main analysis
for image_file in images_files:
    print(f"Displaying the image {image_file}")
    # Load as PIL
    img = Image.open(image_file)
    # Show it
    plt.imshow(img)
    plt.axis('off')  # Hide axis ticks and labels
    plt.show()

    ## Call LLAVA on the unmodified image (CLIP embeddings, no ablation)
    output, attns, mlps = llava_infer(prompt, img, images_embeds = True)
    # Norms of the summed attention and MLP contributions, printed for comparison with the ablated run
    print(attns.sum(3).sum(1).norm())
    print(mlps.sum(1).norm())
    print("Original LLAVA output")
    print(output)
    print()
    # CLIP output for the unmodified image: sum the contributions of token 0 (CLS), then layer norm and projection
    hidden_output_test = ln_post(attns[:, :, 0].sum(1).sum(1) + mlps[:, :, 0].sum(1)) # only CLS token
    test_clip_out = hidden_output_test @ proj
    print("Original CLIP  output text and images")
    visualize_most_similar_texts_images_clip(test_clip_out.detach(), final_embeddings_images, final_embeddings_texts, ds_vis_, classes_, texts_str)
    print()
    print()

    ## Call LLAVA on the mean-ablated image (attention heads and/or MLPs, per the flags above)
    output, attns, mlps = llava_infer(prompt, img, images_embeds = True, mean_ablate = True, up_to_layer_ablate = up_to_layer_ablate, from_layer_ablate= from_layer_ablate,attentions_mean_abl=attns_hid_mean, mlps_mean_abl=mlps_hid_mean, mean_ablate_mlps=mean_ablate_mlps, mean_ablate_attns=mean_ablate_attns)
    print(attns.sum(3).sum(1).norm())
    print(mlps.sum(1).norm())

    print("MLPS mean ablated LLAVA output")
    print(output)
    print()
    # CLIP output for the mean-ablated contributions (same CLS-token computation as above)
    hidden_output_test = ln_post(attns[:, :, 0].sum(1).sum(1) + mlps[:, :, 0].sum(1)) # only CLS token
    test_clip_out = hidden_output_test @ proj
    print("MLPS mean ablated CLIP output text and images")
    visualize_most_similar_texts_images_clip(test_clip_out.detach(), final_embeddings_images, final_embeddings_texts, ds_vis_, classes_, texts_str)
    print()
    print()
    print()


## Test zero-shot accuracy of the VLM on datasets

In [9]:
def get_prompt_formatted(prompt, model, conv_templates):
    """Wrap `prompt` in the LLaVA conversation template and tokenize it.

    Relies on the notebook globals `model_name`, `tokenizer` and `device`.

    Args:
        prompt: user question. If it contains ``IMAGE_PLACEHOLDER`` the
            placeholder is replaced by the image token; otherwise the image
            token is prepended on its own line.
        model: LLaVA model; only ``model.config.mm_use_im_start_end`` is read.
        conv_templates: mapping from conversation-mode name to a template.

    Returns:
        Token ids with a leading batch axis of size 1, moved to `device`.
    """
    ## Insert the image token into the prompt
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        prompt = re.sub(IMAGE_PLACEHOLDER, image_token, prompt)
    else:
        prompt = image_token + "\n" + prompt

    ## Pick the conversation template from the model name
    # (the former `conv_mode != conv_mode` check could never be true and was dropped)
    lowered_name = model_name.lower()
    if "llama-2" in lowered_name:
        conv_mode = "llava_llama_2"
    elif "mistral" in lowered_name:
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in lowered_name:
        conv_mode = "chatml_direct"
    elif "v1" in lowered_name:
        conv_mode = "llava_v1"
    elif "mpt" in lowered_name:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    ## Load conversation mode standard template
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    ## Tokenize prompt
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )

    return input_ids

In [None]:
# Constants and fixed configuration
# NOTE: this cell rebinds several notebook-level names (seed, subset_dim, tot_samples_per_class,
# path, classes_, prompt, ds_, attns_hid_mean, mlps_hid_mean) that earlier cells also define.
seed = 0
subset_dim = 10
tot_samples_per_class = 1000
path = './datasets/'
batch_size = 1  # TODO: only works with batch size 1 for now (the loop below only scores labels[0])
dataset_name = "CIFAR10"
# Class names of the evaluated dataset; raises KeyError for names without an entry here.
classes_ = {
        'imagenet': imagenet_classes, 
        'CIFAR10': cifar_10_classes,
        'waterbirds': cub_classes, 
        'binary_waterbirds': waterbird_classes, 
        'cub': cub_classes}[dataset_name]

# The class list is interpolated into the prompt as its Python repr.
prompt = f"You are a vision-language expert. Analyze the given image and classify it into one of the following categories:  \
{classes_}. Answer with only the most appropriate category in lower case."

# Load dataset (CLIP preprocessing is applied as the transform)
if dataset_name == "imagenet":
    ds_ = ImageNet(root=path+"imagenet/", split="val", transform=preprocess_clip)
elif dataset_name == "binary_waterbirds":
    ds_ = BinaryWaterbirds(root=path+"waterbird_complete95_forest2water2/", split="test", transform=preprocess_clip)
elif dataset_name == "CIFAR100":
    ds_ = CIFAR100(
        root=path, download=True, train=False, transform=preprocess_clip
    )
elif dataset_name == "CIFAR10":
    ds_ = CIFAR10(
        root=path, download=True, train=False, transform=preprocess_clip
    )
else:
    ds_ = ImageFolder(root=path, transform=preprocess_clip)

dataloader = dataset_to_dataloader(
    ds_,
    samples_per_class=subset_dim,
    tot_samples_per_class=tot_samples_per_class,  # or whatever you prefer
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    seed=seed,
)

# Load classifier
classifier_ = torch.tensor(np.load(f"output_dir/{dataset_name}_classifier_{model_CLIP_name}.npy", mmap_mode="r")).to(device, dtype=torch.float16) # embedding of the labels
# NOTE(review): the mean activations are loaded for `datataset_image_name`, not for `dataset_name` —
# confirm that using the means of the other dataset is intended.
attns_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_attns_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r")) # [l, n, h, d], attention values
mlps_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_mlps_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r"))  # [l + 1, n, d], mlp values

# Number of batches times the batch size
num_total_images = len(dataloader) * batch_size
print(f"We are using a dataset containing {num_total_images} images.")

# Metrics to measure
tot_correct_llava = 0
tot_correct_clip = 0
log_it = 10
count = 0

# -------------------------------
# Accumulators for the test loop: for every tested 'lay' value we keep the
# number of correct predictions of LLAVA and of CLIP.
test_lay_values = [0, 1, 5, 10, 15, 20, 21, 22, 23]
mean_ablate_attns = False
mean_ablate_mlps = True
test_results_llava = dict.fromkeys(test_lay_values, 0.)
test_results_clip = dict.fromkeys(test_lay_values, 0.)
# -------------------------------

# Layer where to extract infos on patches
select_layer = -2
nr_heads = 16
max_new_tokens = 512
num_beams = 1  # number of path of decision, less faster
sep =  ","
temperature = 0  # 0 lowest, det
top_p = None

# Copies of CLIP's final layer norm and projection, kept on the GPU
ln_post = copy.deepcopy(model_CLIP.visual.ln_post).to("cuda")
proj = copy.deepcopy(model_CLIP.visual.proj).to("cuda")
print(f"Running the test with ablation up to layer {test_lay_values} and mean ablation for attns {mean_ablate_attns} and mean ablation for mlps {mean_ablate_mlps}")
# Both models in eval mode; LLaVA waits on the CPU until the loop needs it
model.to("cpu")
model.eval()
model_CLIP.eval()

# The prompt is identical for every image, so it is formatted and tokenized once
input_ids = get_prompt_formatted(prompt, model, conv_templates)
# Print the prompt
print(f"The prompt is \n\n {prompt}")

# For every 'lay' value: parallel lists pairing each head with each layer below `lay_val`
precomputed_indices = {
    lay_val: (
        [layer for _ in range(nr_heads) for layer in range(lay_val)],
        [head for head in range(nr_heads) for _ in range(lay_val)],
    )
    for lay_val in test_lay_values
}

# One pass over the dataloader: the hooked CLIP model encodes the image once,
# then every 'lay' value is evaluated with both LLAVA and CLIP on the same
# recorded contributions.
for i, (image, labels) in enumerate(tqdm.tqdm(dataloader)):
    batch_size_here = image.shape[0]
    count += batch_size_here
    with torch.no_grad():

        # `image` is a tensor here, so `image.size` would be a bound method;
        # pass an actual (width, height) pair taken from the last two axes.
        image_sizes = [(image.shape[-1], image.shape[-2])]

        ## Use CLIP Model
        prs.reinit()
        ### THIS IS THE BOTTLENECK OF COMPUTATION (hook)
        model_CLIP.to("cuda")
        spatial_features = model_CLIP.encode_image(
                image.to(device, dtype=torch.float16),
                attn_method='head_no_spatial',
                normalize=False
            )
        model_CLIP.to("cpu")

        # Retrieve the recorded contributions
        attentions, mlps = prs.finalize(spatial_features)  # attentions: [b, l, n, h, d], mlps: [b, l + 1, n, d]

        # Only the first element of the batch is scored (batch size 1).
        target_index = int(labels[0])
        target_name = classes_[target_index].lower()

        model.to("cuda")
        for lay_val in test_lay_values:
            layers, heads = precomputed_indices[lay_val]
            # Mean-ablate the components selected for this 'lay' value.
            attentions_abl, mlps_abl = mean_ablate_head(
                attentions, mlps, select_layer,
                layers=layers,
                heads=heads,
                attentions_mean_abl=attns_hid_mean,
                mlps_mean_abl=mlps_hid_mean,
                mean_ablate_mlps=mean_ablate_mlps,
                mean_ablate_attns=mean_ablate_attns)

            ## LLAVA prediction from the ablated features
            with torch.inference_mode():
                test_output_ids = model.generate(
                    input_ids,
                    images=llava_pred(attentions_abl, mlps_abl, select_layer)[:, 1:], # all patches beside CLS token
                    image_sizes=image_sizes,
                    num_beams=num_beams,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    images_embeds=True  # If want to give images embeds already precomputed TODO: Only support 1 image
                )
            test_out = tokenizer.batch_decode(test_output_ids, skip_special_tokens=True)[0].strip()
            # Count the answer when it is contained in the class name. An empty
            # answer is a substring of every name, so it is excluded explicitly.
            prediction = test_out.lower()
            if prediction and prediction in target_name:
                test_results_llava[lay_val] += 1

            ## CLIP prediction from the same ablated contributions (CLS token only)
            hidden_output_test = ln_post(attentions_abl[:, :, 0].sum(1).sum(1) + mlps_abl[:, :, 0].sum(1)) # only CLS token
            test_clip_out = hidden_output_test @ proj
            if int(torch.argmax((test_clip_out @ classifier_).squeeze())) == target_index:
                test_results_clip[lay_val] += 1

        model.to("cpu")

        # --- End of tests for different lay values for this image ---
        if (i + 1) % log_it == 0:
            for lay_val in test_lay_values:
                acc_llava = test_results_llava[lay_val] / (i + 1) * 100
                acc_clip = test_results_clip[lay_val] / (i + 1) * 100
                print(f"lay = {lay_val}: LLAVA accuracy: {acc_llava:.2f}%, CLIP accuracy: {acc_clip:.2f}%\n")

# Overall accuracy for every tested 'lay' value, collected as text lines.
summary_lines = ["LLAVA and CLIP accuracies for different 'lay' values:\n"]
for lay_val in test_lay_values:
    acc_llava = test_results_llava[lay_val] / num_total_images * 100
    acc_clip = test_results_clip[lay_val] / num_total_images * 100
    summary_lines.append(f"lay = {lay_val}: LLAVA accuracy: {acc_llava:.2f}%, CLIP accuracy: {acc_clip:.2f}%\n")
result_str = "".join(summary_lines)


results_filename = f"test_results_{dataset_name}_abl_mlps_{mean_ablate_mlps}_abl_attns_{mean_ablate_attns}_only_last.txt"

# Write the summary to a text file whose name encodes the dataset and ablation flags.
with open(results_filename, "w") as f:
    f.write(result_str)

print(f"Test results saved to {results_filename}")

## Project a CLIP text embedding into hidden space of ViT Encoder

In [None]:
# Projection matrix of the CLIP model, obtained through a project helper and moved to `device`.
P = retrieve_proj_matrix(model_CLIP).to(device)

# Parameters of the post layer norm; weight and bias are moved to `device`, eps is passed through unchanged.
ln_weight, ln_bias, ln_eps = retrieve_post_layer_norm_par(model_CLIP)
ln_weight, ln_bias, ln_eps = ln_weight.to(device), ln_bias.to(device), ln_eps

In [None]:
model_CLIP.to("cuda")

# Encode one text query and one image with CLIP
with torch.no_grad():
    prs.reinit()
    model_CLIP.eval()
    # Text prompt to encode
    text_query = "cat."
    # Tokenize the text query and move it to `device`
    text_query_token = tokenizer_CLIP(text_query).to(device)  
    # Encode the tokenized text; normalize=False is requested, so the embedding is not normalized here
    topic_emb = model_CLIP.encode_text(text_query_token, normalize=False)
    # Load and preprocess the image from disk
    prs.reinit()  # Reinitialize the hook before the image pass
    # NOTE: `text_query` is reused here to hold the image file name
    text_query = "woman.png"
    image_pil = Image.open(f'images/{text_query}')
    image = preprocess_clip(image_pil)[np.newaxis, :, :, :]  # Add batch dimension
    # Match the half-precision setting chosen for the CLIP model
    if precision == "fp16":
        image = image.to(dtype=torch.float16)
        topic_emb = topic_emb.to(dtype=torch.float16)


    # Encode the image; normalize=False is requested, so the embedding is not normalized here
    image_emb = model_CLIP.encode_image(
        image.to(device), 
        attn_method='head_no_spatial',
        normalize=False
    )
    print(image_emb.shape)
    print(topic_emb.shape)
# Shift the text embedding by the difference between the image mean and the text mean
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

if precision == "fp16":
    mean_final_images = mean_final_images.to(dtype=torch.float16)
    mean_final_texts = mean_final_texts.to(dtype=torch.float16)

topic_emb = topic_emb - mean_final_texts + mean_final_images

# Norms and dot product of the two embeddings before any projection
print("Normal")
print(topic_emb.norm())
print(image_emb.norm())
print(topic_emb @ image_emb.T)

# Project 
def remove(a, b):
    """Subtract from `a` its projection onto `b`; the result gets a leading axis of size 1.

    Same definition as the `remove` helper defined earlier in this notebook.
    """
    vec = a.squeeze()
    direction = b.squeeze()
    coeff = torch.dot(vec, direction) / torch.dot(direction, direction)
    return (vec - coeff * direction).unsqueeze(0)
# Fictitious (placeholder) statistics handed to invert_proj_layer_norm
mean = torch.tensor(0.15)
std = torch.tensor(1)

# Map both embeddings back through the project helper invert_proj_layer_norm
# (presumably the inverse of projection + post layer norm — confirm against the helper),
# then remove the text direction from the inverted image embedding.
invert_topic_emb = invert_proj_layer_norm(topic_emb, P, ln_weight, ln_bias, std, mean, ln_eps)
invert_image_emb = remove(invert_proj_layer_norm(image_emb, P, ln_weight, ln_bias, std, mean, ln_eps), invert_topic_emb)

print("After proj back")
print(invert_topic_emb.norm())
print(invert_image_emb.norm())
print(invert_topic_emb @ invert_image_emb.T)

# Re-apply the post layer norm and the projection P, then compare the results
topic_emb_p = model_CLIP.visual.ln_post(invert_topic_emb) @ P
image_emb_p = model_CLIP.visual.ln_post(invert_image_emb) @ P
print(invert_topic_emb.shape)
print("Normal proj back")
print(topic_emb_p.norm())
print(image_emb_p.norm())
print(torch.norm(topic_emb_p - image_emb_p))
print(topic_emb_p @ image_emb_p.T)