In [1]:
import json 
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

import faiss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/constrained_worlds.tar.gz
# ! tar -zxvf constrained_worlds.tar.gz && rm constrained_worlds.tar.gz 
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/id2name_mappings.tar.gz
# ! tar -zxvf id2name_mappings.tar.gz && rm id2name_mappings.tar.gz

In [3]:
# Build a lookup from relation id (e.g. "P2989") to its English label.
# Each line of the .jsonl file is one JSON object carrying "id" and "en_label".
relation_mapping = {}
with open("id2name_mappings/relation_mapping.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        # Parse with json.loads instead of eval(): the file is downloaded data,
        # and eval() would execute arbitrary code and choke on JSON literals
        # such as true/false/null.
        mapping = json.loads(line)
        relation_mapping[mapping["id"]] = mapping["en_label"]

len(relation_mapping)

1187

In [4]:
# Load the relation list of the "genie_t5_tokenizeable" constrained world
# and display how many entries it contains.
with open("constrained_worlds/genie_t5_tokenizeable/relations.json", "r") as relations_file:
    relations = json.loads(relations_file.read())
len(relations)

1088

In [5]:
# Tokenizer and encoder come from the same Contriever checkpoint;
# the encoder is moved to the GPU with index 1.
checkpoint = 'facebook/contriever'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)
model = model.to('cuda:1')

In [6]:
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the positions selected by ``mask``.

    Args:
        token_embeddings: tensor whose dim 1 is the token axis and whose last
            dim is the embedding axis (presumably (batch, tokens, hidden)).
        mask: tensor with one entry per token, non-zero where the token should
            count towards the mean (e.g. a tokenizer attention mask).

    Returns:
        Tensor with the token axis reduced away: the sum of the kept token
        embeddings divided by the per-row sum of ``mask``.
    """
    # Broadcast the mask over the embedding axis and zero out masked tokens.
    keep = mask.unsqueeze(-1).bool()
    kept_embeddings = token_embeddings.masked_fill(~keep, 0.)
    # Divide by the number of kept tokens, not by the padded sequence length.
    token_counts = mask.sum(dim=1).unsqueeze(-1)
    return kept_embeddings.sum(dim=1) / token_counts

def embed_entity_batch(entity_list):
    """Embed a batch of strings with the module-level ``tokenizer`` and ``model``.

    Args:
        entity_list: list of strings to embed in a single forward pass.

    Returns:
        Tensor with one mean-pooled embedding per input string, located on the
        model's device. It is computed under ``torch.no_grad()`` and therefore
        carries no autograd history.
    """
    inputs = tokenizer(entity_list, padding=True, truncation=True, return_tensors='pt')
    # Follow the model's device instead of repeating a hard-coded device string,
    # so the inputs can never end up on a different device than the weights.
    inputs = inputs.to(model.device)

    # Inference only: skip building the autograd graph to save GPU memory.
    with torch.no_grad():
        outputs = model(**inputs)
        # outputs[0] holds the per-token hidden states; pool them over real tokens.
        embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
    return embeddings

In [7]:
# Split the mapping into two parallel lists that share the same order:
# numeric ids (the key with its first character stripped) and labels.
relation_items = list(relation_mapping.items())
ids = [int(relation_id[1:]) for relation_id, _ in relation_items]
names = [label for _, label in relation_items]

# Both counts should equal len(relation_mapping) if ids and labels are unique.
len(set(ids)), len(set(names))

(1187, 1187)

In [8]:
# Peek at the first few numeric relation ids.
ids[:5]

[2673, 571, 186, 1066, 453]

In [9]:
batch_size = 100
embeddings = []

# Embed the relation labels batch by batch, moving each result to the CPU
# so only one batch at a time lives on the GPU.
for start in tqdm(range(0, len(names), batch_size)):
    # Slicing clamps at the end of the list, so the final short batch
    # needs no special handling.
    label_batch = names[start:start + batch_size]
    batch_embeddings = embed_entity_batch(label_batch)
    embeddings.append(batch_embeddings.detach().to('cpu'))

100%|██████████| 12/12 [00:01<00:00,  7.90it/s]


In [10]:
# Join the per-batch tensors along the first axis and copy the result into
# a NumPy array (one row per relation label).
stacked = torch.concat(embeddings)
output = np.array(stacked)
output.shape

(1187, 768)

In [11]:
# Build an inner-product FAISS index ("Flat" storage wrapped in an "IDMap")
# so that search results come back as the numeric relation ids, not row numbers.
dim = output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(dim, "IDMap,Flat", metric)

# add_with_ids expects the ids as a 1-D int64 array; a plain Python list is
# only accepted by some FAISS versions, so convert explicitly.
index.add_with_ids(output, np.asarray(ids, dtype=np.int64))

In [12]:
# Check the index's is_trained flag before querying it.
print(index.is_trained)

True


In [13]:
# Sanity-check the index: query it with three of the stored vectors (rows 100-102)
# and retrieve the 3 best matches for each. With the inner-product metric the
# returned "distances" are similarity scores (higher means closer).
query_vectors = output[100:103]
distances, indices = index.search(query_vectors, 3)

In [14]:
# Show the retrieved relation ids and their scores, one row per query.
indices, distances

(array([[2989, 3161, 1660],
        [2923, 2151, 7863],
        [2578, 2579, 6153]]),
 array([[1.6228969 , 1.2542918 , 1.0368476 ],
        [1.7279    , 1.3264964 , 1.0802442 ],
        [1.4031705 , 0.8489038 , 0.83135617]], dtype=float32))

In [15]:
# Labels of the three ids retrieved for the first query.
tuple(relation_mapping[f"P{pid}"] for pid in (2989, 3161, 1660))

('has grammatical case', 'has grammatical mood', 'has index case')

In [16]:
# Labels of the three ids retrieved for the second query.
tuple(relation_mapping[f"P{pid}"] for pid in (2923, 2151, 7863))

('focal height', 'focal length', 'aperture')

In [17]:
# Labels of the three ids retrieved for the third query.
tuple(relation_mapping[f"P{pid}"] for pid in (2578, 2579, 6153))

('studies', 'studied by', 'research site')

In [18]:
# Ids of the query rows themselves; each should be the top hit of its own search.
ids[100:103]

[2989, 2923, 2578]

In [19]:
# Persist the index to disk so it can be reloaded without re-embedding.
index_path = "relations.index"
faiss.write_index(index, index_path)

In [20]:
# Store the id -> label mapping next to the index as a single JSON object.
serialized_mapping = json.dumps(relation_mapping)
with open('relation_mapping.json', 'w') as mapping_file:
    mapping_file.write(serialized_mapping)