In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from utils import llava_load_model, llava_process_image, llava_generate
from lens_utils import llava_logit_lens

##### Constants

In [None]:
# Root of the local data directory (absolute path; adjust for your machine).
data_path = "/home/drdo/vlm-compositionality/data"
# NOTE: no trailing comma -- a trailing comma would make this a 1-tuple, not a str.
dataset_folder = data_path + "/raw/sugarcrepe"
image_folder = data_path + "/raw/coco_val_2017"
model_name = "llava-hf/llava-1.5-13b-hf"
image_file = image_folder + "/000000008690.jpg"

# Side length of the square grid of image-patch embeddings (24 x 24 = 576).
# Presumably 336px input / 14px patches for LLaVA-1.5's vision tower -- confirm.
num_patches = 24

##### Load model, image

In [None]:
# Load the image and force 3-channel RGB, so that non-RGB inputs (e.g. grayscale
# "L" or palette "P" mode) are handled. This matters both for the processor and for
# plt.imshow in the localization overlay, which would otherwise apply a colormap
# to a single-channel image.
image = Image.open(image_file).convert("RGB")

# Load model and processor.
# TODO: float32 -- llava_load_model is only given model_name here.
model, processor = llava_load_model(model_name)

##### Process inputs

In [None]:
# Build model inputs from the image and the default prompt (the prompt is chosen
# inside llava_process_image); inputs are placed on the model's device.
inputs = llava_process_image(image, processor, device=model.device)

##### Get hidden states

In [None]:
# Run generation on the processed inputs. `outputs` is passed to llava_logit_lens
# below; presumably it carries the per-layer hidden states (see section header) -- confirm.
outputs = llava_generate(inputs, model)

##### Get logit lens

In [None]:
# TODO: norm before unembedding
# Logit-lens probabilities. Assumed shape: (vocab_dim, num_layers, num_tokens).
# The heatmap cell below indexes axis 0 with vocabulary ids and reads axis 2
# as the image-embedding axis.
softmax_probs = llava_logit_lens(inputs, model, outputs)

##### Target object class

In [None]:
# Object word whose logit-lens probability is visualized below.
class_ = 'hair'
# Vocabulary ids of class_'s (possibly multi-piece) tokenization.
# add_special_tokens=False excludes BOS explicitly. The original sliced off
# index 0 instead, which would drop a real sub-token for a tokenizer that
# does not prepend BOS.
class_token_indices = processor.tokenizer.encode(class_, add_special_tokens=False)
assert class_token_indices, f"tokenizer produced no tokens for {class_!r}"

#### Internal confidence heatmap

In [None]:
# Max over the class's sub-tokens gives (num_layers, num_tokens); transpose so
# rows are image embeddings and columns are LM layers.
heatmap_data = softmax_probs[class_token_indices].max(axis=0).T
num_image_embeddings = softmax_probs.shape[2]

fig, ax = plt.subplots()
im = ax.imshow(
    heatmap_data,
    # Per-cell height/width ratio: the whole map's height equals 30 column widths,
    # regardless of how many image embeddings there are.
    aspect=30 / num_image_embeddings,
    cmap='Blues',
    interpolation='nearest',
)
ax.set_title(f"'{class_}' probabilities")
ax.set_xlabel("LM Layer")
ax.set_ylabel("Image Embedding Index")
fig.tight_layout()
im.set_clim(0, 1)  # fixed probability scale so figures are comparable
fig.colorbar(im, ax=ax)
plt.show()

##### Localization

In [None]:
def viz_localization(softmax_probs, class_token_indices, image, layer=None,
                     class_name=None, num_patches=None):
    """Overlay a class's logit-lens probability map on the image.

    Args:
        softmax_probs: array indexed as (vocab, layer, token), as in the heatmap
            cell; the token axis must hold a square grid of image-patch scores.
        class_token_indices: vocabulary ids of the class's sub-tokens; the
            maximum over them is shown.
        image: PIL image; the map is resized to its size and drawn on top.
        layer: index into the layer axis; None shows the maximum over all layers.
        class_name: label used in the title; defaults to the notebook-level
            ``class_`` (backward compatible with the original behavior).
        num_patches: side length of the patch grid; defaults to sqrt(#tokens).

    Raises:
        ValueError: if there are no tokens or the token count is not
            ``num_patches ** 2``.
    """
    if class_name is None:
        # Backward-compatible fallback: the original read this global directly.
        class_name = class_

    # (vocab, layer, token) -> max over the class's sub-tokens -> (layer, token)
    class_probs = np.asarray(softmax_probs)[class_token_indices].max(axis=0)
    if layer is None:
        token_scores = class_probs.max(axis=0)
        layer_label = "max over layers"
    else:
        token_scores = class_probs[layer]
        layer_label = f"layer {layer}"

    num_tokens = token_scores.size
    if num_patches is None:
        # Infer the grid side rather than depending on a notebook global.
        num_patches = int(np.rint(np.sqrt(num_tokens)))
    if num_tokens == 0 or num_patches * num_patches != num_tokens:
        raise ValueError(
            f"expected a non-empty square grid of image tokens, got {num_tokens} "
            f"tokens for num_patches={num_patches}"
        )
    segmentation = token_scores.reshape(num_patches, num_patches).astype(float)

    # Upsample the patch grid to image resolution so it aligns with the pixels.
    img_width, img_height = image.size
    segmentation_resized = np.array(
        Image.fromarray(segmentation).resize((img_width, img_height), Image.BILINEAR)
    )

    plt.imshow(image)
    plt.imshow(segmentation_resized, cmap='jet', interpolation='bilinear', alpha=.5)
    plt.axis('off')
    plt.title(f"'{class_name}' localization ({layer_label})")
    plt.tight_layout()
    plt.show()

##### Max localization

In [None]:
# layer=None: take the maximum probability over all layers.
viz_localization(softmax_probs, class_token_indices, image, layer=None)

##### Localization by layer

In [None]:
# Single layer: index 33 along the layer axis of softmax_probs.
viz_localization(softmax_probs, class_token_indices, image, layer=33)