## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
# Re-execute the logging module before configuring it. logging.basicConfig()
# does nothing when the root logger already has handlers, so this is presumably
# here to make the configuration below take effect in this kernel — TODO confirm.
reload(logging)
# Show INFO and above, each line prefixed with an HH:MM:SS timestamp and the level name.
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

# Put the parent directory first on the import path so that modules located
# there can be imported from this notebook. The guard keeps the cell
# idempotent: re-running it does not stack another '..' entry on top of the
# one that is already first.
if sys.path[:1] != ['..']:
    sys.path.insert(0, '..')

In [5]:
import os

# Home directory of the current user, used to locate the downloaded dataset.
# os.path.expanduser('~') gives the same value as $HOME when that variable is
# set but, unlike os.getenv('HOME'), always returns a string: a None here would
# be formatted into paths as the literal text 'None'.
home_dir = os.path.expanduser('~')

Dataset source (benchmark datasets for entity resolution): https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

Dataset README (excerpt below): https://www.informatik.uni-leipzig.de/~saeedi/NCVoters_Readme.txt

```
5 party:
--------- 
5 sources. 
Each source has 1,000,000 entities.
There is one file per source, so 5 files in total.
****************************************************
****************************************************
****************************************************
Fields:
---------- 
recId: entities with the same recId refer to the same entity.
givenname: 
surname: 
post code: 
suburb:
```

In [6]:
import glob
import csv
from tqdm.auto import tqdm

current_row_id = 0
row_dict = {}
rows_total = 5000000  # expected row count; only used to size the progress bar
cluster_attr = 'recid'

# Sort the matches: glob returns paths in arbitrary, filesystem-dependent
# order, and the ids assigned below depend on the order the files are read in.
csv_paths = sorted(glob.glob(f'{home_dir}/Downloads/5Party-ocp20/*.csv'))
if not csv_paths:
    # Fail here rather than continue with an empty row_dict.
    raise FileNotFoundError(f'No CSV files found in {home_dir}/Downloads/5Party-ocp20/')

with tqdm(total=rows_total) as pbar:
    for filename in csv_paths:
        # newline='' is what the csv module documentation requires for file
        # objects handed to a reader, so embedded newlines are parsed correctly.
        with open(filename, newline='') as f:
            for row in csv.DictReader(f):
                # Sequential id, unique across all files; also the key in row_dict.
                row['id'] = current_row_id
                row[cluster_attr] = int(row[cluster_attr])  # convert cluster_attr to int
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000000.0), HTML(value='')))




In [7]:
row_dict[0]

{'recid': 7852009,
 'givenname': 'kadelyn',
 'surname': 'gragnani',
 'suburb': 'waxhaw',
 'postcode': '28|73',
 'id': 0}

## Preprocess

In [8]:
attr_list = ['givenname', 'surname', 'suburb', 'postcode']

In [9]:
import unidecode

def clean_str(s, max_len=30):
    """Normalize one raw attribute value.

    The value is passed through ``unidecode.unidecode``, lowercased, stripped
    of leading and trailing whitespace and finally cut to ``max_len``
    characters.

    Args:
        s: the raw string value.
        max_len: maximum number of characters kept after stripping. Defaults
            to 30; ``None`` disables the truncation.

    Returns:
        The normalized string.
    """
    return unidecode.unidecode(s).lower().strip()[:max_len]

# Normalize every attribute named in attr_list on every row, overwriting the
# raw values in place.
for row in tqdm(row_dict.values()):
    cleaned_values = {attr: clean_str(row[attr]) for attr in attr_list}
    row.update(cleaned_values)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000000.0), HTML(value='')))




## Init Data Module

In [10]:
import torch
import numpy as np

# One seed for the whole notebook: applied to torch's and NumPy's global RNGs
# here, and passed to the data module when it is constructed further down.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
import string

# Characters configured for the STRING fields: digits, lowercase ASCII letters,
# ASCII punctuation and the space character, in that order. Built from the
# standard-library constants rather than a hand-typed literal, so no character
# can be dropped or mistyped; the resulting list is the same as before.
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [12]:
# Every attribute in attr_list gets the same configuration, so the mapping is
# derived from attr_list instead of repeating one literal entry per attribute.
# This keeps it in step with the attributes cleaned in the preprocessing step.
attr_info_dict = {
    attr_name: {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
    for attr_name in attr_list
}

In [13]:
from entity_embed import build_row_numericalizer

# Build the numericalizer from the per-attribute config and the loaded rows.
# Per the log output below, max_str_len is computed from the data for each
# attribute and bumped to the next even number when it is odd.
row_numericalizer = build_row_numericalizer(attr_info_dict, row_dict=row_dict)
# Display the resolved per-attribute settings.
row_numericalizer.attr_info_dict

13:25:19 INFO:For attr='givenname', computing actual alphabet and max_str_len
13:25:22 INFO:For attr='givenname', using actual_max_str_len=16
13:25:22 INFO:For attr='surname', computing actual alphabet and max_str_len
13:25:26 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
13:25:26 INFO:For attr='surname', using actual_max_str_len=22
13:25:26 INFO:For attr='suburb', computing actual alphabet and max_str_len
13:25:30 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
13:25:30 INFO:For attr='suburb', using actual_max_str_len=22
13:25:30 INFO:For attr='postcode', computing actual alphabet and max_str_len
13:25:33 INFO:actual_max_str_len=9 must be pair to enable NN pooling. Updating to 10
13:25:33 INFO:For attr='postcode', using actual_max_str_len=10


{'givenname': NumericalizeInfo(field_type=<FieldType.STRING: 'string'>, tokenizer=None, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=16, vocab=None),
 'surname': NumericalizeInfo(field_type=<FieldType.STRING: 'string'>, tokenizer=None, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=22, vocab=None),
 'suburb': NumericalizeInfo(field_type=<FieldType.STRI

In [14]:
from entity_embed import DeduplicationDataModule

# Sizes of the train / validation / test splits.
# NOTE(review): presumably counted in clusters (groups of rows sharing the same
# cluster_attr value) rather than in rows — confirm against entity_embed.
train_cluster_len = 1000
valid_cluster_len = 1000
test_cluster_len = 200_000
datamodule = DeduplicationDataModule(
    row_dict=row_dict,
    cluster_attr=cluster_attr,
    row_numericalizer=row_numericalizer,
    # NOTE(review): the batch sizes below are unexplained magic numbers — confirm how they were chosen.
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=test_cluster_len,
    only_plural_clusters=True,
    log_empty_vals=False,
    # Same seed as the one applied to torch and NumPy above.
    random_seed=random_seed
)

## Training

In [15]:
from entity_embed import EntityEmbed

# ann_k is given to the model here and reused later as the `k` argument of
# ann_index.search_pairs in the manual test.
ann_k = 10
model = EntityEmbed(
    datamodule,
    ann_k=ann_k,
)

In [16]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 50
# Stop early once the monitored validation metric (higher is better, hence
# mode='max') has not improved for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='valid_recall_at_0.5',
    min_delta=0.00,
    patience=10,
    verbose=True,
    mode='max'
)
tb_log_dir = 'tb_logs'
tb_name = 'voters'
trainer = pl.Trainer(
    # Train on one GPU when CUDA is available; otherwise fall back to the
    # Trainer default (CPU) instead of failing on a machine without a GPU.
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    # TensorBoard event files are written under tb_logs/voters.
    logger=TensorBoardLogger(tb_log_dir, name=tb_name)
)

13:25:33 INFO:GPU available: True, used: True
13:25:33 INFO:TPU available: None, using: 0 TPU cores
13:25:33 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [17]:
trainer.fit(model, datamodule)

13:25:49 INFO:Train pair count: 6679
13:25:49 INFO:Valid pair count: 6645
13:25:49 INFO:Test pair count: 1331364
13:25:52 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 2.5 M 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [18]:
model.blocker_net.get_signature_weights()

{'givenname': 0.23608173429965973,
 'surname': 0.31067323684692383,
 'suburb': 0.23410716652870178,
 'postcode': 0.21913790702819824}

In [19]:
import gc
# Training is done: drop the trainer reference and run a garbage-collection
# pass, presumably to release what it was holding before the test phase.
del trainer
# Last expression of the cell, so the value returned by gc.collect() is displayed.
gc.collect()

327

## Testing manually 

In [20]:
datamodule.setup(stage='test')

13:27:43 INFO:Train pair count: 6679
13:27:43 INFO:Valid pair count: 6645
13:27:43 INFO:Test pair count: 1331364


In [21]:
# Embed the test rows. The result is later indexed by row id and is asserted
# (a few cells below) to contain as many entries as test_row_dict.
test_row_dict = datamodule.test_row_dict
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=49951.0), HTML(value='')))




In [22]:
# Copy out the two values still needed for the manual test, because model and
# datamodule are deleted in the next cell.
embedding_size = model.blocker_net.embedding_size
test_true_pair_set = datamodule.test_true_pair_set

In [23]:
import gc
# The embeddings and the true pair set were extracted above, so the model and
# the data module are no longer needed; drop them and run a collection pass.
del model
del datamodule
# Last expression of the cell, so the value returned by gc.collect() is displayed.
gc.collect()

2618

In [24]:
assert len(test_vector_dict) == len(test_row_dict)

In [25]:
%%time

from entity_embed import ANNEntityIndex

# Create an index sized for the model's embeddings, insert every test vector,
# then build it so it can be searched in the next cell.
ann_index = ANNEntityIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 13min 47s, sys: 2.04 s, total: 13min 49s
Wall time: 1min 39s


In [26]:
%%time

# 0.5 is the same value that appears in the name of the metric monitored during
# training ('valid_recall_at_0.5') — presumably intentional, TODO confirm.
sim_threshold = 0.5
# Candidate pairs found by the index, using the same k (ann_k) the model was given.
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold
)

CPU times: user 21min 8s, sys: 647 ms, total: 21min 9s
Wall time: 2min 21s


In [27]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(found_pair_set), len(test_row_dict))

6.157684925439244

In [28]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.27041753940212643, 0.9995823831799568)

In [29]:
# Pairs returned by the search that are not in the true pair set.
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

3590500

In [30]:
# True pairs that the search did not return.
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

556

In [31]:
def cos_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    The dot product is divided by the product of the two Euclidean norms, so
    the result is a true cosine similarity even when the inputs are not
    unit-length. For unit-length inputs it equals the plain dot product.

    Args:
        a: 1-D array-like vector.
        b: 1-D array-like vector of the same length as ``a``.

    Returns:
        The cosine similarity, or 0.0 when either vector has zero norm.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # The cosine is undefined for a zero vector; report "no similarity"
        # rather than dividing by zero.
        return 0.0
    return np.dot(a, b) / norm_product

In [32]:
# Inspect the first 10 missed pairs: the similarity of the two embeddings,
# followed by both records.
for id_left, id_right in false_negatives[:10]:
    pair_similarity = cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right])
    display((pair_similarity, row_dict[id_left], row_dict[id_right]))

(0.44243747,
 {'recid': 3259167,
  'givenname': 'lindsay',
  'surname': 'xmeltzer',
  'suburb': 'salisbury',
  'postcode': '2814q',
  'id': 4155},
 {'recid': 3259167,
  'givenname': 'lindsay',
  'surname': 'schmeltzer',
  'suburb': 'salisbury',
  'postcode': '28144',
  'id': 1742878})

(0.9999999,
 {'recid': 3296693,
  'givenname': 'richard',
  'surname': 'rodriguez',
  'suburb': 'murphy',
  'postcode': '28906',
  'id': 884027},
 {'recid': 3296693,
  'givenname': 'richard',
  'surname': 'rodriguez',
  'suburb': 'murphy',
  'postcode': '28906',
  'id': 1365341})

(0.9999999,
 {'recid': 668953,
  'givenname': 'virginia',
  'surname': 'deal',
  'suburb': 'shelby',
  'postcode': '28152',
  'id': 789539},
 {'recid': 668953,
  'givenname': 'virginia',
  'surname': 'deal',
  'suburb': 'shelby',
  'postcode': '28152',
  'id': 4789986})

(0.62918615,
 {'recid': 8075424,
  'givenname': "cather'lne",
  'surname': 'matthews',
  'suburb': 'charlotte',
  'postcode': '28249',
  'id': 2036282},
 {'recid': 8075424,
  'givenname': 'caterine',
  'surname': 'matthews',
  'suburb': 'charlote',
  'postcode': '28209',
  'id': 4036524})

(0.9066909,
 {'recid': 903033,
  'givenname': 'ada',
  'surname': 'griffin',
  'suburb': 'north wilkesboro',
  'postcode': '28659',
  'id': 2636527},
 {'recid': 903033,
  'givenname': 'acla',
  'surname': 'griffin',
  'suburb': 'north wilkesboro',
  'postcode': '2865g',
  'id': 3045758})

(0.63899314,
 {'recid': 5584048,
  'givenname': 'angel',
  'surname': 'jones',
  'suburb': 'roanocke rapids',
  'postcode': '2787o',
  'id': 1042238},
 {'recid': 5584048,
  'givenname': 'ankel',
  'surname': 'jories',
  'suburb': 'roanoke rapids',
  'postcode': '27870',
  'id': 4042346})

(0.68683827,
 {'recid': 4541623,
  'givenname': 'cyntia',
  'surname': 'clar1<e',
  'suburb': 'charlotte',
  'postcode': '28210',
  'id': 25140},
 {'recid': 4541623,
  'givenname': 'cynthia',
  'surname': 'clare',
  'suburb': 'charlotte',
  'postcode': '2821o',
  'id': 2025266})

(0.5129759,
 {'recid': 584874,
  'givenname': 'james',
  'surname': 'mcrae',
  'suburb': 'laurinburg',
  'postcode': '28352',
  'id': 853639},
 {'recid': 584874,
  'givenname': 'jamws',
  'surname': 'iiicrae',
  'suburb': 'laurinburg',
  'postcode': '28352',
  'id': 1016987})

(0.48339996,
 {'recid': 5601884,
  'givenname': 'maria',
  'surname': 'chaiiiba',
  'suburb': 'charlotte',
  'postcode': '28527',
  'id': 1031015},
 {'recid': 5601884,
  'givenname': 'maria',
  'surname': 'xhamba',
  'suburb': 'charlotte',
  'postcode': '282z7',
  'id': 2031014})

(0.51123005,
 {'recid': 2436435,
  'givenname': 'stefah',
  'surname': 'wacller',
  'suburb': 'greensboro',
  'postcode': '27407',
  'id': 3026338},
 {'recid': 2436435,
  'givenname': 'stedan',
  'surname': 'wadler',
  'suburb': 'greensboro',
  'postcode': '274o7',
  'id': 4026360})