In [1]:
from collections import Counter, defaultdict
from PIL import Image
import json, os, random, math, re, sys
from tqdm import tqdm
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["HF_DATASETS_CACHE"] = "/data/yingshac/hf_cache"
verbose = False
objects = ['mug', 'cap', 'sunglasses', 'remote', 'fork', 'plate', 'headphones', 'tape', 'candle', 'phone', 'spoon', 'book', 'knife', 'flower', 'bowl', 'cup', 'scissors', 'can']


In [3]:
device="cuda"
# use llava 1.5 7b model
model_path = 'llava-hf/llava-1.5-7b-hf'
model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float32, ) 
processor = AutoProcessor.from_pretrained(model_path)
model.to(device)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.00s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): Quick

In [4]:
def llava_vqa(image, prompt, device, verbose=False):
    if verbose: print(prompt)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = {k:v.to(device) for k, v in inputs.items()}
    generate_ids = model.generate(**inputs, max_length=300)
    generated_text = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    answer = generated_text.split('\n')[-1]

    if verbose: print(answer)
    return int('yes' in answer.lower())
    

In [5]:
def evaluate(gth_caption, pilimage, device):
    # input pilimage: 64x64
    tmp = gth_caption.split()
    f1, f2 = tmp[0], tmp[-1]
    true_relation = " ".join(tmp[1:-1])
    if true_relation in ['right of', 'left of']:
        image1 = pilimage.crop((0, 0, 32, 64))
        image2 = pilimage.crop((32, 0, 64, 64))
    else:
        image1 = pilimage.crop((0, 0, 64, 32))
        image2 = pilimage.crop((0, 32, 64, 64))

    image1 = torch.from_numpy(np.array(image1)).to(torch.float32).to(device)
    image2 = torch.from_numpy(np.array(image2)).to(torch.float32).to(device)

    if true_relation in ['right of', 'in-front of']:
        f1, f2 = f2, f1
    
    prompt1 = f"USER: <image>\nIs a {f1} in the image?\nASSISTANT:"
    prompt2 = f"USER: <image>\nIs a {f2} in the image?\nASSISTANT:"

    correct1, correct2 = llava_vqa(image1, prompt1, device), llava_vqa(image2, prompt2, device)
    return correct1, correct2
    

In [6]:
handle = "0228_195834"
whichset="test"
sample_dir = f"../scripts/diffuser_real/output/{handle}/infr/{whichset}_sentences"
epc = sorted([f for f in os.listdir(sample_dir)], key=lambda x: int(x[5:].split("_")[0]))[-1]
imsize = 64

pilimages, gth_captions = [], []

for f in os.listdir(f"{sample_dir}/{epc}/samples"):
    if ".txt" in f: continue
    im = Image.open(f"{sample_dir}/{epc}/samples/{f}")
    W, H = im.size
    nrows, ncols = H//imsize, W//imsize
    captions_file = f.replace(".png", ".txt")
    with open(f"{sample_dir}/{epc}/samples/{captions_file}", "r") as captions:
        gth_captions.extend([x.strip() for x in captions.readlines()])
    for r in range(nrows):
        for c in range(ncols):
            left, top = c*imsize, r*imsize
            right, bottom = left+imsize, top+imsize
            pilimage = im.crop((left, top, right, bottom))
            if np.sum(pilimage) == 255*3*64*64: continue # skip placeholders which are purely white images
            pilimages.append(pilimage)
    print(len(pilimages), len(gth_captions))
    

256 256
376 376
632 632
888 888


In [7]:
correct = []
for pilimage, gth_caption in tqdm(zip(pilimages, gth_captions), total=len(gth_captions)):
    correct1, correct2 = evaluate(gth_caption, pilimage, device)
    correct.append(correct1 and correct2)
    #break

100%|██████████| 888/888 [1:01:20<00:00,  4.14s/it]


In [8]:
round(sum(correct) / len(correct), 4)

0.0653

## Draft

In [2]:
FN = {'can': 6, 'candle': 1, 'tape': 32, 'bowl': 5, 'fork': 11, 'flower': 27, 'sunglasses': 8, 'spoon': 5, 'remote': 3, 'plate': 3, 'scissors': 11, 'knife': 11, 'phone': 1}
FP = {('mug', 'cup'): 63, ('knife', 'scissors'): 5, ('bowl', 'plate'): 27, ('can', 'remote'): 6, 
      ('knife', 'fork'): 8, ('can', 'candle'): 11, ('cup', 'can'): 5, ('cap', 'book'): 6, 
      ('candle', 'cap'): 18, ('candle', 'bowl'): 2, ('tape', 'remote'): 4, ('tape', 'headphones'): 5, 
      ('spoon', 'cap'): 4, ('sunglasses', 'headphones'): 27, ('knife', 'remote'): 2, 
      ('knife', 'phone'): 4, ('knife', 'spoon'): 9, ('can', 'mug'): 4, ('can', 'headphones'): 3, 
      ('can', 'phone'): 1, ('can', 'book'): 1, ('cap', 'phone'): 2, ('scissors', 'cap'): 19, 
      ('scissors', 'remote'): 22, ('scissors', 'headphones'): 21, ('scissors', 'knife'): 17, 
      ('bowl', 'mug'): 4, ('headphones', 'mug'): 1, ('bowl', 'cup'): 10, ('headphones', 'cup'): 1, 
      ('can', 'cap'): 33, ('cup', 'plate'): 1, ('can', 'tape'): 1, ('book', 'mug'): 3, 
      ('book', 'cap'): 2, ('cap', 'tape'): 7, ('book', 'cup'): 4, ('book', 'can'): 1, 
      ('plate', 'spoon'): 8, ('flower', 'candle'): 21, ('sunglasses', 'candle'): 4, 
      ('sunglasses', 'phone'): 4, ('flower', 'cap'): 13, ('bowl', 'remote'): 4, 
      ('flower', 'scissors'): 2, ('spoon', 'knife'): 2, ('spoon', 'scissors'): 5, 
      ('headphones', 'remote'): 10, ('remote', 'tape'): 4, ('bowl', 'candle'): 3, 
      ('remote', 'phone'): 16, ('sunglasses', 'mug'): 3, ('sunglasses', 'cup'): 3, 
      ('bowl', 'phone'): 1, ('plate', 'mug'): 6, ('can', 'plate'): 1, ('plate', 'cup'): 7, 
      ('tape', 'mug'): 1, ('tape', 'spoon'): 4, ('tape', 'cup'): 3, ('scissors', 'phone'): 7, 
      ('scissors', 'spoon'): 4, ('mug', 'candle'): 1, ('remote', 'candle'): 3, ('scissors', 'tape'): 7, 
      ('remote', 'headphones'): 24, ('plate', 'bowl'): 13, ('sunglasses', 'remote'): 9, 
      ('sunglasses', 'can'): 1, ('sunglasses', 'cap'): 9, ('fork', 'remote'): 7, 
      ('fork', 'headphones'): 5, ('fork', 'phone'): 3, ('spoon', 'remote'): 5, ('candle', 'remote'): 3, 
      ('spoon', 'fork'): 2, ('spoon', 'candle'): 7, ('spoon', 'bowl'): 1, ('spoon', 'cup'): 1, 
      ('plate', 'cap'): 3, ('candle', 'cup'): 1, ('can', 'knife'): 2, ('cap', 'mug'): 3, 
      ('cap', 'cup'): 1, ('headphones', 'cap'): 9, ('headphones', 'phone'): 4, 
      ('sunglasses', 'scissors'): 1, ('fork', 'spoon'): 6, ('fork', 'knife'): 4, 
      ('fork', 'scissors'): 5, ('remote', 'can'): 1, ('cap', 'remote'): 1, ('can', 'cup'): 4, 
      ('cap', 'flower'): 7, ('flower', 'remote'): 1, ('flower', 'knife'): 2, ('bowl', 'spoon'): 3, 
      ('mug', 'phone'): 1, ('bowl', 'cap'): 1, ('scissors', 'fork'): 3, ('phone', 'remote'): 2, 
      ('phone', 'headphones'): 2, ('plate', 'can'): 1, ('tape', 'cap'): 2, ('spoon', 'plate'): 1, 
      ('candle', 'mug'): 2, ('candle', 'spoon'): 1, ('candle', 'can'): 1, ('sunglasses', 'book'): 2, 
      ('fork', 'cap'): 3, ('plate', 'remote'): 1, ('plate', 'phone'): 1, ('remote', 'cap'): 7, 
      ('can', 'scissors'): 1, ('remote', 'mug'): 2, ('scissors', 'candle'): 1, ('scissors', 'can'): 1, 
      ('candle', 'phone'): 1, ('book', 'remote'): 1, ('book', 'tape'): 1, ('cup', 'cap'): 1, 
      ('book', 'candle'): 1, ('remote', 'book'): 1, ('can', 'spoon'): 1, ('flower', 'plate'): 2, 
      ('flower', 'bowl'): 1, ('flower', 'cup'): 4, ('candle', 'plate'): 1, ('remote', 'cup'): 1, 
      ('tape', 'phone'): 1, ('knife', 'tape'): 1, ('bowl', 'flower'): 1, ('flower', 'mug'): 2, 
      ('headphones', 'tape'): 1, ('headphones', 'knife'): 1, ('headphones', 'scissors'): 1, 
      ('headphones', 'candle'): 3, ('remote', 'plate'): 1, ('remote', 'bowl'): 1, 
      ('headphones', 'can'): 1, ('spoon', 'flower'): 3, ('spoon', 'headphones'): 3, 
      ('spoon', 'tape'): 1, ('spoon', 'phone'): 1, ('tape', 'bowl'): 1, ('flower', 'can'): 1, 
      ('cap', 'candle'): 1, ('sunglasses', 'plate'): 1, ('sunglasses', 'bowl'): 1, 
      ('book', 'headphones'): 1, ('book', 'phone'): 1, ('headphones', 'sunglasses'): 1, 
      ('remote', 'flower'): 1, ('knife', 'plate'): 1, ('knife', 'bowl'): 1, ('remote', 'scissors'): 1, 
      ('fork', 'plate'): 1, ('fork', 'bowl'): 1, ('fork', 'cup'): 1}

In [5]:
sum(list(FP.values()))

748