In [1]:
import json 
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

import faiss

In [3]:
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/constrained_worlds.tar.gz
# ! tar -zxvf constrained_worlds.tar.gz && rm constrained_worlds.tar.gz 
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/id2name_mappings.tar.gz
# ! tar -zxvf id2name_mappings.tar.gz && rm id2name_mappings.tar.gz

In [4]:
entity_mapping = {}
with open("../id2name_mappings/entity_mapping.jsonl", "r") as f:
    for line in f:
        mapping = eval(line)
        entity_mapping[mapping["id"]] = mapping["en_label"]
        
len(entity_mapping)

6427497

In [7]:
len(set(entity_mapping.values()))

6427497

In [5]:
entities = []
with open("constrained_worlds/genie/entities.json", "r") as f:
    entities = json.load(f)
len(entities)

2724925

In [6]:
constrained_entity_mapping = {}

for ent in entities:
    constrained_entity_mapping[ent] = " ".join(entity_mapping[ent].split("_"))

In [7]:
len(constrained_entity_mapping)

2724925

In [8]:
ent2desc = {}
with open("id2label_and_desc.json", "r") as f:
    ent2desc = json.load(f)
len(ent2desc)

107674454

In [17]:
id2desc_intersected = {}
not_found = []
for ent in tqdm(entities):
    if ent in ent2desc:
        id2desc_intersected[ent] = ent2desc[ent]['description']
    else:
        not_found.append(ent)

100%|██████████| 2724925/2724925 [00:04<00:00, 670616.09it/s]


In [10]:
len(not_found), len(id2desc_intersected)

(5406, 2719519)

In [26]:
len(not_found_desc)

5406

In [15]:
with open('not_found_desc.json', 'r') as f:
    not_found_desc = json.load(f)

In [20]:
id2desc_intersected = {**id2desc_intersected, **not_found_desc}
len(id2desc_intersected)

2724925

In [22]:
empty = 0
for val in id2desc_intersected.values():
    if val == '':
        empty += 1
empty

122664

In [32]:
id2label_desc = {}

for ent_id in constrained_entity_mapping.keys():
    id2label_desc[ent_id] = constrained_entity_mapping[ent_id] + "; " + id2desc_intersected[ent_id]

In [45]:
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
model = AutoModel.from_pretrained('facebook/contriever').to('cuda:4')

In [46]:
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings

def embed_entity_batch(entity_list):
    inputs = tokenizer(entity_list, padding=True, truncation=True, return_tensors='pt')

    outputs = model(**inputs.to('cuda:4'))
    embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
    return embeddings

In [42]:
# entity_items = list(constrained_entity_mapping.items())

entity_items = list(id2label_desc.items())
ids = [int(item[0][1:]) for item in entity_items]
names = [" ".join(item[1].split("_")) for item in entity_items]
del entity_items
del entity_mapping
len(set(ids)), len(set(names))

NameError: name 'entity_mapping' is not defined

In [47]:
embeddings = []
batch_size = 100

for i in tqdm(range(0, len(names), batch_size)):

    if i + batch_size > len(names):
        entity_list = names[i: len(names)]
    else:
        entity_list = names[i: i + batch_size]

    embeddings.append(embed_entity_batch(entity_list).detach().to('cpu'))

100%|██████████| 27250/27250 [45:10<00:00, 10.06it/s]  


In [48]:
output = np.array(torch.concat(embeddings))
output.shape

(2724925, 768)

In [49]:
# ids = [int(id_[1:]) for id_ in ids]
assert all(isinstance(id_, int) for id_ in ids)
ids[:5]

[1607497, 22958416, 1848116, 1513683, 19880251]

In [50]:
dim = output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(dim, "IDMap,Flat", metric)

index.add_with_ids(output, ids)

In [51]:
print(index.is_trained)

True


In [52]:
from time import time

before = time()
distances, indices = index.search(output[100:103, :], 3)
after = time()
after - before

4.170503377914429

In [53]:
indices, distances

(array([[ 7859497,  7859498,  7989648],
        [15238435,  1124728,  1810614],
        [ 6354654,  6354656,  1031041]]),
 array([[2.1675177, 1.9223285, 1.7584894],
        [1.9325339, 1.475885 , 1.4460012],
        [2.5693328, 1.7931323, 1.7416024]], dtype=float32))

In [54]:
id2label_desc["Q"+str(7859497)], id2label_desc["Q"+str(7859498)], id2label_desc["Q"+str(7989648)]

('Tworzanice; village in Greater Poland, Poland',
 'Tworzanki; village in Greater Poland, Poland',
 'Westrza; village in Greater Poland, Poland')

In [58]:
constrained_entity_mapping["Q"+str(7859497)], constrained_entity_mapping["Q"+str(7859498)], constrained_entity_mapping["Q"+str(63570)]

('Tworzanice', 'Tworzanki', 'Nowe Borza')

In [56]:
constrained_entity_mapping["Q"+str(15238435)], constrained_entity_mapping["Q"+str(1124728)], constrained_entity_mapping["Q"+str(1810614)]

('Kikoboga Airstrip', 'Mikumi National Park', 'Kiboga')

In [59]:
id2label_desc["Q"+str(15238435)], id2label_desc["Q"+str(1810614)], id2label_desc["Q"+str(892064)]

('Kikoboga Airstrip; airport in Mikumi National Park',
 'Kiboga; human settlement',
 'Kiboga District; districts of Uganda')

In [60]:
id2label_desc["Q"+str(6354654)], id2label_desc["Q"+str(6354656)], id2label_desc["Q"+str(1031041)]

('Kalutara Physical Culture Club; ',
 'Kalutara Stadium; multi-use stadium in Kalutara town',
 'LILALU; sports club')

In [23]:
constrained_entity_mapping["Q"+str(6354654)], constrained_entity_mapping["Q"+str(6354658)], constrained_entity_mapping["Q"+str(6354656)]

('Kalutara Physical Culture Club', 'Kalutara Town Club', 'Kalutara Stadium')

In [61]:
ids[100:103]

[7859497, 15238435, 6354654]

In [62]:
faiss.write_index(index, "entities_with_descriptions.index")

In [63]:
index = faiss.read_index("entities_with_descriptions.index")  # load the index

In [64]:
distances, indices = index.search(output[100:103, :], 3)

In [None]:
distances, indices

(array([[2.3877087, 1.7078212, 1.5154692],
        [1.8492612, 1.4968052, 1.3931653],
        [2.3846188, 1.9597743, 1.8561847]], dtype=float32),
 array([[ 7859497,  7859498,    63570],
        [15238435,  1810614,   892064],
        [ 6354654,  6354658,  6354656]]))

In [13]:
emb = embed_entity_batch(["Germany"])

In [14]:
index.search(np.array(emb.detach().cpu()), 3)

(array([[1.2840521, 1.1428443, 1.1164908]], dtype=float32),
 array([[    183, 5552084, 1649644]]))

In [15]:
constrained_entity_mapping["Q"+str(183)], constrained_entity_mapping["Q"+str(5552084)], constrained_entity_mapping["Q"+str(1649644)]

('Germany', 'Germany 1985', 'Gaël Germany')

In [65]:
with open("id2label_desc.json", 'w') as f:
    json.dump(id2label_desc, f)

In [17]:
emb = embed_entity_batch(["steve Jobs"])

In [18]:
index.search(np.array(emb.detach().cpu()), 3)

(array([[1.3705723, 1.278435 , 1.1589789]], dtype=float32),
 array([[   19837, 19801748, 18754959]]))

In [20]:
constrained_entity_mapping["Q"+str(19837)], constrained_entity_mapping["Q"+str(19801748)], constrained_entity_mapping["Q"+str(18754959)]

('Steve Jobs', 'Becoming Steve Jobs', 'Steve Jobs (film)')