## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

sys.path.insert(0, '..')

In [5]:
import os
home_dir = os.getenv('HOME')

In [6]:
from collections import defaultdict
import itertools

def Enumerator(start=0, initial=()):
    return defaultdict(itertools.count(start).__next__, initial)

In [7]:
import glob
import csv
from tqdm.auto import tqdm

id_enumerator = Enumerator()
row_dict = {}
left_id_set = set()
right_id_set = set()
rows_total = 1081 + 1092
clusters_total = 1097

with tqdm(total=rows_total) as pbar:
    with open(f'{home_dir}/Downloads/Abt-Buy/Abt.csv', encoding="latin1") as f:
        for row in csv.DictReader(f):
            row['id'] = id_enumerator[row["id"]]
            row['source'] = 'abt'
            row_dict[row['id']] = row
            left_id_set.add(row['id'])
            pbar.update(1)
    
    with open(f'{home_dir}/Downloads/Abt-Buy/Buy.csv', encoding="latin1") as f:
        for row in csv.DictReader(f):
            row['id'] = id_enumerator[row["id"]]
            row['source'] = 'buy'
            row_dict[row['id']] = row
            right_id_set.add(row['id'])
            pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2173.0), HTML(value='')))




In [8]:
true_pair_set = set()

with open(f'{home_dir}/Downloads/Abt-Buy/abt_buy_perfectMapping.csv') as f:
    for row in csv.DictReader(f):
        id_left = id_enumerator[row['idAbt']]
        id_right = id_enumerator[row['idBuy']]
        true_pair_set.add(tuple(sorted([id_left, id_right])))

len(true_pair_set)

1097

In [9]:
from entity_embed.data_utils.utils import id_pairs_to_cluster_mapping_and_dict

cluster_mapping, cluster_dict = id_pairs_to_cluster_mapping_and_dict(true_pair_set)
len(cluster_mapping)

2173

In [10]:
len(cluster_dict)

1076

In [11]:
cluster_attr = 'cluster_id'
max_cluster_id = max(cluster_mapping.values())

for row_id, row in tqdm(row_dict.items()):
    try:
        row[cluster_attr] = cluster_mapping[row_id]
    except KeyError:
        row[cluster_attr] = max_cluster_id
        max_cluster_id += 1

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2173.0), HTML(value='')))




In [12]:
[row_dict[row_id] for row_id in next(iter(true_pair_set))]

[{'id': 305,
  'name': 'TiVo HD Digital Video Recorder (180 Hour) - TCD652160',
  'description': 'TiVo HD Digital Video Recorder - TCD652160/ Search, Record And Watch Shows In HD/ Record Up To 20 Hours In HD (Or 180 Hours In Standard Definition)/ Record Two Shows At Once In HD/ Replaces Your Cable Box And Works With Over-The-Air Antenna/ USB Connectivity/ Remote Control/ Netflix Instant Streaming/ TiVo Service Required And Sold Separately',
  'price': '',
  'source': 'abt',
  'cluster_id': 305},
 {'id': 1508,
  'name': 'Tivo TIV652160 Tivo High Definition',
  'description': '',
  'manufacturer': 'TiVo',
  'price': '',
  'source': 'buy',
  'cluster_id': 305}]

## Preprocess

In [13]:
attr_list = ['name', 'description', 'price']

In [14]:
import unidecode
from entity_embed import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    s_tokens = itertools.islice((s_part[:30] for s_part in default_tokenizer(s)), 0, 30)
    return ' '.join(s_tokens)[:300]

for row in tqdm(row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2173.0), HTML(value='')))




## Init Data Module

In [15]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [16]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [17]:
attr_info_dict = {
    'name': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'description': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'price': {
        'field_type': "STRING",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    }
}

In [18]:
from entity_embed.data_utils.helpers import AttrInfoDictParser

row_numericalizer = AttrInfoDictParser.from_dict(attr_info_dict, row_dict=row_dict)
row_numericalizer.attr_info_dict

16:44:21 INFO:For attr='name', computing actual max_str_len
16:44:21 INFO:actual_max_str_len=15 must be pair to enable NN pooling. Updating to 16
16:44:21 INFO:For attr='name', using actual_max_str_len=16
16:44:21 INFO:For attr='description', computing actual max_str_len
16:44:21 INFO:actual_max_str_len=15 must be pair to enable NN pooling. Updating to 16
16:44:21 INFO:For attr='description', using actual_max_str_len=16
16:44:21 INFO:For attr='price', computing actual max_str_len
16:44:21 INFO:For attr='price', using actual_max_str_len=14


{'name': NumericalizeInfo(field_type=<FieldType.MULTITOKEN: 'multitoken'>, tokenizer=<function default_tokenizer at 0x7f90ac5dac10>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=16, vocab=None, n_channels=8, embed_dropout_p=0.2, use_attention=True, use_mask=True),
 'description': NumericalizeInfo(field_type=<FieldType.MULTITOKEN: 'multitoken'>, tokenizer=<function default_tokenizer at 0x7f90ac5dac10>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', '

In [19]:
from entity_embed import LinkageDataModule

train_cluster_len = 200
valid_cluster_len = 200
datamodule = LinkageDataModule(
    row_dict=row_dict,
    cluster_attr=cluster_attr,
    row_numericalizer=row_numericalizer,
    pos_pair_batch_size=10,
    neg_pair_batch_size=435,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=clusters_total - valid_cluster_len - train_cluster_len,
    only_plural_clusters=True,
    left_id_set=left_id_set,
    right_id_set=right_id_set,
    random_seed=random_seed
)

## Training

In [20]:
from entity_embed import LinkageEmbed

ann_k = 100
model = LinkageEmbed(
    datamodule,
    ann_k=ann_k,
)

In [21]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 50
early_stop_callback = EarlyStopping(
   monitor='valid_recall_at_0.3',
   min_delta=0.00,
   patience=10,
   verbose=True,
   mode='max'
)
tb_log_dir = 'tb_logs'
tb_name = 'abt-buy'
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name),
)

16:44:21 INFO:GPU available: True, used: True
16:44:21 INFO:TPU available: None, using: 0 TPU cores
16:44:21 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [22]:
trainer.fit(model, datamodule)

16:44:21 INFO:Train pair count: 214
16:44:21 INFO:Valid pair count: 200
16:44:21 INFO:Test pair count: 704
16:44:25 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 1.8 M 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [23]:
model.blocker_net.get_signature_weights()

{'name': 0.3976720869541168,
 'description': 0.33654817938804626,
 'price': 0.2657797336578369}

## Testing

In [24]:
trainer.test(ckpt_path='best')

16:45:44 INFO:Train pair count: 214
16:45:44 INFO:Valid pair count: 200
16:45:44 INFO:Test pair count: 704


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_f1_at_0.3': 0.04348828781339572,
 'test_f1_at_0.5': 0.2418780129951792,
 'test_f1_at_0.7': 0.4888597640891219,
 'test_f1_at_0.9': 0.18324607329842932,
 'test_pair_entity_ratio_at_0.3': 21.715226939970716,
 'test_pair_entity_ratio_at_0.5': 2.987554904831625,
 'test_pair_entity_ratio_at_0.7': 0.6120058565153733,
 'test_pair_entity_ratio_at_0.9': 0.05417276720351391,
 'test_precision_at_0.3': 0.02224994100394431,
 'test_precision_at_0.5': 0.14138691497182063,
 'test_precision_at_0.7': 0.4461722488038278,
 'test_precision_at_0.9': 0.9459459459459459,
 'test_recall_at_0.3': 0.9565217391304348,
 'test_recall_at_0.5': 0.836231884057971,
 'test_recall_at_0.7': 0.5405797101449276,
 'test_recall_at_0.9': 0.10144927536231885}
--------------------------------------------------------------------------------


[{'test_precision_at_0.3': 0.02224994100394431,
  'test_recall_at_0.3': 0.9565217391304348,
  'test_f1_at_0.3': 0.04348828781339572,
  'test_pair_entity_ratio_at_0.3': 21.715226939970716,
  'test_precision_at_0.5': 0.14138691497182063,
  'test_recall_at_0.5': 0.836231884057971,
  'test_f1_at_0.5': 0.2418780129951792,
  'test_pair_entity_ratio_at_0.5': 2.987554904831625,
  'test_precision_at_0.7': 0.4461722488038278,
  'test_recall_at_0.7': 0.5405797101449276,
  'test_f1_at_0.7': 0.4888597640891219,
  'test_pair_entity_ratio_at_0.7': 0.6120058565153733,
  'test_precision_at_0.9': 0.9459459459459459,
  'test_recall_at_0.9': 0.10144927536231885,
  'test_f1_at_0.9': 0.18324607329842932,
  'test_pair_entity_ratio_at_0.9': 0.05417276720351391}]

## Testing manually

In [25]:
test_row_dict = datamodule.test_row_dict
test_left_vector_dict, test_right_vector_dict = model.predict(
    row_dict=test_row_dict,
    left_id_set=left_id_set,
    right_id_set=right_id_set,
    batch_size=16
)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=86.0), HTML(value='')))




In [26]:
embedding_size = model.blocker_net.embedding_size
test_true_pair_set = datamodule.test_true_pair_set

In [27]:
assert (len(test_left_vector_dict) + len(test_right_vector_dict)) == len(test_row_dict)

In [28]:
%%time

from entity_embed import ANNLinkageIndex

ann_index = ANNLinkageIndex(embedding_size=embedding_size)
ann_index.insert_vector_dict(left_vector_dict=test_left_vector_dict, right_vector_dict=test_right_vector_dict)
ann_index.build()

CPU times: user 581 ms, sys: 12.9 ms, total: 594 ms
Wall time: 122 ms


In [29]:
%%time

sim_threshold = 0.3
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold=sim_threshold,
    left_vector_dict=test_left_vector_dict,
    right_vector_dict=test_right_vector_dict,
)

CPU times: user 1.08 s, sys: 7.04 ms, total: 1.09 s
Wall time: 244 ms


In [30]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(found_pair_set), len(test_row_dict))

21.715226939970716

In [31]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.02224994100394431, 0.9565217391304348)

In [32]:
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

29003

In [33]:
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

30

In [34]:
cos_similarity = lambda a, b: np.dot(a, b)

In [35]:
for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_left_vector_dict[id_left], test_right_vector_dict[id_right]),
            row_dict[id_left], row_dict[id_right]
        )
    )

(0.28407127,
 {'id': 22,
  'name': 'sony xplod 10 - disc add - on cd / mp3 changer - cdx565mxrf',
  'description': 'sony xplod 10 - disc add - on cd / mp3 changer - cdx565mxrf / cd / cd - r / cd - rw and mp3 playback / mp3 decoding',
  'price': '',
  'source': 'abt',
  'cluster_id': 22},
 {'id': 1143,
  'name': 'sony cdx - 565mxrf 10 - disc cd / mp3 changer',
  'description': '',
  'manufacturer': 'Sony',
  'price': '$ 134 . 12',
  'source': 'buy',
  'cluster_id': 22})

(0.19086897,
 {'id': 607,
  'name': 'denon blu - ray disc dvd / cd digital player / transport - dvd2500btci',
  'description': 'denon blu - ray disc dvd / cd digital player / transport - dvd2500btci / hdmi 1 . 3a with high definition output up to 1080p / full 10 -',
  'price': '',
  'source': 'abt',
  'cluster_id': 607},
 {'id': 1653,
  'name': 'denon dvd - 2500btci blu - ray disc player',
  'description': 'bd - r , dvd - rw , cd - rw , secure digital ( sd ), mini secure digital ( minisd ) - bd video , dvd video ,',
  'manufacturer': 'Denon',
  'price': '$ 655 . 97',
  'source': 'buy',
  'cluster_id': 607})

(0.120742984,
 {'id': 940,
  'name': 'canon black eos 50d digital slr camera body - eos50dbody',
  'description': "canon black eos 50d digital slr camera body - eos50dbody / 15 . 1 megapixel cmos sensor / digic 4 image processor / 3 . 0 ' clear view lcd",
  'price': '',
  'source': 'abt',
  'cluster_id': 940},
 {'id': 2098,
  'name': "canon eos 50d 15 megapixel slr digital camera with live view & face detection , 3 ' lcd , 6 . 3fps & digic 4 image processor - body only",
  'description': '',
  'manufacturer': 'Canon',
  'price': '$ 1 , 127 . 95',
  'source': 'buy',
  'cluster_id': 940})

(0.23513594,
 {'id': 732,
  'name': "lg 24 ' ldf6920ww fully integrated built in white dishwasher - ldf6920wh",
  'description': "lg 24 ' ldf6920ww fully integrated built in white dishwasher - ldf6920wh / xl tall tub cleans up to 16 place settings at once / adjustable upper rack / lodecibel",
  'price': '',
  'source': 'abt',
  'cluster_id': 732},
 {'id': 1858,
  'name': 'lg dishwasher',
  'description': '',
  'manufacturer': 'LG Electronics',
  'price': '',
  'source': 'buy',
  'cluster_id': 732})

(0.010162044,
 {'id': 186,
  'name': 'universal mrf - 350 rf black base station - mrf350',
  'description': 'universal mrf - 350 rf black base station - mrf350 / no more pointing / rf addressable / ir routing / expand operating range up to 100 feet / compatible',
  'price': '$ 250 . 00',
  'source': 'abt',
  'cluster_id': 186},
 {'id': 1437,
  'name': 'universal remote control mrf - 350 addressable narrow band rf base station with rfx - 250 antenna',
  'description': '',
  'manufacturer': 'UNIVERSAL REMOTE CONTROL, INC',
  'price': '$ 174 . 72',
  'source': 'buy',
  'cluster_id': 186})

(0.13233179,
 {'id': 255,
  'name': 'whirlpool white front load washer - wfw9200swh',
  'description': 'whirlpool duet wfw9200sq white front load washer - wfw9200swh / 4 . 0 cu . ft . capacity / 6th sense technology / quiet wash plus noise reduction / built',
  'price': '',
  'source': 'abt',
  'cluster_id': 255},
 {'id': 1952,
  'name': "whirlpool 27 ' duet washer horiz axis wp",
  'description': '',
  'manufacturer': 'Whirlpool',
  'price': '$ 910 . 91',
  'source': 'buy',
  'cluster_id': 255})

(0.099445134,
 {'id': 124,
  'name': 'kingdom hearts ii video game for the sony ps2 - 662248904115',
  'description': 'kingdom hearts ii video game for the sony ps2 - 662248904115 / guide sora and friends in all - new locales based on disney films / fight alongside classic final',
  'price': '',
  'source': 'abt',
  'cluster_id': 124},
 {'id': 1163,
  'name': 'kingdom hearts ii',
  'description': '',
  'manufacturer': 'SQUARE ENIX, LLC',
  'price': '$ 19 . 99',
  'source': 'buy',
  'cluster_id': 124})

(0.12379789,
 {'id': 619,
  'name': 'sony bravia wireless home theater system in black - davhdx576wf',
  'description': 'sony bravia wireless home theater system in black - davhdx576wf / 5 . 1 channel surround sound / s - air technology / bravia sync / digital media port /',
  'price': '',
  'source': 'abt',
  'cluster_id': 619},
 {'id': 1944,
  'name': 'sony bravia dav - hdx576wf home theater system',
  'description': 'dvd player , amplifier , 5 . 1 speakers - 5 disc ( s ) - progressive scan - 1000w rms - dolby pro logic ii',
  'manufacturer': 'Sony',
  'price': '$ 478 . 72',
  'source': 'buy',
  'cluster_id': 619})

(0.23784205,
 {'id': 149,
  'name': 'waring professional cool - touch deep fryer - black / stainless steel finish - df100',
  'description': 'waring professional cool - touch deep fryer - df100 / large frying basket / 60 - minute timer / removable control panel / unique heating element / breakaway cord /',
  'price': '$ 70 . 00',
  'source': 'abt',
  'cluster_id': 149},
 {'id': 1351,
  'name': 'waring pro deep fryer 3qt - black',
  'description': '',
  'manufacturer': 'Waring',
  'price': '',
  'source': 'buy',
  'cluster_id': 149})

(0.29690045,
 {'id': 205,
  'name': 'monster ps3 hdmi - 2m playstation 3 gamelink hdmi digital video / audio cable - ps3hdmi2m',
  'description': 'monster ps3 hdmi - 2m playstation 3 gamelink hdmi digital video / audio cable - ps3hdmi2m / duraflex jacket / all - in - one digital av cable / 24k',
  'price': '',
  'source': 'abt',
  'cluster_id': 205},
 {'id': 1347,
  'name': 'monster game 127961 gamelink ( tm ) 2m hdmi digital video / audio cable for ps3',
  'description': '',
  'manufacturer': 'Monster Game',
  'price': '$ 45 . 79',
  'source': 'buy',
  'cluster_id': 205})