In [2]:
import sys
sys.executable

'/home/ksmehrab/miniconda/envs/llava/bin/python'

In [3]:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import transformers
transformers.__version__

'4.31.0'

In [5]:
torch.__version__

'2.0.1+cu117'

## Load llava

In [32]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

model_path = "liuhaotian/llava-v1.5-7b"
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device=device
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.51s/it]


In [29]:
import llava

In [33]:
llava.__file__

'/home/ksmehrab/FishVLM/LLaVA/llava/__init__.py'

In [34]:
from llava.eval.run_llava import eval_model

model_path = "liuhaotian/llava-v1.5-7b"
prompt = "Describe the contents of the image."
image_file = "/home/ksmehrab/AttentionGrounding/sample_images/dog.jpg"

args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "model_name": get_model_name_from_path(model_path),
    "query": prompt,
    "conv_mode": None,
    "image_file": image_file,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512
})()

eval_model(args)

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.55s/it]


AttributeError: 'tuple' object has no attribute 'unsqueeze'

In [7]:
## Use the eval model code in the main code
import torch

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
import re


def image_parser(
    image_file,
    sep=","
):
    out = image_file.split(sep)
    return out


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out

def get_model(
    model_path,
    model_base=None
):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name
    )

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    return model_name, conv_mode, tokenizer, model, image_processor, context_len
    


def get_outputs(
    tokenizer,
    model,
    image_processor,
    context_len,
    image_file,
    query,
    conv_mode,
    sep=",",
    temperature=0.2,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
):
    """
    image_file: can be a string of paths, separated by sep (,)
    """

    qs = query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if model.config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
    else:
        if model.config.mm_use_im_start_end:
            qs = image_token_se + "\n" + qs
        else:
            """ 
            This is the default behavior. 
            This means the image token is added to the beginning of the prompt.
            """
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_files = image_parser(
        image_file=image_file,
        sep=sep
    )

    images = load_images(image_files)

    """ 
    If we pass in a list of images, we can potentially have a batch of images.
    """
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    """
    The following input_ids are likely the input_ids for the text prompt.
    We can decode these input_ids to see if/how they match the input prompt.

    Also note that we `unsqueeze(0)` to add a batch dimension. This means we can potentially have  a batch of textual prompts.
    This is likely possible if we simply pass in a list of prompts.
    """

    input_ids, last_chunk_tokens, last_chunk_offsets, last_chunk_after_image = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")

    input_ids = input_ids.unsqueeze(0).to(model.device)

    # print(input_ids)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            output_attentions=True,  # Enable attention scores
            return_dict_in_generate=True  # Return detailed outputs
        )

    return prompt, input_ids, last_chunk_tokens, last_chunk_offsets, last_chunk_after_image, images_tensor, output_ids

    """ 
    Notice how the output_ids are separated into two parts: the input_ids and the output_ids.
    output_ids[:, :input_token_len]
    """
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(
            f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
        )
    outputs = tokenizer.batch_decode(
        output_ids[:, input_token_len:], skip_special_tokens=True
    )[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()
    print(outputs)

    return output_ids, outputs


In [14]:
### Main code
### We need to get the model_name, conv_mode, tokenizer, model, image_processor, context_len
model_path = 'liuhaotian/llava-v1.5-7b'
model_name, conv_mode, tokenizer, model, image_processor, context_len = get_model(model_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


In [8]:
### Set the image_file and textual query
query = "Detect the objects in the image"
query = query.strip()
image_file = '/data/OpenPSG/coco/val2017/000000300341.jpg'

In [61]:
### Set the image_file and textual query
query = "Bird eye"
query = query.strip()
image_file = "/home/ksmehrab/AttentionGrounding/sample_images/bird-sample.jpg"

In [15]:
### Call the get_outputs method to print the model output 
prompt, input_ids, last_chunk_tokens, last_chunk_offsets, last_chunk_after_image, images_tensor, output_ids = get_outputs(
    tokenizer=tokenizer,
    model=model,
    image_processor=image_processor,
    context_len=context_len,
    image_file=image_file,
    query=query,
    conv_mode=conv_mode,
    temperature=0
)

In [18]:
type(output_ids)

transformers.generation.utils.GreedySearchDecoderOnlyOutput

In [16]:
output_ids.keys()

odict_keys(['sequences', 'attentions'])

In [65]:
type(output_ids['sequences'])

torch.Tensor

In [66]:
output_ids['sequences'].shape

torch.Size([1, 112])

In [67]:
type(output_ids['attentions'])

tuple

In [68]:
len(output_ids['attentions'])

66

In [69]:
output_ids['attentions'][0]

(tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [8.9600e-01, 1.0382e-01, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [1.3147e-01, 2.4170e-02, 8.4424e-01,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           ...,
           [2.0351e-03, 1.0529e-03, 2.1267e-03,  ..., 1.4290e-02,
            0.0000e+00, 0.0000e+00],
           [2.0561e-03, 1.0805e-03, 2.0790e-03,  ..., 1.5320e-02,
            7.8726e-04, 0.0000e+00],
           [2.0719e-04, 8.7917e-05, 1.6308e-04,  ..., 1.6708e-02,
            1.4515e-03, 9.3933e-02]],
 
          [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [6.5918e-01, 3.4082e-01, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [4.3506e-01, 2.8369e-01, 2.8125e-01,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           ...,
           [2.9678e-03, 1.6556e-03, 1.

In [70]:
type(output_ids['attentions'][0])

tuple

In [71]:
len(output_ids['attentions'][0])

32

In [93]:
output_ids['attentions'][0][31].shape

torch.Size([1, 32, 621, 621])

In [73]:
# input_ids[input_ids == -200] = 0

In [74]:
input_ids

tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 29871,  -200, 29871,    13, 29933,  1823,
         10977,   319,  1799,  9047, 13566, 29901]], device='cuda:0')

In [75]:
tokenizer.convert_ids_to_tokens(29933)

'B'

In [76]:
len(input_ids[0])

46

In [77]:
last_chunk_tokens

[1, 29871, 13, 29933, 1823, 10977, 319, 1799, 9047, 13566, 29901]

In [78]:
last_chunk_offsets

[(0, 0),
 (0, 1),
 (0, 1),
 (1, 2),
 (2, 5),
 (5, 9),
 (9, 11),
 (11, 13),
 (13, 16),
 (16, 19),
 (19, 20)]

In [79]:
# Method that returns the map of words to token ids, and words to offset indices
# Split the sentence into words and punctuation
import re

import queue
def get_q(list):
    q = queue.Queue()
    for item in list:
        q.put(item)
    return q

def print_elements_in_queue(q):
    # Print all elements currently in the queue
    temp_list = []
    while not q.empty():
        item = q.get()
        temp_list.append(item)
        print(item)

    # Put the elements back into the queue
    for item in temp_list:
        q.put(item)

def get_word_to_token_map(sentence, token_ids, offsets):
    assert (token_ids[0] == 1) and (token_ids[1] == 29871) and (token_ids[2] == 13), "The first three tokens should be <sep>, :, and \n"

    assert (token_ids[-1] == 29901) and (token_ids[-2] == 13566) and (token_ids[-3] == 9047) and (token_ids[-4] == 1799) and (token_ids[-5] == 319), "The last 5 tokens shoud be 'ASSISTANT:'"

    words = re.findall(r'\w+|[^\w\s]', sentence, re.UNICODE)
    assert (words[-1] == ':') and (words[-2] == 'ASSISTANT'), "The last two words should be 'ASSISTANT' and ':'"
    
    req_words = words[:-2] 
    req_offsets = offsets[3:-5]
    req_token_ids = token_ids[3:-5]

    # print(req_token_ids)

    offsets_q = get_q(req_offsets)
    words_q = get_q(req_words)
    words_q_peek = get_q(req_words)
    token_ids_q = get_q(req_token_ids)

    current_word = None
    current_token_to_word = ''
    current_input_ids = []
    current_indices = []
    word_to_token_ids_map = {}
    word_to_token_index_map = {}
    counter = 0
    while not offsets_q.empty():
        if not current_word:
            current_word = words_q_peek.get()
        
        offset = offsets_q.get()
        token_id = token_ids_q.get()
        current_input_ids.append(token_id)
        current_indices.append(counter)
        sentence_part = sentence[offset[0]:offset[1]]
        current_token_to_word += sentence_part
       
        if current_token_to_word.strip() == current_word:
            # print('Match')
            
            word_to_token_ids_map[current_word] = current_input_ids

            if current_word not in word_to_token_index_map.keys():
                word_to_token_index_map[current_word] = current_indices
            else:
                word_to_token_index_map[current_word].extend(current_indices)

            current_word = None
            current_token_to_word = ''
            current_input_ids = []
            current_indices = []
            words_q.get()

        counter += 1

        if counter >= len(req_token_ids):
            break
    
    # print_elements_in_queue(words_q)

    assert offsets_q.empty(), "Offsets queue should be empty after while loop"
    assert words_q.empty(), "Words queue should be empty after while loop"

    return word_to_token_ids_map, word_to_token_index_map
    

In [80]:
word_to_token_id_map, word_to_relative_index_map = get_word_to_token_map(last_chunk_after_image, last_chunk_tokens, last_chunk_offsets)

In [81]:
word_to_relative_index_map

{'Bird': [0, 1], 'eye': [2]}

In [82]:
def get_number_of_tokens_before_image_and_till_separator(input_ids, separator_ids = [29871, 13], num_image_tokens=576):
    input_ids = input_ids.clone().detach().to('cpu')
    separator_ids = torch.tensor(separator_ids)
    num_tokens_before_separator = -1
    # Find the position of the separator sequence
    for i in range(len(input_ids[0]) - len(separator_ids) + 1):
        if torch.equal(input_ids[0][i:i+len(separator_ids)], separator_ids):
            num_tokens_before_separator = i
            break

    assert num_tokens_before_separator != -1, "The separator sequence was not found in the input_ids"

    # The num tokens before the separator include a token for image
    num_tokens_before_image = num_tokens_before_separator - 1

    # The total number of tokens with the separator sequence
    total_num_tokens_till_separator_inclusive = num_tokens_before_image + num_image_tokens + len(separator_ids)

    return num_tokens_before_image, total_num_tokens_till_separator_inclusive

In [83]:
num_tokens_before_image, total_num_tokens_till_separator_inclusive = get_number_of_tokens_before_image_and_till_separator(input_ids)

In [84]:
num_image_tokens= 576
separator_ids = [29871, 13]
assert (num_tokens_before_image + num_image_tokens + len(last_chunk_tokens) - 1) == output_ids['attentions'][0][0].shape[2], "The number of tokens before the image, the image tokens, and the tokens after the image should match the number of tokens in the attention matrix"

In [85]:
image_token_start_idx = num_tokens_before_image
image_token_end_idx = num_tokens_before_image + num_image_tokens

In [112]:
def extract_visual_attention_for_given_word(attention_matrix, num_layer, num_head, word, word_to_relative_index_map, num_tokens_before_relative_idx, image_token_start_idx, image_token_end_idx):
    sequence_attention = attention_matrix[num_layer][:, num_head, :, :].squeeze().detach().cpu()
    # print(sequence_attention.shape)
    assert word in word_to_relative_index_map.keys(), "The word should be in the word_to_relative_index_map"
    relative_indices = word_to_relative_index_map[word]
    token_indices = [idx + num_tokens_before_relative_idx for idx in relative_indices]
    # print(token_indices)
    req_attention_maps = []
    for idx in token_indices:
        req_attention_maps.append(sequence_attention[idx][image_token_start_idx:image_token_end_idx])
        # print(sequence_attention[idx][image_token_start_idx:image_token_end_idx].shape)
    return req_attention_maps

In [113]:
attention_maps = extract_visual_attention_for_given_word(
    attention_matrix=output_ids['attentions'][0], 
    num_layer=0, 
    num_head=0, 
    word='Bird', 
    word_to_relative_index_map=word_to_relative_index_map, 
    num_tokens_before_relative_idx=total_num_tokens_till_separator_inclusive, 
    image_token_start_idx=image_token_start_idx, 
    image_token_end_idx=image_token_end_idx
)

In [88]:
""" 
627 - 51 is how many image tokens get added 
576
Therefore, this version of llava uses patch sizes of 14, resulting in 24 x 24 = 576 image tokens 
"""

' \n627 - 51 is how many image tokens get added \n576\nTherefore, this version of llava uses patch sizes of 14, resulting in 24 x 24 = 576 image tokens \n'

In [19]:
tokenizer.batch_decode(
        output_ids, skip_special_tokens=False
    )

TypeError: argument 'ids': Can't extract `str` to `Vec`

In [102]:
""" 
So we kinda know the position of the image token. Its where the input id is -200
"""

' \nSo we kinda know the position of the image token. Its where the input id is -200\n'

In [93]:
images_tensor.shape

torch.Size([1, 3, 336, 336])

In [94]:
model.config.mm_use_im_start_end

False

In [20]:
prompt

"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nBird's eye ASSISTANT:"

In [108]:
DEFAULT_IM_START_TOKEN

'<im_start>'

In [109]:
DEFAULT_IM_END_TOKEN

'<im_end>'

In [114]:
tokenizer.convert_ids_to_tokens(4002)

'Des'

In [None]:
"""
We need to get the index of the token -200
We also need a way of mapping the tokenized ids to the actual tokens
For example: Given the word, 'blue,' we need to know what exactly the tokens are for 'blue'

Look at the code for tokenizer_image_token() inside mm_utils.py

Look at how they chunk based on the <image> token 

We basically want to pass through the tokenizer for the chunk that follows the image token, do a return_offsets_mapping=True for this chunk, map sentence offsets to the tokens
For subword tokenization, get the list of words in the sentence by doing a split. 
If a part of a sentence (corresponding to a sentence offset) is not present in the list of words, then we know that it is a subword token.
In this case, move to the next offset. Check if combining the previous offset with the current offset gives a word in the list of words. 

We can implement this as a queue. Pop an offset from the queue, get the sentence part, see if that word is present in the list of words. If not, pop the next offset and combine the two offsets. Repeat this process until we get a word in the list of words. Repeat this process for all the offsets (that is until the queue is empty). 
"""

In [None]:
# from transformers import AutoTokenizer

# # Load the tokenizer
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# # Input sentence
# prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nThis is justatest to makesurethatthe sub-word tokenization is working. ASSISTANT:"

# last_chunk_after_image = prompt.split('<image>')[-1] # Last chunk will have the user query 
# tokenized = tokenizer(last_chunk_after_image, return_offsets_mapping=True)

# # Extract tokens and offsets
# last_chunk_token_ids = tokenized["input_ids"]
# last_chunk_offsets = tokenized["offset_mapping"]

In [40]:
last_chunk_after_image

'\nThis is justatest to makesurethatthe sub-word tokenization is working. ASSISTANT:'

In [41]:
# words_in_sentence = last_chunk_after_image.split()

In [None]:
# Split the sentence into words and punctuation
import re
words_in_sentence = re.findall(r'\w+|[^\w\s]', last_chunk_after_image, re.UNICODE)

In [43]:
words_in_sentence

['This',
 'is',
 'justatest',
 'to',
 'makesurethatthe',
 'sub',
 '-',
 'word',
 'tokenization',
 'is',
 'working',
 '.',
 'ASSISTANT',
 ':']

In [44]:
tokens_in_sentence = [last_chunk_after_image[offsets[0]:offsets[1]] for offsets in last_chunk_offsets]

In [45]:
tokens_in_sentence

['',
 'This',
 'is',
 'just',
 'ates',
 't',
 'to',
 'makes',
 'ure',
 'tha',
 'tt',
 'he',
 'sub',
 '-',
 'word',
 'token',
 'ization',
 'is',
 'working',
 '.',
 'ASSISTANT',
 ':',
 '']

In [46]:
req_tokens_in_sentence, req_offsets, req_token_ids = [], [], []
for token_in_sentence, offset, token_id in zip(tokens_in_sentence, last_chunk_offsets, last_chunk_token_ids):
    token_in_sentence = token_in_sentence.strip()
    if token_in_sentence:
        req_tokens_in_sentence.append(token_in_sentence)
        req_offsets.append(offset) 
        req_token_ids.append(token_id)

In [47]:
req_token_ids

[2023,
 2003,
 2074,
 8520,
 2102,
 2000,
 3084,
 5397,
 8322,
 4779,
 5369,
 4942,
 1011,
 2773,
 19204,
 3989,
 2003,
 2551,
 1012,
 3353,
 1024]

In [48]:
req_tokens_in_sentence

['This',
 'is',
 'just',
 'ates',
 't',
 'to',
 'makes',
 'ure',
 'tha',
 'tt',
 'he',
 'sub',
 '-',
 'word',
 'token',
 'ization',
 'is',
 'working',
 '.',
 'ASSISTANT',
 ':']

In [49]:
import queue
def get_q(list):
    q = queue.Queue()
    for item in list:
        q.put(item)
    return q

offsets_q = get_q(req_offsets)
words_q = get_q(words_in_sentence)
words_q_peek = get_q(words_in_sentence)
token_ids_q = get_q(req_token_ids)

In [50]:
current_word = None
current_token_to_word = ''
current_input_ids = []
word_to_token_ids_map = {}
counter = 0
while not offsets_q.empty():
    if not current_word:
        current_word = words_q_peek.get()

    offset = offsets_q.get()
    token_id = token_ids_q.get()
    current_input_ids.append(token_id)
    sentence_part = last_chunk_after_image[offset[0]:offset[1]]
    current_token_to_word += sentence_part
    if current_token_to_word == current_word:
        # print('Match')
        # if current_word not in word_to_token_ids_map.keys():
        #     word_to_token_ids_map[current_word] = current_input_ids
        # else:
        #     word_to_token_ids_map[current_word].extend(current_input_ids)
        word_to_token_ids_map[current_word] = current_input_ids
        current_word = None
        current_token_to_word = ''
        current_input_ids = []
        words_q.get()

    counter += 1

    if counter >= len(req_token_ids):
        break
    

In [51]:
word_to_token_ids_map

{'This': [2023],
 'is': [2003],
 'justatest': [2074, 8520, 2102],
 'to': [2000],
 'makesurethatthe': [3084, 5397, 8322, 4779, 5369],
 'sub': [4942],
 '-': [1011],
 'word': [2773],
 'tokenization': [19204, 3989],
 'working': [2551],
 '.': [1012],
 'ASSISTANT': [3353],
 ':': [1024]}

In [None]:
""" 
We can actually modify this implementation to compute the token positions in a map as well, so we don't have to do it later.
We know the number of tokens before the -200 token. We know the number of tokens in the -200 token. We can also know the number of tokens that get added before the user part of the sentence begins. If we already know the relative position of each word from the beginning of the sentence, we can simply add the number of stuff that come before to get the absolute position of the token in our sequence
"""