## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
# Reload the logging module before configuring it, then emit INFO-level
# messages prefixed with an HH:MM:SS timestamp and the level name.
# NOTE(review): the reload presumably discards handlers installed earlier in
# the kernel so that basicConfig takes effect — confirm.
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

# Put the parent directory first on the module search path.
# NOTE(review): presumably so the local `entity_embed` checkout is the one imported — confirm.
sys.path.insert(0, '..')

In [5]:
import os
# Resolve the current user's home directory. `os.path.expanduser('~')` is used
# instead of `os.getenv('HOME')` because the latter returns None when HOME is
# not set (e.g. on Windows), which would silently turn the f-string paths built
# from `home_dir` into 'None/Downloads/...'.
home_dir = os.path.expanduser('~')

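The `affiliationstrings` dataset loaded below comes from the University of Leipzig benchmark datasets for entity resolution: https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution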

In [6]:
import glob
import csv
from tqdm.auto import tqdm

# Load every affiliation record into `row_dict`, keyed by its integer id.
row_dict = {}
# Expected number of data rows; in this cell it only sizes the progress bar.
rows_total = 2260

with tqdm(total=rows_total) as pbar:
    # newline='' is what the csv module documentation asks for when opening a
    # file for a reader, so that newlines embedded in quoted fields are parsed
    # correctly.
    with open(f'{home_dir}/Downloads/affiliationstrings/affiliationstrings_ids.csv', newline='') as f:
        for row in csv.DictReader(f):
            # The id column is named 'id1' in the file; keep it as an int under 'id'.
            row['id'] = int(row.pop('id1'))
            row_dict[row['id']] = row
            pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




In [7]:
# Read the ground-truth mapping file and collect each matching pair of ids as
# a sorted 2-tuple, so (a, b) and (b, a) collapse to the same set element.
with open(f'{home_dir}/Downloads/affiliationstrings/affiliationstrings_mapping.csv') as mapping_file:
    # Column names are supplied explicitly, so the first line is read as data.
    mapping_reader = csv.DictReader(mapping_file, fieldnames=['id1', 'id2'])
    true_pair_set = {
        tuple(sorted((int(pair['id1']), int(pair['id2']))))
        for pair in mapping_reader
    }

len(true_pair_set)

16408

In [8]:
from entity_embed.data_utils.utils import id_pairs_to_cluster_mapping_and_dict

# Derive the cluster structures from the ground-truth pairs. In the cells that
# follow, `cluster_mapping` is indexed by record id and `cluster_dict` by
# cluster id.
cluster_mapping, cluster_dict = id_pairs_to_cluster_mapping_and_dict(true_pair_set)
len(cluster_mapping)

2260

In [9]:
len(cluster_dict)

330

In [10]:
from entity_embed.data_utils.utils import cluster_dict_to_id_pairs

# Sanity check: every ground-truth pair must also appear among the pairs
# regenerated from the clusters (the set difference is empty).
assert len(true_pair_set - cluster_dict_to_id_pairs(cluster_dict)) == 0

In [11]:
# Key under which each record stores the id of its cluster.
cluster_attr = 'cluster_id'

# Annotate every record, in place, with the cluster it belongs to.
for row_id, row in tqdm(row_dict.items()):
    row[cluster_attr] = cluster_mapping[row_id]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




In [12]:
row_dict[2727]

{'affil1': 'IBM Yamato Software Laboratory', 'id': 2727, 'cluster_id': 2727}

In [13]:
# Show every record that shares a cluster with record 2727.
[row_dict[row_id] for row_id in cluster_dict[row_dict[2727]['cluster_id']]]

[{'affil1': 'IBM Yamato Software Laboratory', 'id': 2727, 'cluster_id': 2727},
 {'affil1': 'IBM Tokyo Research Lab, Tokyo, Japan',
  'id': 7609,
  'cluster_id': 2727},
 {'affil1': 'IBM Tokyo Research Laboratory', 'id': 2725, 'cluster_id': 2727}]

## Preprocess

In [14]:
attr_list = ['affil1']

In [15]:
import unidecode
from entity_embed import default_tokenizer

def clean_str(s):
    """Normalize a raw attribute string.

    The input is passed through ``unidecode.unidecode``, lower-cased and
    stripped of surrounding whitespace, then split with ``default_tokenizer``
    and re-joined with single spaces.

    Args:
        s: The raw string to clean.

    Returns:
        The tokens of the normalized string joined by a single space.
    """
    normalized = unidecode.unidecode(s).lower().strip()
    # str.join accepts any iterable of strings, so the tokenizer output is
    # passed directly instead of being wrapped in a pass-through generator.
    return ' '.join(default_tokenizer(normalized))

# Clean every attribute listed in `attr_list` on every record, in place.
# NOTE(review): this overwrites the raw values, so re-running the cell cleans
# already-cleaned strings — confirm clean_str is idempotent.
for row in tqdm(row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2260.0), HTML(value='')))




## Init Data Module

In [16]:
import torch
import numpy as np

# Fix the seed for both PyTorch and NumPy; the same value is also handed to
# the data module when it is constructed.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [17]:
# Character list: digits, lowercase ASCII letters, ASCII punctuation and the space.
# NOTE(review): `alphabet` is not referenced by any later cell visible here — confirm it is still needed.
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [18]:
# Per-attribute numericalization settings. The single attribute 'affil1' is
# described by its field type, its tokenizer (given as a dotted path) and the
# name of the vocabulary to use.
attr_info_dict = {
    'affil1': {
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
    }
}

In [19]:
from entity_embed import build_row_numericalizer

# Build the numericalizer from the attribute settings and the loaded records,
# then display the per-attribute configuration it resolved.
row_numericalizer = build_row_numericalizer(attr_info_dict, row_dict=row_dict)
row_numericalizer.attr_info_dict

14:43:16 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt


{'affil1': NumericalizeInfo(field_type=<FieldType.SEMANTIC_MULTITOKEN: 'semantic_multitoken'>, tokenizer=<function default_tokenizer at 0x7f377999ca60>, alphabet=None, max_str_len=None, vocab=<torchtext.vocab.Vocab object at 0x7f3751012e50>)}

In [20]:
from entity_embed import DeduplicationDataModule

# Split sizes are expressed in number of clusters: fixed sizes for train and
# validation, and whatever remains of `cluster_dict` goes to the test split.
train_cluster_len = 100
valid_cluster_len = 100
test_cluster_len = len(cluster_dict) - valid_cluster_len - train_cluster_len
# The data module receives the records, the key holding each record's cluster
# id, the numericalizer, the batch sizes, the split sizes and the shared seed.
datamodule = DeduplicationDataModule(
    row_dict=row_dict,
    cluster_attr=cluster_attr,
    row_numericalizer=row_numericalizer,
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=test_cluster_len,
    only_plural_clusters=True,
    log_empty_vals=False,
    random_seed=random_seed
)

## Training

In [21]:
from entity_embed import EntityEmbed

# `ann_k` is given to the model here and reused later as the `k` of the
# manual pair search on the test split.
ann_k = 100
model = EntityEmbed(
    datamodule,
    ann_k=ann_k,
    use_mask=True,
    embedding_size=300,
)

In [22]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Upper bound on the number of training epochs.
max_epochs = 50
# Early stopping on the 'valid_recall_at_0.3' metric (mode='max') with a
# patience of 10.
# NOTE(review): the 0.3 in the metric name presumably corresponds to the
# similarity threshold of 0.3 used in the manual test further down — confirm
# the two stay in sync.
early_stop_callback = EarlyStopping(
   monitor='valid_recall_at_0.3',
   min_delta=0.00,
   patience=10,
   verbose=True,
   mode='max'
)
# TensorBoard logs are written under `tb_log_dir` with run name `tb_name`.
tb_log_dir = 'tb_logs'
tb_name = 'affiliations'
# NOTE(review): gpus=1 requests one GPU unconditionally; presumably this
# fails on a CPU-only machine — confirm.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name)
)

14:43:20 INFO:GPU available: True, used: True
14:43:20 INFO:TPU available: None, using: 0 TPU cores
14:43:20 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [23]:
trainer.fit(model, datamodule)

14:43:20 INFO:Train pair count: 5922
14:43:20 INFO:Valid pair count: 3867
14:43:20 INFO:Test pair count: 7006
14:43:22 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 935 K 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
407 K     Trainable params
528 K     Non-trainable params
935 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [24]:
import gc
# The trainer is not used after fitting; drop it and force a garbage-collection pass.
del trainer
gc.collect()

123

## Testing manually 

In [25]:
datamodule.setup(stage='test')

14:45:00 INFO:Train pair count: 5922
14:45:00 INFO:Valid pair count: 3867
14:45:00 INFO:Test pair count: 7006


In [26]:
# Embed every record of the test split with the trained model. The result is
# looked up by record id in the cells further down.
test_row_dict = datamodule.test_row_dict
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=55.0), HTML(value='')))




In [27]:
# Keep the two values still needed once `model` and `datamodule` are deleted
# in the next cell.
embedding_size = model.blocker_net.embedding_size
test_true_pair_set = datamodule.test_true_pair_set

In [28]:
import gc
# Release the model and the data module; the remaining visible cells only use
# the values extracted from them above.
del model
del datamodule
gc.collect()

1097

In [29]:
assert len(test_vector_dict) == len(test_row_dict)

In [30]:
%%time

from entity_embed import ANNEntityIndex

# Insert the test embeddings into an ANNEntityIndex sized for the model's
# embedding dimension, then build the index.
ann_index = ANNEntityIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 208 ms, sys: 1.47 ms, total: 209 ms
Wall time: 49.8 ms


In [31]:
%%time

# Search the index for candidate pairs, using the same `ann_k` that was given
# to the model and a similarity threshold of 0.3.
sim_threshold = 0.3
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold
)

CPU times: user 372 ms, sys: 4.97 ms, total: 377 ms
Wall time: 72.4 ms


In [32]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(found_pair_set), len(test_row_dict))

64.5836216839677

In [33]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.11072614923027467, 0.8849557522123894)

In [34]:
# Pairs returned by the search that are not ground-truth test pairs.
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

49794

In [35]:
# Ground-truth test pairs that the search did not return.
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

806

In [36]:
def cos_similarity(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    The dot product is divided by the product of the two Euclidean norms, so
    the result is a true cosine similarity even when the inputs are not
    unit-length. For unit-length inputs this equals the plain dot product.

    Args:
        a: 1-D array-like of numbers.
        b: 1-D array-like of the same length as ``a``.

    Returns:
        The cosine similarity, or 0.0 if either vector has zero norm.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero vector has no direction; avoid dividing by zero.
    if norm_product == 0:
        return 0.0
    return np.dot(a, b) / norm_product

In [37]:
# Inspect the first ten missed pairs: show their similarity score next to the
# two records involved.
for id_left, id_right in false_negatives[:10]:
    similarity = cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right])
    display((similarity, row_dict[id_left], row_dict[id_right]))

(0.21186301,
 {'affil1': 'department of computer science , university of maryland',
  'id': 1423,
  'cluster_id': 3452},
 {'affil1': 'department of computer science , university of helsinki , finland',
  'id': 8174,
  'cluster_id': 3452})

(0.3952813,
 {'affil1': 'at & t labs , research', 'id': 6764, 'cluster_id': 7710},
 {'affil1': 'at & t labs - research , florham park , usa 07932',
  'id': 9264,
  'cluster_id': 7710})

(0.430941,
 {'affil1': 'university of california at santa barbara , santa barbara , ca',
  'id': 2732,
  'cluster_id': 3202},
 {'affil1': 'university of california , davis , department of computer science , one shields avenue , davis , ca 95616 , usa ; e - mail : gertz @ cs . ucdavis . edu',
  'id': 7866,
  'cluster_id': 3202})

(0.20608443,
 {'affil1': 'university of maryland', 'id': 3452, 'cluster_id': 3452},
 {'affil1': 'department of computer science , university of helsinki , finland',
  'id': 8174,
  'cluster_id': 3452})

(0.4382892,
 {'affil1': 'university of california , santa barbara , ca',
  'id': 45,
  'cluster_id': 3202},
 {'affil1': 'university of california , davis , department of computer science , one shields avenue , davis , ca 95616 , usa ; e - mail : gertz @ cs . ucdavis . edu',
  'id': 7866,
  'cluster_id': 3202})

(0.38150468,
 {'affil1': 'at & t labs -- research', 'id': 7429, 'cluster_id': 7710},
 {'affil1': 'at & t labs --- research , florham park , nj',
  'id': 7659,
  'cluster_id': 7710})

(0.18702777,
 {'affil1': 'the mitre corporation , mclean , virginia and computer science department , center for automation research and institute for advanced computer studies , university of maryland at college park',
  'id': 1450,
  'cluster_id': 3452},
 {'affil1': 'univ . of maryland', 'id': 3343, 'cluster_id': 3452})

(0.20384452,
 {'affil1': 'university of helsinki , finland',
  'id': 8286,
  'cluster_id': 3452},
 {'affil1': 'university of maryland , college park , usa',
  'id': 8682,
  'cluster_id': 3452})

(0.46497926,
 {'affil1': 'ibm t . j . watson res . center', 'id': 627, 'cluster_id': 5636},
 {'affil1': 'ibm t . j . watson research center , yorktown heights , ny , usa',
  'id': 6254,
  'cluster_id': 5636})

(0.42087674,
 {'affil1': 'at & t labs - research , usa', 'id': 8577, 'cluster_id': 7710},
 {'affil1': 'at & t labs , florham park , nj', 'id': 9240, 'cluster_id': 7710})