In [3]:
import os, sys
sys.path.append("../")
import torch
from torch import nn
import re
from functools import partial
from tqdm import tqdm
from src.t5 import t5_encode_text, get_encoded_dim
from src.helper import *


### Process Chatgpt output

In [7]:
verbs = []
for file in ["at_the_bottom_of.txt", "on_top_of.txt", "entail_spatial_relation.txt"]:
    with open(f"../data/chatgpt_output/{file}", "r") as f:
        lines = f.readlines()
        for line in lines:
            match = re.match(r"\d+\. (?P<verb>.*)\b", line)
            if match is not None:
                v = match.group('verb')
                if len(v): verbs.append(v.lower())
print(len(verbs))
verbs = list(set(verbs))
verbs.sort()
print(len(verbs))
print(verbs)
#with open("../data/verbs.txt", "w") as f:
    #for v in verbs: f.write(f"{v}\n")

650
364
['abut', 'adhere', 'adjoin', 'adorn', 'affix', 'align', 'anchor', 'append', 'apply', 'approach', 'arrange', 'arrive', 'ascend', 'attach', 'avoid', 'bake', 'balance', 'bandage', 'base', 'bed', 'bend', 'bind', 'blanket', 'blend', 'block', 'boast', 'boil', 'bolt', 'bounce', 'brace', 'bracket', 'brag', 'break', 'bring', 'broil', 'brush', 'buckle', 'build', 'bundle', 'burn', 'burrow', 'bury', 'button', 'camouflage', 'capsize', 'carry', 'catch', 'cement', 'chase', 'chew', 'cinch', 'clamp', 'clasp', 'climb', 'cling', 'close', 'clothe', 'cluster', 'clutch', 'coat', 'collapse', 'combine', 'compress', 'conceal', 'connect', 'consume', 'cook', 'couple', 'cover', 'cram', 'crawl', 'creep', 'crisscross', 'cross', 'crumble', 'curl', 'curve', 'cut', 'dance', 'dangle', 'decorate', 'deepen', 'delve', 'depose', 'deposit', 'descend', 'devour', 'dig', 'disconnect', 'disguise', 'display', 'dive', 'dock', 'drag', 'drape', 'droop', 'drop', 'drown', 'embed', 'embrace', 'encase', 'encircle', 'enclose', '

In [9]:
verbs = [l.strip() for l in open("../data/verbs.txt", "r").readlines()]

general_transitive_verbs = []
with open(f"../data/chatgpt_output/transitive_verbs.txt", "r") as f:
    lines = f.readlines()
    for line in lines:
        match = re.match(r"\d+\. (?P<verb>.*)\b", line)
        if match is not None:
            v = match.group('verb').strip()
            if len(v): general_transitive_verbs.append(v.lower())
print(len(general_transitive_verbs))
general_transitive_verbs = list(set(general_transitive_verbs))
print(len(general_transitive_verbs))
complementary_transitive_verbs = list(set(general_transitive_verbs) - set(verbs))
complementary_transitive_verbs.sort()
print(len(complementary_transitive_verbs))
print(complementary_transitive_verbs)
#with open("../data/complementary_transitive_verbs.txt", "w") as f:
#    for v in complementary_transitive_verbs: f.write(f"{v}\n")

1254
858
725
['abandon', 'abduct', 'abhor', 'abide', 'absorb', 'abstain', 'accept', 'accompany', 'accuse', 'achieve', 'acquire', 'admire', 'adopt', 'adore', 'advance', 'advertise', 'advise', 'affect', 'agitate', 'aid', 'aim', 'alert', 'allocate', 'allow', 'alter', 'amaze', 'amuse', 'analyze', 'annoy', 'answer', 'anticipate', 'apologize', 'appoint', 'appreciate', 'apprehend', 'approve', 'argue', 'arrest', 'assemble', 'assess', 'assign', 'assist', 'associate', 'assume', 'attract', 'auction', 'authorize', 'award', 'bargain', 'barter', 'beat', 'beg', 'behave', 'belief', 'believe', 'benefit', 'betray', 'blame', 'bless', 'borrow', 'breathe', 'bribe', 'calculate', 'capture', 'care', 'cause', 'celebrate', 'challenge', 'change', 'charge', 'charm', 'cheat', 'check', 'choose', 'claim', 'clean', 'clear', 'collect', 'comfort', 'command', 'communicate', 'compare', 'compel', 'compete', 'complain', 'complete', 'complicate', 'compose', 'compromise', 'conclude', 'condemn', 'conduct', 'confess', 'confron

In [4]:
t5_name = 't5-small'
#t5_name = 'google/flan-t5-xxl'
t5 = partial(t5_encode_text, name = t5_name, dtype = torch.float32) 
t5("hello world").size()

torch.Size([1, 3, 512])

In [5]:
relations = ["is on top of", "is at the bottom of"] #, "has common boundary with"]
relation_embeddings = torch.stack([t5(f"A {r} B")[0, 0, :] for r in relations])
print(f"A {relations[0]} B")

A is on top of B


In [10]:
verb_embeddings = []
for v in tqdm(verbs+complementary_transitive_verbs):
    verb_embeddings.append(t5(f"B {convert_to_third_person_singular(v)} A")[0, 0, :])
verb_embeddings = torch.stack(verb_embeddings)
print(f"B {convert_to_third_person_singular(verbs[10])} A")

100%|██████████| 1089/1089 [00:04<00:00, 270.25it/s]

B arranges A





In [12]:
def cos_sim(a, b, eps=1e-8):
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

In [39]:
sim_mt = cos_sim(verb_embeddings, relation_embeddings)
print(sim_mt.size())
std, mean = torch.std_mean(sim_mt, dim=0)
print(std, mean)
sim_mt = (sim_mt - mean) / std
print(sim_mt)

torch.Size([1089, 3])
tensor([0.0321, 0.0314, 0.0313], device='cuda:0') tensor([0.6377, 0.6443, 0.6553], device='cuda:0')
tensor([[ 0.2911,  0.2994,  0.1782],
        [-0.5744, -0.5608, -0.5840],
        [ 0.6578,  0.7344,  0.5663],
        ...,
        [ 0.7240,  0.9933,  1.1877],
        [-2.1166, -2.2848, -2.3957],
        [-0.5334, -0.4472, -0.3971]], device='cuda:0')


In [14]:
print([f"A {r} B" for r in relations])
# encode relations with dummy subj and obj, i.e. A <rel> B
relations = ["is on top of", "is at the bottom of", "has common boundary with"]
relation_embeddings = torch.stack([t5(f"A {r} B")[0, 0, :] for r in relations])

print(f"A {convert_to_third_person_singular(verbs[100])} B")
# encode verbs with dummy subj and obj, i.e. A <verb> B
verb_embeddings = []
for v in tqdm(verbs+complementary_transitive_verbs):
    verb_embeddings.append(t5(f"A {convert_to_third_person_singular(v)} B")[0, 0, :])
verb_embeddings = torch.stack(verb_embeddings)

# compute embedding similarities
sim_mt = cos_sim(verb_embeddings, relation_embeddings)
std, mean = torch.std_mean(sim_mt, dim=0)
print(std, mean)
sim_mt = (sim_mt - mean) / std
print(sim_mt.size())

print([f"B {r} A" for r in relations])
print(f"B {convert_to_third_person_singular(verbs[100])} A")
# encode verbs with subj-obj interchanged, i.e. B <verb> A
verb_embeddings = []
for v in tqdm(verbs+complementary_transitive_verbs):
    verb_embeddings.append(t5(f"B {convert_to_third_person_singular(v)} A")[0, 0, :])
verb_embeddings = torch.stack(verb_embeddings)

# compute embedding similarities
sim_mt2 = cos_sim(verb_embeddings, relation_embeddings)
std, mean = torch.std_mean(sim_mt2, dim=0)
print(std, mean)
sim_mt2 = (sim_mt2 - mean) / std
print(sim_mt2.size())

d = {
    "Strongly prefer 'on top of': ": [],
    "Strongly prefer 'at the bottom of': ": [],
    "Always prefer 'on top of': ": [],
    "Always prefer 'at the bottom of': ": [],
}
for i, v in enumerate(verbs):
    if sim_mt[i][0] > sim_mt[i][1] and sim_mt2[i][0] < sim_mt[i][1]: 
        d["Strongly prefer 'on top of': "].append(v)
    elif sim_mt[i][0] < sim_mt[i][1] and sim_mt2[i][0] > sim_mt[i][1]:
        d["Strongly prefer 'at the bottom of': "].append(v)
    elif sim_mt[i][0] >= sim_mt[i][1] and sim_mt2[i][0] >= sim_mt[i][1]:
        d["Always prefer 'on top of': "].append(v)
    elif sim_mt[i][0] <= sim_mt[i][1] and sim_mt2[i][0] <= sim_mt[i][1]:
        d["Always prefer 'at the bottom of': "].append(v)

for k, v in d.items():
    print(k, len(v))

['A is on top of B', 'A is at the bottom of B', 'A has common boundary with B']
A encases B


  0%|          | 0/1089 [00:00<?, ?it/s]

100%|██████████| 1089/1089 [00:03<00:00, 278.89it/s]


tensor([0.0448, 0.0463, 0.0448], device='cuda:0') tensor([0.8718, 0.8715, 0.8748], device='cuda:0')
torch.Size([1089, 3])
['B is on top of A', 'B is at the bottom of A', 'B has common boundary with A']
B encases A


100%|██████████| 1089/1089 [00:03<00:00, 277.05it/s]


tensor([0.0321, 0.0314, 0.0313], device='cuda:0') tensor([0.6377, 0.6443, 0.6553], device='cuda:0')
torch.Size([1089, 3])
Strongly prefer 'on top of':  62
Strongly prefer 'at the bottom of':  96
Always prefer 'on top of':  89
Always prefer 'at the bottom of':  117


In [48]:
d["Strongly prefer 'on top of': "]

['abut',
 'append',
 'arrive',
 'bind',
 'bring',
 'carry',
 'catch',
 'cover',
 'cross',
 'dance',
 'drape',
 'embed',
 'encircle',
 'enclose',
 'hitch',
 'implant',
 'integrate',
 'intersect',
 'lash',
 'lock',
 'mix',
 'move',
 'observe',
 'overlap',
 'overlay',
 'overturn',
 'pass',
 'permeate',
 'pose',
 'present',
 'prop',
 'push',
 'ride',
 'roast',
 'scrub',
 'show',
 'staple',
 'step',
 'stock',
 'straddle',
 'superpose',
 'thrust',
 'topple',
 'upend',
 'walk',
 'zip-tie',
 'zipper']

In [50]:
import json
json.dump(d, open("../data/t5_xxl_top-bottom_preference.json", "w"), indent=4)

### Create prompts for LLaMA/Vicuna (since they are decoder-only)

In [56]:
template = "Question: Which one of the following spatial relationships between the subject and the object does the verb '{verb}' best entail? \n\nOptions: \nA. subject is on top of object \nB. subject is at the bottom of object \n\nAnswer: "

In [57]:
filler = {"verb": "catch"}
print(template.format(**filler))

Question: Which one of the following spatial relationships between the subject and the object does the verb 'catch' best entail? 

Options: 
A. subject is on top of object 
B. subject is at the bottom of object 

Answer: 


In [58]:
len(verbs+complementary_transitive_verbs)

1089

In [59]:
with open(f'../data/feed_decoder_LM/zero_shot.jsonl', 'w') as outfile:

    for i, v in enumerate(verbs+complementary_transitive_verbs):
        entry = {
            "question_id": i,
            "text": template.format(**{"verb": v})
        }
        json.dump(entry, outfile)
        outfile.write('\n')