In [1]:
# Reload imported modules before each cell runs, so edits to project code
# (e.g. entity_embed, which is imported from ../.. below) are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from importlib import reload

# Re-executing the logging module gives a fresh root logger with no handlers,
# presumably so the basicConfig call below is not a no-op in Jupyter, where a
# root handler may already be installed. (On Python 3.8+ `force=True` would do the same.)
reload(logging)

# Timestamped INFO-level log lines, as seen in the cell outputs below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%H:%M:%S',
)

In [3]:
import sys

# Put ../.. first on the import path so `entity_embed` (imported below) is
# presumably resolved from the local source tree rather than an installed copy.
project_root = '../..'
sys.path.insert(0, project_root)

## Load Dataset

In [4]:
from entity_embed.benchmarks import FodorsZagatsBenchmark

# Extracts and reads the Fodors-Zagats record dict and train/valid/test CSVs
# under this directory (see the log output).
data_dir_path = "../data/"
benchmark = FodorsZagatsBenchmark(data_dir_path=data_dir_path)

# Display the benchmark's repr (includes the dataset source URL).
benchmark

22:44:46 INFO:Extracting Fodors-Zagats...
22:44:46 INFO:Reading Fodors-Zagats record_dict...
22:44:46 INFO:Reading Fodors-Zagats train.csv...
22:44:46 INFO:Reading Fodors-Zagats valid.csv...
22:44:46 INFO:Reading Fodors-Zagats test.csv...


<FodorsZagatsBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/Fodors-Zagats/fodors_zagat_exp_data.zip

## Preprocess

In [5]:
# Record attributes that are cleaned below and fed to the matcher.
field_list = [
    'name',
    'addr',
    'city',
    'phone',
    'type',
]

In [6]:
import unidecode


def clean_str(s):
    """Normalize a raw field value for matching.

    Transliterates ``s`` to ASCII with ``unidecode``, then lowercases it and
    strips surrounding whitespace.
    """
    ascii_s = unidecode.unidecode(s)
    return ascii_s.lower().strip()


# Clean every configured field of every record, in place, across all three splits.
split_record_dicts = (
    benchmark.train_record_dict,
    benchmark.valid_record_dict,
    benchmark.test_record_dict,
)
for record_dict in split_record_dicts:
    for record in record_dict.values():
        for field in field_list:
            raw_value = record[field]
            record[field] = clean_str(raw_value)

### Rename attribute `type` to `type_` (avoids a name clash in PyTorch)

In [7]:
# Move each record's `type` value to the key `type_`. Records that no longer
# have `type` (already renamed) are skipped, so re-running this cell is safe.
split_record_dicts = (
    benchmark.train_record_dict,
    benchmark.valid_record_dict,
    benchmark.test_record_dict,
)
for record_dict in split_record_dicts:
    for record in record_dict.values():
        if 'type' not in record:
            continue
        record['type_'] = record.pop('type')

In [8]:
# Keep `field_list` in sync with the `type` -> `type_` record rename above.
# Substituting the entry (instead of `del` + `append`) makes the cell idempotent:
# re-running it is a no-op, whereas `field_list.index('type')` raises ValueError
# once 'type' is gone. 'type' is the last field, so the first run yields the same
# list as before: ['name', 'addr', 'city', 'phone', 'type_'].
field_list = ['type_' if field == 'type' else field for field in field_list]

## Init Data Module

In [9]:
import random

import numpy as np
import torch

# Single seed for every global RNG below. It is also passed to
# build_matcher_datamodule further down.
random_seed = 42

# Seed Python's stdlib RNG in addition to torch and numpy. Without this, any
# library code that draws from `random` is not reproducible across runs.
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [10]:
from entity_embed import PairNumericalizer

# Numericalizer over the preprocessed fields. `field_list` now holds 'type_'
# rather than 'type'.
numericalized_fields = field_list
pair_numericalizer = PairNumericalizer(numericalized_fields)

In [11]:
# Separate batch sizes for training and for evaluation.
batch_size, eval_batch_size = 32, 256

datamodule_kwargs = dict(
    pair_numericalizer=pair_numericalizer,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    random_seed=random_seed,
)
datamodule = benchmark.build_matcher_datamodule(**datamodule_kwargs)

## Training

In [12]:
from entity_embed import Matcher

# Pair matcher built on the same numericalizer as the datamodule above.
model = Matcher(pair_numericalizer=pair_numericalizer)

22:44:50 INFO:Load pretrained SentenceTransformer: stsb-distilbert-base
22:44:50 INFO:Did not find folder stsb-distilbert-base
22:44:50 INFO:Search model on server: http://sbert.net/models/stsb-distilbert-base.zip
22:44:50 INFO:Load SentenceTransformer from folder: /home/fjsj/.cache/torch/sentence_transformers/sbert.net_models_stsb-distilbert-base
22:44:51 INFO:Use pytorch device: cuda


In [13]:
fit_kwargs = dict(
    min_epochs=5,
    max_epochs=100,
    check_val_every_n_epoch=1,
    # Early stopping monitors validation F1 at a 0.5 decision threshold.
    early_stop_monitor="valid_f1_at_0.5",
    # TensorBoard logs (and, per the output, checkpoints) go under ../tb_logs/<tb_name>.
    tb_save_dir='../tb_logs',
    tb_name=f'matcher-{benchmark.dataset_name}',
)
trainer = model.fit(datamodule, **fit_kwargs)

22:44:51 INFO:GPU available: True, used: True
22:44:51 INFO:TPU available: False, using: 0 TPU cores
22:44:51 INFO:Train positive pair count: 66
22:44:51 INFO:Train negative pair count: 501
22:44:51 INFO:Valid positive pair count: 22
22:44:51 INFO:Valid positive pair count: 168
22:44:51 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
22:44:54 INFO:
  | Name        | Type              | Params
--------------------------------------------------
0 | matcher_net | MatcherNet        | 66.4 M
1 | loss_fn     | BCEWithLogitsLoss | 0     
--------------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.455   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

22:47:47 INFO:Loading the best validation model from ../tb_logs/matcher-Fodors-Zagats/version_1/checkpoints/epoch=7-step=143.ckpt...
22:47:47 INFO:Load pretrained SentenceTransformer: stsb-distilbert-base
22:47:47 INFO:Did not find folder stsb-distilbert-base
22:47:47 INFO:Search model on server: http://sbert.net/models/stsb-distilbert-base.zip
22:47:47 INFO:Load SentenceTransformer from folder: /home/fjsj/.cache/torch/sentence_transformers/sbert.net_models_stsb-distilbert-base
22:47:48 INFO:Use pytorch device: cuda


In [14]:
# Precision / recall / F1 on the validation split at several decision thresholds.
validation_metrics = model.validate(datamodule)
validation_metrics

{'valid_f1_at_0.3': tensor(0.9362),
 'valid_f1_at_0.5': tensor(0.9565),
 'valid_f1_at_0.7': tensor(0.9565),
 'valid_f1_at_0.9': tensor(0.9767),
 'valid_precision_at_0.3': tensor(0.8800),
 'valid_precision_at_0.5': tensor(0.9167),
 'valid_precision_at_0.7': tensor(0.9167),
 'valid_precision_at_0.9': tensor(1.),
 'valid_recall_at_0.3': tensor(1.),
 'valid_recall_at_0.5': tensor(1.),
 'valid_recall_at_0.7': tensor(1.),
 'valid_recall_at_0.9': tensor(0.9545)}

## Testing

In [15]:
# Precision / recall / F1 on the held-out test split at several decision thresholds.
test_metrics = model.test(datamodule)
test_metrics

22:47:56 INFO:Test positive pair count: 22
22:47:56 INFO:Test positive pair count: 167


{'test_f1_at_0.3': tensor(0.9545),
 'test_f1_at_0.5': tensor(0.9767),
 'test_f1_at_0.7': tensor(0.9767),
 'test_f1_at_0.9': tensor(0.9268),
 'test_precision_at_0.3': tensor(0.9545),
 'test_precision_at_0.5': tensor(1.),
 'test_precision_at_0.7': tensor(1.),
 'test_precision_at_0.9': tensor(1.),
 'test_recall_at_0.3': tensor(0.9545),
 'test_recall_at_0.5': tensor(0.9545),
 'test_recall_at_0.7': tensor(0.9545),
 'test_recall_at_0.9': tensor(0.8636)}