## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

sys.path.insert(0, '..')

In [5]:
import os
home_dir = os.getenv('HOME')

https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

In [6]:
import glob
import csv
from tqdm.auto import tqdm

row_dict = {}
rows_total = 2260

with tqdm(total=rows_total) as pbar:
    with open(f'{home_dir}/Downloads/affiliationstrings/affiliationstrings_ids.csv') as f:
        for row in csv.DictReader(f):
            row['id'] = int(row.pop('id1'))
            row_dict[row['id']] = row
            pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




In [7]:
true_pair_set = set()

with open(f'{home_dir}/Downloads/affiliationstrings/affiliationstrings_mapping.csv') as f:
    for row in csv.DictReader(f, fieldnames=['id1', 'id2']):
        true_pair_set.add(tuple(sorted([int(row['id1']), int(row['id2'])])))

len(true_pair_set)

16408

In [8]:
from entity_embed.data_utils.utils import id_pairs_to_cluster_mapping_and_dict

cluster_mapping, cluster_dict = id_pairs_to_cluster_mapping_and_dict(true_pair_set)
len(cluster_mapping)

2260

In [9]:
len(cluster_dict)

330

In [10]:
from entity_embed.data_utils.utils import cluster_dict_to_id_pairs

assert len(true_pair_set - cluster_dict_to_id_pairs(cluster_dict)) == 0

In [11]:
cluster_attr = 'cluster_id'

for row_id, row in tqdm(row_dict.items()):
    row[cluster_attr] = cluster_mapping[row_id]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




In [12]:
row_dict[2727]

{'affil1': 'IBM Yamato Software Laboratory', 'id': 2727, 'cluster_id': 2727}

In [13]:
[row_dict[row_id] for row_id in cluster_dict[row_dict[2727]['cluster_id']]]

[{'affil1': 'IBM Yamato Software Laboratory', 'id': 2727, 'cluster_id': 2727},
 {'affil1': 'IBM Tokyo Research Lab, Tokyo, Japan',
  'id': 7609,
  'cluster_id': 2727},
 {'affil1': 'IBM Tokyo Research Laboratory', 'id': 2725, 'cluster_id': 2727}]

## Preprocess

In [14]:
attr_list = ['affil1']

In [15]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    return ' '.join(s_part for s_part in default_tokenizer(s))

for row in tqdm(row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




## Init Data Module

In [16]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [17]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [18]:
attr_info_dict = {
    'affil1': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
}

In [19]:
from entity_embed import build_row_encoder

row_encoder = build_row_encoder(attr_info_dict, row_dict=row_dict)
row_encoder.attr_info_dict

22:36:31 INFO:For attr='affil1', computing actual alphabet and max_str_len
22:36:31 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
22:36:31 INFO:For attr='affil1', using actual_max_str_len=22


{'affil1': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7f97f86df940>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=22)}

In [20]:
from entity_embed import DeduplicationDataModule

train_cluster_len = 100
valid_cluster_len = 100
test_cluster_len = len(cluster_dict) - valid_cluster_len - train_cluster_len
datamodule = DeduplicationDataModule(
    row_dict=row_dict,
    cluster_attr=cluster_attr,
    row_encoder=row_encoder,
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=test_cluster_len,
    only_plural_clusters=True,
    log_empty_vals=False,
    random_seed=random_seed
)

## Training

In [21]:
from entity_embed import EntityEmbed

ann_k = 100
model = EntityEmbed(
    datamodule,
    ann_k=ann_k,
    use_mask=True
)

In [22]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 50
early_stop_callback = EarlyStopping(
   monitor='valid_recall_at_0.3',
   min_delta=0.00,
   patience=10,
   verbose=True,
   mode='max'
)
tb_log_dir = 'tb_logs'
tb_name = 'affiliations'
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name)
)

22:36:31 INFO:GPU available: True, used: True
22:36:31 INFO:TPU available: None, using: 0 TPU cores
22:36:31 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [23]:
trainer.fit(model, datamodule)

22:36:31 INFO:Train pair count: 5922
22:36:31 INFO:Valid pair count: 3867
22:36:31 INFO:Test pair count: 7006
22:36:33 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 851 K 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
851 K     Trainable params
0         Non-trainable params
851 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [24]:
import gc
del trainer
gc.collect()

69

## Testing manually 

In [25]:
datamodule.setup(stage='test')

22:38:32 INFO:Train pair count: 5922
22:38:32 INFO:Valid pair count: 3867
22:38:32 INFO:Test pair count: 7006


In [26]:
test_row_dict = datamodule.test_row_dict
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=55.0), HTML(value='')))




In [27]:
embedding_size = model.blocker_net.embedding_size
test_true_pair_set = datamodule.test_true_pair_set

In [28]:
import gc
del model
del datamodule
gc.collect()

1119

In [29]:
assert len(test_vector_dict) == len(test_row_dict)

In [30]:
%%time

from entity_embed import ANNEntityIndex

ann_index = ANNEntityIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 178 ms, sys: 0 ns, total: 178 ms
Wall time: 31.6 ms


In [31]:
%%time

sim_threshold = 0.3
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold
)

CPU times: user 355 ms, sys: 4.16 ms, total: 359 ms
Wall time: 62.3 ms


In [32]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(found_pair_set), len(test_row_dict))

36.92156862745098

In [33]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.18581112742494768, 0.848986582928918)

In [34]:
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

26063

In [35]:
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

1058

In [36]:
cos_similarity = lambda a, b: np.dot(a, b)

In [37]:
for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),
            row_dict[id_left], row_dict[id_right]
        )
    )

(0.2332712,
 {'affil1': 'department of computer science , university of maryland',
  'id': 1423,
  'cluster_id': 3452},
 {'affil1': 'department of computer science , university of helsinki , finland',
  'id': 8174,
  'cluster_id': 3452})

(0.112500355,
 {'affil1': 'iit bombay', 'id': 237, 'cluster_id': 4481},
 {'affil1': 'computer science & engineering , indian institute of technology , mumbai , india',
  'id': 1126,
  'cluster_id': 4481})

(0.16948275,
 {'affil1': 'uc , berkeley', 'id': 1816, 'cluster_id': 903},
 {'affil1': 'computer science department , university of california at santa barbara , usa',
  'id': 8514,
  'cluster_id': 903})

(-0.15189417,
 {'affil1': 'u . c . berkeley , intel research , berkeley',
  'id': 4199,
  'cluster_id': 903},
 {'affil1': 'university of california , usa', 'id': 8359, 'cluster_id': 903})

(0.23071076,
 {'affil1': 'university of california at santa barbara , santa barbara , ca',
  'id': 2732,
  'cluster_id': 3202},
 {'affil1': 'university of california , davis , department of computer science , one shields avenue , davis , ca 95616 , usa ; e - mail : gertz @ cs . ucdavis . edu',
  'id': 7866,
  'cluster_id': 3202})

(0.22582476,
 {'affil1': 'university of california , santa barbara , ca',
  'id': 45,
  'cluster_id': 3202},
 {'affil1': 'university of california , davis , department of computer science , one shields avenue , davis , ca 95616 , usa ; e - mail : gertz @ cs . ucdavis . edu',
  'id': 7866,
  'cluster_id': 3202})

(0.2501775,
 {'affil1': 'university of maryland', 'id': 3452, 'cluster_id': 3452},
 {'affil1': 'department of computer science , university of helsinki , finland',
  'id': 8174,
  'cluster_id': 3452})

(0.21527648,
 {'affil1': 'data mining technologies , oracle', 'id': 347, 'cluster_id': 370},
 {'affil1': 'oracle usa , redwood city , ca , usa',
  'id': 5873,
  'cluster_id': 370})

(-0.17255367,
 {'affil1': 'university of california los angeles , los angeles',
  'id': 3323,
  'cluster_id': 3323},
 {'affil1': 'ucla computer science', 'id': 4524, 'cluster_id': 3323})

(-0.1314865,
 {'affil1': 'at & t labs', 'id': 1635, 'cluster_id': 7710},
 {'affil1': 'att labs - research , florham park , nj',
  'id': 5204,
  'cluster_id': 7710})