In [1]:
import json 
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

import faiss

In [2]:
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/constrained_worlds.tar.gz
# ! tar -zxvf constrained_worlds.tar.gz && rm constrained_worlds.tar.gz 
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/id2name_mappings.tar.gz
# ! tar -zxvf id2name_mappings.tar.gz && rm id2name_mappings.tar.gz

In [1]:
relation_mapping = {}
with open("id2name_mappings/relation_mapping.jsonl", "r") as f:
    for line in f:
        mapping = eval(line)
        relation_mapping[mapping["id"]] = mapping["en_label"]
        
len(relation_mapping)

FileNotFoundError: [Errno 2] No such file or directory: 'id2name_mappings/relation_mapping.jsonl'

In [None]:
relation2desc = {}
with open("data/relation2description.json", "r") as f:
   relation2desc = json.load(f)
        
len(relation2desc)

1187

In [None]:
rel2name_and_desc = {}

for rel_id in relation2desc.keys():
    rel2name_and_desc[rel_id] = relation_mapping[rel_id] + "; " + relation2desc[rel_id]

In [None]:
rel2name_and_desc

{'P2673': 'next crossing upstream; next crossing of this river, canal, etc. upstream of this subject',
 'P571': 'inception; time when an entity begins to exist; for date of official opening use P1619',
 'P186': 'material used; material the subject or the object is made of or derived from (do not confuse with P10672 which is used for processes)',
 'P1066': 'student of; person who has taught this person',
 'P453': 'character role; specific role played or filled by subject -- use only as qualifier of "cast member" (P161), "voice actor" (P725)',
 'P1418': 'orbits completed; number of orbits a spacecraft has done around a body',
 'P1108': 'electronegativity; tendency of an atom or functional group to attract electrons',
 'P2319': 'elector; people or other entities which are qualified to participate in the subject election',
 'P5998': 'distributary; stream that branches off and flows away from the river',
 'P1536': 'immediate cause of; immediate effect of this cause',
 'P6532': 'has phenotyp

In [None]:
relations = []
with open("constrained_worlds/genie_t5_tokenizeable/relations.json", "r") as f:
    relations = json.load(f)
len(relations)

1088

In [None]:
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
model = AutoModel.from_pretrained('facebook/contriever').to('cuda:1')

In [None]:
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings

def embed_entity_batch(entity_list):
    inputs = tokenizer(entity_list, padding=True, truncation=True, return_tensors='pt')

    outputs = model(**inputs.to('cuda:1'))
    embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
    return embeddings

In [None]:
# relation_items = list(relation_mapping.items())
relation_items = list(rel2name_and_desc.items())
ids = [int(item[0][1:]) for item in relation_items]
names = [item[1] for item in relation_items]

len(set(ids)), len(set(names))

(1187, 1187)

In [14]:
ids[:5]

[2673, 571, 186, 1066, 453]

In [15]:
names[:5]

['next crossing upstream; next crossing of this river, canal, etc. upstream of this subject',
 'inception; time when an entity begins to exist; for date of official opening use P1619',
 'material used; material the subject or the object is made of or derived from (do not confuse with P10672 which is used for processes)',
 'student of; person who has taught this person',
 'character role; specific role played or filled by subject -- use only as qualifier of "cast member" (P161), "voice actor" (P725)']

In [16]:
embeddings = []
batch_size = 100

for i in tqdm(range(0, len(names), batch_size)):

    if i + batch_size > len(names):
        entity_list = names[i: len(names)]
    else:
        entity_list = names[i: i + batch_size]

    embeddings.append(embed_entity_batch(entity_list).detach().to('cpu'))

100%|██████████| 12/12 [00:02<00:00,  4.37it/s]


In [17]:
output = np.array(torch.concat(embeddings))
output.shape

(1187, 768)

In [18]:
dim = output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(dim, "IDMap,Flat", metric)

index.add_with_ids(output, ids)

In [19]:
print(index.is_trained)

True


In [20]:
distances, indices = index.search(output[100:103, :], 3)

In [21]:
indices, distances

(array([[2989,  279, 2936],
        [2923, 7863, 2929],
        [2578,  101,  186]]),
 array([[1.9348701, 1.0913169, 1.0648031],
        [1.9616466, 1.3006667, 1.2933571],
        [2.0756724, 1.4308015, 1.3919606]], dtype=float32))

In [22]:
relation_mapping["P"+str(2989)], relation_mapping["P"+str(3161)], relation_mapping["P"+str(1660)]

('has grammatical case', 'has grammatical mood', 'has index case')

In [23]:
relation_mapping["P"+str(2989)], relation_mapping["P"+str(279)], relation_mapping["P"+str(2936)]

('has grammatical case', 'subclass of', 'language used')

In [16]:
relation_mapping["P"+str(2923)], relation_mapping["P"+str(2151)], relation_mapping["P"+str(7863)]

('focal height', 'focal length', 'aperture')

In [24]:
relation_mapping["P"+str(2923)], relation_mapping["P"+str(7863)], relation_mapping["P"+str(2929)]

('focal height', 'aperture', 'lighthouse range')

In [25]:
relation_mapping["P"+str(2578)], relation_mapping["P"+str(101)], relation_mapping["P"+str(186)]

('studies', 'field of work', 'material used')

In [26]:
ids[100:103]

[2989, 2923, 2578]

In [27]:
faiss.write_index(index, "relations_with_descriptions.index")

In [28]:
with open('relation_mapping_with_descriptions.json', 'w') as f:
    json.dump(rel2name_and_desc, f)