In [131]:
import ast
import json
import os

import pandas as pd
from PIL import Image, ImageDraw, ImageFont

In [132]:
# Visual Genome image folders. get_vg_img_from_idx looks for "<idx>.jpg" in
# vg_path_1 first and falls back to vg_path_2.
vg_path_1 = "/mmfs1/gscratch/krishna/clip_hn/datasets/vg/VG_100K"
vg_path_2 = "/mmfs1/gscratch/krishna/clip_hn/datasets/vg/VG_100K_2"
# Directory that contains scene_graphs.json (read when building sg_df).
vg_path = "/gscratch/krishna/shared_data/datasets/vg" 

In [133]:
def tag(img, objs):
    """Return a copy of ``img`` with each object's box and name drawn on it.

    Args:
        img: PIL image. It is not modified; drawing happens on a copy.
        objs: iterable of dicts, each with a ``"bbox"`` entry
            ``(x1, y1, x2, y2)`` and a ``"name"`` string used as the label.

    Returns:
        The annotated copy of ``img``.
    """
    W, H = img.size
    img1 = img.copy()
    draw = ImageDraw.Draw(img1)
    # The default font is the same for every object, so load it once.
    font = ImageFont.load_default()
    for obj in objs:
        box = obj["bbox"]
        draw.rectangle(box, outline="red", width=6)
        x1, y1 = box[0], box[1]
        label = obj["name"]

        font_box = font.getbbox(label)
        text_width = font_box[2] - font_box[0]
        text_height = font_box[3] - font_box[1]

        # The label normally sits directly above the box's top-left corner.
        # Clamp it to the image so it stays visible for boxes that touch the
        # top or right edge (previously it could be drawn at negative y).
        label_x = max(0, min(x1, W - text_width))
        label_y = max(0, min(y1 - text_height, H - text_height))

        draw.rectangle(
            (label_x, label_y, label_x + text_width, label_y + text_height),
            fill="red",
        )
        draw.text((label_x, label_y), label, fill="white", font=font)
    return img1

In [134]:
def get_vg_img_from_idx(idx):
    """Open the Visual Genome image ``<idx>.jpg``.

    The file is looked up in ``vg_path_1``; if it is not there, the path
    under ``vg_path_2`` is used instead (without checking that it exists).
    Returns whatever ``Image.open`` returns for the chosen path.
    """
    filename = f"{idx}.jpg"
    primary = os.path.join(vg_path_1, filename)
    chosen = primary if os.path.exists(primary) else os.path.join(vg_path_2, filename)
    return Image.open(chosen)

In [135]:
def display_after_resize(img, resize=True, target_width=512):
    """Show ``img`` in the notebook, optionally scaled to a fixed width.

    Args:
        img: PIL image to display.
        resize: when true, scale the image to ``target_width`` pixels wide
            while keeping its aspect ratio; when false, display it as is.
        target_width: width in pixels of the resized image (default 512,
            the value that used to be hard-coded).
    """
    if resize:
        original_width, original_height = img.size
        # Image.resize takes (width, height): the width is fixed and the
        # height follows from the aspect ratio. Never let it truncate to 0.
        aspect_ratio = original_height / original_width
        target_height = max(1, int(target_width * aspect_ratio))
        img = img.resize((target_width, target_height))

    display(img)

In [136]:
def get_label_attributes(objects):
    """Collect ground-truth attributes per object name.

    Only objects whose ``'names'`` list has exactly one entry are used; that
    single name is the key. Attributes of objects sharing a name are
    concatenated in input order. Objects without an ``'attributes'`` entry
    contribute nothing but still create the key.

    Args:
        objects: iterable of dicts with a ``'names'`` list and an optional
            ``'attributes'`` list.

    Returns:
        dict mapping object name -> list of attributes. The lists are owned
        by the result; the input dicts are never modified.
    """
    obj2attrs = {}
    for obj in objects:
        if len(obj['names']) != 1:
            continue
        # Extend a list created here instead of storing the caller's list:
        # storing it and then using += would grow the input object's own
        # 'attributes' in place and duplicate entries on repeated calls.
        obj2attrs.setdefault(obj['names'][0], []).extend(obj.get('attributes') or [])
    return obj2attrs

def get_label_relationships(relationships):
    """Group ground-truth predicates by (subject_id, object_id).

    Each predicate is lower-cased and stripped of surrounding whitespace.
    Predicates for the same pair are kept in input order, duplicates included.
    """
    pair2preds = {}
    for entry in relationships:
        key = (entry['subject_id'], entry['object_id'])
        predicate = entry['predicate'].lower().strip()
        pair2preds.setdefault(key, []).append(predicate)
    return pair2preds

def get_pred_relationships(relationships):
    """Group predicted relationship strings by (subject_id, object_id).

    Every record carries a ``'relationships'`` list; each string in it is
    lower-cased and stripped. Lists for the same pair are concatenated in
    input order.
    """
    pair2preds = {}
    for entry in relationships:
        subj_id = entry['subject_id']
        obj_id = entry['object_id']
        cleaned = [text.lower().strip() for text in entry['relationships']]
        pair2preds.setdefault((subj_id, obj_id), []).extend(cleaned)
    return pair2preds
            
def get_pred_attributes(objects):
    """Collect predicted attributes per object name.

    Attributes of objects that share a name are concatenated in input order.

    Args:
        objects: iterable of dicts with a ``'name'`` string and an
            ``'attributes'`` list (a missing or empty entry adds nothing).

    Returns:
        dict mapping object name -> list of attributes. The lists are owned
        by the result; the input dicts are never modified.
    """
    obj2attrs = {}
    for obj in objects:
        # Append THIS object's attributes. The previous code extended the
        # stored list with itself, which doubled the first object's
        # attributes and dropped those of later objects with the same name.
        obj2attrs.setdefault(obj['name'], []).extend(obj.get('attributes') or [])
    return obj2attrs

In [27]:
# Load the Visual Genome scene graphs into a DataFrame.
# The context manager closes the file handle as soon as parsing is done.
with open(os.path.join(vg_path, "scene_graphs.json")) as f:
    sg = json.load(f)
sg_df = pd.DataFrame.from_records(sg)
sg_df.head()

Unnamed: 0,relationships,image_id,objects
0,"[{'synsets': ['along.r.01'], 'predicate': 'ON'...",1,"[{'synsets': ['clock.n.01'], 'h': 339, 'object..."
1,"[{'synsets': ['wear.v.01'], 'predicate': 'wear...",2,"[{'synsets': [], 'h': 103, 'object_id': 5069, ..."
2,"[{'synsets': ['in.r.01'], 'predicate': 'in fro...",3,"[{'synsets': [], 'h': 79, 'object_id': 5091, '..."
3,"[{'synsets': ['have.v.01'], 'predicate': 'has'...",4,"[{'synsets': ['curtain.n.01'], 'h': 300, 'obje..."
4,"[{'synsets': ['along.r.01'], 'predicate': 'ON'...",5,"[{'synsets': ['floor.n.01'], 'h': 108, 'object..."


In [72]:
# Model predictions, one row per image_id. The 'objects' and 'relationships'
# columns are read from the CSV as strings and are parsed into Python objects
# in the display loop at the end of the notebook.
preds = pd.read_csv('output_defined_cog_vlm.csv')
preds.head()

Unnamed: 0,image_id,objects,relationships
0,770,"[{'object_id': 1637131, 'name': 'crosswalk', '...",[]
1,861,"[{'object_id': 1030020, 'name': 'building', 'b...",[]
2,1017,"[{'object_id': 1541014, 'name': 'artwork', 'bb...","[{'subject_id': 3568085, 'object_id': 1541018,..."
3,2748,"[{'object_id': 1034635, 'name': 'trees', 'bbox...","[{'subject_id': 1034645, 'object_id': 1034635,..."
4,3891,"[{'object_id': 4355114, 'name': 'water', 'bbox...","[{'subject_id': 4355161, 'object_id': 4355122,..."


In [73]:
# Keep only the ground-truth scene graphs whose image_id also appears in preds.
labels = sg_df[sg_df['image_id'].isin(preds['image_id'])]
print(len(labels))
labels.head()

100


Unnamed: 0,relationships,image_id,objects
769,"[{'synsets': ['along.r.01'], 'predicate': 'ON'...",770,"[{'synsets': ['crossing.n.05'], 'h': 148, 'obj..."
860,"[{'synsets': ['stand.v.01'], 'predicate': 'sta...",861,"[{'synsets': ['tile.n.01'], 'h': 284, 'object_..."
1016,"[{'synsets': ['look.v.02'], 'predicate': 'look...",1017,"[{'synsets': ['picture.n.01'], 'h': 386, 'obje..."
2747,"[{'synsets': ['next.r.01'], 'predicate': 'next...",2748,"[{'synsets': ['flower.n.01'], 'h': 81, 'object..."
3890,"[{'synsets': [], 'predicate': 'ON', 'relations...",3891,"[{'synsets': [], 'h': 90, 'object_id': 4355110..."


In [105]:
# object_id -> {'name': ..., 'bbox': [x1, y1, x2, y2]}, filled from the
# ground-truth objects in `labels` by the loop in the following cell.
obj_id2info = {}

In [106]:
# Index every ground-truth object by its id so that relationship records,
# which only carry subject/object ids, can be resolved to a name and a box.
for _, row in labels.iterrows():
    for obj in row['objects']:
        obj_id = obj['object_id']
        if obj_id in obj_id2info:
            continue  # keep the first record seen for an id
        # Objects carry x/y/w/h; tag() expects corner coordinates.
        bbox = [obj['x'], obj['y'], obj['x'] + obj['w'], obj['y'] + obj['h']]
        obj_id2info[obj_id] = {'name': obj['names'][0], 'bbox': bbox}

In [139]:
# For every image with predictions: show each predicted object with its
# attributes, then each predicted relationship with its subject/object boxes.
for _, row in preds.iterrows():
    pred_id = row['image_id']
    img = get_vg_img_from_idx(pred_id)

    # The CSV stores these columns as Python-literal strings. Each one is
    # parsed once per row. NOTE: this deliberately uses ast.literal_eval
    # instead of eval() so that text from the CSV is never executed as code;
    # it raises ValueError if a cell holds anything other than plain literals.
    pred_objs = ast.literal_eval(row['objects'])
    pred_rlt_records = ast.literal_eval(row['relationships'])
    pred_attrs = get_pred_attributes(pred_objs)
    pred_rlts = get_pred_relationships(pred_rlt_records)

    # TODO: ground-truth attributes/relationships (get_label_attributes /
    # get_label_relationships on the matching `labels` row) are not merged
    # into the printed sets yet.

    # One tagged image per predicted object, followed by its attributes.
    for obj in pred_objs:
        tagged_img = tag(img, [obj])
        display_after_resize(tagged_img)
        obj_name = obj['name']
        all_attrs = set(pred_attrs[obj_name])
        print(f"{obj_name}: {all_attrs}")

    # One tagged image per predicted relationship. The endpoints are looked
    # up in obj_id2info, i.e. name and box come from the ground-truth objects.
    for rlt in pred_rlt_records:
        subj_id, obj_id = rlt['subject_id'], rlt['object_id']
        subj, obj = obj_id2info[subj_id], obj_id2info[obj_id]
        tagged_img = tag(img, [subj, obj])
        display_after_resize(tagged_img)
        all_rlts = set(pred_rlts[(subj_id, obj_id)])
        print(f"{subj['name']} -> {obj['name']}: {all_rlts}")