In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
import sys

sys.path.insert(0, '../..')

## Load Dataset

In [4]:
from entity_embed.benchmarks import ITunesAmazonStructuredBenchmark

benchmark = ITunesAmazonStructuredBenchmark(data_dir_path="../data/")
benchmark

16:50:14 INFO:Extracting iTunes-Amazon-Structured...
16:50:14 INFO:Reading iTunes-Amazon-Structured row_dict...
16:50:14 INFO:Reading iTunes-Amazon-Structured train.csv...
16:50:14 INFO:Reading iTunes-Amazon-Structured valid.csv...
16:50:14 INFO:Reading iTunes-Amazon-Structured test.csv...


<ITunesAmazonStructuredBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/iTunes-Amazon/itunes_amazon_exp_data.zip

## Preprocess

In [5]:
attr_list = ['Song_Name', 'Artist_Name', 'Album_Name', 'Genre', 'Price', 'CopyRight', 'Time', 'Released']

In [6]:
from tqdm.auto import tqdm
import unidecode
import itertools
from entity_embed import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    s_tokens = itertools.islice((s_part[:30] for s_part in default_tokenizer(s)), 0, 30)
    return ' '.join(s_tokens)[:300]

for row in tqdm(benchmark.row_dict.values()):
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=62830.0), HTML(value='')))




## Init Data Module

In [7]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [8]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [9]:
attr_info_dict = {
    'Song_Name': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'semantic_Song_Name': {
        'source_attr': 'Song_Name',
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
        'use_mask': True,
    },
    'Artist_Name': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'semantic_Artist_Name': {
        'source_attr': 'Artist_Name',
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
        'use_mask': True,
    },
    'Album_Name': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'semantic_Album_Name': {
        'source_attr': 'Album_Name',
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
        'use_mask': True,
    },
    'Genre': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'semantic_Genre': {
        'source_attr': 'Album_Name',
        'field_type': "SEMANTIC_MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'vocab': "fasttext.en.300d",
        'use_mask': True,
    },
    'Price': {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'CopyRight': {
        'field_type': "MULTITOKEN",
        'tokenizer': "entity_embed.default_tokenizer",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
        'use_mask': True,
    },
    'Time': {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'Released': {
        'field_type': "STRING",
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
}

In [10]:
from entity_embed import AttrInfoDictParser

row_numericalizer = AttrInfoDictParser.from_dict(attr_info_dict, row_dict=benchmark.row_dict)
row_numericalizer.attr_info_dict

16:50:17 INFO:For attr=Song_Name, computing actual max_str_len
16:50:17 INFO:actual_max_str_len=21 must be pair to enable NN pooling. Updating to 22
16:50:17 INFO:For attr=Song_Name, using actual_max_str_len=22
16:50:17 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
16:50:21 INFO:For attr=Artist_Name, computing actual max_str_len
16:50:21 INFO:For attr=Artist_Name, using actual_max_str_len=20
16:50:21 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
16:50:24 INFO:For attr=Album_Name, computing actual max_str_len
16:50:24 INFO:For attr=Album_Name, using actual_max_str_len=16
16:50:24 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
16:50:28 INFO:For attr=Genre, computing actual max_str_len
16:50:28 INFO:actual_max_str_len=13 must be pair to enable NN pooling. Updating to 14
16:50:28 INFO:For attr=Genre, using actual_max_str_len=14
16:50:28 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
16:50:31 INFO:For attr=Price, computing actual max_str_len
16:50:31 INFO:F

{'Song_Name': NumericalizeInfo(source_attr='Song_Name', field_type=<FieldType.MULTITOKEN: 'multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=22, vocab=None, n_channels=8, embed_dropout_p=0.2, use_attention=True, use_mask=True),
 'semantic_Song_Name': NumericalizeInfo(source_attr='Song_Name', field_type=<FieldType.SEMANTIC_MULTITOKEN: 'semantic_multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',

In [11]:
datamodule = benchmark.build_pairwise_datamodule(
    row_numericalizer=row_numericalizer,
    batch_size=20,
    row_batch_size=16,
    random_seed=random_seed
)

## Training

In [12]:
from entity_embed import LinkageEmbed

ann_k = 100
model = LinkageEmbed(
    datamodule,
    ann_k=ann_k,
    embedding_size=300
)

In [13]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

max_epochs = 100
early_stop_callback = EarlyStopping(
   monitor='valid_f1_at_0.9',
   min_delta=0.00,
   patience=20,
   verbose=True,
   mode='max'
)
tb_log_dir = '../tb_logs'
tb_name = f'f1-{benchmark.dataset_name}'
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name)
)

16:50:31 INFO:GPU available: True, used: True
16:50:31 INFO:TPU available: None, using: 0 TPU cores
16:50:31 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [14]:
trainer.fit(model, datamodule)

16:50:33 INFO:
  | Name        | Type       | Params
-------------------------------------------
0 | blocker_net | BlockerNet | 26.3 M
1 | losser      | SupConLoss | 0     
-------------------------------------------
16.3 M    Trainable params
10.1 M    Non-trainable params
26.3 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [15]:
model.blocker_net.get_signature_weights()

{'Song_Name': 0.10861807316541672,
 'semantic_Song_Name': 0.1035393550992012,
 'Artist_Name': 0.07977375388145447,
 'semantic_Artist_Name': 0.08798324316740036,
 'Album_Name': 0.08326321095228195,
 'semantic_Album_Name': 0.09653947502374649,
 'Genre': 0.0609779879450798,
 'semantic_Genre': 0.09708672761917114,
 'Price': 0.06809183210134506,
 'CopyRight': 0.06887254118919373,
 'Time': 0.08170082420110703,
 'Released': 0.06355297565460205}

In [16]:
from entity_embed import validate_best

validate_best(trainer)

{'valid_f1_at_0.3': 0.8059701492537313,
 'valid_f1_at_0.5': 0.8437499999999999,
 'valid_f1_at_0.7': 0.9152542372881356,
 'valid_f1_at_0.9': 0.9454545454545454,
 'valid_pair_entity_ratio_at_0.3': 0.7547169811320755,
 'valid_pair_entity_ratio_at_0.5': 0.6981132075471698,
 'valid_pair_entity_ratio_at_0.7': 0.6037735849056604,
 'valid_pair_entity_ratio_at_0.9': 0.5283018867924528,
 'valid_precision_at_0.3': 0.675,
 'valid_precision_at_0.5': 0.7297297297297297,
 'valid_precision_at_0.7': 0.84375,
 'valid_precision_at_0.9': 0.9285714285714286,
 'valid_recall_at_0.3': 1.0,
 'valid_recall_at_0.5': 1.0,
 'valid_recall_at_0.7': 1.0,
 'valid_recall_at_0.9': 0.9629629629629629}

## Testing

In [17]:
trainer.test(ckpt_path='best', verbose=False)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…




[{'test_f1_at_0.3': 0.8437499999999999,
  'test_f1_at_0.5': 0.9310344827586207,
  'test_f1_at_0.7': 0.9642857142857143,
  'test_f1_at_0.9': 0.9642857142857143,
  'test_pair_entity_ratio_at_0.3': 0.6851851851851852,
  'test_pair_entity_ratio_at_0.5': 0.5740740740740741,
  'test_pair_entity_ratio_at_0.7': 0.5370370370370371,
  'test_pair_entity_ratio_at_0.9': 0.5370370370370371,
  'test_precision_at_0.3': 0.7297297297297297,
  'test_precision_at_0.5': 0.8709677419354839,
  'test_precision_at_0.7': 0.9310344827586207,
  'test_precision_at_0.9': 0.9310344827586207,
  'test_recall_at_0.3': 1.0,
  'test_recall_at_0.5': 1.0,
  'test_recall_at_0.7': 1.0,
  'test_recall_at_0.9': 1.0}]