## Load Dataset

In [1]:
# Automatically reload edited Python modules (e.g. local packages made importable via sys.path below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
# Reloading `logging` re-creates its root logger without any handlers. basicConfig()
# does nothing when the root logger already has handlers, so this presumably ensures
# the format/level below take effect inside an already-running kernel.
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch.
# Workaround for a libgomp (GNU OpenMP runtime) loading problem: keep this cell ahead
# of any cell that imports torch. The imported name is not used in the cells shown
# here; the import is needed for its load-order side effect.
from n2 import HnswIndex

In [4]:
import sys

# Prepend the parent directory to the import path so packages living next to this
# notebook's folder (presumably the local `entity_embed`) win over installed copies.
sys.path[:0] = ['..']

In [5]:
import os

# User's home directory, used to locate the downloaded dataset below.
# On POSIX, os.path.expanduser('~') returns $HOME when it is set (same result as
# before). When HOME is unset it falls back to the platform's own lookup (e.g.
# USERPROFILE on Windows) instead of returning None. A None would silently turn
# the data path into 'None/Downloads/...'.
home_dir = os.path.expanduser('~')

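Data: the North Carolina Voters "5 party" dataset from the [Leipzig benchmark datasets for entity resolution](https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution). The loading code below expects its CSV files in `~/Downloads/5Party-ocp20/`.

Relevant excerpt of the dataset README ([NCVoters_Readme.txt](https://www.informatik.uni-leipzig.de/~saeedi/NCVoters_Readme.txt)). Note that in the CSV headers the fields are named `recid` and `postcode`, not `recId` and `post code`: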

```
5 party:
--------- 
5 sources. 
Each source 1000,000 entities.
There is one file per source, so totally 5 files
****************************************************
****************************************************
****************************************************
Fields:
---------- 
recId: entites with the same recId refer to the same entity.
givenname: 
surname: 
post code: 
suburb:
```

In [6]:
import glob
import csv
from tqdm.auto import tqdm

current_row_id = 0
row_dict = {}
# Expected row count (5 source files x 1,000,000 rows, per the README above);
# only used to size the progress bar.
rows_total = 5000000
cluster_attr = 'recid'

# glob returns files in arbitrary, filesystem-dependent order. Sort the list so the
# sequentially assigned row ids are the same on every machine and every run.
data_glob = f'{home_dir}/Downloads/5Party-ocp20/*.csv'
csv_filename_list = sorted(glob.glob(data_glob))
if not csv_filename_list:
    # Fail loudly instead of continuing with an empty row_dict.
    raise FileNotFoundError(f'No CSV files match {data_glob!r}')

with tqdm(total=rows_total) as pbar:
    for filename in csv_filename_list:
        # The csv module expects files opened with newline='' so that quoted fields
        # containing line breaks are parsed correctly.
        with open(filename, newline='') as f:
            for row in csv.DictReader(f):
                row['id'] = current_row_id
                # Records sharing the same cluster_attr value refer to the same entity
                # (see README); store it as an int.
                row[cluster_attr] = int(row[cluster_attr])
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000000.0), HTML(value='')))




In [7]:
# Peek at the first loaded record to check CSV parsing and the added 'id' field.
row_dict[0]

{'recid': 7852009,
 'givenname': 'kadelyn',
 'surname': 'gragnani',
 'suburb': 'waxhaw',
 'postcode': '28|73',
 'id': 0}

## Preprocess

In [8]:
# Record fields that get cleaned and embedded, in a fixed order.
attr_list = 'givenname surname suburb postcode'.split()

In [9]:
import unidecode

def clean_str(s, max_len=30):
    """Normalize one raw attribute value.

    Transliterates to ASCII with ``unidecode``, lowercases, strips surrounding
    whitespace and truncates to ``max_len`` characters.

    Args:
        s: raw value. ``None`` (what ``csv.DictReader`` yields for fields missing
            from a short row) is treated as an empty string instead of being
            passed to ``unidecode``.
        max_len: maximum length of the returned string (default 30).

    Returns:
        The cleaned string. Whitespace left at the end by the truncation is also
        removed, so re-applying ``clean_str`` to an already-cleaned value leaves it
        unchanged and re-running this cell does not alter the data.
    """
    if s is None:
        return ''
    return unidecode.unidecode(s).lower().strip()[:max_len].rstrip()


# Clean every attribute of every row in place.
for row in tqdm(row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000000.0), HTML(value='')))




## Init Data Module

In [10]:
import random

import torch
import numpy as np

# Seed every RNG the notebook and the libraries it calls may draw from.
# The stdlib `random` module is seeded as well because library code (e.g. data
# splitting/sampling) may use it rather than numpy/torch. That is an assumption,
# but seeding it is cheap either way.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
# Candidate character set used as every attribute's `alphabet`:
# digits, lowercase ASCII letters, ASCII punctuation and the space (69 symbols).
alphabet = [
    *'0123456789',
    *'abcdefghijklmnopqrstuvwxyz',
    *'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~',
    ' ',
]

In [12]:
# Every attribute is a single-token string over the same alphabet. max_str_len is left
# as None so the library computes it from the data (see the log output below).
# Each attribute gets its own fresh settings dict; only the alphabet list is shared.
attr_info_dict = {
    attr_name: {
        'is_multitoken': False,
        'tokenizer': None,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
    for attr_name in ('givenname', 'surname', 'suburb', 'postcode')
}

In [13]:
from entity_embed import MultiSigDedupEmbed

train_cluster_len = 1000
valid_cluster_len = 1000
test_cluster_len = 200_000
ann_k = 10
use_mask = False

# data kwargs
data_kwargs = dict(
    row_dict=row_dict,
    attr_info_dict=attr_info_dict,
    cluster_attr=cluster_attr,
    # NOTE(review): 45 == C(10, 2) and 1225 == C(50, 2), presumably pair counts for
    # 10 / 50 records per batch. Confirm against entity_embed before changing.
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=test_cluster_len,
    only_plural_clusters=True,
    random_seed=random_seed,
)
# model kwargs
model_kwargs = dict(
    use_mask=use_mask,
    ann_k=ann_k,
)
model = MultiSigDedupEmbed(**data_kwargs, **model_kwargs)

11:54:22 INFO:For attr='givenname', computing actual alphabet and max_str_len
11:54:25 INFO:For attr='givenname', using actual_max_str_len=16
11:54:25 INFO:For attr='surname', computing actual alphabet and max_str_len
11:54:28 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
11:54:28 INFO:For attr='surname', using actual_max_str_len=22
11:54:28 INFO:For attr='suburb', computing actual alphabet and max_str_len
11:54:31 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
11:54:31 INFO:For attr='suburb', using actual_max_str_len=22
11:54:31 INFO:For attr='postcode', computing actual alphabet and max_str_len
11:54:34 INFO:actual_max_str_len=9 must be pair to enable NN pooling. Updating to 10
11:54:34 INFO:For attr='postcode', using actual_max_str_len=10


## Training

In [14]:
gpus = 1
max_epochs = 50
check_val_every_n_epoch = 1
early_stopping_monitor = 'valid_recall_at_0.5'
tb_log_dir = 'tb_logs'
tb_name = 'voters'

# Collect the training options in one mapping and pass them through unchanged.
fit_kwargs = {
    'gpus': gpus,
    'max_epochs': max_epochs,
    'check_val_every_n_epoch': check_val_every_n_epoch,
    'early_stopping_monitor': early_stopping_monitor,
    'tb_log_dir': tb_log_dir,
    'tb_name': tb_name,
}
model.fit(**fit_kwargs)

11:54:34 INFO:Fit model_sig_i=0, learning signature with unused_attr_list=['givenname', 'surname', 'suburb', 'postcode']
11:54:34 INFO:GPU available: True, used: True
11:54:34 INFO:TPU available: None, using: 0 TPU cores
11:54:34 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
11:54:48 INFO:Train pair count: 6679
11:54:48 INFO:Valid pair count: 6645
11:54:48 INFO:Test pair count: 1331364
11:54:51 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 2.5 M 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [15]:
# Learned per-attribute weights of each signature module (one mapping per module).
for lt_module in model.lt_module_list:
    signature_weights = lt_module.get_signature_weights()
    display(signature_weights)

{'givenname': 0.2618091106414795,
 'surname': 0.3332642614841461,
 'suburb': 0.20276573300361633,
 'postcode': 0.20216096937656403}

## Testing manually 

In [16]:
# Prepare the data module for testing so `test_row_dict` / `test_true_pair_set` can be read below.
model.datamodule.setup(stage='test')

11:56:56 INFO:Train pair count: 6679
11:56:56 INFO:Valid pair count: 6645
11:56:56 INFO:Test pair count: 1331364


In [17]:
# Embed every test row; the result maps each signature (attribute tuple) to its vectors.
test_row_dict = model.datamodule.test_row_dict
test_multisig_dict = model.predict(row_dict=test_row_dict, batch_size=16)
test_multisig_dict.keys()

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=49951.0), HTML(value='')))




dict_keys([('givenname', 'surname', 'suburb', 'postcode')])

In [18]:
# Ground-truth duplicate pairs within the test split (the reference for precision/recall below).
test_true_pair_set = model.datamodule.test_true_pair_set
len(test_true_pair_set)

1331364

In [19]:
%%time

from entity_embed import MultiSigANNDedupIndex

# Build an approximate-nearest-neighbour index over the test embeddings.
# The %%time output records the cost (~1.5 min wall time in the saved run).
ann_index = MultiSigANNDedupIndex(
    multisig_dict_keys=model.multisig_dict_keys,
    embedding_size=model.embedding_size,
)
ann_index.insert_multisig_dict(test_multisig_dict)
ann_index.build()

CPU times: user 15min 25s, sys: 2.59 s, total: 15min 27s
Wall time: 1min 30s


In [20]:
%%time

# Similarity threshold per signature, presumably the minimum for a neighbour to be kept
# as a candidate pair. 0.5 matches the `valid_recall_at_0.5` metric monitored during training.
sim_threshold_dict = dict.fromkeys(
    [('givenname', 'surname', 'suburb', 'postcode')],
    0.5,
)
found_pair_set = ann_index.search_pairs(k=ann_k, sim_threshold_dict=sim_threshold_dict)

CPU times: user 21min 2s, sys: 2.04 s, total: 21min 4s
Wall time: 1min 59s


In [21]:
from entity_embed.evaluation import pair_entity_ratio

# Candidate pairs found per test row. This is consistent with the saved outputs:
# ~4.93M found pairs / ~0.8M test rows ~= 6.16. Read it together with precision/recall below.
pair_entity_ratio(len(found_pair_set), len(test_row_dict))

6.162908807903765

In [22]:
from entity_embed.evaluation import precision_and_recall

# precision: share of found pairs that are true duplicates.
# recall: share of true duplicate pairs that were found.
precision, recall = precision_and_recall(found_pair_set, test_true_pair_set)
precision, recall

(0.2701052871363072, 0.999275179440033)

In [23]:
# Found pairs that are not true duplicates. This list holds ~3.6M tuples in the saved run, so it is memory-heavy.
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

3595084

In [24]:
# True duplicate pairs the search missed. It is a list so the first few can be sliced
# for inspection below. Set-difference order is arbitrary, so which pairs come first
# carries no meaning.
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

965

In [25]:
def cos_similarity(a, b):
    """Cosine similarity between two 1-D vectors.

    The dot product is divided by the product of the vector norms, so the result is
    correct whether or not the inputs are L2-normalized. For already-normalized
    embeddings it equals the plain dot product up to float rounding.

    Args:
        a, b: array-likes of equal length.

    Returns:
        The cosine of the angle between ``a`` and ``b``, or ``0.0`` if either vector
        has zero norm (where cosine similarity is undefined).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return np.dot(a, b) / norm_product

In [26]:
# Signature keys of the prediction output; the single key shown is used to pick the vectors below.
test_multisig_dict.keys()

dict_keys([('givenname', 'surname', 'suburb', 'postcode')])

In [28]:
# Inspect a few true duplicate pairs the search missed, showing their embedding
# similarity next to both records. Records are looked up in `test_row_dict`, the
# dict the embeddings were computed from via `model.predict`, rather than the full
# `row_dict`, so the ids always refer to rows that were actually embedded.
# NOTE(review): in the saved run several missed pairs score well above the 0.5
# threshold (e.g. ~0.95). That suggests they fell outside the top-`ann_k` neighbours
# rather than below the threshold; worth confirming.
test_vector_dict = \
    test_multisig_dict[('givenname', 'surname', 'suburb', 'postcode')]

for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),
            test_row_dict[id_left], test_row_dict[id_right],
        )
    )

(0.59101814,
 {'recid': 3457209,
  'givenname': 'aggie',
  'surname': 'fillips',
  'suburb': 'wadesb0ro',
  'postcode': '28170',
  'id': 2045963},
 {'recid': 3457209,
  'givenname': 'aggie',
  'surname': 'phillips',
  'suburb': 'wadesboro',
  'postcode': '28170',
  'id': 4697033})

(0.9505236,
 {'recid': 7705558,
  'givenname': 'matthew',
  'surname': 'osbon',
  'suburb': 'moregead city',
  'postcode': '28557',
  'id': 1026297},
 {'recid': 7705558,
  'givenname': 'matthew',
  'surname': 'osborn',
  'suburb': 'morehead city',
  'postcode': '28557',
  'id': 3093755})

(0.62121016,
 {'recid': 5089659,
  'givenname': 'saron',
  'surname': 'johnston',
  'suburb': 'charlotte',
  'postcode': '28212',
  'id': 25894},
 {'recid': 5089659,
  'givenname': 'sharon',
  'surname': 'johnson',
  'suburb': 'charlotte',
  'postcode': '28212',
  'id': 3534923})

(0.94381684,
 {'recid': 6165506,
  'givenname': 'franes',
  'surname': 'harmon',
  'suburb': 'elizabeth city',
  'postcode': '27999',
  'id': 22283},
 {'recid': 6165506,
  'givenname': 'frances',
  'surname': 'harmon',
  'suburb': 'elizabeth city',
  'postcode': '27909',
  'id': 3107424})

(0.91287136,
 {'recid': 4591586,
  'givenname': 'zennifer',
  'surname': '5rebb',
  'suburb': 'southern pines',
  'postcode': '28387',
  'id': 79677},
 {'recid': 4591586,
  'givenname': 'jennifer',
  'surname': 'trebb',
  'suburb': 'southern pines',
  'postcode': '28387',
  'id': 2154122})

(0.5258381,
 {'recid': 932111,
  'givenname': 'robert',
  'surname': 'fillips',
  'suburb': 'duhn',
  'postcode': '28334',
  'id': 3063662},
 {'recid': 932111,
  'givenname': 'robert',
  'surname': 'phillips',
  'suburb': 'dunn',
  'postcode': '28334',
  'id': 4131364})

(0.92020535,
 {'recid': 3338272,
  'givenname': 'dhomas',
  'surname': 'becker',
  'suburb': 'hillsborough',
  'postcode': '27228',
  'id': 33974},
 {'recid': 3338272,
  'givenname': 'thomas',
  'surname': 'becker',
  'suburb': 'hillsborough',
  'postcode': '27278',
  'id': 3393875})

(0.47834933,
 {'recid': 4292018,
  'givenname': 'horscio',
  'surname': 'corbie5e',
  'suburb': 'winston salem',
  'postcode': '27105',
  'id': 13194},
 {'recid': 4292018,
  'givenname': 'oracio',
  'surname': 'corbiere',
  'suburb': 'winston salem',
  'postcode': '27805',
  'id': 1015508})

(0.6729281,
 {'recid': 5584048,
  'givenname': 'angel',
  'surname': 'jones',
  'suburb': 'roanocke rapids',
  'postcode': '2787o',
  'id': 1042238},
 {'recid': 5584048,
  'givenname': 'ankel',
  'surname': 'jories',
  'suburb': 'roanoke rapids',
  'postcode': '27870',
  'id': 4042346})

(0.93382704,
 {'recid': 5797641,
  'givenname': 'crystal',
  'surname': 'durham',
  'suburb': 'henderson',
  'postcode': '27536',
  'id': 2366027},
 {'recid': 5797641,
  'givenname': 'crystal',
  'surname': 'durhsm',
  'suburb': 'henderso',
  'postcode': '27536',
  'id': 4028591})