In [None]:
import re
import torch
from PIL import Image
import gc
import requests
import copy
from PIL import Image
from io import BytesIO
from torch.nn.functional import mse_loss
import numpy as np
import einops
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
from llava.conversation import conv_templates
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from utils.models.factory import create_model_and_transforms, get_tokenizer
from utils.models.prs_hook import hook_prs_logger


In [2]:
# Helper functions
def image_parser(image_file, sep=","):
    """Split a delimited string of image locations into a list.

    Args:
        image_file: One or more image paths/URLs joined by ``sep``.
        sep: Delimiter placed between the entries (default ``","``).

    Returns:
        list[str]: The entries in their original order, exactly as
        ``str.split`` produces them (no whitespace is stripped).
    """
    return image_file.split(sep)


def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: ``http(s)://`` URL or filesystem path of the image.
        timeout: Seconds to wait for the HTTP server before giving up
            (only used when ``image_file`` is a URL).

    Returns:
        PIL.Image.Image: The decoded image converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # "https://..." also starts with "http", so one prefix check covers both.
    if image_file.startswith("http"):
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(image_file, timeout=timeout)
        # Fail here on an HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns a new, fully loaded image, so the underlying file
    # handle can be released as soon as the conversion is done.
    with Image.open(source) as image:
        return image.convert("RGB")


def load_images(image_files):
    """Load every image named in ``image_files``.

    Args:
        image_files: Iterable of image paths or URLs.

    Returns:
        list: One loaded image per entry, in the same order, each produced
        by ``load_image``.
    """
    return [load_image(location) for location in image_files]

# Device string used when moving tensors to the accelerator and when attaching the CLIP hook.
device = "cuda"
# Model identifier: the LLaVA model name is derived from it and it selects the conversation template.
model_name = "liuhaotian/llava-v1.5-7b"
# Local snapshot directory passed to load_pretrained_model.
# NOTE(review): absolute, cluster-specific path — must be adjusted to run on another machine.
model_path = "/cluster/work/vogtlab/Group/vstrozzi/cache/models--liuhaotian--llava-v1.5-7b/snapshots/4481d270cc22fd5c4d1bb5df129622006ccd9234/"

## Load the LLaVA model

In [None]:
### IMPORTANT: on the Biomedcluster, edit the config under model_path so that
### the vision tower entry points to the correct CLIP path before running this cell.
# Loads the LLaVA checkpoint from model_path; the model name is derived from
# model_name via get_model_name_from_path.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_name)
)

In [None]:
# Stand-alone CLIP model plus the hook that records its residual stream
model_CLIP_name = 'ViT-L-14-336'
pretrained = "hf"
precision = "fp16"

torch.cuda.empty_cache()
model_CLIP, _, preprocess_clip = create_model_and_transforms(
    model_CLIP_name,
    pretrained=pretrained,
    precision=precision,
    cache_dir="../cache",
)
model_CLIP.eval()

# In this section these two values are only used for the summary printed below.
context_length = model_CLIP.context_length
vocab_size = model_CLIP.vocab_size

# Total number of scalar parameters in the visual branch.
visual_param_count = np.sum([int(np.prod(p.shape)) for p in model_CLIP.visual.parameters()])
print("Model parameters:", f"{visual_param_count:,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model_CLIP.visual.transformer.resblocks))

# Attach the hook that records the residual stream. It is configured for:
# no projection onto the shared space, no spatial tokens in the output
# (i.e. no per-token attention contributions), and hidden outputs for all tokens.
prs = hook_prs_logger(
    model_CLIP,
    device,
    spatial=False,
    vision_projection=False,
    full_output=True,
)

## Experiment with LLaVA

In [None]:
# Layer index used as the (exclusive, after +1) upper bound when summing the
# per-layer CLIP contributions for the patch tokens; negative values count from the end.
select_layer = -2
# Generation parameters
prompt = "Describe the image focusing on main subjects. ignore background"
image_file = "images/catdog.png"
max_new_tokens = 512
num_beams = 1 # number of beams (candidate decoding paths); fewer beams is faster
sep =  ","
temperature = 0 # 0 is the lowest setting: sampling is disabled and decoding is deterministic
top_p = None
images_embeds = True # if True, precomputed image embeddings are passed to generate instead of the processed images

## Insert the image token into the prompt
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
# The model config decides whether the image token is wrapped in start/end markers.
image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
if IMAGE_PLACEHOLDER in prompt:
    # An explicit placeholder marks where the image token belongs.
    prompt = re.sub(IMAGE_PLACEHOLDER, image_token, prompt)
else:
    # No placeholder: put the image token on its own line before the text.
    prompt = image_token + "\n" + prompt

## Pick the conversation template that matches the model family
# The checks run from most to least specific: "v1.6-34b" must be tested
# before the generic "v1" substring, which it also contains.
model_name_lower = model_name.lower()
if "llama-2" in model_name_lower:
    conv_mode = "llava_llama_2"
elif "mistral" in model_name_lower:
    conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name_lower:
    conv_mode = "chatml_direct"
elif "v1" in model_name_lower:
    conv_mode = "llava_v1"
elif "mpt" in model_name_lower:
    conv_mode = "mpt"
else:
    # Fallback when no known family substring is found in the name.
    conv_mode = "llava_v0"

## Wrap the prompt in the conversation template of the selected mode
conv = conv_templates[conv_mode].copy()
# roles[0] carries the prompt; roles[1] is appended with no text (None).
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)

## Load the images (URL or local path) and preprocess them
image_files = image_parser(image_file, sep)
images = load_images(image_files)
image_sizes = [img.size for img in images]
images_tensor = process_images(images, image_processor, model.config)
images_tensor = images_tensor.to(model.device, dtype=torch.float16)

## Tokenize the prompt and add a leading batch dimension
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(device)


## Recompute the image features with the hooked CLIP model
if images_embeds:
    prs.reinit()
    model_CLIP.eval()
    with torch.no_grad():
        # LLaVA is moved to the CPU while CLIP runs on the accelerator
        # (presumably to limit GPU memory use — confirm).
        model.to("cpu")
        model_CLIP.to(device)
        # The forward pass fills the hook; its output is handed to prs.finalize below.
        spatial_features = model_CLIP.encode_image(
            images_tensor.to(device),
            attn_method='head_no_spatial',
            normalize=False
        )

        model_CLIP.to("cpu")

        # Collect the recorded per-layer contributions
        attentions, mlps = prs.finalize(spatial_features)  # attentions: [b, l, n, h, d], mlps: [b, l + 1, n, d]
        attentions = einops.rearrange(attentions, "b l n h d -> b l h n d")

        # Exclusive upper bound on the layer axis. For select_layer == -1 the
        # expression select_layer + 1 is 0, which would select no layers at all
        # and silently yield all-zero features; None keeps every layer instead.
        layer_end = None if select_layer == -1 else select_layer + 1

        # Sum the attention terms over layers and heads, add the MLP terms summed
        # over layers. Index 1: on the token axis drops the first token
        # (presumably the class token — confirm).
        spatial_features = (
            attentions[:, :layer_end, :, 1:, :].sum(1).sum(1)
            + mlps[:, :layer_end, 1:, :].sum(1)
        )
    # Hand the reconstructed features to LLaVA in place of the processed images
    images_tensor = spatial_features


## Run the full LLaVA model to generate the answer
model.to("cuda")
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=images_tensor,
        image_sizes=image_sizes,
        # Sampling is enabled only for a strictly positive temperature.
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        # Tells generate whether `images` already holds precomputed embeddings.
        # TODO: only one image is supported in this mode.
        images_embeds=images_embeds,
    )

## Decode the first sequence of the batch and show it
decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
outputs = decoded[0].strip()
print(outputs)