In [1]:
import json
from tqdm import tqdm

In [11]:
# WhatsUp dataset download link: https://drive.google.com/drive/u/0/folders/164q6X9hrvP-QYpi3ioSnfMuyHpG5oRkZ

# `version` is a string of subset letters to load:
#   'a' -> controlled_images_dataset.json, 'b' -> controlled_clevr_dataset.json
# e.g. "ab" loads both (in that order).
version = "b"
WHATSUP_DIR = "../data/whatsup_vlms"
WHATSUP_SUBSET_FILES = {
    "a": "controlled_images_dataset.json",
    "b": "controlled_clevr_dataset.json",
}

annotations = []
for subset, subset_file in WHATSUP_SUBSET_FILES.items():
    if subset in version:
        # `with` closes the handle deterministically; json.load(open(...)) leaves it to the GC.
        with open(f"{WHATSUP_DIR}/{subset_file}", "r") as fh:
            annotations.extend(json.load(fh))
# Fail with a clear message rather than an IndexError on annotations[0] below.
assert annotations, f"no annotations loaded for version={version!r}"

print("num_examples = ", len(annotations))
print("E.g.")
annotations[0]

num_examples =  408
E.g.


{'image_path': 'data/controlled_clevr/mug_right_of_knife.jpeg',
 'caption_options': ['A mug to the right of a knife',
  'A mug in front of a knife',
  'A mug behind a knife',
  'A mug to the left of a knife']}

In [16]:
### Build [caption, image path, (subj, obj, relation)] records for every annotation.
### The image filename stem encodes the triplet, e.g.
###   "data/controlled_clevr/mug_right_of_knife.jpeg" -> ("mug", "knife", "right of")
IMAGE_PREFIX = "data/"  # every annotation image_path is expected to start with this
J = []
SUBJ, OBJ = [], []
for a in tqdm(annotations):
    image_path = a['image_path']
    # Fail loudly instead of silently mangling the path if the prefix ever differs.
    assert image_path.startswith(IMAGE_PREFIX), f"unexpected image_path: {image_path}"
    filename = image_path[len(IMAGE_PREFIX):]
    # Drop directory and extension (of any length) -> "<subj>_<relation words>_<obj>".
    stem = image_path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
    name_parts = stem.split("_")
    subj, obj = name_parts[0], name_parts[-1]
    r = " ".join(name_parts[1:-1])
    SUBJ.append(subj)
    OBJ.append(obj)
    J.append([
        # presumably the ground-truth caption (it matches the filename in the example above) — verify
        a['caption_options'][0],
        "whatsup_vlms/" + filename,
        (subj, obj, r)
    ])
print("\n=== Statistics of original WhatsUp dataset ===")
print("     #instances = ", len(J))
print("     #unique subj = ", len(set(SUBJ)))
print("     #unique obj = ", len(set(OBJ)))
print("     #unique concepts = ", len(set(SUBJ).union(set(OBJ))))
print("     #unique train_triplets = ", len(set([a[-1] for a in J])))
# Uncomment to export the records:
#json.dump(J, open(f"../data/aggregated/whatsup_vlm_{version}.json", "w"), indent=4)


100%|██████████| 408/408 [00:00<00:00, 370646.75it/s]


=== Statistics of original WhatsUp dataset ===
     #instances =  408
     #unique subj =  9
     #unique obj =  17
     #unique concepts =  18
     #unique train_triplets =  408





In [19]:
# Inverse (converse) spatial relation: (s, o, r) holds iff (o, s, SYMMETRIC_REL[r]) holds.
# Keys use the filename spelling ("in-front of"); captions spell it "in front of".
SYMMETRIC_REL = {
    "left of": "right of",
    "right of": "left of", 
    "in-front of": "behind", 
    "behind": "in-front of",
}

### write to json, filtering for relations (rel_version) and objects (skip_nouns)
### toggle autofill --- only works for version b
J = []
SUBJ, OBJ = [], []
assert version == "b"
autofill_symmetric_rel = True
skip_nouns = ["sunglasses", "remote", "phone"] # None #
rel_version = "lr"  # "lr": left/right only, "fb": front/behind only, anything else: keep all

suffix = "_autofill" if autofill_symmetric_rel else ""
if skip_nouns is not None: suffix += "_remove_" + "_".join([x[:3] for x in skip_nouns])
IMAGE_PREFIX = "data/"  # every annotation image_path is expected to start with this
for a in tqdm(annotations):
    image_path = a['image_path']
    assert image_path.startswith(IMAGE_PREFIX), f"unexpected image_path: {image_path}"
    filename = image_path[len(IMAGE_PREFIX):]
    # filename stem is "<subj>_<relation words>_<obj>"; strip directory and extension
    name_parts = image_path.rsplit("/", 1)[-1].rsplit(".", 1)[0].split("_")
    subj, obj = name_parts[0], name_parts[-1]
    if skip_nouns is not None and (subj in skip_nouns or obj in skip_nouns): continue
    r = " ".join(name_parts[1:-1])
    if rel_version == "lr" and r in ["in-front of", "behind"]: continue
    if rel_version == "fb" and r in ['left of', "right of"]: continue
    # Record nouns only for instances that survive BOTH filters, so the statistics match J.
    SUBJ.append(subj)
    OBJ.append(obj)
    J.append([
        a['caption_options'][0],
        "whatsup_vlms/" + filename,
        (subj, obj, r)
    ])

if autofill_symmetric_rel:
    # For every kept triplet (s, o, r), add the equivalent (o, s, inverse(r)) with a rewritten caption.
    autofill = []
    for caption, image, (subj, obj, r) in J:
        inv_r = SYMMETRIC_REL[r]
        # Captions write relations without the filename hyphen ("in front of", not "in-front of").
        old_phrase, new_phrase = r.replace("-", " "), inv_r.replace("-", " ")
        # Without this check a non-matching replace silently keeps the old relation,
        # producing a caption that contradicts the inverted triplet.
        assert old_phrase in caption, f"relation {old_phrase!r} not found in caption: {caption!r}"
        words = caption.replace(old_phrase, new_phrase).split()
        # Swap the nouns; assumes captions look like "A <subj> ... a <obj>" with one-word nouns.
        words[1] = obj
        words[-1] = subj
        autofill.append([
            " ".join(words),
            image,
            (obj, subj, inv_r)
        ])
        SUBJ.append(obj)
        OBJ.append(subj)

    print(f"\nautofill {len(autofill)} tuples")
    J.extend(autofill)

print("\n=== Statistics of preprocessed WhatsUp dataset ===")
print("     #instances = ", len(J))
print("     #unique subj = ", len(set(SUBJ)))
print("     #unique obj = ", len(set(OBJ)))
print("     #unique concepts = ", len(set(SUBJ).union(set(OBJ))))
print("     #unique train_triplets = ", len(set([a[-1] for a in J])))
# Uncomment to export the filtered / autofilled records:
#json.dump(J, open(f"../data/aggregated/whatsup_vlm_{version}_{rel_version}{suffix}.json", "w"), indent=4)


100%|██████████| 408/408 [00:00<00:00, 713328.90it/s]


autofill 154 tuples

=== Statistics of preprocessed WhatsUp dataset ===
     #instances =  308
     #unique subj =  15
     #unique obj =  15
     #unique concepts =  15
     #unique train_triplets =  308



