In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from utils import llava_load_model, llava_process_image, llava_generate
from lens_utils import llava_logit_lens

##### Paths

In [None]:
# Root of the project data. Set the VLM_DATA_PATH environment variable to point the
# notebook elsewhere; otherwise the original absolute path is used.
data_path = os.environ.get("VLM_DATA_PATH", "/home/drdo/vlm-compositionality/data")
# Plain str: a trailing comma on this assignment would silently make it a 1-tuple.
dataset_folder = f"{data_path}/raw/sugarcrepe"
image_folder = f"{data_path}/raw/coco_val_2017"
model_name = "llava-hf/llava-1.5-13b-hf"
image_file = f"{image_folder}/000000008690.jpg"

##### Load model, image

In [None]:
# load image
# TODO: random image
# Open inside a context manager so the file handle is released once the pixels are
# decoded. Also normalise to RGB: a single-channel image (e.g. mode "L") would
# otherwise be drawn by plt.imshow with a false-colour colormap in the overlay plots.
with Image.open(image_file) as raw_image:
    image = raw_image.convert("RGB")

# load model, processor
model, processor = llava_load_model(model_name)

##### Process inputs

In [None]:
# Turn the image plus the helper's default prompt into model inputs,
# placed on the same device as the loaded model.
target_device = model.device
inputs = llava_process_image(image, processor, device=target_device)

##### Get hidden states

In [None]:
# generate
# Run the model on the processed inputs. `outputs` is passed to llava_logit_lens in the
# next cell; per this section's title it presumably carries per-layer hidden states —
# confirm against llava_generate.
outputs = llava_generate(inputs, model)

##### Get logit lens

In [None]:
# Logit lens: vocabulary probabilities for each (layer, token) position, as returned
# by llava_logit_lens.
# TODO: norm before unembedding
# Axis order of softmax_probs: vocab_dim, num_layers, num_tokens.
# Later cells rely on this order: they index axis 0 with token ids and read shape[2]
# as the number of image embeddings. The token axis appears to cover only image
# embeddings, because the localization cell reshapes it to a square grid — confirm.
softmax_probs = llava_logit_lens(inputs, model, outputs)

##### Object

In [None]:
# Object word whose logit-lens probabilities are visualised below.
class_: str = "hair"

#### Internal confidence heatmap

In [None]:
# Vocabulary ids of the object word. `add_special_tokens=False` excludes BOS explicitly.
# Slicing off element 0 instead would drop a real sub-word id for any tokenizer that
# does not prepend BOS.
class_token_indices = processor.tokenizer.encode(class_, add_special_tokens=False)
assert len(class_token_indices) > 0, f"{class_!r} encodes to no tokens"

# softmax_probs is indexed (vocab, layer, token). Take the max over the word's sub-word
# tokens -> (layer, token), then transpose to (token, layer) so layers run along x.
heatmap_data = softmax_probs[class_token_indices].max(axis=0).T

# plot
num_image_embeddings = softmax_probs.shape[2]
fig, ax = plt.subplots()
im = ax.imshow(
    heatmap_data,
    # Gives a display height:width ratio of 30:num_layers.
    # NOTE(review): origin of the constant 30 is undocumented.
    aspect=30/num_image_embeddings,
    cmap='Blues',
    interpolation='nearest',
    # Fixed probability scale, so heatmaps for different words are comparable.
    vmin=0,
    vmax=1,
)
ax.set_title(f"'{class_}' probabilities")
ax.set_xlabel("LM Layer")
ax.set_ylabel("Image Embedding Index")
fig.colorbar(im, ax=ax)
# Lay out after the colorbar exists so its space is taken into account.
fig.tight_layout()
plt.show()

##### Max localization

In [None]:
img_width, img_height = image.size

# Max over the object's sub-word tokens, then over layers:
# one peak probability per image embedding.
embedding_max = softmax_probs[class_token_indices].max(axis=0).max(axis=0)

# The image embeddings form a square patch grid (previously hard-coded as 24 x 24).
# Derive the side length from the data, and fail with a clear message if the
# embedding count is not a perfect square.
num_patches = int(round(np.sqrt(embedding_max.size)))
assert num_patches * num_patches == embedding_max.size, (
    f"{embedding_max.size} image embeddings do not form a square patch grid"
)
# float32 matches PIL's "F" image mode, which is used for the resize below.
segmentation = embedding_max.reshape(num_patches, num_patches).astype(np.float32)

# Upsample the patch grid to the image resolution so it overlays pixel-for-pixel.
segmentation_resized = np.array(
    Image.fromarray(segmentation).resize((img_width, img_height), Image.BILINEAR)
)

fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(segmentation_resized, cmap='jet', interpolation='bilinear', alpha=.5)
ax.axis('off')
ax.set_title(f"'{class_}' localization")
fig.tight_layout()
plt.show()

##### Localization by layer