In [1]:
from IPython.core.display import display, HTML

%load_ext autoreload

In [41]:
# Data root and Wikipedia version; both are passed to the REL loaders
# (TrainingEvaluationDatasets, MentionDetection) in the cells below.
base_url = "./data"
wiki_version = "wiki_2019"

## Helper functions

In [None]:
!pip install pympler

In [36]:
from pympler import asizeof

## Load data

In [37]:
# Load a small sample of MS MARCO v2 documents, split their bodies into
# sentences with syntok, and group the sentences into pseudo-documents.
import cProfile
pr = cProfile.Profile()
pr.enable()

from flair.data import Sentence
from flair.models import SequenceTagger
import time, json, gzip
import syntok.segmenter as segmenter

in_file = 'msmarco_doc_00.gz'
# NOTE(review): absolute, machine-specific path; point this at your own copy of the sample.
data_path = f'/Users/vanhulsm/Desktop/msmarco_v2_doc_sample/msmarcov2/{in_file}'

sample_size = 50 # Change to bigger number for better estimate
batch_size = 10 # Change to value that batch fits ~60% avg? GPU mem

batch = []
lines = []
try:
    tagger = SequenceTagger.load("ner-fast")
    with gzip.open(data_path, 'rt') as f:
        for i, line in enumerate(f):
            # Stop once exactly `sample_size` documents have been read.
            if i >= sample_size:
                break
            document = json.loads(line)
            # Only the body is needed here; each sentence is re-joined from its tokens.
            for paragraph in segmenter.process(document['body']):
                for sentence in paragraph:
                    lines.append(' '.join(token.value for token in sentence))
finally:
    # Always stop the profiler (also on interruption) so that it does not add
    # overhead to the timing benchmarks in later cells.
    pr.disable()

# Each pseudo-document is 5 consecutive sentences; chunks 100..249 are used, so
# fewer than 1250 collected sentences yields (partly) empty texts.
docs = {f'doc_{i}': [' '.join(lines[i*5:((i+1)*5)]), []] for i in range(100, 250)}# = " ".join(lines[:100])

2021-12-19 21:05:20,077 --------------------------------------------------------------------------------
2021-12-19 21:05:20,079 The model key 'ner-fast' now maps to 'https://huggingface.co/flair/ner-english-fast' on the HuggingFace ModelHub
2021-12-19 21:05:20,081  - The most current version of the model is automatically downloaded from there.
2021-12-19 21:05:20,082  - (you can alternatively manually download the original model at https://nlp.informatik.hu-berlin.de/resources/models/ner-fast/en-ner-fast-conll03-v0.4.pt)
2021-12-19 21:05:20,084 --------------------------------------------------------------------------------
2021-12-19 21:05:20,790 loading file /Users/vanhulsm/.flair/models/ner-english-fast/4c58e7191ff952c030b82db25b3694b58800b0e722ff15427f527e1631ed6142.e13c7c4664ffe2bbfa8f1f5375bd0dced866b8c1dd7ff89a6d705518abf0a611


KeyboardInterrupt: 

In [42]:
%autoreload 2

from REL.training_datasets import TrainingEvaluationDatasets

datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]

server = True
docs = {}
for i, doc in enumerate(datasets):
    sentences = []
    for x in datasets[doc]:
        if x["sentence"] not in sentences:
            sentences.append(x["sentence"])
    text = ". ".join([x for x in sentences])
    docs[i] = [text, []]
    if len(docs) == 150:
        print("length docs is 50.")
        print("====================")
        break

Loading aida_train
Loading aida_testA
Loading aida_testB
Loading wned-ace2004
Loading wned-aquaint
Loading wned-clueweb
Loading wned-msnbc
Loading wned-wikipedia
length docs is 50.


# Mention detection
Bottlenecks:
1. SQLite3 is too slow, try array-based approach.
2. Flair is too slow:
    - Try different model (e.g. replace LSTM with CNN+LSTM), can also stop using Flair.
    - Convert to ONNX
    - ... 
3. Wikipedia2Vec is too time-consuming to train; it should be quite easy to replace with PyTorch and sparse embeddings.
    - Data format is quite tedious to use though.
    - HashEmbeddings? This would result in ~110.000.900 parameters, which in turn would take ~419MB of memory when using float32.
    - Alternatively, we can tokenize entities and share parameters across different entities.

## Process intermediate results Flair

In [43]:
from REL.ner import load_flair_ner

# Load the "ner-fast" tagger through REL's helper; it is handed to
# `process_and_predict_flair` in a later cell.
tagger_ner = load_flair_ner("ner-fast")

2021-12-19 21:06:21,414 --------------------------------------------------------------------------------
2021-12-19 21:06:21,417 The model key 'ner-fast' now maps to 'https://huggingface.co/flair/ner-english-fast' on the HuggingFace ModelHub
2021-12-19 21:06:21,417  - The most current version of the model is automatically downloaded from there.
2021-12-19 21:06:21,418  - (you can alternatively manually download the original model at https://nlp.informatik.hu-berlin.de/resources/models/ner-fast/en-ner-fast-conll03-v0.4.pt)
2021-12-19 21:06:21,419 --------------------------------------------------------------------------------
2021-12-19 21:06:21,956 loading file /Users/vanhulsm/.flair/models/ner-english-fast/4c58e7191ff952c030b82db25b3694b58800b0e722ff15427f527e1631ed6142.e13c7c4664ffe2bbfa8f1f5375bd0dced866b8c1dd7ff89a6d705518abf0a611


In [44]:
# Reload edited modules automatically before executing code.
%autoreload 2

import REL
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results

# Mention-detection object used by the Flair and candidate-lookup cells below.
mention_detection = MentionDetection(base_url, wiki_version)

In [45]:
# Run Flair over `docs`; the four returned pieces are pickled next and later fed
# to `process_predictions_flair`.
dataset, dataset_sentences_raw, processed_sentences, splits = (
    mention_detection.process_and_predict_flair(docs, tagger_ner)
)

In [46]:
import pickle

# Persist the Flair outputs so later cells can reload them from disk.
with open('./data/intermediate_flair_res.pkl', 'wb') as f:
    pickle.dump(
        [dataset, dataset_sentences_raw, processed_sentences, splits],
        f,
    )

## Process predictions

In [47]:
import numpy as np 

In [48]:
import pickle
# Load the intermediate Flair results written by the previous section.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open('./data/intermediate_flair_res.pkl', 'rb') as f:
    dataset, dataset_sentences_raw, processed_sentences, splits = pickle.load(f)

In [49]:
%autoreload 2

import REL
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results

# Fresh MentionDetection instance; its `profiler` summary is printed in the next cell.
mention_detection = MentionDetection(base_url, wiki_version)
# `results` maps each key to a list of mention dicts (each has a 'mention' entry, read below).
results, total_ment = mention_detection.process_predictions_flair(dataset, dataset_sentences_raw, processed_sentences, splits)

100%|██████████| 150/150 [00:12<00:00, 12.46it/s]


In [50]:
# Widen the notebook container to 90%, then print the profiler's timing summary.
display(HTML("<style>.container { width:90% !important; }</style>"))
print(mention_detection.profiler.summary())

Profiler Report

Action            	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------
Total             	|  -              	|_              	|  12.117         	|  100 %          	|
-----------------------------------------------------------------------------------------------------------------
get_candidates    	|  0.0022919      	|3198           	|  7.3296         	|  60.5 %         	|
preprocess_mention	|  0.00081884     	|3198           	|  2.6187         	|  21.6 %         	|
get_ctxt          	|  0.00059791     	|2861           	|  1.7106         	|  14.1 %         	|



In [51]:
# Flatten `results` into one list of mention strings (duplicates are kept).
mentions = [
    found['mention']
    for found_in_doc in results.values()
    for found in found_in_doc
]

In [52]:
from tqdm import tqdm

# Benchmark `get_candidates` on mentions that were already preprocessed.
# The results are cached per mention string, so repeated mentions collapse
# into a single key of `mention_candidates`.
mention_candidates = {}

import time
start = time.monotonic()
for m in tqdm(mentions):
    cands = mention_detection.get_candidates(m)
    mention_candidates[m] = cands
time_taken = time.monotonic()-start

# The average is per processed mention (duplicates included).
print(f'Total time taken {time_taken:.3f}s, average time taken {time_taken/len(mentions):.10f}')

100%|██████████| 2861/2861 [00:06<00:00, 418.39it/s]

Total time taken 6.840s, average time taken 0.0023907366





In [11]:
# Size in bytes, as reported by pympler, of a tiny NumPy string array (reference point).
asizeof.asizeof(np.array(['aa', 'aa', 'aa', 'aa']))

152

In [53]:
# Best-case comparison: time the same lookups when the candidates are served
# from the in-memory dict built in the previous benchmark.
start = time.monotonic()

for m in tqdm(mentions):
    cands = mention_candidates[m]
time_taken = time.monotonic()-start

print(f'Total time taken {time_taken:.3f}s, average time taken {time_taken/len(mentions):.10f}')

100%|██████████| 2861/2861 [00:00<00:00, 953963.25it/s]

Total time taken 0.005s, average time taken 0.0000018032





In [14]:
# Memory footprint (per pympler) of the plain dict: mention -> candidates.
print(f'Approx {asizeof.asizeof(mention_candidates) // 1024 }KBs memory usage for {len(mention_candidates)} mentions and their candidates')

Approx 8151KBs memory usage for 1309 mentions and their candidates


In [15]:
# NOTE(review): this measures a set holding one np.int16 index per mention only;
# the candidates are not part of the measured object although the message mentions
# them, and np.int16 cannot represent indices above 32767.
print(f'Approx {asizeof.asizeof({np.int16(i) for i, (k, v) in enumerate(mention_candidates.items())}) // 1024 }KBs memory usage for {len(mention_candidates)} mentions and their candidates')

Approx 169KBs memory usage for 1309 mentions and their candidates


# Efficiently represent mentions and their entities

## Preliminary index building using Python native data structures

In [28]:
import collections
from tqdm import tqdm

# mention -> int32 array of [entity index, probability index] rows (filled below).
mention_entity_references = {}
# entity -> index and probability -> index, grown by `process_entities`.
all_entities = {}
ent_probability = {}

# NOTE(review): the three counters below are not read again in the visible cells.
max_length_ment = -1

idx_ment = 0
idx_ent = 0

def process_entities(entities,
                     all_entities, 
                     ent_probability):
    """Encode (entity, probability) pairs as pairs of integer indices.

    Both lookup tables are updated in place: a key that has not been seen
    before is assigned the next free index, i.e. the table's current length.

    Args:
        entities: iterable of ``(entity, probability)`` pairs.
        all_entities: dict mapping entity -> index.
        ent_probability: dict mapping probability -> index.

    Returns:
        Tuple ``(processed_entities, all_entities, ent_probability)`` where
        ``processed_entities`` is a list of ``[entity_index, probability_index]``
        rows in input order and the two dicts are the same (mutated) objects
        that were passed in.
    """
    processed_entities = []
    for entity, prob in entities:
        # len() is evaluated before setdefault inserts, so a new key gets the
        # next free index while a known key keeps its existing one.
        prob_index = ent_probability.setdefault(prob, len(ent_probability))
        ent_index = all_entities.setdefault(entity, len(all_entities))

        processed_entities.append([ent_index, prob_index])
    return (processed_entities, all_entities, ent_probability)

# Replace every mention's candidate list by an int32 array of
# [entity index, probability index] rows; the lookup tables grow as new
# entities and probabilities are encountered.
for ment, entities in tqdm(mention_candidates.items()):
    processed_entities, all_entities, ent_probability = process_entities(
        entities, all_entities, ent_probability
    )

    # Keep the array two-dimensional even when a mention has no candidates, so
    # the column lookups (`refs[:, 0]`, `refs[:, 1]`) used later never fail.
    mention_entity_references[ment] = (
        np.array(processed_entities, dtype=np.int32)
        if processed_entities
        else np.empty((0, 2), dtype=np.int32)
    )

100%|██████████| 1309/1309 [00:00<00:00, 22223.70it/s]


330264

## Convert to Numpy array objects

In [31]:
# Add up the (pympler-reported) size of the three parts of the array-based index.
tot_size = 0

# Convert mentions and ment/entity references.
tot_size += asizeof.asizeof(mention_entity_references)

# Convert entity for index lookup.
arr_all_entities = np.array(list(all_entities.keys()), dtype=object)

# NOTE(review): the size is taken from the list of entity names, not from
# `arr_all_entities` itself — confirm this is intentional.
tot_size += asizeof.asizeof(list(all_entities.keys()))

# Convert array probabilities
# float16 keeps the array small; the validity check in the look-up test
# compares against np.float16(p) for the same reason.
arr_probs = np.array(list(ent_probability.keys()), dtype=np.float16)

tot_size += asizeof.asizeof(arr_probs)

# Number of KBs
tot_size // 1024

3841

## Look-up speed test

In [None]:
# sqlite: mean = 0.0016143512, total time = 4.598s



In [19]:
# Time the array-based lookup (index arrays -> entity names and probabilities)
# and then verify that it reproduces the original `mention_candidates` dict.
'''
TODO:
1. Create .py file and class, convert to pickled files.
2. Measure RAM usage as .py file when loading all into memory.
3. Report to RU with next steps <-- probably makes sense to keep mention_candidates_lookup in memory and the remainder with their indexes
in a Redis table or something.

The indexes refer to both entity_embeddings, entity_names and their probabilities, so can be done in one bulk lookup.

But as `mention_candidates_lookup` remains in-memory, preprocess_mention remains untouched.

'''

# mention -> {entity name: probability}, rebuilt from the arrays for the check below.
mention_candidates_lookup = collections.defaultdict(dict)
times = []

for m in tqdm(mentions):
    start = time.monotonic()
    
    # Lookup entities
    entity_identifiers = mention_entity_references[m]
    
    # Column 0 holds entity indices, column 1 probability indices.
    entity_strings = arr_all_entities[entity_identifiers[:,0]]
    entity_probs = arr_probs[entity_identifiers[:,1]]
    
    # One (entity, probability) row per candidate.
    result = np.stack((entity_strings, entity_probs)).T
    
    times.append(time.monotonic()-start)
        
    # This is solely done for validity checks, so excluded from time test
    for ent_str, ent_prob in result:
        mention_candidates_lookup[m][ent_str] = ent_prob

print(
    f"""Time taken {sum(times):.10f}s, 
    mean per ment {np.mean(times):.10f}, 
    std per ment {np.std(times):.10f}, 
    max per ment {np.max(times):.10f}, 
    min per ment {np.min(times):.10f}"""
)

# Compare between array-based lookups and original
for mention, entities in tqdm(mention_candidates.items()):
    lookup = mention_candidates_lookup[mention]
    
    entities_original = []
    assert len(lookup) == len(entities), 'Unequal length of # of entities'
    
    for e, p in entities:
        # Probabilities were stored as float16, so compare at that precision.
        assert lookup[e] == np.float16(p), 'Probability for entity not equal'
        entities_original.append(e)
    
    # An empty symmetric difference means both sides hold exactly the same entities.
    lookup_entities = set(list(lookup.keys())).symmetric_difference(entities_original)
    assert len(lookup_entities) == 0, 'Entities found that are not present in one anothers sets'
print('All tests passed, array is equal to dictionary.')

100%|██████████| 2848/2848 [00:00<00:00, 16921.74it/s]
100%|██████████| 1309/1309 [00:00<00:00, 14195.11it/s]

Time taken 0.0560099000s, 
    mean per ment 0.0000196664, 
    std per ment 0.0000256856, 
    max per ment 0.0009202860, 
    min per ment 0.0000104540
All tests passed, array is equal to dictionary.





# Clean class

In [5]:
import sqlite3

In [6]:
def process_entities(entities,
                     all_entities, 
                     ent_probability):
    """Encode (entity, probability) pairs as pairs of integer indices.

    Both lookup tables are updated in place: a key that has not been seen
    before is assigned the next free index, i.e. the table's current length.

    Args:
        entities: iterable of ``(entity, probability)`` pairs.
        all_entities: dict mapping entity -> index.
        ent_probability: dict mapping probability -> index.

    Returns:
        Tuple ``(processed_entities, all_entities, ent_probability)`` where
        ``processed_entities`` is a list of ``[entity_index, probability_index]``
        rows in input order and the two dicts are the same (mutated) objects
        that were passed in.
    """
    processed_entities = []
    for entity, prob in entities:
        # len() is evaluated before setdefault inserts, so a new key gets the
        # next free index while a known key keeps its existing one.
        prob_index = ent_probability.setdefault(prob, len(ent_probability))
        ent_index = all_entities.setdefault(entity, len(all_entities))

        processed_entities.append([ent_index, prob_index])
    return (processed_entities, all_entities, ent_probability)

In [8]:
%autoreload 2

import pickle

import numpy as np

import REL
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results

base_url = "./data"
wiki_version = "wiki_2019"

import collections
from tqdm import tqdm

mention_entity_references = {}
all_entities = {}
ent_probability = {}

max_length_ment = -1

idx_ment = 0
idx_ent = 0

mention_detection = MentionDetection(base_url, wiki_version)

c = mention_detection.wiki_db.db.cursor()
n_rows= c.execute("select COUNT(word) from wiki").fetchone()[0]

c = mention_detection.wiki_db.db.cursor()
c.execute("select word, p_e_m from wiki")
cnt = 0
batch_size = 1_000
for i in tqdm(range((n_rows//batch_size)+1)):
    batch = c.fetchmany(batch_size)
    cnt += 1
    
    for ment, ent_binary in batch:
        entities = mention_detection.wiki_db.binary_to_dict(ent_binary)
        
        # Mentions of length refers to specific mention that will be ordered    
        processed_entities, all_entities, ent_probability = process_entities(entities, 
                                                                            all_entities,
                                                                            ent_probability 
                                                                            )

        n_entities = len(processed_entities)
        mention_entity_references[ment] = np.array(processed_entities, dtype=np.int32)
    
    if not batch:
        break
        
    if i % 250 == 0:
        # Convert and store intermediately so that we can restart somewhere if needed.
        ## Convert entity for index lookup.
        arr_all_entities = np.array(list(all_entities.keys()), dtype=object)

        ## Convert array probabilities
        arr_probs = np.array(list(ent_probability.keys()), dtype=np.float16)

        with open(f'./data/ment_cands_lookup_{i}.pkl', 'wb') as f:
            pickle.dump(mention_entity_references, f)

        np.save('./data/arr_entities', arr_all_entities)
        np.save('./data/arr_mention_entity_probs', arr_probs)
        
        mention_entity_references = {}

100%|██████████| 23203/23203 [13:47<00:00, 28.04it/s] 


In [9]:
# Final flush after the build loop: write the complete lookup arrays and the
# last, not yet checkpointed, chunk of mention references.

# Convert entity for index lookup.
arr_all_entities = np.array(list(all_entities.keys()), dtype=object)

# Convert array probabilities
arr_probs = np.array(list(ent_probability.keys()), dtype=np.float16)

# If the loop's last iteration was itself a checkpoint, the chunk for this `i`
# is already on disk and the dict is empty; writing it again would replace that
# file with an empty dict.
if mention_entity_references:
    with open(f'./data/ment_cands_lookup_{i}.pkl', 'wb') as f:
        pickle.dump(mention_entity_references, f)

np.save('./data/arr_entities', arr_all_entities)
np.save('./data/arr_mention_entity_probs', arr_probs)

In [10]:
# Pympler-reported size of what is currently in memory.
# NOTE(review): the build loop empties `mention_entity_references` after every
# checkpoint, so at this point it holds only the last chunk, not every mention.
tot_size = 0

# Convert mentions and ment/entity references.
tot_size += asizeof.asizeof(mention_entity_references)
tot_size += asizeof.asizeof(arr_probs)
tot_size += asizeof.asizeof(list(all_entities.keys()))

# Number of KBs
tot_size // 1024

417168

In [14]:
# Convert the KB total shown by the previous cell to MB.
417168 // 1024

407

100%|██████████| 102/102 [00:57<00:00,  1.78it/s]


In [24]:
# Raw buffer size of all reference arrays (array data only), reported in KB.
total_bytes = sum(refs.nbytes for refs in tqdm(ment_cands_lookup.values()))

total_bytes // 1024

100%|██████████| 23202365/23202365 [00:13<00:00, 1775728.24it/s]


239242

In [27]:
# Mentions currently held in `mention_entity_references` per distinct entity.
# NOTE(review): after the checkpointing build loop this dict holds only the last chunk.
len(mention_entity_references) / len(arr_all_entities)

0.045652653959389736

# Profile 
- When loading `arr_all_entities` and `arr_probs`, memory usage is ~464MB. 
- When also loading `mention_entity_references`, it goes up to ~6800MB, meaning ~6.3GB is used for 
`mention_entity_references`.

In [None]:
# Open Python terminal and test memory usage of:
# 1. All three in memory
# 2. Only `mention_entity_references` in memory


# Test with: 
# top -pid + ps


# This cell is meant to run in a fresh interpreter, so it imports everything it uses.
import os
import pickle
from tqdm import tqdm
import numpy as np
from natsort import natsorted
files = natsorted(os.listdir('./data/'))

# Merge all checkpointed chunks back into one mention -> reference-array dict.
ment_cands_lookup = {}
for dirr in tqdm(files):
    if 'ment_cands_lookup' in dirr:
        # NOTE: pickle.load can execute arbitrary code; only load files written
        # by this notebook.
        with open(f'./data/{dirr}', 'rb') as f:
            tmp = pickle.load(f)
        for k, v in tmp.items():
            assert k not in ment_cands_lookup, 'Not unique.'
            ment_cands_lookup[k] = v

# np.save appended ".npy" to both file names; the entity array has dtype=object,
# which np.load only reads when pickling is explicitly allowed.
arr_all_entities = np.load('./data/arr_entities.npy', allow_pickle=True)
arr_probs = np.load('./data/arr_mention_entity_probs.npy')

In [30]:
from tqdm import tqdm
import time

In [59]:
# Number of mentions used in the benchmarks (duplicates included).
len(mentions)

2861

In [54]:

# Baseline: per-mention timing of `get_candidates` on already preprocessed mentions.
# NOTE(review): `mention_candidates` is reset here but not filled in this cell.
mention_candidates = {}

times = []

for m in tqdm(mentions):
    start = time.monotonic()
    cands = mention_detection.get_candidates(m)
    # The emptiness check sits inside the timed region.
    assert len(cands) != 0, 'wut'
    times.append(time.monotonic()-start)

print(
    f"""Time taken {sum(times):.10f}s, 
    mean per ment {np.mean(times):.10f}, 
    std per ment {np.std(times):.10f}, 
    max per ment {np.max(times):.10f}, 
    min per ment {np.min(times):.10f}"""
)

100%|██████████| 2861/2861 [00:06<00:00, 431.80it/s]

Time taken 6.5767760050s, 
    mean per ment 0.0022987683, 
    std per ment 0.0029809774, 
    max per ment 0.0401902410, 
    min per ment 0.0000630020





In [69]:
# Speed-up factor: mean time per mention of `get_candidates` divided by the mean
# time per mention of the array-based lookup (values copied from the printed outputs).
0.0022987683 / 0.0000292318

78.63930035098763

In [66]:

# Per-mention timing of the array-based lookup using the index loaded from disk.
# NOTE(review): `mention_candidates` is reset here but not filled in this cell.
mention_candidates = {}

times = []

for m in tqdm(mentions):
    start = time.monotonic()
    
    # Lookup entities
    entity_identifiers = ment_cands_lookup[m]
    # Column 0 holds entity indices, column 1 probability indices.
    entity_strings = arr_all_entities[entity_identifiers[:,0]]
    entity_probs = arr_probs[entity_identifiers[:,1]]
    
    times.append(time.monotonic()-start)
    
    
    # Stacking into (entity, probability) rows is left out of this timing.
    #result = np.stack((entity_strings, entity_probs)).T
    
    
print(
    f"""Time taken {sum(times):.10f}s, 
    mean per ment {np.mean(times):.10f}, 
    std per ment {np.std(times):.10f}, 
    max per ment {np.max(times):.10f}, 
    min per ment {np.min(times):.10f}"""
)

100%|██████████| 2861/2861 [00:00<00:00, 31799.11it/s]

Time taken 0.0836321040s, 
    mean per ment 0.0000292318, 
    std per ment 0.0000871654, 
    max per ment 0.0042386290, 
    min per ment 0.0000046650



