In [1]:
# Enable IPython's autoreload extension. Mode 2 re-imports modules before each cell is
# executed, so edits to source files on disk are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
# Re-import the logging module before configuring it.
# NOTE(review): presumably this discards handlers installed earlier in the kernel so that
# basicConfig below actually takes effect — confirm.
reload(logging)
# Emit INFO-and-above records formatted as "HH:MM:SS LEVEL:message".
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
import sys

# Put the directory two levels up at the front of the module search path. The entry is
# relative, so it is resolved against the kernel's current working directory.
# NOTE(review): presumably this makes the local entity_embed checkout importable — confirm.
sys.path.insert(0, '../..')

## Load Dataset

In [4]:
from entity_embed.benchmarks import AmazonGoogleBenchmark

# Build the Amazon-Google benchmark with ../data/ as its data directory. The log output of
# this cell shows the archive being extracted and the record/train/valid/test files read.
benchmark = AmazonGoogleBenchmark(data_dir_path="../data/")
# Bare expression so the notebook displays the benchmark's repr.
benchmark

18:40:52 INFO:Extracting Amazon-Google...
18:40:52 INFO:Reading Amazon-Google record_dict...
18:40:52 INFO:Reading Amazon-Google train.csv...
18:40:52 INFO:Reading Amazon-Google valid.csv...
18:40:52 INFO:Reading Amazon-Google test.csv...


<AmazonGoogleBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/Amazon-Google/amazon_google_exp_data.zip

## Preprocess

In [5]:
# Names of the record fields that the preprocessing loop passes through clean_str.
field_list = ['title', 'manufacturer', 'price']

In [6]:
import unidecode

def clean_str(s):
    """Return ``s`` passed through ``unidecode``, lower-cased, with surrounding whitespace removed."""
    transliterated = unidecode.unidecode(s)
    lowered = transliterated.lower()
    return lowered.strip()

# Normalize every field named in field_list with clean_str, overwriting the values stored
# in benchmark.record_dict in place (the original raw strings are not kept).
for record in benchmark.record_dict.values():
    for field in field_list:
        record[field] = clean_str(record[field])

## Init Data Module

In [7]:
import random

import numpy as np
import torch


def set_global_seeds(seed):
    """Seed the global random number generators of Python, NumPy and PyTorch.

    All three are seeded with the same value so that no commonly used source of
    randomness is left unseeded when the notebook is re-run.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Single seed shared by the global RNGs and by the cells below that accept a seed.
random_seed = 42
set_global_seeds(random_seed)

In [8]:
from entity_embed.data_utils.field_config_parser import DEFAULT_ALPHABET

# The character-level fields below all use the library's default alphabet.
alphabet = DEFAULT_ALPHABET

# One entry per field to embed. 'title' is configured twice: once under its own name and
# once as 'title_semantic', which reads the same record key through 'key'.
field_config_dict = dict(
    title=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # None: the length is computed from the records
    ),
    title_semantic=dict(
        key='title',
        field_type="SEMANTIC_MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        vocab="fasttext.en.300d",
    ),
    manufacturer=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # None: the length is computed from the records
    ),
    price=dict(
        field_type="STRING",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        max_str_len=None,  # None: the length is computed from the records
    ),
)

In [9]:
from entity_embed import FieldConfigDictParser

# Parse the field configuration into a record numericalizer. The records are passed along
# as record_list; per the log output of this cell, this is where each max_str_len left as
# None is replaced by a length computed from the data (rounded up to an even number).
record_numericalizer = FieldConfigDictParser.from_dict(
    field_config_dict, record_list=benchmark.record_dict.values())

18:40:52 INFO:For field=title, computing actual max_str_len
18:40:52 INFO:For field=title, using actual_max_str_len=26
18:40:52 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
18:40:55 INFO:For field=manufacturer, computing actual max_str_len
18:40:55 INFO:actual_max_str_len=15 must be even to enable NN pooling. Updating to 16
18:40:55 INFO:For field=manufacturer, using actual_max_str_len=16
18:40:55 INFO:For field=price, computing actual max_str_len
18:40:55 INFO:actual_max_str_len=9 must be even to enable NN pooling. Updating to 10
18:40:55 INFO:For field=price, using actual_max_str_len=10


In [10]:
# Training uses smaller batches than evaluation.
batch_size = 32
eval_batch_size = 128
# The benchmark builds the data module from the numericalizer, the two batch sizes and the
# notebook-wide random seed.
datamodule = benchmark.build_datamodule(
    record_numericalizer=record_numericalizer,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    random_seed=random_seed
)

## Training

In [11]:
from entity_embed import LinkageEmbed
from pytorch_metric_learning.losses import NTXentLoss

# NOTE(review): ann_k is presumably the number of approximate nearest neighbours retrieved
# per record when evaluating — confirm against the entity_embed documentation.
ann_k = 100
# Build the model; source_field and left_source are taken from the data module so the
# model and the data agree on them.
model = LinkageEmbed(
    record_numericalizer,
    ann_k=ann_k,
    source_field=datamodule.source_field,
    left_source=datamodule.left_source,
    loss_cls=NTXentLoss,
    # These thresholds show up in the reported metric names (e.g. valid_f1_at_0.7).
    sim_threshold_list=[0.5, 0.7, 0.9],
)

In [12]:
# Train for between 5 and 100 epochs, validating after every epoch. Early stopping watches
# valid_f1_at_0.7, which corresponds to the 0.7 entry of the model's sim_threshold_list.
# Logs and checkpoints are written under ../tb_logs in a run named after the dataset (the
# log output of this cell shows the best checkpoint being loaded from there afterwards).
trainer = model.fit(
    datamodule,
    min_epochs=5,
    max_epochs=100,
    check_val_every_n_epoch=1,
    early_stop_monitor="valid_f1_at_0.7",
    tb_save_dir='../tb_logs',
    tb_name=f'matcher-{benchmark.dataset_name}'
)

18:40:55 INFO:GPU available: True, used: True
18:40:55 INFO:TPU available: False, using: 0 TPU cores
18:40:55 INFO:Train positive pair count: 699
18:40:55 INFO:Train negative pair count: 6175
18:40:55 INFO:Valid positive pair count: 234
18:40:55 INFO:Valid positive pair count: 2059
18:40:55 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
18:41:11 INFO:
  | Name        | Type       | Params
-------------------------------------------
0 | blocker_net | BlockerNet | 7.2 M 
1 | loss_fn     | NTXentLoss | 0     
-------------------------------------------
5.5 M     Trainable params
1.7 M     Non-trainable params
7.2 M     Total params
28.776    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

18:52:32 INFO:Loading the best validation model from ../tb_logs/matcher-Amazon-Google/version_35/checkpoints/epoch=5-step=3107.ckpt...


In [13]:
# Compute the validation metrics; the returned dict (F1, precision, recall and pair/entity
# ratio per similarity threshold) is displayed as the cell output.
model.validate(datamodule)

{'valid_f1_at_0.5': 0.10125427594070696,
 'valid_f1_at_0.7': 0.22260869565217395,
 'valid_f1_at_0.9': 0.28733997155049784,
 'valid_pair_entity_ratio_at_0.5': 2.20915380521554,
 'valid_pair_entity_ratio_at_0.7': 0.7935071846726982,
 'valid_pair_entity_ratio_at_0.9': 0.24960085151676423,
 'valid_precision_at_0.5': 0.05348108889424235,
 'valid_precision_at_0.7': 0.12877263581488935,
 'valid_precision_at_0.9': 0.21535181236673773,
 'valid_recall_at_0.5': 0.9487179487179487,
 'valid_recall_at_0.7': 0.8205128205128205,
 'valid_recall_at_0.9': 0.43162393162393164}

In [14]:
# Display the per-field weights reported by the model, one entry per configured field.
model.get_pool_weights()

{'title': 0.3514834940433502,
 'title_semantic': 0.3839889168739319,
 'manufacturer': 0.09599753469228745,
 'price': 0.16853000223636627}

## Testing

In [15]:
# Compute the same set of metrics on the test split; the returned dict is the cell output.
model.test(datamodule)

18:52:36 INFO:Test positive pair count: 234
18:52:36 INFO:Test positive pair count: 2059


{'test_f1_at_0.5': 0.09949148795047534,
 'test_f1_at_0.7': 0.21894498014747588,
 'test_f1_at_0.9': 0.27816411682892905,
 'test_pair_entity_ratio_at_0.5': 2.3046749059645353,
 'test_pair_entity_ratio_at_0.7': 0.8216012896292316,
 'test_pair_entity_ratio_at_0.9': 0.2606125738850081,
 'test_precision_at_0.5': 0.0524597808346934,
 'test_precision_at_0.7': 0.12622629169391758,
 'test_precision_at_0.9': 0.20618556701030927,
 'test_recall_at_0.5': 0.9615384615384616,
 'test_recall_at_0.7': 0.8247863247863247,
 'test_recall_at_0.9': 0.42735042735042733}