# Entity recognition and embedding base building

In [1]:
%cd ..

/home/codeholder/code/python-playground/app_noisemon


My baseline will be comparing embeddings of ORG chunks and company names.

The large spacy model already provides embedding functionality, but those are static.
I'm gonna need a transformer.
For the embeddings matching I will use faiss

#### Pipeline
1. Load labelstudio'ed texts
2. For each labeled document
    - transform document to embedding matrix
    - get entity vectors
    - store it into index

## 1. Import data

In [10]:
import json
import os
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from functools import partial as p
# from lenses import lens



from wasabi import Printer
msg = Printer()

In [6]:
data_path = Path("./data/05-labeled/project-10-at-2021-10-02-22-43-62e3404c.json")

In [7]:
data = json.loads(data_path.read_text())

In [8]:
len(data)

79

## 2. Prepare embedding functions

In [18]:
import spacy_alignments as tokenizations

In [11]:
from typing import List

In [17]:
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer_ = AutoTokenizer.from_pretrained("cointegrated/LaBSE-en-ru")
model = AutoModel.from_pretrained("cointegrated/LaBSE-en-ru")

sentences = ["Hello World", "Привет Мир"]
tokenizer = p(tokenizer_, padding=True, truncation=False, return_tensors="pt")

def embed_documents(encoded_input):
    
    with torch.no_grad():
        model_output = model(**encoded_input)
    embeddings = model_output.last_hidden_state
    embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings

Some weights of the model checkpoint at cointegrated/LaBSE-en-ru were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
model_output.last_hidden_state.shape

torch.Size([2, 4, 768])

In [57]:
# {'value': {'start': 0,
#   'end': 17,
#   'text': 'ООО «Венера-плюс»',
#   'labels': ['Organization']},
#  'id': 'BDbfvhbqc9',
#  'from_name': 'label',
#  'to_name': 'text',
#  'type': 'labels'},
# {'value': {'start': 350,
#   'end': 360,
#   'text': '"Заказчик"',
#   'labels': ['Role']},
#  'id': 'ytWWFWYrvX',
#  'from_name': 'label',
#  'to_name': 'text',
#  'type': 'labels'},
# {'from_id': 'BDbfvhbqc9',
#  'to_id': '8WJgOdX7zR',
#  'type': 'relation',
#  'direction': 'right',
#  'labels': []},
qid_to_vector_list = defaultdict(list)
qid_to_alias = defaultdict(set)
def ann_to_ent(labelstudio):
    id_to_qid_name_pair = defaultdict(dict)
    text = labelstudio["data"]["text"]
    encoded_input = tokenizer([text])
    wordpieces = tokenizer_.batch_decode(encoded_input.input_ids[0])
    embedding_alignment, _ = tokenizations.get_alignments(text, wordpieces)
    emb = embed_documents(encoded_input)
    emb = emb / np.linalg.norm(emb)
    emb = emb.squeeze()


    # 1. Matching named entities with QIDs
    for chunk in labelstudio["annotations"][0]["result"]:
        if chunk["from_name"] == "ner":
            id_to_qid_name_pair[chunk["id"]]["text"] = chunk["value"]["text"]
        if chunk["from_name"] == "entity":
            id_to_qid_name_pair[chunk["id"]]["qid"] = chunk["value"]["text"][0]

    # 2. Create ents and assign kb_id to ents

    for chunk in labelstudio["annotations"][0]["result"]:
        if chunk["from_name"] == "ner":
            # 1. Check if entity is labeled
            QID = id_to_qid_name_pair[chunk["id"]].get("qid", None)
            if not QID:
                print(f"{id_to_qid_name_pair[chunk['id']]['text']} has no matching QID")
                continue
            try:
                # 2. Get entity vector
                entity_start, entity_end = chunk["value"]["start"], chunk["value"]["end"]
                entity = text[entity_start:entity_end]
                ent_idxs = [idx for list_of_indices in embedding_alignment[entity_start: entity_end] for idx in list_of_indices]
                ent_idxs = sorted(set(ent_idxs))
#                 print(ent_idxs, embedding_alignment)
                entity_embeddings = emb[ent_idxs]
                entity_vector = torch.mean(entity_embeddings, dim=0)
                qid_to_vector_list[QID].append(entity_vector)
                qid_to_alias[QID].add(entity)
            except:
                msg.fail("Result:", chunk)
                msg.fail("Doc:", doc)
                msg.fail("----------")
                continue


In [58]:
for labelstudio in data:
    ann_to_ent(labelstudio)

In [60]:
sum([len(v) for v in qid_to_vector_list.values()])

222

In [65]:
for qid, aliases in qid_to_alias.items():
    print(f"{qid:<15}", aliases)

Q1642605        {'Русал', 'РусАл', 'РУСАЛа'}
Q952937         {'Лондонской бирже металлов (LME)'}
Q108398998      {'Открытие Капитал'}
Q1141123        {'Роснефти', 'Роснефтью', 'Роснефть', 'РОСНЕФТЬ'}
Q940518         {'Магнита', 'Магнит', 'МАГНИТ', 'МАГНИТА', 'ПАО «Магнит»'}
Q379271         {'Интерфаксом', 'ИНТЕРФАКС', 'Интерфаксу'}
Q7907607        {'ВТБ-Капитал', 'ВТБ КАПИТАЛ', 'ВТБ Капитала', 'ВТБ капитала'}
Q1355823        {'FTSE'}
Q294508         {'Алросы', 'АЛРОСА', 'Алроса'}
Q4400200        {'Русагро'}
Q2369311        {'РуссНефти'}
Q1616858        {'Татнефть', 'Татнефти'}
Q3063197        {'ФСК ЕЭС'}
Q2116312        {'Транснефти', 'Транснефть'}
Q1884500        {'ММК', 'ММК\n'}
Q4304175        {'МКБ\n'}
Q4327204        {'ОТКРЫТИЕМ'}
Q1549389        {'ВТБ\n', 'ВТБ'}
Q102673         {'Газпромом', 'Газпром', 'Газпрома'}
Q182477         {'NVidia'}
Q173395         {'Cisco Systems'}
Q483551         {'Wal-Mart'}
Q3656098        {'Атон'}
Q1809133        {'ТМК', 'ПАО "Трубная металлургическа

## FAISS

In [70]:
idx, (index_pair, a) = next(enumerate(qid_to_vector_list.items()))

In [71]:
index_pair

'Q1642605'

In [101]:
vector_index_to_qid = {}
vectors_tensor = []

index = 0
for qid, vectors in qid_to_vector_list.items():
    for vector in vectors:
        vector_index_to_qid[index] = qid
        index += 1
        vectors_tensor.append(vector)
        
vectors_tensor = torch.vstack(vectors_tensor)
vectors_tensor.shape

torch.Size([222, 768])

In [105]:
# making a query vector
encoded_input = tokenizer(["Cisco Systems"])
emb = embed_documents(encoded_input)
emb = emb.squeeze()
print(emb.shape)
emb = emb[[1,2]]
emb = torch.mean(emb, dim=0)
emb = emb.view((1, -1))
emb.shape

torch.Size([4, 768])


torch.Size([1, 768])

In [114]:
dir(faiss)

['AdditiveQuantizer',
 'AlignedTableFloat32',
 'AlignedTableFloat32_round_capacity',
 'AlignedTableUint16',
 'AlignedTableUint16_round_capacity',
 'AlignedTableUint8',
 'AlignedTableUint8_round_capacity',
 'AlignedTable_to_array',
 'ArrayInvertedLists',
 'AutoTuneCriterion',
 'BitstringReader',
 'BitstringWriter',
 'BlockInvertedLists',
 'BufferList',
 'BufferedIOReader',
 'BufferedIOWriter',
 'ByteVector',
 'ByteVectorVector',
 'CMax_float_partition_fuzzy',
 'CMax_uint16_partition_fuzzy',
 'CMin_float_partition_fuzzy',
 'CMin_uint16_partition_fuzzy',
 'CenteringTransform',
 'CharVector',
 'Cloner',
 'Clustering',
 'ClusteringIterationStats',
 'ClusteringIterationStatsVector',
 'ClusteringParameters',
 'ConcatenatedInvertedLists',
 'DirectMap',
 'DirectMapAdd',
 'DistanceComputer',
 'DoubleVector',
 'EnumeratedVectors',
 'FAISS_VERSION_MAJOR',
 'FAISS_VERSION_MINOR',
 'FAISS_VERSION_PATCH',
 'FastScanStats',
 'FileIOReader',
 'FileIOWriter',
 'Float32Vector',
 'Float32VectorVector',
 '

In [115]:
import numpy as np
import faiss 

d = 768                           # dimension
               
faiss_index = faiss.IndexFlatIP(d)   # build the index
print(faiss_index.is_trained)
faiss_index.add(vectors_tensor.numpy())                  # add vectors to the index
print(faiss_index.ntotal)

True
222


In [116]:
k = 4                        # we want to see 4 nearest neighbors

D, I = faiss_index.search(emb.numpy(), k)     # actual search

In [117]:
print(I)                   # neighbors 

[[ 98 216 221  99]]


In [118]:
from collections import Counter
def get_majority(indices):
    qids = [vector_index_to_qid[index] for index in indices]
    counts = Counter(qids)
    return counts
    

In [119]:
get_majority(I[0])

Counter({'Q173395': 1, 'Q487907': 1, 'Q131723': 1, 'Q483551': 1})

In [120]:
D

array([[1.4481782 , 0.9453915 , 0.92165357, 0.7970917 ]], dtype=float32)

In [121]:
for qid in get_majority(I[0]).keys():
    print(qid_to_alias[qid])

{'Cisco Systems'}
{'Bank of America'}
{'биткоин'}
{'Wal-Mart'}


In [126]:
faiss.write_index(faiss_index, "faiss_index.binary")

In [127]:
faiss_index2 = faiss.read_index("faiss_index.binary")

In [128]:
faiss_index2

<faiss.swigfaiss.IndexFlat; proxy of <Swig Object of type 'faiss::IndexFlat *' at 0x7fe0e7330ed0> >

In [125]:
faiss_index

<faiss.swigfaiss.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x7fe0e6add720> >

In [113]:
dir(faiss_index)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__swig_destroy__',
 '__weakref__',
 'add',
 'add_c',
 'add_with_ids',
 'add_with_ids_c',
 'assign',
 'assign_c',
 'compute_distance_subset',
 'compute_residual',
 'compute_residual_n',
 'd',
 'get_distance_computer',
 'is_trained',
 'metric_arg',
 'metric_type',
 'ntotal',
 'range_search',
 'range_search_c',
 'reconstruct',
 'reconstruct_c',
 'reconstruct_n',
 'reconstruct_n_c',
 'remove_ids',
 'remove_ids_c',
 'reset',
 'sa_code_size',
 'sa_decode',
 'sa_decode_c',
 'sa_encode',
 'sa_encode_c',
 'search',
 'search_and_reconstruct',
 'search_and_reconstruct_c',
 'search_c',
 'this',
 'thisown',
 'train'