# Building faiss index for entity linking purposes

In [3]:
%cd ..

/home/codeholder/code/python-playground/app_noisemon


Given labelstudio NER+NEL results, I will build faiss index to match entity context embedding with QID

#### Pipeline
1. Load labelstudio'ed texts
2. For each labeled document
    - get "QID to alias" matching
    - get "QID to vector" matching
3. Build index

## 1. Import data

In [4]:
import json
import os
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
import numpy as np
import torch

In [3]:
data_path = Path("./data/05-labeled/project-10-at-2021-10-02-22-43-62e3404c.json")

In [4]:
data = json.loads(data_path.read_text())

In [5]:
len(data)

79

## 2. Prepare labelstudio data reformatters

In [50]:
from typing import List, Dict, Set, Tuple

In [51]:
from scripts.char_span_to_vector import ContextualEmbedding
embedder = ContextualEmbedding()

Some weights of the model checkpoint at cointegrated/LaBSE-en-ru were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [52]:
def ann_to_ent(labelstudio, embedder=embedder) -> Tuple[Dict[str, Set[str]], Dict[str, List[torch.Tensor]]]:
    """
    Given labelstudio data returns two dicts: Dict[QID, {Alias}] and Dict[QID, [Vector]]
    """
    # Temporary structures
    id_to_qid_name_pair = defaultdict(dict)
    qid_and_span_pairs = [] # [(qid, (span_s,span_e))]
    # Result buffers
    qid_to_vector_list = defaultdict(list)
    qid_to_alias = defaultdict(set)
    
    text = labelstudio["data"]["text"]
    embedder.embed_text(text)
    
    # 1. Matching labeling result chunks by their labelstudio internal IDs
    for chunk in labelstudio["annotations"][0]["result"]:
        if chunk["from_name"] == "ner":
            id_to_qid_name_pair[chunk["id"]]["text"] = chunk["value"]["text"]
        if chunk["from_name"] == "entity":
            id_to_qid_name_pair[chunk["id"]]["qid"] = chunk["value"]["text"][0]

    # 2. Match QIDs with respecting spans and text chunks (aliases)
    for chunk in labelstudio["annotations"][0]["result"]:
        if chunk["from_name"] == "ner":
            QID = id_to_qid_name_pair[chunk["id"]].get("qid", None)
            if not QID:
                print(f"{id_to_qid_name_pair[chunk['id']]['text']} has no matching QID")
                continue
            # aliases
            entity_start, entity_end = chunk["value"]["start"], chunk["value"]["end"]
            entity = text[entity_start: entity_end]
            qid_to_alias[QID].add(entity)
            # vecs will be calculated later
            qid_and_span_pairs.append((QID, (entity_start, entity_end)))
            
    # Given spans, get context vectors
    qids, spans = zip(*qid_and_span_pairs) # [(qid, span), (qid, span)] -> [qid, qid], [span, span]
    entity_vectors = embedder.get_char_span_vectors(spans)
    for QID, entity_vector in zip(qids, entity_vectors):
        qid_to_vector_list[QID].append(entity_vector)
    
    return qid_to_alias, qid_to_vector_list

In [54]:
qid_to_vector_list = defaultdict(list)
qid_to_alias = defaultdict(set)

for labelstudio in tqdm(data):
    # Process data
    qid_aliass, qid_vecs = ann_to_ent(labelstudio)
    # Merge results
    for QID, vecs in qid_vecs.items():
        qid_to_vector_list[QID].extend(vecs)
    for QID, aliases in qid_aliass.items():
        qid_to_alias[QID].update(aliases)

In [55]:
# Total num of vectors
sum([len(v) for v in qid_to_vector_list.values()])

222

In [63]:
# Show aliases to check sanity
for qid, aliases in qid_to_alias.items():
    print(f"{qid:<15}", aliases)

Q1642605        {'РУСАЛа', 'Русал', 'РусАл'}
Q952937         {'Лондонской бирже металлов (LME)'}
Q108398998      {'Открытие Капитал'}
Q1141123        {'Роснефти', 'Роснефтью', 'РОСНЕФТЬ', 'Роснефть'}
Q940518         {'МАГНИТ', 'ПАО «Магнит»', 'МАГНИТА', 'Магнита', 'Магнит'}
Q379271         {'Интерфаксу', 'ИНТЕРФАКС', 'Интерфаксом'}
Q7907607        {'ВТБ капитала', 'ВТБ Капитала', 'ВТБ-Капитал', 'ВТБ КАПИТАЛ'}
Q1355823        {'FTSE'}
Q294508         {'АЛРОСА', 'Алросы', 'Алроса'}
Q4400200        {'Русагро'}
Q2369311        {'РуссНефти'}
Q1616858        {'Татнефть', 'Татнефти'}
Q3063197        {'ФСК ЕЭС'}
Q2116312        {'Транснефть', 'Транснефти'}
Q1884500        {'ММК', 'ММК\n'}
Q4304175        {'МКБ\n'}
Q4327204        {'ОТКРЫТИЕМ'}
Q1549389        {'ВТБ', 'ВТБ\n'}
Q102673         {'Газпром', 'Газпромом', 'Газпрома'}
Q182477         {'NVidia'}
Q173395         {'Cisco Systems'}
Q483551         {'Wal-Mart'}
Q3656098        {'Атон'}
Q1809133        {'ТМК', 'ПАО "Трубная металлургическа

## 3. FAISS

In [1]:
import numpy as np
import faiss
import torch

### 3.1 Merge vector list into tensor

In [22]:
# We will need to preserve this mapping, because FAISS supports only integer indexing
vector_index_to_qid = {}
vectors_tensor = []

index = 0
for qid, vectors in qid_to_vector_list.items():
    for vector in vectors:
        vector_index_to_qid[index] = qid
        index += 1
        vectors_tensor.append(vector)
        
vectors_tensor = torch.vstack(vectors_tensor)
vectors_tensor.shape

torch.Size([222, 768])

### 3.2 Build an actual FAISS index

I use inner product index, cuz I did normalize my vectors

In [31]:
d = 768  # dimension 
        
faiss_index = faiss.IndexFlatIP(d)   # build the index
print(faiss_index.is_trained)
faiss_index.add(vectors_tensor.numpy()) # add vectors to the index
print(faiss_index.ntotal)

True
222


### 3.3 Make a test request

In [59]:
# making a query vector
company_name = "Газпром"
embedder.embed_text(company_name)
emb = embedder.get_char_span_vectors([(0, len(company_name))])[0]
emb = emb.view((1, -1))
emb.shape

torch.Size([1, 768])

In [64]:
emb = torch.vstack([emb, emb]).numpy()
emb.shape

(2, 768)

In [60]:
# Helper function evaluate index search output
from collections import Counter
def get_majority(indices):
    qids = [vector_index_to_qid[index] for index in indices]
    counts = Counter(qids)
    return counts
    

In [66]:
k = 4 # we want to see 4 nearest neighbors
# I - indices of neighbours
# D - distances to them
D, I = faiss_index.search(emb, k)     # actual search

In [67]:
print(I) # neighbors 
print(D)

[[86 83 87 84]
 [86 83 87 84]]
[[52.420643 46.50787  36.403435 30.952461]
 [52.420643 46.50787  36.403435 30.952461]]


In [69]:
list(I)

[array([86, 83, 87, 84]), array([86, 83, 87, 84])]

In [68]:
type(I)

numpy.ndarray

In [37]:
get_majority(I[0])

Counter({'Q102673': 4})

In [38]:
for qid in get_majority(I[0]).keys():
    print(qid_to_alias[qid])

{'Газпром', 'Газпромом', 'Газпрома'}


### 3.4 Save index to disk

In [None]:
faiss.write_index(faiss_index, "faiss_index.binary")

In [5]:
faiss_index2 = faiss.read_index("faiss_index.binary")

In [None]:
faiss_index2

In [None]:
faiss_index

In [None]:
dir(faiss_index)

In [71]:
Counter(["a", "a", "b"]).most_common(1)

[('a', 2)]

## 4. Adding new vectors

In [6]:
faiss_index2

<faiss.swigfaiss.IndexFlat; proxy of <Swig Object of type 'faiss::IndexFlat *' at 0x7f12c65b60c0> >

In [22]:
faiss_index2.add(np.random.random((1, 768)).astype('float32'))

In [19]:
faiss_index2.ntotal

229

In [20]:
faiss.write_index(faiss_index2, "faiss2_index.binary")
faiss_index2 = faiss.read_index("faiss2_index.binary")

In [21]:
faiss_index2.ntotal

229