# In [None]:
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.utils import save_image

from pope_loader import POPEDataSet
from minigpt4.common.dist_utils import get_rank
from minigpt4.models import load_preprocess

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

from PIL import Image
from torchvision.utils import save_image

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn
import json


# Evaluation config YAML for every supported model key; the file name is
# always "<model>_eval.yaml" inside eval_configs/.
MODEL_EVAL_CONFIG_PATH = {
    model_name: "eval_configs/{}_eval.yaml".format(model_name)
    for model_name in ("minigpt4", "instructblip", "lrv_instruct", "shikra", "llava-1.5")
}

# COCO POPE question files, one per negative-sampling split.
POPE_PATH = {
    split: "coco_pope/coco_pope_{}.json".format(split)
    for split in ("random", "popular", "adversarial")
}

# MiniGPT-4 and LRV-Instruction share the same chat format.
_MINIGPT4_STYLE_TEMPLATE = "###Human: <Img><ImageHere></Img> <question> ###Assistant:"

# Per-model prompt templates: "<ImageHere>" marks where the image goes and
# "<question>" is replaced by the user question.
INSTRUCTION_TEMPLATE = {
    "minigpt4": _MINIGPT4_STYLE_TEMPLATE,
    "instructblip": "<ImageHere><question>",
    "lrv_instruct": _MINIGPT4_STYLE_TEMPLATE,
    "shikra": "USER: <im_start><ImageHere><im_end> <question> ASSISTANT:",
    "llava-1.5": "USER: <ImageHere> <question> ASSISTANT:",
}


def setup_seeds(config):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The seed is ``config.run_cfg.seed`` plus the distributed rank returned by
    ``get_rank()``, so each rank gets a different but reproducible stream.
    """
    seed = config.run_cfg.seed + get_rank()

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn.deterministic = True
    cudnn.benchmark = False





# Command-line options. parse_known_args() ignores flags this parser does not
# define instead of failing (presumably so the cell also runs inside a
# notebook kernel, whose argv carries its own flags).
parser = argparse.ArgumentParser(description="POPE-Adv evaluation on LVLMs.")
parser.add_argument("--model", type=str, help="model")
parser.add_argument("--gpu-id", type=int, help="specify the gpu to load the model.")
parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
    "in xxx=yyy format will be merged into config file (deprecate), "
    "change to --cfg-options instead.",
)
parser.add_argument("--data_path", type=str, default="/mnt/petrelfs/share_data/wangjiaqi/mllm-data-alg/COCO_2014/ori/val2014/val2014/", help="data path")
parser.add_argument("--batch_size", type=int, help="batch size")
parser.add_argument("--num_workers", type=int, default=2, help="num workers")
args = parser.parse_known_args()[0]


# Hard-coded overrides: these replace whatever was given on the command line.
# Switch models by (un)commenting the lines below.
args.model = "llava-1.5"
# args.model = "instructblip"
# args.model = "minigpt4"
# args.model = "shikra"
args.gpu_id = "0"  # NOTE(review): a str here, while --gpu-id is parsed as int
args.batch_size = 1


# Expose only the selected GPU to this process. NOTE(review): presumably this
# must happen before CUDA is first initialised in order to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
# The Config is built from `args`; cfg_path points it at the model's eval YAML.
args.cfg_path = MODEL_EVAL_CONFIG_PATH[args.model]
cfg = Config(args)
setup_seeds(cfg)
# A torch.device when CUDA is available, otherwise the plain string "cpu".
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# ========================================
#             Model Initialization
# ========================================
print('Initializing Model')

model_config = cfg.model_cfg
# NOTE(review): presumably the device used when the LLM is loaded in 8-bit;
# confirm in the model class's from_config().
model_config.device_8bit = args.gpu_id
# Look up the model class registered under the config's `arch` and build it.
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
model.eval()
# Build the image/text processors with normalisation switched off: the image
# tensor is normalised explicitly with `norm` (defined below) when it is
# passed to model.generate().
processor_cfg = cfg.get_config().preprocess
processor_cfg.vis_processor.eval.do_normalize = False
vis_processors, txt_processors = load_preprocess(processor_cfg)
print(vis_processors["eval"].transform)
print("Done!")

# Normalisation applied manually to the processed image. NOTE(review): these
# look like the standard CLIP mean/std; confirm they match the model's vision
# encoder.
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
norm = transforms.Normalize(mean, std)

# In [None]:
# Image to analyse -- replace the placeholder path before running.
image_path = "XXX.jpg"
# Open the file with a context manager so its handle is released once the RGB
# copy exists. The image is still displayed as loaded, before conversion.
with Image.open(image_path) as raw_file:
    plt.imshow(raw_file)
    plt.show()
    raw_image = raw_file.convert("RGB")
# Preprocess (normalisation is applied later via `norm`) and add a batch dim.
image = vis_processors["eval"](raw_image).unsqueeze(0)
image = image.to(device)
print(image.shape)

# Question to ask; alternatives are kept for quick switching.
qu = "Please describe this image in detail."
# qu = "What can you see in this image?"
# qu = "Introduce about this image."
# qu = "Is there a bottle in the image?"
# Wrap the question in the model-specific prompt template.
template = INSTRUCTION_TEMPLATE[args.model]
qu = template.replace("<question>", qu)


# First pass: get the model's answer (no nucleus sampling, a single beam).
# inference_mode already disables autograd, so no nested no_grad() is needed.
with torch.inference_mode():
    out = model.generate(
        {"image": norm(image), "prompt": qu},
        use_nucleus_sampling=False,
        num_beams=1,
        max_new_tokens=512,
        output_attentions=True,
        # Experimental generate() options, left disabled:
        # ours_decoding=True,
        # scale_factor=50,
        # threshold=15.0,
        # num_attn_candidates=5,
    )
print(out[0])
print("\n")


# Second pass: append the first answer to the prompt and generate again, this
# time returning attention maps, so the attention spans prompt + answer.
qu_append = out[0]
qu = qu + qu_append


# inference_mode already disables autograd, so no nested no_grad() is needed.
# output_attentions is passed exactly once: repeating a keyword argument in a
# call is a SyntaxError and would stop this cell from running at all.
with torch.inference_mode():
    out = model.generate(
        {"image": norm(image), "prompt": qu},
        use_nucleus_sampling=False,
        num_beams=1,
        max_new_tokens=512,
        output_attentions=True,
        return_dict_in_generate=True,
        # Experimental generate() options, left disabled:
        # ours_decoding=True,
        # scale_factor=50,
        # threshold=15.0,
        # num_attn_candidates=5,
    )

# Copy the attention tensors so later processing works on independent data.
# Layout noted by the author: [layer_id, bs, head_id, qk, v]
# (NOTE(review): confirm against the model's generate() implementation).
attns = [attn.clone() for attn in out.attentions]


# Tokenize the prompt around the image placeholder so that every position of
# the attention map can be labelled with its token.
# The wrapped model exposes its tokenizer either as `llama_tokenizer` or as
# `llm_tokenizer`. Select it explicitly instead of falling back via a bare
# `except:`, which also hid unrelated failures (for instance a hard-coded
# .to("cuda") on a CPU-only machine).
tokenizer = getattr(model, "llama_tokenizer", None)
if tokenizer is None:
    tokenizer = model.llm_tokenizer

# maxsplit=1 keeps the split at the template's placeholder even if the
# appended answer text happens to contain the same marker.
p_before, p_after = qu.split('<ImageHere>', 1)
if args.model in ["shikra", "llava-1.5"]:
    # These models prepend a system message to the text before the image.
    p_before = model.system_message + p_before


def _text_to_tokens(text):
    """Return the token strings for *text*, without special tokens."""
    ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
    return tokenizer.convert_ids_to_tokens(ids[0].tolist())


p_before_tokens = _text_to_tokens(p_before)
p_after_tokens = _text_to_tokens(p_after)

# Label for the BOS position at the start of the sequence. The id tensor only
# feeds convert_ids_to_tokens, so it lives on `device` rather than "cuda".
bos = torch.ones([1, 1], dtype=torch.int64, device=device) * tokenizer.bos_token_id
bos_tokens = tokenizer.convert_ids_to_tokens(bos[0].tolist())



print(qu)

# Number of image placeholder positions in the LLM input sequence; models
# other than shikra / llava-1.5 use 32.
NUM_IMAGE_TOKENS = {"shikra": 256, "llava-1.5": 576}.get(args.model, 32)

# Full token sequence: BOS, text before the image, image slots, text after.
tokens = [*bos_tokens, *p_before_tokens, *(['img_token'] * NUM_IMAGE_TOKENS), *p_after_tokens]
seq_len = len(tokens)

# len1: index of the first image slot.
len1 = len(bos_tokens) + len(p_before_tokens)
# len2: index just past the first end-of-instruction marker in the text after
# the image ("." for instructblip, ":" otherwise). Raises ValueError if absent.
end_marker = "." if args.model == "instructblip" else ":"
len2 = len1 + NUM_IMAGE_TOKENS + p_after_tokens.index(end_marker) + 1
print(len1, len2)

# Prefix each token with its position, e.g. "0-<s>", for the heatmap labels.
tokens = ["-".join((str(position), tok)) for position, tok in enumerate(tokens)]
print(tokens)

# Head-wise maximum of the last layer's attention, renormalised so that every
# query row sums to 1. NOTE(review): assumes each element of `attns` is laid
# out as (bs, head, q, k) with bs == 1, so squeeze() leaves (q, k).
attn_last = attns[-1].amax(dim=1).squeeze()
attn_last = attn_last / attn_last.sum(dim=-1, keepdim=True)
# Same aggregation, but taking the maximum over heads and all layers at once
# (stacking adds the layer axis in front, so heads move to dim 2).
attn_max = torch.stack(attns, dim=0).amax(dim=(0, 2)).squeeze()
attn_max = attn_max / attn_max.sum(dim=-1, keepdim=True)





# Output directory for this model's self-attention figures. os.makedirs also
# creates missing parent folders (os.mkdir raised FileNotFoundError when
# "text_feat/self_attention/" did not exist yet); exist_ok=True replaces the
# separate existence check.
path = "text_feat/self_attention/{}/".format(args.model)
os.makedirs(path, exist_ok=True)

plt.rcParams['axes.unicode_minus'] = False  # so that minus signs are displayed correctly
plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
def draw(data, x, y, ax):
    """Render *data* as a seaborn heatmap on *ax*.

    ``x`` labels the columns and ``y`` the rows. Cells are square, the colour
    scale is fixed to [0, 1], and no colour bar is drawn. Returns None.
    """
    heatmap_options = dict(
        xticklabels=x,
        yticklabels=y,
        square=True,
        vmin=0,
        vmax=1.0,
        cbar=False,
        ax=ax,
    )
    seaborn.heatmap(data, **heatmap_options)
# Set up a single, very large canvas so every token label stays legible.
fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(100, 100))
# Multiply by 5 before plotting: the heatmap colour range is fixed to [0, 1],
# so this brightens weak attention (values above 0.2 saturate).
boosted_attn = attn_last.cpu().numpy() * 5
draw(boosted_attn, x=tokens, y=tokens, ax=axs)
plt.show()
# plt.savefig("xxx.pdf", dpi=1600)