In [1]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from importlib import reload

# Re-execute the logging module before configuring it -- presumably so that
# handlers left over from an earlier run in this kernel do not turn the
# basicConfig call below into a no-op.
reload(logging)
logging.basicConfig(
    level=logging.INFO,
    datefmt='%H:%M:%S',
    format='%(asctime)s %(levelname)s:%(message)s',
)

In [3]:
import sys

# Prepend the directory two levels up to the module search path
# (presumably the checkout that holds the entity_embed package -- confirm).
sys.path[:0] = ['../..']

## Load Dataset

In [4]:
from entity_embed.benchmarks import AbtBuyBenchmark

# Build the Abt-Buy benchmark, using ../data/ as its data directory.
benchmark = AbtBuyBenchmark(
    data_dir_path="../data/",
)
benchmark  # last expression: show the benchmark's repr as the cell output

15:07:15 INFO:Extracting Abt-Buy...
15:07:15 INFO:Reading Abt-Buy row_dict...
15:07:16 INFO:Reading Abt-Buy train.csv...
15:07:16 INFO:Reading Abt-Buy valid.csv...
15:07:16 INFO:Reading Abt-Buy test.csv...


<AbtBuyBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Textual/Abt-Buy/abt_buy_exp_data.zip

## Preprocess

In [5]:
# Record attributes that get cleaned in the preprocessing loop.
attr_list = [
    'name',
    'description',
    'price',
]

In [6]:
from tqdm.auto import tqdm
import unidecode

def clean_str(s):
    """Normalize a raw attribute value.

    The value is passed through ``unidecode.unidecode``, lower-cased and
    stripped of leading/trailing whitespace.

    Parameters
    ----------
    s : str or None
        Raw attribute value. ``None`` (a missing value) is mapped to an
        empty string instead of raising inside ``unidecode``.

    Returns
    -------
    str
        The cleaned string; ``''`` when ``s`` is ``None``.
    """
    if s is None:
        return ''
    return unidecode.unidecode(s).lower().strip()

# Clean every configured attribute of every record in place, with a progress bar.
for row in tqdm(benchmark.row_dict.values()):
    # The generator is consumed pair by pair, so each attribute is read,
    # cleaned and written back in attr_list order.
    row.update((attr, clean_str(row[attr])) for attr in attr_list)

  0%|          | 0/2173 [00:00<?, ?it/s]

## Init Data Module

In [7]:
import random

import numpy as np
import torch

# A single seed shared by every random number generator seeded here, so that
# a "Restart & Run All" starts each of them from the same state.
random_seed = 42
random.seed(random_seed)  # Python's built-in RNG
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [8]:
import string

# Character alphabet handed to the 'alphabet' entries of the attribute configs:
# digits, lowercase ASCII letters, ASCII punctuation and the space character,
# in that order (69 characters in total).
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [9]:
# Per-field configuration. 'name' and 'description' each feed two fields: a
# MULTITOKEN field over the shared alphabet and a SEMANTIC field (declared via
# 'source_attr') backed by the "fasttext.en.300d" vocab. 'price' is a plain
# STRING field.
attr_config_dict = dict(
    name=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # compute
    ),
    semantic_name=dict(
        source_attr='name',
        field_type="SEMANTIC",
        tokenizer="entity_embed.default_tokenizer",
        vocab="fasttext.en.300d",
    ),
    description=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # compute
    ),
    semantic_description=dict(
        source_attr='description',
        field_type="SEMANTIC",
        tokenizer="entity_embed.default_tokenizer",
        vocab="fasttext.en.300d",
    ),
    price=dict(
        field_type="STRING",
        alphabet=alphabet,
        max_str_len=None,  # compute
    ),
)

In [10]:
from entity_embed import AttrConfigDictParser

# Parse the declarative config against the benchmark's records to obtain the
# row numericalizer used by the datamodule and the model.
row_numericalizer = AttrConfigDictParser.from_dict(
    attr_config_dict,
    row_list=benchmark.row_dict.values(),
)
# Last expression: display the parsed per-attribute configuration.
row_numericalizer.attr_config_dict

15:07:16 INFO:For attr=name, computing actual max_str_len
15:07:16 INFO:actual_max_str_len=15 must be even to enable NN pooling. Updating to 16
15:07:16 INFO:For attr=name, using actual_max_str_len=16
15:07:16 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
15:07:20 INFO:For attr=description, computing actual max_str_len
15:07:20 INFO:actual_max_str_len=15 must be even to enable NN pooling. Updating to 16
15:07:20 INFO:For attr=description, using actual_max_str_len=16
15:07:20 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
15:07:24 INFO:For attr=price, computing actual max_str_len
15:07:24 INFO:actual_max_str_len=7 must be even to enable NN pooling. Updating to 8
15:07:24 INFO:For attr=price, using actual_max_str_len=8


{'name': AttrConfig(source_attr='name', field_type=<FieldType.MULTITOKEN: 'multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=16, vocab=None, n_channels=8, embed_dropout_p=0.2, use_attention=True),
 'semantic_name': AttrConfig(source_attr='name', field_type=<FieldType.SEMANTIC: 'semantic_multitoken'>, tokenizer='entity_embed.data_utils.numericalizer.default_tokenizer', alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(

In [11]:
# Build the datamodule from the benchmark, reusing the notebook-wide seed.
datamodule = benchmark.build_datamodule(
    random_seed=random_seed,
    eval_batch_size=128,
    batch_size=32,
    row_numericalizer=row_numericalizer,
)

## Training

In [12]:
from entity_embed import LinkageEmbed
from pytorch_metric_learning.losses import NTXentLoss

# ann_k is forwarded unchanged to LinkageEmbed below.
ann_k = 100

# The threshold list includes 0.7, the threshold named in the
# 'valid_f1_at_0.7' metric that early stopping monitors.
model = LinkageEmbed(
    row_numericalizer,
    loss_cls=NTXentLoss,
    sim_threshold_list=[0.5, 0.7, 0.9],
    left_source=datamodule.left_source,
    source_attr=datamodule.source_attr,
    ann_k=ann_k,
)

In [13]:
import pytorch_lightning as pl
from entity_embed.early_stopping import EarlyStoppingMinEpochs
from pytorch_lightning.loggers import TensorBoardLogger

# Epoch bounds shared by the early-stopping callback and the trainer.
min_epochs = 5
max_epochs = 100

# Early stopping watches the validation F1 at threshold 0.7 (higher is better)
# with a patience of 20; min_epochs is passed through to the callback
# (presumably so it never stops before that many epochs -- confirm).
early_stop_callback = EarlyStoppingMinEpochs(
    min_epochs=min_epochs,
    monitor='valid_f1_at_0.7',
    min_delta=0.00,
    patience=20,
    verbose=True,
    mode='max',
)

# TensorBoard runs are written under ../tb_logs/matcher-<dataset name>.
tb_save_dir = '../tb_logs'
tb_name = f'matcher-{benchmark.dataset_name}'

trainer = pl.Trainer(
    # Use one GPU when CUDA is available and fall back to CPU otherwise, so
    # the notebook also runs on machines without a GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    min_epochs=min_epochs,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_save_dir, name=tb_name),
    reload_dataloaders_every_epoch=True,
)

GPU available: True, used: True
15:07:24 INFO:GPU available: True, used: True
TPU available: None, using: 0 TPU cores
15:07:24 INFO:TPU available: None, using: 0 TPU cores


In [14]:
# Train the model on the datamodule with the trainer configured above
# (early-stopping callback and TensorBoard logger attached).
trainer.fit(model, datamodule)

15:07:24 INFO:Train positive pair count: 905
15:07:24 INFO:Valid positive pair count: 559
15:07:24 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | blocker_net | BlockerNet | 7.7 M 
1 | loss_fn     | NTXentLoss | 0     
-------------------------------------------
4.9 M     Trainable params
2.7 M     Non-trainable params
7.7 M     Total params
30.753    Total estimated model params size (MB)
15:07:58 INFO:
  | Name        | Type       | Params
-------------------------------------------
0 | blocker_net | BlockerNet | 7.7 M 
1 | loss_fn     | NTXentLoss | 0     
-------------------------------------------
4.9 M     Trainable params
2.7 M     Non-trainable params
7.7 M     Total params
30.753    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



1

## Testing

In [15]:
# Evaluate on the test split with ckpt_path='best' (presumably the best
# checkpoint recorded during training -- confirm); the returned list of
# metric dicts is displayed as the cell output.
trainer.test(ckpt_path='best', verbose=False)

15:24:00 INFO:Test positive pair count: 576
15:24:00 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_f1_at_0.5': 0.4364746280045784,
  'test_f1_at_0.7': 0.7057324840764332,
  'test_f1_at_0.9': 0.8904347826086957,
  'test_pair_entity_ratio_at_0.5': 1.4231036882393877,
  'test_pair_entity_ratio_at_0.7': 0.6917188587334725,
  'test_pair_entity_ratio_at_0.9': 0.39944328462073764,
  'test_precision_at_0.5': 0.27970660146699267,
  'test_precision_at_0.7': 0.5573440643863179,
  'test_precision_at_0.9': 0.89198606271777,
  'test_recall_at_0.5': 0.9930555555555556,
  'test_recall_at_0.7': 0.9618055555555556,
  'test_recall_at_0.9': 0.8888888888888888}]