## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

sys.path.insert(0, '../..')

In [5]:
import os
home_dir = os.getenv('HOME')

https://github.com/anhaidgroup/deepmatcher/blob/master/Datasets.md

In [6]:
import glob
import csv
from tqdm.auto import tqdm

from entity_embed.data_utils.utils import Enumerator

row_dict = {}
id_enumerator = Enumerator()
left_id_set = set()
right_id_set = set()
rows_total = 2554 + 22074
clusters_total = 1154

with tqdm(total=rows_total) as pbar:
    with open(f'{home_dir}/Downloads/walmart_amazon_exp_data/exp_data/tableA.csv') as f:
        for row in csv.DictReader(f):
            row['id'] = id_enumerator[f'left-{int(row["id"])}']
            row['source'] = 'left'
            row_dict[row['id']] = row
            left_id_set.add(row['id'])
            pbar.update(1)
    
    with open(f'{home_dir}/Downloads/walmart_amazon_exp_data/exp_data/tableB.csv') as f:
        for row in csv.DictReader(f):
            row['id'] = id_enumerator[f'right-{int(row["id"])}']
            row['source'] = 'right'
            row_dict[row['id']] = row
            right_id_set.add(row['id'])
            pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24628.0), HTML(value='')))




In [7]:
train_true_pair_set = set()
valid_true_pair_set = set()
test_true_pair_set = set()

with open(f'{home_dir}/Downloads/walmart_amazon_exp_data/exp_data/train.csv') as f:
    for row in csv.DictReader(f):
        if int(row['label']) == 1:
            id_left = id_enumerator[f'left-{int(row["ltable_id"])}']
            id_right = id_enumerator[f'right-{int(row["rtable_id"])}']
            train_true_pair_set.add((id_left, id_right))

with open(f'{home_dir}/Downloads/walmart_amazon_exp_data/exp_data/valid.csv') as f:
    for row in csv.DictReader(f):
        if int(row['label']) == 1:
            id_left = id_enumerator[f'left-{int(row["ltable_id"])}']
            id_right = id_enumerator[f'right-{int(row["rtable_id"])}']
            valid_true_pair_set.add((id_left, id_right))

with open(f'{home_dir}/Downloads/walmart_amazon_exp_data/exp_data/test.csv') as f:
    for row in csv.DictReader(f):
        if int(row['label']) == 1:
            id_left = id_enumerator[f'left-{int(row["ltable_id"])}']
            id_right = id_enumerator[f'right-{int(row["rtable_id"])}']
            test_true_pair_set.add((id_left, id_right))
        
display(('train_true_pair_set', len(train_true_pair_set)))
display(('valid_true_pair_set', len(valid_true_pair_set)))
display(('test_true_pair_set', len(test_true_pair_set)))

('train_true_pair_set', 576)

('valid_true_pair_set', 193)

('test_true_pair_set', 193)

## Preprocess

In [8]:
attr_list = ['title', 'category', 'brand', 'modelno', 'price']

In [9]:
import unidecode
import itertools
from entity_embed import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    s_tokens = itertools.islice((s_part[:30] for s_part in default_tokenizer(s)), 0, 30)
    return ' '.join(s_tokens)[:300]

for row in tqdm(row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24628.0), HTML(value='')))




## Init Data Module

In [10]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [12]:
attr_info_dict = {
    'title': {
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
    },
    'category': {
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
    },
    'brand': {
        'field_type': "STRING",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'modelno': {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'price': {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
}

In [13]:
from entity_embed import build_row_numericalizer

row_numericalizer = build_row_numericalizer(attr_info_dict, row_dict=row_dict)
row_numericalizer.attr_info_dict

14:15:02 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
14:15:05 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
14:15:09 INFO:For attr='brand', computing actual alphabet and max_str_len
14:15:09 INFO:For attr='brand', using actual_max_str_len=48
14:15:09 INFO:For attr='modelno', computing actual alphabet and max_str_len
14:15:09 INFO:For attr='modelno', using actual_max_str_len=48
14:15:09 INFO:For attr='price', computing actual alphabet and max_str_len
14:15:09 INFO:For attr='price', using actual_max_str_len=10


{'title': NumericalizeInfo(field_type=<FieldType.SEMANTIC_MULTITOKEN: 'semantic_multitoken'>, tokenizer=<function default_tokenizer at 0x7f5dbba54dc0>, alphabet=None, max_str_len=None, vocab=<torchtext.vocab.Vocab object at 0x7f5d919d2580>),
 'category': NumericalizeInfo(field_type=<FieldType.SEMANTIC_MULTITOKEN: 'semantic_multitoken'>, tokenizer=<function default_tokenizer at 0x7f5dbba54dc0>, alphabet=None, max_str_len=None, vocab=<torchtext.vocab.Vocab object at 0x7f5e483b1340>),
 'brand': NumericalizeInfo(field_type=<FieldType.STRING: 'string'>, tokenizer=<function default_tokenizer at 0x7f5dbba54dc0>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=48, vocab=None),
 

In [14]:
from entity_embed import PairwiseDataModule

datamodule = PairwiseDataModule(
    row_dict=row_dict,
    row_numericalizer=row_numericalizer,
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_true_pair_set=train_true_pair_set,
    valid_true_pair_set=valid_true_pair_set,
    test_true_pair_set=test_true_pair_set,
    random_seed=random_seed
)

## Training

In [15]:
from entity_embed import LinkageEmbed

ann_k = 100
model = LinkageEmbed(
    datamodule,
    ann_k=ann_k,
    use_mask=True,
    embedding_size=300
)

In [16]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 50
early_stop_callback = EarlyStopping(
   monitor='valid_f1_at_0.7',
   min_delta=0.00,
   patience=10,
   verbose=True,
   mode='max'
)
tb_log_dir = 'tb_logs'
tb_name = 'walmart-amazon'
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name)
)

14:15:09 INFO:GPU available: True, used: True
14:15:09 INFO:TPU available: None, using: 0 TPU cores
14:15:09 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [17]:
trainer.fit(model, datamodule)

14:15:11 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 18.5 M
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
9.6 M     Trainable params
8.9 M     Non-trainable params
18.5 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [18]:
model.blocker_net.get_signature_weights()

{'title': 0.20108994841575623,
 'category': 0.10303279012441635,
 'brand': 0.2324979156255722,
 'modelno': 0.2948997914791107,
 'price': 0.1684795469045639}

## Testing

In [19]:
trainer.test(ckpt_path='best')

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_f1_at_0.3': 0.42505592841163314,
 'test_f1_at_0.5': 0.7927927927927929,
 'test_f1_at_0.7': 0.8119402985074626,
 'test_f1_at_0.9': 0.33620689655172414,
 'test_pair_entity_ratio_at_0.3': 1.8255208333333333,
 'test_pair_entity_ratio_at_0.5': 0.6536458333333334,
 'test_pair_entity_ratio_at_0.7': 0.3697916666666667,
 'test_pair_entity_ratio_at_0.9': 0.1015625,
 'test_precision_at_0.3': 0.2710413694721826,
 'test_precision_at_0.5': 0.701195219123506,
 'test_precision_at_0.7': 0.9577464788732394,
 'test_precision_at_0.9': 1.0,
 'test_recall_at_0.3': 0.9844559585492227,
 'test_recall_at_0.5': 0.9119170984455959,
 'test_recall_at_0.7': 0.7046632124352331,
 'test_recall_at_0.9': 0.20207253886010362}
--------------------------------------------------------------------------------


[{'test_precision_at_0.3': 0.2710413694721826,
  'test_recall_at_0.3': 0.9844559585492227,
  'test_f1_at_0.3': 0.42505592841163314,
  'test_pair_entity_ratio_at_0.3': 1.8255208333333333,
  'test_precision_at_0.5': 0.701195219123506,
  'test_recall_at_0.5': 0.9119170984455959,
  'test_f1_at_0.5': 0.7927927927927929,
  'test_pair_entity_ratio_at_0.5': 0.6536458333333334,
  'test_precision_at_0.7': 0.9577464788732394,
  'test_recall_at_0.7': 0.7046632124352331,
  'test_f1_at_0.7': 0.8119402985074626,
  'test_pair_entity_ratio_at_0.7': 0.3697916666666667,
  'test_precision_at_0.9': 1.0,
  'test_recall_at_0.9': 0.20207253886010362,
  'test_f1_at_0.9': 0.33620689655172414,
  'test_pair_entity_ratio_at_0.9': 0.1015625}]

## Testing manually 

In [20]:
from entity_embed.data_utils.utils import id_pairs_to_cluster_mapping_and_dict, cluster_dict_to_filtered_row_dict

__, test_cluster_dict = id_pairs_to_cluster_mapping_and_dict(test_true_pair_set)
test_row_dict = cluster_dict_to_filtered_row_dict(row_dict, test_cluster_dict)
test_left_vector_dict, test_right_vector_dict = model.predict(
    row_dict=test_row_dict,
    left_id_set=left_id_set,
    right_id_set=right_id_set,
    batch_size=16
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=24.0), HTML(value='')))




In [21]:
embedding_size = model.blocker_net.embedding_size

In [22]:
assert (len(test_left_vector_dict) + len(test_right_vector_dict)) == len(test_row_dict)

In [23]:
%%time

from entity_embed import ANNLinkageIndex

ann_index = ANNLinkageIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(left_vector_dict=test_left_vector_dict, right_vector_dict=test_right_vector_dict)
ann_index.build()

CPU times: user 648 ms, sys: 0 ns, total: 648 ms
Wall time: 96.6 ms


In [24]:
%%time

sim_threshold = 0.7
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold,
    left_vector_dict=test_left_vector_dict,
    right_vector_dict=test_right_vector_dict,
)

CPU times: user 357 ms, sys: 0 ns, total: 357 ms
Wall time: 46.2 ms


In [25]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(found_pair_set), len(test_row_dict))

0.3697916666666667

In [26]:
from entity_embed.evaluation import precision_and_recall, f1_score

precision, recall = precision_and_recall(found_pair_set, test_true_pair_set)
precision, recall

(0.9577464788732394, 0.7046632124352331)

In [27]:
f1_score(precision, recall)

0.8119402985074626

In [28]:
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

6

In [29]:
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

57

In [30]:
cos_similarity = lambda a, b: np.dot(a, b)

In [31]:
for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_left_vector_dict[id_left], test_right_vector_dict[id_right]),
            row_dict[id_left], row_dict[id_right]
        )
    )

(0.58983624,
 {'id': 932,
  'title': 'da - lite dual vision tensioned cosmopolitan electrol - av format 10 x 10 diagonal',
  'category': 'electronics - general',
  'brand': 'da - lite',
  'modelno': '84992',
  'price': '2735 . 99',
  'source': 'left'},
 {'id': 5367,
  'title': 'da - lite tensioned cosmopolitan electrol - projection screen motorized - 1 1 - dual vision',
  'category': 'projection screens',
  'brand': 'da - lite',
  'modelno': 'cosmopolitan electrol',
  'price': '',
  'source': 'right'})

(0.55199784,
 {'id': 428,
  'title': 'hp 27 black inkjet cartridge c8727an',
  'category': 'printers',
  'brand': 'hp',
  'modelno': 'c8727an',
  'price': '19 . 97',
  'source': 'left'},
 {'id': 9544,
  'title': 'hp 27 black ink cartridge in retail packaging c8727an 140',
  'category': 'inkjet printer ink',
  'brand': 'hp',
  'modelno': 'hewc8727an',
  'price': '18 . 61',
  'source': 'right'})

(0.61013633,
 {'id': 145,
  'title': 'case logic medium slr camera bag',
  'category': 'photography - general',
  'brand': 'case logic',
  'modelno': '133827',
  'price': '47 . 99',
  'source': 'left'},
 {'id': 18445,
  'title': 'case logic slrc - 202 medium slr camera bag black',
  'category': 'cases bags',
  'brand': 'case logic',
  'modelno': 'slrc - 202black',
  'price': '37 . 95',
  'source': 'right'})

(0.49401593,
 {'id': 1883,
  'title': 'prince of persia pc',
  'category': 'software',
  'brand': 'encore',
  'modelno': '21191',
  'price': '10 . 88',
  'source': 'left'},
 {'id': 21677,
  'title': 'prince of persia jc cs',
  'category': 'audio video accessories',
  'brand': 'encore',
  'modelno': '21191',
  'price': '12 . 54',
  'source': 'right'})

(0.37433603,
 {'id': 1613,
  'title': 'pinnacle studio hd v15 ultimate collection pc',
  'category': 'software',
  'brand': 'avid technology',
  'modelno': '',
  'price': '119 . 0',
  'source': 'left'},
 {'id': 19772,
  'title': 'pinnacle studio ultimate collection v . 15',
  'category': 'computers accessories',
  'brand': 'avid',
  'modelno': '82103002401',
  'price': '99 . 99',
  'source': 'right'})

(0.64158684,
 {'id': 1340,
  'title': 'cta mini battery chargers for nikon en - el9 digital cameras',
  'category': 'electronics - general',
  'brand': 'cta',
  'modelno': 'mr - enel9',
  'price': '15 . 98',
  'source': 'left'},
 {'id': 11050,
  'title': 'cta mr - enel9 mini battery charger kit for nikon en - el9 battery',
  'category': 'battery chargers',
  'brand': 'cta digital',
  'modelno': 'mr - enel9',
  'price': '8 . 32',
  'source': 'right'})

(0.40458918,
 {'id': 1810,
  'title': 'microsoft comfort optical mouse 3000 silver blue',
  'category': 'mice',
  'brand': 'microsoft',
  'modelno': 'd1t00011',
  'price': '19 . 95',
  'source': 'left'},
 {'id': 3131,
  'title': 'comfort opt mse3000 silver blu',
  'category': 'computers accessories',
  'brand': 'microsoft',
  'modelno': 'd1t - 00011',
  'price': '11 . 99',
  'source': 'right'})

(0.67768085,
 {'id': 1559,
  'title': 'paper mate mechanical pencil starter set',
  'category': 'stationery & office machinery',
  'brand': 'paper mate',
  'modelno': '1739312',
  'price': '4 . 87',
  'source': 'left'},
 {'id': 17107,
  'title': 'paper mate 1739312 - megalead mechanical pencil starter set 0 . 7 mm 2 erasure refills blue barrel',
  'category': 'printer accessories',
  'brand': 'paper mate',
  'modelno': '',
  'price': '23 . 96',
  'source': 'right'})

(0.55054814,
 {'id': 2468,
  'title': 'mohawk color copy gloss paper ream of 500 sheets',
  'category': 'stationery & office machinery',
  'brand': 'mohawk',
  'modelno': '36101',
  'price': '12 . 94',
  'source': 'left'},
 {'id': 13367,
  'title': 'mohawk 36101 - color copy gloss paper 96 brightness 32lb 8 - 1 2 x 11 white 500 sheets ream',
  'category': 'projection screens',
  'brand': 'mohawk',
  'modelno': '',
  'price': '',
  'source': 'right'})

(0.67922693,
 {'id': 1396,
  'title': 'dj - tech t545a 15 600w 2 - way active loudspeaker',
  'category': 'stereos / audio',
  'brand': 'dj tech',
  'modelno': 't545a',
  'price': '358 . 0',
  'source': 'left'},
 {'id': 4992,
  'title': 'djtech super powered 15 600 w dj spkrs',
  'category': 'home audio theater',
  'brand': 'dj - tech',
  'modelno': 't545a',
  'price': '396 . 74',
  'source': 'right'})