In [1]:
# Reload imported modules automatically before executing code, so edits to the
# library sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging

# logging.basicConfig() is a no-op once the root logger already has handlers,
# and the notebook environment may have installed some. Reloading the module
# yields a fresh root logger, so the configuration below actually takes effect.
logging = reload(logging)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%H:%M:%S',
)

In [3]:
import sys

# Put the directory two levels above the working directory first on the import
# path (presumably the repository root, so the local `entity_embed` is used).
# Guarded so re-running the cell does not keep prepending duplicate entries.
if sys.path[:1] != ['../..']:
    sys.path.insert(0, '../..')

## Load Dataset

In [4]:
from entity_embed.benchmarks import AmazonGoogleBenchmark

# Load the Amazon-Google benchmark from data_dir_path. Per the log below this
# extracts the archive and reads the row dict plus the train/valid/test splits.
benchmark = AmazonGoogleBenchmark(data_dir_path="../data/")
# Bare last expression: display the benchmark repr (shows the source URL).
benchmark

14:20:23 INFO:Extracting Amazon-Google...
14:20:23 INFO:Reading Amazon-Google row_dict...
14:20:23 INFO:Reading Amazon-Google train.csv...
14:20:23 INFO:Reading Amazon-Google valid.csv...
14:20:23 INFO:Reading Amazon-Google test.csv...


<AmazonGoogleBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/Amazon-Google/amazon_google_exp_data.zip

## Preprocess

In [5]:
# Raw benchmark attributes that are normalized with clean_str below.
attr_list = [
    'title',
    'manufacturer',
    'price',
]

In [6]:
from tqdm.auto import tqdm
import unidecode
import itertools
from entity_embed import default_tokenizer

def clean_str(s, max_tokens=30, max_token_len=30, max_str_len=300):
    """Normalize a raw attribute value before numericalization.

    The value is transliterated to ASCII with ``unidecode``, lowercased and
    stripped, then split with ``default_tokenizer``. At most ``max_tokens``
    tokens are kept, each cut to ``max_token_len`` characters. They are joined
    with single spaces and the result is cut to ``max_str_len`` characters.

    The defaults reproduce the previously hard-coded limits (30 / 30 / 300),
    so existing calls ``clean_str(s)`` behave exactly as before.

    Args:
        s: Raw attribute value (expected to be a ``str``).
        max_tokens: Maximum number of tokens kept.
        max_token_len: Maximum length of each kept token.
        max_str_len: Maximum length of the final joined string.

    Returns:
        The cleaned, space-joined string.
    """
    s = unidecode.unidecode(s).lower().strip()
    # Lazily truncate each token; islice stops consuming the tokenizer early.
    truncated_tokens = (token[:max_token_len] for token in default_tokenizer(s))
    return ' '.join(itertools.islice(truncated_tokens, max_tokens))[:max_str_len]

# Normalize every attribute in attr_list. The rows of benchmark.row_dict are
# overwritten in place, not copied.
for record in tqdm(benchmark.row_dict.values()):
    for attr_name in attr_list:
        raw_value = record[attr_name]
        record[attr_name] = clean_str(raw_value)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4589.0), HTML(value='')))




## Init Data Module

In [7]:
import torch
import numpy as np

import random

random_seed = 42
# Seed every global RNG the notebook may draw from: torch, NumPy and Python's
# `random` module. Without the last one, any code using `random` would differ
# between runs even with the other two seeded.
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

In [8]:
# Alphabet passed to the character-based fields in attr_info_dict: digits,
# lowercase ASCII letters, ASCII punctuation and the space (69 symbols).
alphabet = [
    *'0123456789',
    *'abcdefghijklmnopqrstuvwxyz',
    *'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~',
    ' ',
]

In [9]:
# Per-attribute numericalization config consumed by AttrInfoDictParser.
# max_str_len=None means the parser computes it from the data (see its log).
attr_info_dict = dict(
    # `title` encoded over `alphabet`.
    title=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # compute
        use_mask=True,
    ),
    # A second view of `title` (source_attr) using the fastText vocab.
    semantic_title=dict(
        source_attr='title',
        field_type="SEMANTIC_MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        vocab="fasttext.en.300d",
        use_mask=True,
    ),
    manufacturer=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # compute
        use_mask=True,
    ),
    # `price` is configured as a single STRING field rather than MULTITOKEN.
    price=dict(
        field_type="STRING",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # compute
        use_mask=True,
    ),
)

In [10]:
from entity_embed import AttrInfoDictParser

# Parse attr_info_dict into the row numericalizer. Passing row_dict lets the
# parser fill in max_str_len=None entries from the data; the log below shows
# odd lengths being rounded up to even values.
row_numericalizer = AttrInfoDictParser.from_dict(attr_info_dict, row_dict=benchmark.row_dict)
# Display the resolved per-attribute NumericalizeInfo objects.
row_numericalizer.attr_info_dict

14:20:24 INFO:For attr=title, computing actual max_str_len
14:20:24 INFO:For attr=title, using actual_max_str_len=26
14:20:24 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
14:20:27 INFO:For attr=manufacturer, computing actual max_str_len
14:20:27 INFO:actual_max_str_len=15 must be pair to enable NN pooling. Updating to 16
14:20:27 INFO:For attr=manufacturer, using actual_max_str_len=16
14:20:27 INFO:For attr=price, computing actual max_str_len
14:20:27 INFO:actual_max_str_len=11 must be pair to enable NN pooling. Updating to 12
14:20:27 INFO:For attr=price, using actual_max_str_len=12


{'title': NumericalizeInfo(source_attr='title', field_type=<FieldType.MULTITOKEN: 'multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=26, vocab=None, n_channels=8, embed_dropout_p=0.2, use_attention=True, use_mask=True),
 'semantic_title': NumericalizeInfo(source_attr='title', field_type=<FieldType.SEMANTIC_MULTITOKEN: 'semantic_multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 

In [11]:
# Build the pairwise data module over the benchmark splits, seeded with
# random_seed so the data side is reproducible.
# NOTE(review): batch_size vs row_batch_size semantics come from entity_embed
# (presumably training batches vs batches of rows for embedding) — confirm.
datamodule = benchmark.build_pairwise_datamodule(
    row_numericalizer=row_numericalizer,
    batch_size=20,
    row_batch_size=16,
    random_seed=random_seed
)

## Training

In [12]:
from entity_embed import LinkageEmbed

# ann_k is given to the model here and reused later as `k` in
# ANNLinkageIndex.search_pairs for the manual test evaluation.
ann_k = 100
model = LinkageEmbed(
    datamodule,
    ann_k=ann_k,
    embedding_size=300  # read back later via model.blocker_net.embedding_size
)

In [13]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 100

# Stop when valid_f1_at_0.7 has not improved (min_delta=0) for 20 validation
# checks; validation runs every epoch (check_val_every_n_epoch=1 below).
early_stop_callback = EarlyStopping(
    monitor='valid_f1_at_0.7',
    min_delta=0.00,
    patience=20,
    verbose=True,
    mode='max',
)

tb_log_dir = '../tb_logs'
tb_name = 'f1-amzn-googl'
trainer = pl.Trainer(
    # Use one GPU when CUDA is available; otherwise fall back to CPU so the
    # notebook still runs top-to-bottom on a CPU-only kernel.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name),
)

14:20:27 INFO:GPU available: True, used: True
14:20:27 INFO:TPU available: None, using: 0 TPU cores
14:20:27 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [14]:
# Train the model; early stopping on valid_f1_at_0.7 may end before max_epochs.
trainer.fit(model, datamodule)

14:20:29 INFO:
  | Name        | Type       | Params
-------------------------------------------
0 | blocker_net | BlockerNet | 7.4 M 
1 | losser      | SupConLoss | 0     
-------------------------------------------
5.7 M     Trainable params
1.7 M     Non-trainable params
7.4 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [15]:
# Per-attribute signature weights learned by the blocker network
# (in this run they sum to 1, with the two title views dominating).
model.blocker_net.get_signature_weights()

{'title': 0.4040873646736145,
 'semantic_title': 0.35267874598503113,
 'manufacturer': 0.12780876457691193,
 'price': 0.11543218046426773}

In [16]:
from entity_embed import validate_best

# Validation metrics at each similarity threshold.
# NOTE(review): presumably computed with the best checkpoint — confirm in entity_embed.
validate_best(trainer)

{'valid_f1_at_0.3': 0.420863309352518,
 'valid_f1_at_0.5': 0.7553366174055828,
 'valid_f1_at_0.7': 0.8340080971659919,
 'valid_f1_at_0.9': 0.7193877551020408,
 'valid_pair_entity_ratio_at_0.3': 1.9212253829321664,
 'valid_pair_entity_ratio_at_0.5': 0.8205689277899344,
 'valid_pair_entity_ratio_at_0.7': 0.5689277899343544,
 'valid_pair_entity_ratio_at_0.9': 0.34573304157549234,
 'valid_precision_at_0.3': 0.26651480637813213,
 'valid_precision_at_0.5': 0.6133333333333333,
 'valid_precision_at_0.7': 0.7923076923076923,
 'valid_precision_at_0.9': 0.8924050632911392,
 'valid_recall_at_0.3': 1.0,
 'valid_recall_at_0.5': 0.9829059829059829,
 'valid_recall_at_0.7': 0.8803418803418803,
 'valid_recall_at_0.9': 0.6025641025641025}

## Testing

In [17]:
# Evaluate on the test split using the best checkpoint (ckpt_path='best').
# verbose=False suppresses the printed table; the returned metrics are displayed.
trainer.test(ckpt_path='best', verbose=False)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…




[{'test_f1_at_0.3': 0.4152617568766637,
  'test_f1_at_0.5': 0.7240829346092504,
  'test_f1_at_0.7': 0.832,
  'test_f1_at_0.9': 0.7684729064039408,
  'test_pair_entity_ratio_at_0.3': 1.9370932754880694,
  'test_pair_entity_ratio_at_0.5': 0.8524945770065075,
  'test_pair_entity_ratio_at_0.7': 0.5770065075921909,
  'test_pair_entity_ratio_at_0.9': 0.37310195227765725,
  'test_precision_at_0.3': 0.2620380739081747,
  'test_precision_at_0.5': 0.5776081424936387,
  'test_precision_at_0.7': 0.7819548872180451,
  'test_precision_at_0.9': 0.9069767441860465,
  'test_recall_at_0.3': 1.0,
  'test_recall_at_0.5': 0.9700854700854701,
  'test_recall_at_0.7': 0.8888888888888888,
  'test_recall_at_0.9': 0.6666666666666666}]

## Testing manually 

In [18]:
from entity_embed.data_utils.utils import id_pairs_to_cluster_mapping_and_dict, cluster_dict_to_filtered_row_dict

# Build the test row subset from the test ground-truth pairs (via their cluster
# dict), then embed its left and right rows separately.
# The first return value is bound to a real name instead of `__`: IPython uses
# `__` for output history, and assigning it stops `_`/`__`/`___` from updating.
test_cluster_mapping, test_cluster_dict = id_pairs_to_cluster_mapping_and_dict(benchmark.test_true_pair_set)
test_row_dict = cluster_dict_to_filtered_row_dict(benchmark.row_dict, test_cluster_dict)
test_left_id_set = {id_ for id_, row in test_row_dict.items() if row['__source'] == 'left'}
test_right_id_set = {id_ for id_, row in test_row_dict.items() if row['__source'] == 'right'}
test_left_vector_dict, test_right_vector_dict = model.predict(
    row_dict=test_row_dict,
    left_id_set=test_left_id_set,
    right_id_set=test_right_id_set,
    batch_size=16,
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=29.0), HTML(value='')))




In [19]:
# Embedding dimension of the trained model; used to size the ANN index below.
embedding_size = model.blocker_net.embedding_size

In [20]:
# Every test row must be embedded exactly once, on exactly one side.
# Matching total lengths alone could hide wrong or overlapping ids, so the id
# sets are compared as well.
assert (len(test_left_vector_dict) + len(test_right_vector_dict)) == len(test_row_dict)
assert set(test_left_vector_dict).isdisjoint(test_right_vector_dict)
assert set(test_left_vector_dict) | set(test_right_vector_dict) == set(test_row_dict)

In [21]:
%%time

# Insert the test-set embeddings (left and right passed separately) into an
# ANNLinkageIndex and build it; %%time reports the cost of this step.
from entity_embed import ANNLinkageIndex

ann_index = ANNLinkageIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(left_vector_dict=test_left_vector_dict, right_vector_dict=test_right_vector_dict)
ann_index.build()

CPU times: user 573 ms, sys: 0 ns, total: 573 ms
Wall time: 68.5 ms


In [22]:
%%time

# Search the index for candidate pairs using k=ann_k and a 0.7 similarity
# threshold. The precision/recall/ratio computed below match the
# test_*_at_0.7 metrics from trainer.test, so this reproduces that evaluation.
sim_threshold = 0.7
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold,
    left_vector_dict=test_left_vector_dict,
    right_vector_dict=test_right_vector_dict,
)

CPU times: user 251 ms, sys: 0 ns, total: 251 ms
Wall time: 33 ms


In [23]:
from entity_embed.evaluation import pair_entity_ratio

# Found pairs relative to the number of test rows (equals test_pair_entity_ratio_at_0.7).
pair_entity_ratio(len(found_pair_set), len(test_row_dict))

0.5770065075921909

In [24]:
from entity_embed.evaluation import precision_and_recall, f1_score

# Compare found pairs against the test ground-truth pairs.
precision, recall = precision_and_recall(found_pair_set, benchmark.test_true_pair_set)
precision, recall

(0.7819548872180451, 0.8888888888888888)

In [25]:
# F1 from the precision/recall above (equals test_f1_at_0.7).
f1_score(precision, recall)

0.832

In [26]:
# Predicted pairs that are not in the test ground truth.
false_positives = list(found_pair_set.difference(benchmark.test_true_pair_set))
len(false_positives)

58

In [27]:
# Ground-truth test pairs that the ANN search did not return.
false_negatives = list(benchmark.test_true_pair_set.difference(found_pair_set))
len(false_negatives)

26

In [28]:
def cos_similarity(a, b):
    """Cosine similarity between two 1-D embedding vectors.

    Unlike a bare dot product, this does not assume ``a`` and ``b`` are
    already L2-normalized. For unit-norm vectors the result equals
    ``np.dot(a, b)``.

    Args:
        a: First vector (array-like).
        b: Second vector (array-like, same length as ``a``).

    Returns:
        The cosine similarity, or 0.0 when either vector has zero norm.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # Cosine is undefined for a zero vector; avoid a NaN from 0/0.
        return 0.0
    return np.dot(a, b) / norm_product

In [29]:
# Show the first 10 missed matches: their similarity (all below sim_threshold in
# this run) alongside both cleaned rows.
for id_left, id_right in false_negatives[:10]:
    pair_similarity = cos_similarity(test_left_vector_dict[id_left], test_right_vector_dict[id_right])
    display((pair_similarity, test_row_dict[id_left], test_row_dict[id_right]))

(0.48584753,
 {'id': 866,
  'title': 'system care professional',
  'manufacturer': 'avanquest',
  'price': '49 . 95',
  '__source': 'left'},
 {'id': 1566,
  'title': 'avanquest usa llc system care professional',
  'manufacturer': '',
  'price': '43 . 32',
  '__source': 'right'})

(0.6904577,
 {'id': 395,
  'title': 'world history ( win / mac ) ( jewel case )',
  'manufacturer': 'fogware publishing',
  'price': '9 . 99',
  '__source': 'left'},
 {'id': 4269,
  'title': 'high school world history ( pc / mac ) fogware',
  'manufacturer': '',
  'price': '9 . 99',
  '__source': 'right'})

(0.6412057,
 {'id': 1206,
  'title': 'contentbarrier x4 10 . 4 single user ( mac )',
  'manufacturer': 'intego',
  'price': '49 . 99',
  '__source': 'left'},
 {'id': 1900,
  'title': 'intego contentbarrier x4 10 . 4',
  'manufacturer': '',
  'price': '54 . 99',
  '__source': 'right'})

(0.5173285,
 {'id': 345,
  'title': 'portfolio media kit be syst recovery 7 . 0 win small business ed',
  'manufacturer': 'symantec',
  'price': '50 . 0',
  '__source': 'left'},
 {'id': 2883,
  'title': 'symantec 11859201 be sys recovery 7 . 0 win sbs ed media cd m / l',
  'manufacturer': '',
  'price': '31 . 98',
  '__source': 'right'})

(0.53023946,
 {'id': 1360,
  'title': 'dragon naturally speaking standard v9',
  'manufacturer': 'nuance - communications - inc .',
  'price': '99 . 99',
  '__source': 'left'},
 {'id': 1973,
  'title': 'nuance communications inc . dragon ns standard v9',
  'manufacturer': 'nuance - communications - inc .',
  'price': '92 . 51',
  '__source': 'right'})

(0.6465797,
 {'id': 404,
  'title': 'cosmi rom07524 print perfect business cards dvd',
  'manufacturer': 'cosmi',
  'price': '',
  '__source': 'left'},
 {'id': 4253,
  'title': 'print perfect business cards dvd ( pc ) cosmi',
  'manufacturer': '',
  'price': '29 . 99',
  '__source': 'right'})

(0.63293093,
 {'id': 726,
  'title': 'simply put software data eliminator',
  'manufacturer': 'simply - put - software',
  'price': '',
  '__source': 'left'},
 {'id': 4151,
  'title': 'simply put software llc de905 - s data elminator ( win 95 98 me nt 2000 xp )',
  'manufacturer': 'simply - put - software',
  'price': '34 . 97',
  '__source': 'right'})

(0.62741995,
 {'id': 812,
  'title': 'extensis smartscale 1 - user pxe - 11433 )',
  'manufacturer': 'extensis corporation',
  'price': '',
  '__source': 'left'},
 {'id': 4160,
  'title': 'onone software pxe - 11433 pxl smartscale elect 1 - user english mac / win',
  'manufacturer': 'onone software',
  'price': '158 . 99',
  '__source': 'right'})

(0.65016204,
 {'id': 168,
  'title': 'autodesk discreet combustion 4 windows )',
  'manufacturer': 'autodesk',
  'price': '',
  '__source': 'left'},
 {'id': 3527,
  'title': 'autodesk combustion 4 . 0 compositing software win compositing software',
  'manufacturer': '',
  'price': '889 . 0',
  '__source': 'right'})

(0.6109418,
 {'id': 97,
  'title': 'tournament poker 2005',
  'manufacturer': 'eagle games',
  'price': '20 . 99',
  '__source': 'left'},
 {'id': 3000,
  'title': "eagle games egl 150 tournament poker no limit texas hold ' em",
  'manufacturer': 'eagle games',
  'price': '7 . 95',
  '__source': 'right'})