In [3]:
from PIL import Image
from sklearn.metrics import f1_score

from pycocotools.coco import COCO

from data_utils import(
    load_sugarcrepe,
    coco_cats,
    coco_object_mask,
    filter_sugarcrepe_distict_objects,
)

from llava_utils import(
    llava_load_model, 
    llava_process_image, 
    llava_generate,
)
from lens_utils import(
    llava_logit_lens,
    get_mask_from_lens,
)

##### Constants

In [4]:
# root of the project checkout; every data path below is derived from it
project_dir = "/root/vlm-compositionality"

# SugarCrepe benchmark files
dataset_dir = f"{project_dir}/data/raw/sugarcrepe"

# COCO val2017 images
image_dir = f"{project_dir}/data/raw/coco/val2017"
# COCO annotation folder and the instance-annotation file inside it
ann_dir = f"{project_dir}/data/raw/coco/annotations"
ann_file = f"{ann_dir}/instances_val2017.json"

# model settings
model_name = "llava-hf/llava-1.5-7b-hf"  # checkpoint id handed to llava_load_model
topk = 50          # forwarded to llava_logit_lens as its `topk` argument
num_patches = 24   # forwarded to get_mask_from_lens (presumably patches per image side - confirm)

##### Load dataset and annotations

In [None]:
# SugarCrepe data, read from dataset_dir
sugarcrepe = load_sugarcrepe(dataset_dir)

# COCO API handle built from the val2017 instance annotations
coco = COCO(ann_file)

# take every COCO image id and pass the list through the
# SugarCrepe distinct-objects filter in a single step
image_ids = filter_sugarcrepe_distict_objects(coco, coco.getImgIds())

##### Load model

In [None]:
# Build the LLaVA model and its matching processor from the configured checkpoint.
# Optional overrides currently left disabled: flash_attention=False, torch_dtype=torch.float32
model, processor = llava_load_model(model_name)

##### Eval loop

In [None]:
# Compare logit-lens localisation against COCO ground truth.
# For every image: collect its object tokens, build one logit-lens mask per
# token, and score the overlap with the matching COCO mask using F1.
# Every score is printed and also stored in `f1_results` as
# (image_id, token, f1) so that later cells can aggregate the numbers.
f1_results = []

for image_id in image_ids:

    # image info
    image_info = coco.loadImgs(image_id)[0]
    image_file = image_info['file_name']
    image_width = image_info['width']
    image_height = image_info['height']

    # object tokens present in this image
    tokens = coco_cats(coco, image_id)

    # ground-truth COCO masks, looked up by token below
    token_to_mask = coco_object_mask(coco, image_id)

    # load image; the context manager releases the underlying file as soon
    # as the independent RGB copy has been created
    with Image.open(image_dir+'/'+image_file) as raw_image:
        image = raw_image.convert("RGB")

    # process image and prompt(default)
    inputs = llava_process_image(image, processor, device=model.device)

    # generate
    outputs = llava_generate(inputs, model)

    # get logit lens
    # vocab_dim, num_layers, num_tokens
    # TODO: what if token not in topk?
    softmax_probs = llava_logit_lens(inputs, model, outputs, topk=topk)

    # compare for each token
    for token in tokens:
        # get non zero mask from lens
        ll_mask = get_mask_from_lens(
            softmax_probs,
            token,
            processor,
            num_patches,
            image_width, image_height
        )

        # compare coco mask and logit lens mask
        coco_mask = token_to_mask[token]

        # f1 score to estimate overlap
        f1 = f1_score(coco_mask.flatten(), ll_mask.flatten())
        print("f1 score for {} : {}".format(token, f1))

        # keep the score; the former debugging `quit()` here stopped the run
        # (and shut down the notebook kernel) after the very first token
        f1_results.append((image_id, token, f1))