In [None]:
!nvidia-smi

In [None]:
import re
import torch
from PIL import Image
import matplotlib.pyplot as plt
import gc
import requests
import copy
from PIL import Image
from io import BytesIO
from torch.nn.functional import mse_loss
import numpy as np
import einops
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
from llava.conversation import conv_templates
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.models.prs_hook import hook_prs_logger
from utils.scripts.utils_llava import *

In [2]:
## Parameters
# Notebook-wide configuration. Several later cells read these names directly.
device = 'cuda'  # device string handed to the `.to(...)` calls in later cells
seed = 0  # part of the file names of the cached .npy arrays loaded in the CLIP cell
num_last_layers_ = 4  # not referenced in the cells shown here
subset_dim = 10  # not referenced in the cells shown here
tot_samples_per_class = 50  # not referenced in the cells shown here
dataset_text_name = "top_1500_nouns_5_sentences_imagenet_clean"  # stem of the text-embedding .npy file
# NOTE: the variable name is misspelled ("datataset") but later cells reference it as written.
datataset_image_name = "imagenet"  # stem of the image-embedding / mean-activation .npy files
cache_dir = "../cache"

In [3]:
# Helper functions
def image_parser(image_file, sep=","):
    """Split a delimiter-separated string of image locations into a list.

    Args:
        image_file: String holding one or more image paths/URLs joined by ``sep``.
        sep: Delimiter between the entries (a comma by default).

    Returns:
        The individual entries as a list, in their original order.
    """
    return image_file.split(sep)


def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it as an RGB PIL image.

    Args:
        image_file: ``http://`` / ``https://`` URL, or a path on the local filesystem.
        timeout: Seconds to wait for the HTTP server before giving up (remote images only).

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the server
            answers with an error status.
    """
    if image_file.startswith(("http://", "https://")):
        # Without a timeout a stalled server would hang the notebook indefinitely.
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of trying to decode an error page as an image.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        # `convert` returns an independent copy, so the file handle can be released right away.
        with Image.open(image_file) as handle:
            image = handle.convert("RGB")
    return image


def load_images(image_files):
    """Load every entry of ``image_files`` with ``load_image``.

    Args:
        image_files: Iterable of image URLs or local paths.

    Returns:
        List of the loaded images, in the same order as the input.
    """
    return [load_image(location) for location in image_files]

# Same value as the `device` set in the parameters cell; rebound here.
device = "cuda"
# Hub-style identifier; later cells inspect it to choose the conversation template.
model_name = "liuhaotian/llava-v1.5-7b"
# NOTE(review): hard-coded absolute path to a cluster-specific snapshot directory —
# this has to be adapted on any other machine.
model_path = "/cluster/work/vogtlab/Group/vstrozzi/cache/models--liuhaotian--llava-v1.5-7b/snapshots/4481d270cc22fd5c4d1bb5df129622006ccd9234/"

## Load the LLaVA model

In [None]:
### IMPORTANT: on the Biomedcluster, change .config under model_path so that it points to the correct vision_tower CLIP path
# Load tokenizer, LLaVA model, image processor and context length from the local snapshot.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_name),
)
# Keep the LLaVA model on the CPU for now; later cells move it to the GPU when they need it.
model.to("cpu")

In [None]:
# CLIP model
model_CLIP_name = 'ViT-L-14-336' 
pretrained = "openai"
precision = "fp16"

torch.cuda.empty_cache()
# Use the notebook-level `cache_dir` instead of repeating the literal path.
model_CLIP, _, preprocess_clip = create_model_and_transforms(model_CLIP_name, pretrained=pretrained, precision=precision, cache_dir=cache_dir)

model_CLIP.eval()
context_length = model_CLIP.context_length
# vocab_size is only used for the summary printout in this cell
vocab_size = model_CLIP.vocab_size
tokenizer_CLIP = get_tokenizer(model_CLIP_name)

print(model_CLIP)
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model_CLIP.visual.parameters()]):,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model_CLIP.visual.transformer.resblocks))
# Attach the hook that records the residual stream. Requested configuration: no projection
# onto the shared space, no spatial tokens in the output (i.e. no per-token attention
# contributions), and the hidden outputs of all tokens.
prs = hook_prs_logger(model_CLIP, device, spatial=False, vision_projection=False, full_output=True)

# Cached final embeddings (moved to `device`)
final_embeddings_images = torch.tensor(np.load(f"output_dir/{datataset_image_name}_embeddings_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r")).to(device)
final_embeddings_texts = torch.tensor(np.load(f"output_dir/{dataset_text_name}_{model_CLIP_name}.npy", mmap_mode="r")).to(device)

# Cached mean activations used for mean ablation (these are left on the CPU)
attns_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_attns_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r")) # [l, n, h, d], attention values
mlps_hid_mean = torch.tensor(np.load(f"output_dir/{datataset_image_name}_mlps_mean_{model_CLIP_name}_seed_{seed}.npy", mmap_mode="r"))  # [l + 1, n, d], mlp values

## Experiment with LLaVA

In [6]:
def mean_ablate_head(attentions, mlps, select_layer, layers=None, heads=None, 
                     attentions_mean_abl=None, mlps_mean_abl=None, mean_ablate_mlps=False):
    """Mean-ablate selected attention heads (and optionally MLP layers).

    The inputs are never modified; the function works on clones.

    Args:
        attentions: Per-head attention contributions, indexed as [b, l, n, h, d].
        mlps: Per-layer MLP contributions, indexed as [b, l + 1, n, d].
        select_layer: Last layer (inclusive) used when the replacement mean has to be
            computed from the inputs; negative values count from the end.
        layers: Layer indices to ablate, paired element-wise with ``heads``.
        heads: Head indices to ablate, paired element-wise with ``layers``.
        attentions_mean_abl: Optional precomputed attention means, [l, n, h, d]. When
            omitted, a single [d] vector (mean over batch, layers, tokens and heads of
            the selected layers) is used for every ablated head.
        mlps_mean_abl: Optional precomputed MLP means, [l + 1, n, d]. When omitted, a
            single [d] vector (mean over batch, layers and tokens) is used.
        mean_ablate_mlps: If True, the MLP entries at the layers listed in ``layers``
            are replaced as well.

    Returns:
        Tuple ``(attentions, mlps)`` holding the ablated copies.
    """
    # Clone the inputs so that in-place writes do not leak into later calls.
    attentions = attentions.clone()
    mlps = mlps.clone()

    # `:(select_layer + 1)` would be the empty slice `:0` for select_layer == -1;
    # use an open-ended slice in that case so "up to the last layer" keeps every layer.
    stop = select_layer + 1
    if stop == 0:
        stop = None

    # Replacement values: either a global [d] mean computed here, or the supplied
    # per-position means with a leading broadcast (batch) dimension.
    if attentions_mean_abl is None:
        attentions_mean_abl = attentions[:, :stop, :, :, :].mean(dim=(0, 1, 2, 3))
    else:
        attentions_mean_abl = attentions_mean_abl.unsqueeze(0)
    if mlps_mean_abl is None:
        mlps_mean_abl = mlps[:, :stop, :, :].mean(dim=(0, 1, 2))
    else:
        mlps_mean_abl = mlps_mean_abl.unsqueeze(0)

    if heads is not None and layers is not None:
        pairs = list(zip(layers, heads))

        # Overwrite every requested (layer, head) with its replacement value.
        for layer, head in pairs:
            attentions[:, layer, :, head, :] = (
                attentions_mean_abl
                if attentions_mean_abl.dim() < 2
                else attentions_mean_abl[:, layer, :, head, :])

        # The MLP replacement does not depend on the head, so apply it once per layer
        # rather than once per (layer, head) pair.
        if mean_ablate_mlps:
            for layer in dict.fromkeys(layer for layer, _ in pairs):
                mlps[:, layer, :, :] = (
                    mlps_mean_abl[:, layer, :, :]
                    if mlps_mean_abl.dim() >= 2
                    else mlps_mean_abl)

    return attentions, mlps

def llava_pred(attentions, mlps, select_layer):
    """Rebuild per-token features by summing the decomposed contributions.

    Args:
        attentions: Attention contributions indexed as [b, l, n, h, d].
        mlps: MLP contributions indexed as [b, l + 1, n, d].
        select_layer: Last layer index (inclusive) to include; negative values count
            from the end, so -1 keeps every layer.

    Returns:
        Tensor of shape [b, n, d]: the attention contributions summed over the selected
        layers and all heads, plus the MLP contributions summed over the selected layers.
    """
    # `:(select_layer + 1)` collapses to the empty slice `:0` when select_layer == -1,
    # which would silently return zeros; use an open-ended slice in that case.
    stop = select_layer + 1
    if stop == 0:
        stop = None

    attention_part = attentions[:, :stop, :, :, :].sum(1).sum(2)  # sum layers, then heads
    mlp_part = mlps[:, :stop, :, :].sum(1)  # sum layers
    return attention_part + mlp_part

# Project 
def remove(a, b):
    """Project the direction ``b`` out of the vector ``a``.

    Both inputs are squeezed first (they must reduce to 1-D vectors for ``torch.dot``).

    Returns:
        ``a`` minus its component along ``b``, with a leading dimension of size 1.
    """
    vector = a.squeeze()
    direction = b.squeeze()
    coefficient = torch.dot(vector, direction) / torch.dot(direction, direction)
    return (vector - coefficient * direction).unsqueeze(0)

def remove_patches(p, b):
    """Project the direction ``b`` out of every patch vector in ``p``.

    The input tensor is left untouched (the previous row-by-row version wrote into a
    squeezed view of ``p`` and therefore modified the caller's tensor), and inputs
    holding a single patch are handled as well.

    Args:
        p: Patch features whose last dimension is the feature dimension, e.g.
            [1, n, d] or [n, d]. All leading dimensions are flattened into one patch axis.
        b: Direction to remove; any shape that flattens to a vector of length d.

    Returns:
        Tensor of shape [1, num_patches, d] with the component along ``b`` removed
        from each patch.
    """
    direction = b.reshape(-1)
    patches = p.reshape(-1, direction.shape[0])
    # One projection coefficient per patch: <patch, b> / <b, b>
    coefficients = (patches @ direction) / torch.dot(direction, direction)
    return (patches - coefficients.unsqueeze(1) * direction).unsqueeze(0)

In [None]:
def llava_infer(prompt, pil_image, images_embeds=False, mean_ablate=False, up_to_layer_ablate = 10, mean_ablate_mlps=False, attentions_mean_abl=None, mlps_mean_abl=None,):
    """Run LLaVA on one image, optionally feeding it features rebuilt from the hooked CLIP model.

    Relies on notebook globals: ``model``, ``model_name``, ``tokenizer``,
    ``image_processor``, ``conv_templates``, ``device``, ``model_CLIP`` and ``prs``.
    The function moves ``model`` and ``model_CLIP`` between CPU and GPU as a side effect.

    Args:
        prompt: User question; may contain ``IMAGE_PLACEHOLDER``.
        pil_image: The PIL image to describe.
        images_embeds: If True, the image is encoded with ``model_CLIP``, its residual
            stream is rebuilt with ``llava_pred`` and passed to ``model.generate`` as
            precomputed embeddings (without the first token). If False, the processed
            pixel tensor is passed unchanged.
        mean_ablate: If True (and ``images_embeds`` is True), mean-ablate all 16 heads of
            layers ``0 .. up_to_layer_ablate - 1`` before rebuilding the features.
        up_to_layer_ablate: Number of leading layers to ablate.
        mean_ablate_mlps: Also ablate the MLP entries of those layers.
        attentions_mean_abl: Optional precomputed attention means for the ablation.
        mlps_mean_abl: Optional precomputed MLP means for the ablation.

    Returns:
        Tuple ``(outputs, images_tensor)``: the decoded answer and the image tensor
        (pixel values, or the rebuilt features including the first token).
    """
    select_layer = -2  # rebuild features up to (and including) the penultimate layer
    max_new_tokens = 512
    num_beams = 1  # number of decoding beams; fewer is faster
    num_heads = 16  # heads per layer covered by the ablation

    ## Insert the image token(s) into the prompt
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        prompt = re.sub(IMAGE_PLACEHOLDER, image_token, prompt)
    else:
        prompt = image_token + "\n" + prompt

    ## Choose the conversation template from the model name
    name_lower = model_name.lower()
    if "llama-2" in name_lower:
        conv_mode = "llava_llama_2"
    elif "mistral" in name_lower:
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in name_lower:
        conv_mode = "chatml_direct"
    elif "v1" in name_lower:
        conv_mode = "llava_v1"
    elif "mpt" in name_lower:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    ## Wrap the prompt in the standard conversation template
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Use the image that was passed in (previously this read the global `img`).
    image_sizes = [pil_image.size]

    images_tensor = process_images(
        [pil_image],
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    ## Tokenize prompt
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )

    ## Optionally replace the pixel tensor by features rebuilt from the hooked CLIP model
    if images_embeds:
        prs.reinit()
        model_CLIP.eval()
        with torch.no_grad():
            # Swap the two models on the GPU: LLaVA out, CLIP in.
            model.to("cpu")
            model_CLIP.to("cuda")
            spatial_features = model_CLIP.encode_image(
                    images_tensor.to(device), 
                    attn_method='head_no_spatial',
                    normalize=False
                )

            model_CLIP.to("cpu")

            # attentions: [b, l, n, h, d], mlps: [b, l + 1, n, d]
            attentions, mlps = prs.finalize(spatial_features)

            if mean_ablate:
                # Every head of every layer below `up_to_layer_ablate`.
                layers = [layer for _ in range(num_heads) for layer in range(up_to_layer_ablate)]
                heads = [head for head in range(num_heads) for _ in range(up_to_layer_ablate)]
                attentions, mlps = mean_ablate_head(
                    attentions, mlps, select_layer,
                    layers=layers,
                    heads=heads,
                    attentions_mean_abl=attentions_mean_abl,
                    mlps_mean_abl=mlps_mean_abl,
                    mean_ablate_mlps=mean_ablate_mlps)

            # `llava_pred` already restricts both tensors to `select_layer`; slicing
            # `mlps` here as well would drop one more MLP layer than intended.
            images_tensor = llava_pred(attentions, mlps, select_layer)

    ## Generate an answer with the full LLaVA model
    model.to("cuda")
    # Only rebuilt embeddings carry a leading CLS token to skip; for raw pixel tensors
    # the same slice would cut off a colour channel.
    image_inputs = images_tensor[:, 1:] if images_embeds else images_tensor
    image_inputs = image_inputs.to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_inputs,
            image_sizes=image_sizes,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            images_embeds=images_embeds,  # precomputed embeddings; TODO: only one image is supported
        )

    ## Decode the answer
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

    return outputs, images_tensor


# Params
# Classification prompt; the backslash continues the string literal on the next line.
prompt = "You are a vision-language expert. Analyze the given image and classify it into one of the following categories:  \
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']. Answer with only the most appropriate category."
image_file = "images/catdog.png"

## Visualize image
img = Image.open(image_file)
plt.imshow(img)
plt.axis('off')  # Hide axis ticks and labels
plt.show()

# Run LLaVA on CLIP features in which the first 20 layers (attention heads and MLPs)
# are replaced by the cached mean activations. As the last expression of the cell, the
# returned (answer, image tensor) tuple is displayed.
llava_infer(prompt, img, images_embeds = True, mean_ablate = True, up_to_layer_ablate = 20, attentions_mean_abl=attns_hid_mean, mlps_mean_abl=mlps_hid_mean,  mean_ablate_mlps=True)

In [8]:
# The image features a brown and white dog and a brown and black cat sitting together on a carpeted floor. They appear to be relaxed and comfortable in each other's company. The dog is positioned on the left side of the cat, with both animals facing the same direction.
# In the background, there is a bookshelf with several books on it, adding a cozy and lived-in atmosphere to the scene.

## Test the zero-shot accuracy of the VLM on CIFAR-10

In [9]:
def get_prompt_formatted(prompt, model, conv_templates):
    """Turn a raw user prompt into LLaVA input ids.

    Inserts the image token(s), wraps the prompt in the conversation template selected
    from the model name, and tokenizes the result.

    Relies on notebook globals: ``model_name``, ``tokenizer`` and ``device``.

    Args:
        prompt: User question; may contain ``IMAGE_PLACEHOLDER``.
        model: Model whose ``config.mm_use_im_start_end`` decides the image-token format.
        conv_templates: Mapping from conversation-mode name to template.

    Returns:
        Token ids with a leading batch dimension of size 1, moved to ``device``.
    """
    ## Insert the image token(s) into the prompt
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        prompt = re.sub(IMAGE_PLACEHOLDER, image_token, prompt)
    else:
        prompt = image_token + "\n" + prompt

    ## Choose the conversation template from the model name
    # (the former `conv_mode != conv_mode` check could never be true and was removed)
    name_lower = model_name.lower()
    if "llama-2" in name_lower:
        conv_mode = "llava_llama_2"
    elif "mistral" in name_lower:
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in name_lower:
        conv_mode = "chatml_direct"
    elif "v1" in name_lower:
        conv_mode = "llava_v1"
    elif "mpt" in name_lower:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    ## Wrap the prompt in the standard conversation template
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    ## Tokenize prompt
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )

    return input_ids


In [None]:
# TODO: Add on mean ablation our loaded values of memory
from utils.misc.visualization import visualization_preprocess
from utils.datasets_constants.cifar_10_classes import cifar_10_classes
import numpy as np
import torch
from utils.misc.misc import accuracy, accuracy_correct
from utils.scripts.algorithms_text_explanations import *
from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.misc.visualization import visualization_preprocess
from utils.models.prs_hook import hook_prs_logger
from utils.datasets_constants.imagenet_classes import imagenet_classes
from utils.datasets_constants.cifar_10_classes import cifar_10_classes
from utils.datasets_constants.cub_classes import cub_classes, waterbird_classes
import os
from utils.datasets.dataset_helpers import dataset_to_dataloader
from utils.scripts.algorithms_text_explanations_funcs import *
import tqdm
from torchvision.transforms import ToPILImage
import copy

# Constants and fixed configuration
# NOTE: this rebinds the notebook-level `seed` that was set in the parameters cell.
seed = 1
path = './datasets/'
batch_size = 1  # TODO: the evaluation loop only works with batch size 1 for now
prompt = "You are a vision-language expert. Analyze the given image and classify it into one of the following categories:  \
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']. Answer with only the most appropriate category in lower case."

# Prepare dataset
# NOTE(review): `CIFAR10` is not imported by name in this notebook; presumably it is
# provided by one of the star imports — verify.
ds_ = CIFAR10(root=path, download=True, train=False, transform=preprocess_clip)

dataloader = dataset_to_dataloader(
    ds_,
    samples_per_class=1,
    tot_samples_per_class=1000,  # or whatever you prefer
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    seed=seed,
)


classes_ = cifar_10_classes
classifier_ = torch.tensor(np.load(f"output_dir/CIFAR10_classifier_{model_CLIP_name}.npy", mmap_mode="r")).to(device, dtype=torch.float16) # embedding of the labels

print(classes_)
num_total_images = len(dataloader) * batch_size
print(f"We are using a dataset containing {num_total_images} images.")

# Metrics to measure
# (the two `tot_correct_*` counters are never updated below; the per-`lay`
# dictionaries hold the actual results)
tot_correct_llava = 0
tot_correct_clip = 0
log_it = 100
count = 0

# -------------------------------
# Accumulators for the test loop: for each 'lay' value (number of leading layers that
# are mean-ablated) we store the number of correct predictions.
test_lay_values = [1, 5, 10, 15, 20, 21, 22, 23]
test_results_llava = {lay: 0 for lay in test_lay_values}
test_results_clip = {lay: 0 for lay in test_lay_values}
# -------------------------------

# Layer where to extract infos on patches
select_layer = -2
max_new_tokens = 512
num_beams = 1  # number of decoding beams; fewer is faster
sep =  ","
temperature = 0  # 0 is the lowest value (deterministic)
top_p = None

# Copies of CLIP's final layer norm and projection, kept on the GPU
ln_post = copy.deepcopy(model_CLIP.visual.ln_post).to("cuda")
proj = copy.deepcopy(model_CLIP.visual.proj).to("cuda")
print(f"Running the test with ablation up to layer {test_lay_values}")
model.to("cpu")
model.eval()

# The prompt is identical for every image, so format and tokenize it once
input_ids = get_prompt_formatted(prompt, model, conv_templates)


def _accuracy_percent(results, n_seen):
    """Return {lay: accuracy in percent} for a {lay: number correct} dictionary."""
    return {lay: hits / max(n_seen, 1) * 100 for lay, hits in results.items()}


# Inference loop over images
for i, (image, labels) in enumerate(tqdm.tqdm(dataloader)):
    batch_size_here = image.shape[0]
    count += batch_size_here
    # `image_sizes` was never defined at notebook level (NameError on a fresh kernel).
    # Use the (width, height) of the preprocessed tensor.
    # NOTE(review): confirm this is what the patched `generate` expects.
    image_sizes = [(image.shape[-1], image.shape[-2])]

    with torch.no_grad():
        ## Run the hooked CLIP model
        prs.reinit()
        model_CLIP.eval()
        model_CLIP.to("cuda")
        spatial_features = model_CLIP.encode_image(
                image.to(device, dtype=torch.float16), 
                attn_method='head_no_spatial',
                normalize=False
            )
        model_CLIP.to("cpu")

        # attentions: [b, l, n, h, d], mlps: [b, l + 1, n, d]
        attentions, mlps = prs.finalize(spatial_features)

        ## For each 'lay' value, recompute the ablated features and run both the
        ## LLaVA and the CLIP prediction; results are accumulated per 'lay' value.
        model.to("cuda")
        for lay_val in test_lay_values:
            # Mean-ablate all 16 heads (and the MLPs) of layers 0 .. lay_val - 1
            attentions_abl, mlps_abl =  mean_ablate_head(
                attentions, mlps, select_layer,
                layers = [y for x in range(0, 16) for y in range(0, lay_val)],
                heads = [x for x in range(0, 16) for y in range(0, lay_val)],
                attentions_mean_abl = attns_hid_mean,
                mlps_mean_abl = mlps_hid_mean,
                mean_ablate_mlps = True)

            ## LLaVA prediction on the rebuilt features (first token skipped)
            with torch.inference_mode():
                test_output_ids = model.generate(
                    input_ids,
                    images= llava_pred(attentions_abl, mlps_abl, select_layer)[:, 1:],
                    image_sizes=image_sizes,
                    num_beams=num_beams,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    images_embeds=True  # precomputed embeddings; TODO: only one image is supported
                )
            test_out = tokenizer.batch_decode(test_output_ids, skip_special_tokens=True)[0].strip()
            print("Correct sol is ", classes_[labels[0]].lower())
            print(test_out.lower())
            if test_out.lower() == classes_[labels[0]].lower():
                test_results_llava[lay_val] += 1

            ## CLIP prediction from the first token, summed over all layers and heads
            hidden_output_test = ln_post(attentions_abl[:, :, 0].sum(1).sum(1) + mlps_abl[:, :, 0].sum(1))
            test_clip_out = hidden_output_test @ proj
            clip_logits = (test_clip_out @ classifier_).squeeze()
            print(clip_logits)
            print(classes_[torch.argmax(clip_logits)])
            print(classes_[torch.argmax((spatial_features @ classifier_).squeeze())])

            if torch.argmax(clip_logits) == labels[0]:
                test_results_clip[lay_val] += 1
        model.to("cpu")

    # --- End of tests for different lay values for this image ---

    if (i + 1) % log_it == 0:
        # The counters are dictionaries, so the accuracy is computed per 'lay' value.
        print(f"Tot accuracy LLAVA so far is {_accuracy_percent(test_results_llava, count)}")
        print(f"Tot accuracy ClIP so far is {_accuracy_percent(test_results_clip, count)}")

print(f"Final accuracy LLAVA is {_accuracy_percent(test_results_llava, count)}")
print(f"Final accuracy CLIP is {_accuracy_percent(test_results_clip, count)}")

# After processing all images, compute the overall accuracy for each test 'lay' value.
# Normalise by the number of images actually processed.
final_acc_llava = _accuracy_percent(test_results_llava, count)
final_acc_clip = _accuracy_percent(test_results_clip, count)
result_str = "LLAVA and CLIP accuracies for different 'lay' values:\n"
for lay_val in test_lay_values:
    acc_llava = final_acc_llava[lay_val]
    acc_clip = final_acc_clip[lay_val]
    result_str += f"lay = {lay_val}: LLAVA accuracy: {acc_llava:.2f}%, CLIP accuracy: {acc_clip:.2f}%\n"

# Save the final test results into a text file.
with open("test_results.txt", "w") as f:
    f.write(result_str)

print("Test results saved to test_results.txt")


## Project a CLIP text embedding into hidden space of ViT Encoder

In [None]:
# Values returned by the two retrieval helpers for the CLIP model (presumably the
# projection matrix and the post-layer-norm weight / bias / epsilon — see the helpers).
# The tensors are moved to `device`; `ln_eps` is kept as returned.
P = retrieve_proj_matrix(model_CLIP)
P = P.to(device)

ln_weight, ln_bias, ln_eps = retrieve_post_layer_norm_par(model_CLIP)
ln_weight = ln_weight.to(device)
ln_bias = ln_bias.to(device)

In [None]:
model_CLIP.to("cuda")

# Encode one text query and one image with the hooked CLIP model
with torch.no_grad():
    prs.reinit()
    model_CLIP.eval()
    # Text prompt to encode
    text_query = "cat."
    # Tokenize the text query and move it to the device (GPU/CPU)
    text_query_token = tokenizer_CLIP(text_query).to(device)  
    # Encode the tokenized text; normalize=False, so the embedding is NOT normalized
    topic_emb = model_CLIP.encode_text(text_query_token, normalize=False)
    # Load and preprocess the image from disk
    prs.reinit()  # Reinitialize the hook state before the image pass
    # NOTE: `text_query` is reused here to hold the image file name
    text_query = "woman.png"
    image_pil = Image.open(f'images/{text_query}')
    image = preprocess_clip(image_pil)[np.newaxis, :, :, :]  # Add batch dimension
    if precision == "fp16":
        image = image.to(dtype=torch.float16)
        topic_emb = topic_emb.to(dtype=torch.float16)
        

    # Encode the image; normalize=False, so the embedding is NOT normalized
    image_emb = model_CLIP.encode_image(
        image.to(device), 
        attn_method='head_no_spatial',
        normalize=False
    )
    print(image_emb.shape)
    print(topic_emb.shape)
# Re-center the text embedding on the image embeddings: subtract the mean of the cached
# text embeddings and add the mean of the cached image embeddings
mean_final_images = torch.mean(final_embeddings_images, axis=0)
mean_final_texts = torch.mean(final_embeddings_texts, axis=0)

if precision == "fp16":
    mean_final_images = mean_final_images.to(dtype=torch.float16)
    mean_final_texts = mean_final_texts.to(dtype=torch.float16)

topic_emb = topic_emb - mean_final_texts + mean_final_images

# Norms and dot product of the two embeddings before projecting back
print("Normal")
print(topic_emb.norm())
print(image_emb.norm())
print(topic_emb @ image_emb.T)

# Remove from `a` its component along `b`; the result has a leading dimension of size 1.
# NOTE: this repeats, line for line, the `remove` defined in an earlier cell.
def remove(a, b):
    a = a.squeeze()
    b = b.squeeze()
    return (a - (torch.dot(a, b) / torch.dot(b, b)) * b).unsqueeze(0)
# Placeholder ("fictitious") statistics passed to invert_proj_layer_norm
mean = torch.tensor(0.15)
std = torch.tensor(1)

# Map both embeddings back through invert_proj_layer_norm (presumably the inverse of
# projection + layer norm — see the helper), then project the text direction out of
# the image vector
invert_topic_emb = invert_proj_layer_norm(topic_emb, P, ln_weight, ln_bias, std, mean, ln_eps)
invert_image_emb = remove(invert_proj_layer_norm(image_emb, P, ln_weight, ln_bias, std, mean, ln_eps), invert_topic_emb)

print("After proj back")
print(invert_topic_emb.norm())
print(invert_image_emb.norm())
print(invert_topic_emb @ invert_image_emb.T)

# Apply CLIP's post layer norm and P again, then re-evaluate norms, distance and dot product
topic_emb_p = model_CLIP.visual.ln_post(invert_topic_emb) @ P
image_emb_p = model_CLIP.visual.ln_post(invert_image_emb) @ P
print(invert_topic_emb.shape)
print("Normal proj back")
print(topic_emb_p.norm())
print(image_emb_p.norm())
print(torch.norm(topic_emb_p - image_emb_p))
print(topic_emb_p @ image_emb_p.T)