In [None]:
import torch
import clip
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import torch.nn as nn
import nlp_utils
import os
from PIL import Image
import numpy as np
from quintuplets import *
from nlp_utils import add_text_to_image
import matplotlib.pyplot as plt

In [None]:
# Cache directory handed to every `from_pretrained` call below.
cache_dir = nlp_utils.get_cache_dir()

# Run all models on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Dataset root. The default is a machine-specific absolute path; set the
# QUINTUPLETS_V1_PATH environment variable to point elsewhere without editing the notebook.
dataset_path = os.environ.get("QUINTUPLETS_V1_PATH", "/home/dcor/roeyron/datasets/quintuplets_v1")

# Seeded shuffle (seed=1) of the directory listing, so the same sample is inspected on every run.
qp_ids = np.random.RandomState(1).permutation(sorted(os.listdir(dataset_path)))
# Fail with a clear message instead of a bare IndexError on an empty dataset directory.
assert len(qp_ids) > 0, f"no quintuplets found in {dataset_path!r}"

qp_id = qp_ids[0]
qp = Quintuplet.load(dataset_path, qp_id)
print(qp)

# Load LLaVA and CLIP

In [None]:
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"

# fp16 halves the memory of the 7B model on GPU, but many fp16 kernels are missing or
# extremely slow on CPU, so use fp32 when falling back to the CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32

processor = LlavaNextProcessor.from_pretrained(model_name, cache_dir=cache_dir)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
    cache_dir=cache_dir,
)

# Give generate() an explicit pad token; fall back to EOS if the tokenizer defines none.
pad_token_id = processor.tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = processor.tokenizer.eos_token_id
model.generation_config.pad_token_id = pad_token_id

model = model.to(device)

# OpenAI CLIP ViT-B/32, used to score image/text agreement.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Get LLaVA's text description of an image

In [None]:
def get_llava_desc(image, text, max_new_tokens=100):
    """Ask LLaVA-NeXT to respond to the instruction `text` about `image`.

    Args:
        image: image passed to the LLaVA processor (presumably a PIL image; TODO confirm).
        text: instruction inserted into the Mistral-style [INST] prompt.
        max_new_tokens: generation budget (default 100, as before); the prompt also
            asks the model for at most 2 short sentences.

    Returns:
        Only the model's reply (prompt excluded), with surrounding whitespace stripped.
    """
    prompt = f"[INST] <image>\n{text} \n Limit your response to no more than 2 short sentences. [/INST]"
    # Keyword arguments: the positional order of LlavaNextProcessor.__call__ differs
    # between transformers versions (text-first vs images-first).
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    # Hidden states were never used; not requesting them avoids a large memory cost.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation. Decode only the new tokens rather than
    # splitting the decoded text on the prompt, which raises IndexError whenever
    # detokenization does not reproduce the prompt text exactly.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()

# Get the CLIP similarity between an image and a text

In [None]:
def get_clip_image_text_similarity(image, text):
    """Return the CLIP cosine similarity between `image` and `text`, rescaled to [0, 1].

    Args:
        image: input accepted by `clip_preprocess` (presumably a PIL image; TODO confirm).
        text: caption to compare; text beyond CLIP's 77-token context is truncated.

    Returns:
        (cos + 1) / 2 as a Python float; 0.5 corresponds to orthogonal embeddings.
    """
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    # truncate=True: clip.tokenize raises RuntimeError for texts longer than the
    # 77-token context, and LLaVA descriptions (up to 100 new tokens) can exceed it.
    text_input = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image_input)
        text_features = clip_model.encode_text(text_input)
    # CLIP weights are typically fp16 on CUDA; compute the cosine in fp32 so that
    # near-equal scores compared by the evaluation loop are not merged by rounding.
    sim = nn.functional.cosine_similarity(
        image_features[0].float(), text_features[0].float(), dim=0
    ).item()
    return (sim + 1) / 2

# Evaluate the quadruplet baseline

In [None]:
# Baseline: describe each query image with LLaVA, then count how often CLIP rates the
# positive image as closer to that description than the negative image.
# NOTE(review): ties (pos_score == neg_score) are counted as failures.
score = 0
qn_ids = get_splits_ids()['train']
qd_ids = [
    QuadrupletId(qn_id, which)
    for qn_id in qn_ids
    for which in ("gamma_positive", "delta_positive")
]

for qd_id in qd_ids:
    qd = load_quadruplet(QUINTUPLETS_DATASET_PATH, qd_id)
    desc = get_llava_desc(qd.query, qd.prompt)
    pos_score = get_clip_image_text_similarity(qd.positive, desc)
    neg_score = get_clip_image_text_similarity(qd.negative, desc)
    score += int(pos_score > neg_score)

In [None]:
# Baseline accuracy: fraction of quadruplets where CLIP preferred the positive image.
# Evaluates to NaN (instead of raising ZeroDivisionError) when no quadruplets were evaluated.
score / len(qd_ids) if qd_ids else float("nan")