In [None]:
import argparse
import torch
import time
from llava.constants import (
    IMAGE_TOKEN_INDEX, # IMAGE_TOKEN_INDEX = -200
    DEFAULT_IMAGE_TOKEN, # DEFAULT_IMAGE_TOKEN = "<image>"
    DEFAULT_IM_START_TOKEN, # DEFAULT_IM_START_TOKEN = "<im_start>"
    DEFAULT_IM_END_TOKEN, # DEFAULT_IM_END_TOKEN = "<im_end>"
    IMAGE_PLACEHOLDER, # IMAGE_PLACEHOLDER = "<image-placeholder>"
)

from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from PIL import Image
import base64
import requests
from PIL import Image
from io import BytesIO
import re


def image_parser(args):
    """Split ``args.image_file`` into a list of image locations.

    Args:
        args: Object exposing ``image_file`` (a string) and ``sep`` (the
            delimiter string that separates entries inside ``image_file``).

    Returns:
        The list produced by ``args.image_file.split(args.sep)``; entries are
        returned as-is, without stripping whitespace.
    """
    joined_paths = args.image_file
    return joined_paths.split(args.sep)


def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: Either an ``http://`` / ``https://`` URL or a filesystem
            path. Only strings that start with a real URL scheme are fetched
            over the network, so a local file whose name merely begins with
            "http" is still opened from disk.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A ``PIL.Image.Image`` converted to ``"RGB"``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the request exceeds ``timeout``.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns a fully loaded copy, so the underlying file can be
    # closed as soon as the conversion is done.
    with Image.open(source) as raw_image:
        return raw_image.convert("RGB")


def load_images(image_files):
    """Load every entry of ``image_files`` with ``load_image``.

    Args:
        image_files: Iterable of image locations accepted by ``load_image``.

    Returns:
        A list with one loaded image per input entry, in the same order.
    """
    return [load_image(image_file) for image_file in image_files]

# Checkpoint location; the same value is used both to derive the model name
# and to load the weights.
MODEL_PATH = "llava-v1.5-7b"

model_name = get_model_name_from_path(MODEL_PATH)
# The second positional argument (model base) is None; everything is placed on
# the first CUDA device.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH,
    None,
    model_name,
    device="cuda:0",
)

In [None]:
# Question asked about the image. Alternatives that were tried:
#   "Generate a caption for this image."
#   "Please describe this image in detail."
query = "Describe the image in detail."
image_file = "case3.jpg"
# Partial assistant answer that is appended after "ASSISTANT: " when the model
# input is assembled later in this cell.
probe_prompt = "The image showcases a street corner in a city, featuring street signs mounted on a street light post. The signs provide directions to various locations, including the local library and a nearby market. The sky above the street corner is blue, indicating a clear day.\n\nIn the background, there is a"

# Settings for this run. temperature is not positive, so the
# `args.temperature > 0` check in the generate call turns sampling off.
args_dict = dict(
    model_path="llava-v1.5-7b",
    model_base=None,
    model_name=get_model_name_from_path("llava-v1.5-7b"),
    query=query,
    conv_mode=None,
    image_file=image_file,
    sep=",",
    temperature=-1,
    top_p=None,
    num_beams=1,
    max_new_tokens=2048,
)
# Build a throwaway class whose attributes are the settings above and
# instantiate it, so they can be read (and reassigned) as `args.<name>`.
args = type("Args", (), args_dict)()

# Model
# (disable_torch_init() is intentionally left disabled here.)

# Choose the image marker once: the bare image token, or the token wrapped in
# start/end markers when the model config requests them.
qs = args.query
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
image_marker = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN

if IMAGE_PLACEHOLDER in qs:
    # The query says where the image goes: substitute the placeholder.
    qs = re.sub(IMAGE_PLACEHOLDER, image_marker, qs)
else:
    # Otherwise the image marker goes on its own line before the query.
    qs = image_marker + "\n" + qs

# Infer the conversation template from the model name; the first matching
# marker wins, and "llava_v0" is used when none matches.
name_lowered = model_name.lower()
conv_mode = next(
    (
        mode
        for marker, mode in (
            ("llama-2", "llava_llama_2"),
            ("v1", "llava_v1"),
            ("mpt", "mpt"),
        )
        if marker in name_lowered
    ),
    "llava_v0",
)

if args.conv_mode is None or conv_mode == args.conv_mode:
    args.conv_mode = conv_mode
else:
    # An explicit, different conversation mode was requested: keep it and warn.
    print(
        "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
            conv_mode, args.conv_mode, args.conv_mode
        )
    )

# Single-turn conversation: the user message, then an empty assistant slot.
conv = conv_templates[args.conv_mode].copy()
user_role, assistant_role = conv.roles[0], conv.roles[1]
conv.append_message(user_role, qs)
conv.append_message(assistant_role, None)
prompt = conv.get_prompt()
    
# Load every image named in args.image_file and preprocess it with the model's
# image processor, then move the batch to the model's device in float16.
image_files = image_parser(args)
images = load_images(image_files)
images_tensor = process_images(
    images,
    image_processor,  # clip
    model.config,
).to(model.device, dtype=torch.float16)

# Model input: the conversation prompt followed by the partial assistant
# answer. `prompt` ends with the assistant role tag, so a single space joins
# the two. Reusing `prompt` (instead of a second hard-coded copy of the system
# prompt) keeps this in sync with the selected conversation template and with
# the image-token handling applied to the query.
# NOTE(review): for the "llava_v1" template this should match the previously
# hard-coded "... USER: <image>\n{query} ASSISTANT: " string — confirm.
probe_input = prompt + " " + probe_prompt

# Tokenize once (the earlier tokenization of `prompt` alone was overwritten
# before use), add a batch dimension, and place the ids on the model's device.
input_ids = (
    tokenizer_image_token(probe_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
)

# Stop string: the template's second separator for SeparatorStyle.TWO,
# otherwise its first separator.
if conv.sep_style == SeparatorStyle.TWO:
    stop_str = conv.sep2
else:
    stop_str = conv.sep
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

# Generate with attentions and hidden states returned, so the cells below can
# inspect per-layer predictions.
with torch.inference_mode(), torch.no_grad():
    output_dict = model.generate(
        input_ids,
        images=images_tensor,
        # Sampling is used only for a positive temperature.
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=args.num_beams,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
        output_attentions=True,
        return_dict_in_generate=True,
        output_hidden_states=True,
        stopping_criteria=[stopping_criteria],
        # DoLa-specific options forwarded to generate().
        dola_decoding=True,
        alpha=0.6,
        threshold_top_p=0.9,
        threshold_top_k=10,
        min_candidate_tokens=None,
        early_exit_layers=list(range(15, 25)),  # layers 15..24
    )

A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
Describe the image in detail. ASSISTANT:
torch.Size([1, 3, 336, 336])
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
Describe the image in detail. ASSISTANT: The image showcases a street corner in a city, featuring street signs mounted on a street light post. The signs provide directions to various locations, including the local library and a nearby market. The sky above the street corner is blue, indicating a clear day.

In the background, there is a car
torch.float16
using dola greedy search


In [None]:
import torch.nn.functional as F

# Project the last hidden-state entry of the first generation step, at the
# final sequence position, through the output embedding layer.
with torch.no_grad():
    mapping_layer = model.get_output_embeddings()
    first_step_states = output_dict.hidden_states[0]
    final_position_state = first_step_states[-1][:, -1, :]
    out = mapping_layer(final_position_state)

# Normalise over the vocabulary axis and keep the 20 most probable token ids.
next_probs = F.softmax(out, dim=-1)
top_k_probs, top_k_ids = next_probs.topk(20, dim=-1)

top_k_probs, top_k_ids

(tensor([[5.2832e-01, 4.4482e-01, 2.2141e-02, 3.1910e-03, 4.3178e-04, 3.2592e-04,
          1.6654e-04, 1.1623e-04, 9.1910e-05, 6.7234e-05, 5.7518e-05, 4.3452e-05,
          2.8491e-05, 2.0862e-05, 1.5974e-05, 8.7023e-06, 8.5235e-06, 7.5698e-06,
          7.3314e-06, 6.9737e-06]], device='cuda:0', dtype=torch.float16),
 tensor([[ 1939,  3869,  1670,   450,  5806,   512,  8512, 16564,   739,  1094,
            306,   319, 14832,   910,  4001,  3118,  2398,  2216, 11511,  8669]],
        device='cuda:0'))

In [None]:
# Show the token strings behind the top-k ids (size-1 dimensions dropped first).
tokenizer.convert_ids_to_tokens(top_k_ids.squeeze().tolist())

['▁No',
 '▁Yes',
 '▁There',
 '▁The',
 '▁While',
 '▁In',
 '▁Although',
 '▁Based',
 '▁It',
 '▁As',
 '▁I',
 '▁A',
 '▁Though',
 '▁This',
 '▁Since',
 '▁One',
 '▁However',
 '▁Not',
 '▁Unfortunately',
 '▁Instead']

In [None]:
import torch.nn.functional as F

# Vocabulary ids whose probability is tracked layer by layer. According to the
# top-k listing above, 3869 is '▁Yes', 1939 is '▁No' and 1670 is '▁There';
# 29900 does not appear in that listing.
first_token_id, second_token_id, third_token_id, fourth_token_id = 3869, 1939, 29900, 1670
first_token_list, second_token_list, third_token_list, fourth_token_list = [], [], [], []

# Pair each tracked id with the list that collects its per-layer probability.
tracked_tokens = (
    (first_token_id, first_token_list),
    (second_token_id, second_token_list),
    (third_token_id, third_token_list),
    (fourth_token_id, fourth_token_list),
)
last_layer_index = 32

with torch.no_grad():
    mapping_layer = model.get_output_embeddings()
    first_step_states = output_dict.hidden_states[0]
    for layer_index in range(1, last_layer_index + 1):
        position_state = first_step_states[layer_index][:, -1, :]
        if layer_index != last_layer_index:
            # Entries before the last one go through the model's final norm
            # prior to projection; the last entry is projected directly.
            # NOTE(review): presumably the last entry is already normalised
            # by the model — confirm.
            position_state = model.model.norm(position_state)
        out = mapping_layer(position_state)
        next_probs = F.softmax(out, dim=-1)
        for token_id, collected in tracked_tokens:
            collected.append(next_probs[:, token_id].cpu().numpy())

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Line chart: probability of each tracked token at every layer.
x = np.arange(1, len(first_token_list) + 1)  # x coordinates: 1-based layer index
k1 = first_token_list   # y values of line 1
k2 = second_token_list  # y values of line 2
k3 = third_token_list
k4 = fourth_token_list

# Take the legend labels from the tokenizer so they always match the tracked
# ids (the former hand-written "cup"/"chair" labels did not: 1670 is '▁There').
tracked_ids = [first_token_id, second_token_id, third_token_id, fourth_token_id]
legend_labels = [
    token.replace("▁", "") or token  # drop the word-boundary marker for display
    for token in tokenizer.convert_ids_to_tokens(tracked_ids)
]

# One marker/colour pair per line so that all four curves can be told apart.
line_styles = [
    ("s-", "r"),  # squares
    ("o-", "g"),  # circles
    ("^-", "b"),  # triangles
    ("d-", "m"),  # diamonds
]

fig, ax = plt.subplots()
for series, label, (fmt, color) in zip((k1, k2, k3, k4), legend_labels, line_styles):
    ax.plot(x, series, fmt, color=color, label=label)
ax.set_xlabel("layer")  # x-axis name
ax.set_ylabel("prob")   # y-axis name
ax.legend(loc="best")   # legend
ax.set_ylim(0, 1)
plt.show()