In [1]:
import json 
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

import faiss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# One-time data download, kept commented out so a normal run does not re-fetch the archives.
# Each archive is fetched with wget, unpacked with tar, and then removed.
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/constrained_worlds.tar.gz
# ! tar -zxvf constrained_worlds.tar.gz && rm constrained_worlds.tar.gz 
# ! wget https://huggingface.co/datasets/martinjosifoski/SynthIE/resolve/main/id2name_mappings.tar.gz
# ! tar -zxvf id2name_mappings.tar.gz && rm id2name_mappings.tar.gz

In [3]:
# Build a lookup from entity id to its English label; the file holds one record per line.
entity_mapping = {}
with open("id2name_mappings/entity_mapping.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            continue
        # Parse with the JSON parser instead of eval(): eval() would execute
        # arbitrary code from the downloaded file and cannot read JSON literals
        # such as null / true / false.
        mapping = json.loads(line)
        entity_mapping[mapping["id"]] = mapping["en_label"]

len(entity_mapping)

6427497

In [4]:
# Load the entities of the constrained world; they are used as keys into entity_mapping.
with open("constrained_worlds/genie/entities.json", "r") as entities_file:
    entities = json.load(entities_file)

len(entities)

2724925

In [5]:
# Label for every constrained entity, with underscores in the label turned into spaces.
constrained_entity_mapping = {
    ent: entity_mapping[ent].replace("_", " ") for ent in entities
}

In [10]:
# Number of entries in the constrained mapping (compare with len(entities)).
len(constrained_entity_mapping)

2724925

In [11]:
# Load the Contriever encoder onto the first GPU, together with its matching tokenizer.
model = AutoModel.from_pretrained('facebook/contriever')
model = model.to('cuda:0')
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')

In [12]:
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over dim 1, counting only positions where mask is non-zero.

    Args:
        token_embeddings: tensor whose dim 1 indexes tokens and whose last
            dim holds the embedding values.
        mask: tensor with one entry per token; zero marks positions to ignore.

    Returns:
        The per-sequence sum of the kept token embeddings divided by the
        number of kept tokens (the sum of ``mask`` along dim 1).
    """
    # Broadcast the mask over the embedding dimension and zero out ignored tokens.
    ignored = ~mask.bool().unsqueeze(-1)
    summed = token_embeddings.masked_fill(ignored, 0.).sum(dim=1)
    token_counts = mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts

def embed_entity_batch(entity_list):
    """Embed a batch of entity names with the notebook-level tokenizer and model.

    Args:
        entity_list: batch of strings to embed (callers pass a list of names).

    Returns:
        The mean-pooled embeddings produced by ``mean_pooling`` for the batch,
        located on the model's device. The result carries no autograd history.
    """
    inputs = tokenizer(entity_list, padding=True, truncation=True, return_tensors='pt')
    # Follow the model's device instead of hard-coding 'cuda:0', so the function
    # keeps working wherever the model has been placed.
    inputs = inputs.to(model.device)

    # Inference only: without no_grad every batch would record an autograd graph,
    # wasting memory and time for embeddings that are never back-propagated.
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
    return embeddings

In [13]:
# Split the mapping into two parallel lists: numeric ids (the key without its
# one-character prefix) and the labels with underscores replaced by spaces.
ids = [int(entity_id[1:]) for entity_id in constrained_entity_mapping]
names = [label.replace("_", " ") for label in constrained_entity_mapping.values()]

# Drop the full id -> label table; cells that read entity_mapping cannot be
# re-run afterwards without reloading it.
del entity_mapping

# Distinct ids and distinct names.
len(set(ids)), len(set(names))

(2724925, 2724925)

In [14]:
batch_size = 200
embeddings = []

# Slicing past the end of a list is clamped, so the last (shorter) batch needs
# no special handling.
for start in tqdm(range(0, len(names), batch_size)):
    batch_names = names[start:start + batch_size]
    # Move each batch result to the CPU before storing it.
    embeddings.append(embed_entity_batch(batch_names).detach().to('cpu'))

100%|██████████| 13625/13625 [22:11<00:00, 10.24it/s]


In [15]:
# Join the per-batch tensors along the first axis and copy the result into one NumPy array.
output = np.array(torch.cat(embeddings, dim=0))
output.shape

(2724925, 768)

In [16]:
# The ids were already converted to ints (prefix character stripped) when the
# list was built, so no conversion is needed here; just verify the type and
# peek at the first few values.
assert all(isinstance(id_, int) for id_ in ids)
ids[:5]

[1607497, 22958416, 1848116, 1513683, 19880251]

In [17]:
# Build an inner-product "IDMap,Flat" index so that search results are reported
# using the ids supplied below rather than row positions.
dim = output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(dim, "IDMap,Flat", metric)

# FAISS takes the ids as an int64 array. Passing the plain Python list relies on
# an implicit conversion inside the Python wrapper that not every FAISS release
# performs, so convert explicitly.
index.add_with_ids(output, np.asarray(ids, dtype=np.int64))

In [18]:
# Print the index's is_trained flag before searching it.
print(index.is_trained)

True


In [19]:
from time import time

# Time a 3-nearest-neighbour query for three rows taken from the indexed matrix;
# the cell's value is the elapsed wall-clock time in seconds.
search_started = time()
distances, indices = index.search(output[100:103], 3)
time() - search_started

1.1845037937164307

In [20]:
# Show the neighbour ids followed by their scores for the three query rows.
indices, distances

(array([[ 7859497,  7859498,    63570],
        [15238435,  1810614,   892064],
        [ 6354654,  6354658,  6354656]]),
 array([[2.3877087, 1.7078212, 1.5154692],
        [1.8492612, 1.4968052, 1.3931653],
        [2.3846188, 1.9597743, 1.8561847]], dtype=float32))

In [21]:
# Labels of the three neighbours returned for the first query row.
tuple(constrained_entity_mapping["Q" + str(qid)] for qid in (7859497, 7859498, 63570))

('Tworzanice', 'Tworzanki', 'Nowe Borza')

In [22]:
# Labels of the three neighbours returned for the second query row.
tuple(constrained_entity_mapping["Q" + str(qid)] for qid in (15238435, 1810614, 892064))

('Kikoboga Airstrip', 'Kiboga', 'Kiboga District')

In [23]:
# Labels of the three neighbours returned for the third query row.
tuple(constrained_entity_mapping["Q" + str(qid)] for qid in (6354654, 6354658, 6354656))

('Kalutara Physical Culture Club', 'Kalutara Town Club', 'Kalutara Stadium')

In [24]:
# Ids belonging to the query rows output[100:103], for comparison with the
# first id returned for each query.
ids[100:103]

[7859497, 15238435, 6354654]

In [25]:
# Serialize the index to the file "entities.index".
faiss.write_index(index, "entities.index")

In [26]:
# Reload the index from the saved file, replacing the in-memory one; the
# following cells repeat the earlier query against it.
index = faiss.read_index("entities.index")  # load the index

In [27]:
# Run the same 3-neighbour query for rows 100-102 of the embedding matrix.
distances, indices = index.search(output[100:103, :], 3)

In [28]:
# Show the scores followed by the neighbour ids for the query above.
distances, indices

(array([[2.3877087, 1.7078212, 1.5154692],
        [1.8492612, 1.4968052, 1.3931653],
        [2.3846188, 1.9597743, 1.8561847]], dtype=float32),
 array([[ 7859497,  7859498,    63570],
        [15238435,  1810614,   892064],
        [ 6354654,  6354658,  6354656]]))

In [29]:
# Embed a single ad-hoc query string with the same function used for the entity names.
emb = embed_entity_batch(["Germany"])

In [30]:
# Look up the three nearest entities for the query embedding (converted to a CPU NumPy array).
index.search(emb.detach().cpu().numpy(), 3)

(array([[1.2840521, 1.1428443, 1.1164908]], dtype=float32),
 array([[    183, 5552084, 1649644]]))

In [31]:
# Labels of the three neighbours returned for the "Germany" query.
tuple(constrained_entity_mapping["Q" + str(qid)] for qid in (183, 5552084, 1649644))

('Germany', 'Germany 1985', 'Gaël Germany')

In [32]:
# Save the id -> label mapping as JSON so the ids returned by the index can be
# turned back into labels later.
with open("constrained_entity_mapping.json", 'w') as mapping_file:
    json.dump(constrained_entity_mapping, mapping_file)