In [8]:
import requests
import json
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

In [17]:
# project directory
project_dir = "/root/vlm-compositionality"

# sugarcrepe
dataset_dir = project_dir+'/data/raw/sugarcrepe'

# coco
image_dir = project_dir+'/data/raw/coco/val2017'

In [9]:
# load sugarcrepe
def load_sugarcrepe(folder_path):
    """load sugarcrepe dataset from local"""

    if folder_path[-1] != '/':
        folder_path += '/'

    # add attribute
    with open (folder_path+'add_att.json', encoding='utf8') as f:
        add_attribute = json.load(f)

    # add object
    with open (folder_path+'add_obj.json', encoding='utf8') as f:
        add_object = json.load(f)

    # replace attribute
    with open (folder_path+'replace_att.json', encoding='utf8') as f:
        replace_attribute = json.load(f)

    # replace object
    with open (folder_path+'replace_obj.json', encoding='utf8') as f:
        replace_object = json.load(f)

    # replace relation
    with open (folder_path+'replace_rel.json', encoding='utf8') as f:
        replace_relation = json.load(f)

    # swap attribute
    with open (folder_path+'swap_att.json', encoding='utf8') as f:
        swap_attribute = json.load(f)

    # swap object
    with open (folder_path+'swap_obj.json', encoding='utf8') as f:
        swap_object = json.load(f)

    # collate together
    dataset = {
        'add_attribute': add_attribute, 'add_object': add_object, 
        'replace_attribute': replace_attribute, 'replace_object': replace_object, 
        'replace_relation': replace_relation, 'swap_attribute': swap_attribute, 
        'swap_object': swap_object,
    }

    return dataset
    
sugarcrepe = load_sugarcrepe(dataset_dir)

In [10]:
sugarcrepe['add_attribute']['5']

{'filename': '000000121506.jpg',
 'caption': 'A white umbrella that has been blown up the wrong way.',
 'negative_caption': 'A polka-dot white umbrella that has been blown up the wrong way.'}

In [11]:
# load f1 values
with open (project_dir+'/data/processed/image_to_f1_logit_50_mask_10000.json', encoding='utf8') as f:
    f1_map = json.load(f)

In [12]:
f1_map

{'000000403385.jpg': {'toilet': 0.16461897942262757,
  'sink': 0.543382897510397},
 '000000006818.jpg': {'toilet': 0.12236593518383956},
 '000000331352.jpg': {'toilet': 0.2224179099422533, 'sink': 0.38028530565844},
 '000000289393.jpg': {'bird': 0.33791358698348134,
  'giraffe': 0.3383099892257965,
  'cow': 0.388809374720458,
  'potted plant': 0.0},
 '000000143931.jpg': {'bus': 0.0640954530250361,
  'person': 0.03075365436942353},
 '000000308394.jpg': {'person': 0.0,
  'umbrella': 0.0,
  'bench': 0.05382999889466122,
  'handbag': 0.05686291663021608},
 '000000184321.jpg': {'train': 0.09546465325541707},
 '000000297343.jpg': {'stop sign': 0.358081545546445},
 '000000336587.jpg': {'stop sign': 0.7133558061141462, 'truck': 0.0},
 '000000122745.jpg': {'stop sign': 0.4145700043176198},
 '000000443303.jpg': {'cat': 0.32599737605558327,
  'suitcase': 0.1350540287283381,
  'book': 0.07588304884231874},
 '000000025560.jpg': {'tv': 0.0,
  'cat': 0.33793916738980484,
  'person': 0.0,
  'cup': 0.0

In [3]:
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto", 
)

processor = AutoProcessor.from_pretrained(model_id)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:  21%|##        | 1.04G/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

In [19]:
# each value in "content" has to be a list of dicts with types ("text", "image") 
example = sugarcrepe['add_attribute']['5']
prompt = "Select the correct caption for the given image. Return either 'Caption 1' or 'Caption 2'. The captions are as follows :  Caption 1: {} Caption 2: {}"
prompt = prompt.format(example['negative_caption'], example['caption'])
conversation = [
    {
      "role": "user",
      "content": [
          {"type": "text", "text": prompt},
          {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image_file = image_dir+'/'+example['filename']
raw_image = Image.open(image_file)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(model.device)

output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))

ER:  
Select the correct caption for the given image. Return either 'Caption 1' or 'Caption 2'. The captions are as follows :  Caption 1: A polka-dot white umbrella that has been blown up the wrong way. Caption 2: A white umbrella that has been blown up the wrong way. ASSISTANT: Caption 1
