In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from importlib import reload

# logging.basicConfig() is a no-op when the root logger already has handlers.
# Reloading the logging module first yields a fresh root logger so the format
# below actually takes effect (presumably handlers were attached earlier by the
# notebook environment — confirm).
logging = reload(logging)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%H:%M:%S',
)

In [3]:
import sys

# Prepend the directory two levels up (presumably the repo root) so that
# `entity_embed` imports resolve to the local source tree — confirm layout.
sys.path[:0] = ['../..']

## Load Dataset

In [4]:
from entity_embed.benchmarks import WalmartAmazonStructuredBenchmark

# Per the log output, constructing the benchmark extracts the dataset and reads
# its record_dict plus the train/valid/test CSVs from this directory.
data_dir_path = "../data/"
benchmark = WalmartAmazonStructuredBenchmark(data_dir_path=data_dir_path)
benchmark

11:14:17 INFO:Extracting Walmart-Amazon-Structured...
11:14:17 INFO:Reading Walmart-Amazon-Structured record_dict...
11:14:17 INFO:Reading Walmart-Amazon-Structured train.csv...
11:14:17 INFO:Reading Walmart-Amazon-Structured valid.csv...
11:14:17 INFO:Reading Walmart-Amazon-Structured test.csv...


<WalmartAmazonStructuredBenchmark> from http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/Walmart-Amazon/walmart_amazon_exp_data.zip

## Preprocess

In [5]:
# Record attributes whose text is normalized in the next cell.
field_list = [
    'title',
    'category',
    'brand',
    'modelno',
    'price',
]

In [6]:
import unidecode


def clean_str(s):
    """Normalize a raw field value: transliterate to ASCII, lowercase, trim outer whitespace."""
    ascii_text = unidecode.unidecode(s)
    return ascii_text.lower().strip()


# Overwrite every configured field in place, across all three splits.
# NOTE(review): the numericalizer is later built from benchmark.record_dict;
# this assumes the split dicts share record objects with it — confirm.
split_record_dicts = (
    benchmark.train_record_dict,
    benchmark.valid_record_dict,
    benchmark.test_record_dict,
)
for split in split_record_dicts:
    for rec in split.values():
        for field_name in field_list:
            rec[field_name] = clean_str(rec[field_name])

## Initialize Data Module

In [7]:
import random

import numpy as np
import torch

# Single seed for the whole run; it is also passed to the datamodule later.
random_seed = 42

# Seed every RNG source so reruns are reproducible. Python's stdlib `random`
# was previously left unseeded; seed it too in case any code path uses it.
torch.manual_seed(random_seed)
random.seed(random_seed)
# Keep a None-returning call last so the cell shows no output
# (torch.manual_seed returns a Generator, which would otherwise be displayed).
np.random.seed(random_seed)

In [10]:
from entity_embed.data_utils.field_config_parser import DEFAULT_ALPHABET

alphabet = DEFAULT_ALPHABET
tokenizer_path = "entity_embed.default_tokenizer"
semantic_vocab = "fasttext.en.300d"


def char_field(field_type, with_tokenizer=True):
    """Build a character-level field config; max_str_len=None lets the parser compute it."""
    config = {'field_type': field_type}
    if with_tokenizer:
        config['tokenizer'] = tokenizer_path
    config['alphabet'] = alphabet
    config['max_str_len'] = None
    return config


def semantic_field(source_key):
    """Build a SEMANTIC (pretrained word-vector) field config whose 'key' names the record attribute it reads."""
    return {
        'key': source_key,
        'field_type': "SEMANTIC",
        'tokenizer': tokenizer_path,
        'vocab': semantic_vocab,
    }


# Outer key order is kept identical to the original literal.
field_config_dict = {
    'title': char_field("MULTITOKEN"),
    'title_semantic': semantic_field('title'),
    'category': char_field("MULTITOKEN"),
    'category_semantic': semantic_field('category'),
    # NOTE(review): brand is STRING yet carries a tokenizer, unlike
    # modelno/price — confirm whether MULTITOKEN (or no tokenizer) was intended.
    'brand': char_field("STRING"),
    'modelno': char_field("STRING", with_tokenizer=False),
    'price': char_field("STRING", with_tokenizer=False),
}

In [11]:
from entity_embed import FieldConfigDictParser

# Passing every record lets the parser compute each max_str_len left as None
# in the field configs (see the "computing actual max_str_len" log lines).
all_records = benchmark.record_dict.values()
record_numericalizer = FieldConfigDictParser.from_dict(field_config_dict, record_list=all_records)

11:14:30 INFO:For field=title, computing actual max_str_len
11:14:30 INFO:actual_max_str_len=23 must be even to enable NN pooling. Updating to 24
11:14:30 INFO:For field=title, using actual_max_str_len=24
11:14:31 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
11:14:34 INFO:For field=category, computing actual max_str_len
11:14:34 INFO:actual_max_str_len=15 must be even to enable NN pooling. Updating to 16
11:14:34 INFO:For field=category, using actual_max_str_len=16
11:14:34 INFO:Loading vectors from .vector_cache/wiki.en.vec.pt
11:14:37 INFO:For field=brand, computing actual max_str_len
11:14:37 INFO:For field=brand, using actual_max_str_len=46
11:14:37 INFO:For field=modelno, computing actual max_str_len
11:14:37 INFO:For field=modelno, using actual_max_str_len=48
11:14:37 INFO:For field=price, computing actual max_str_len
11:14:37 INFO:For field=price, using actual_max_str_len=8


In [12]:
batch_size = 32
eval_batch_size = 256

# Reuse the global seed so the datamodule's internal randomness (presumably
# pair sampling / shuffling — confirm in entity_embed) is reproducible.
datamodule_kwargs = dict(
    record_numericalizer=record_numericalizer,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    random_seed=random_seed,
)
datamodule = benchmark.build_datamodule(**datamodule_kwargs)

## Training

In [13]:
from entity_embed import Matcher

# The matcher shares record_numericalizer with the datamodule built above.
model = Matcher(record_numericalizer=record_numericalizer)

In [14]:
# Train with early stopping monitored on validation F1 at the 0.5 threshold;
# per the final log line, the best checkpoint is reloaded when fit() returns.
fit_kwargs = dict(
    min_epochs=5,
    max_epochs=100,
    check_val_every_n_epoch=1,
    early_stop_monitor="valid_f1_at_0.5",
    tb_save_dir='../tb_logs',
    tb_name=f'matcher-{benchmark.dataset_name}',
)
trainer = model.fit(datamodule, **fit_kwargs)

11:14:38 INFO:GPU available: True, used: True
11:14:38 INFO:TPU available: False, using: 0 TPU cores
11:14:38 INFO:Train positive pair count: 576
11:14:38 INFO:Train negative pair count: 5568
11:14:38 INFO:Valid positive pair count: 193
11:14:38 INFO:Valid positive pair count: 1856
11:14:38 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
11:15:56 INFO:
  | Name        | Type              | Params
--------------------------------------------------
0 | matcher_net | MatcherNet        | 32.9 M
1 | loss_fn     | BCEWithLogitsLoss | 0     
--------------------------------------------------
23.8 M    Trainable params
9.0 M     Non-trainable params
32.9 M    Total params
131.517   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

11:26:47 INFO:Loading the best validation model from ../tb_logs/matcher-Walmart-Amazon-Structured/version_0/checkpoints/epoch=10-step=2111.ckpt...


In [15]:
# Score the validation split with the best checkpoint that fit() reloaded;
# reports F1 / precision / recall at several decision thresholds.
model.validate(datamodule)

{'valid_f1_at_0.3': tensor(0.6213),
 'valid_f1_at_0.5': tensor(0.6212),
 'valid_f1_at_0.7': tensor(0.6319),
 'valid_f1_at_0.9': tensor(0.5976),
 'valid_precision_at_0.3': tensor(0.5271),
 'valid_precision_at_0.5': tensor(0.5690),
 'valid_precision_at_0.7': tensor(0.6368),
 'valid_precision_at_0.9': tensor(0.6966),
 'valid_recall_at_0.3': tensor(0.7565),
 'valid_recall_at_0.5': tensor(0.6839),
 'valid_recall_at_0.7': tensor(0.6269),
 'valid_recall_at_0.9': tensor(0.5233)}

## Testing

In [16]:
# Evaluate on the held-out test split; same metrics and thresholds as validation.
model.test(datamodule)

11:26:56 INFO:Test positive pair count: 193
11:26:56 INFO:Test positive pair count: 1856


{'test_f1_at_0.3': tensor(0.5624),
 'test_f1_at_0.5': tensor(0.5694),
 'test_f1_at_0.7': tensor(0.5782),
 'test_f1_at_0.9': tensor(0.5697),
 'test_precision_at_0.3': tensor(0.4750),
 'test_precision_at_0.5': tensor(0.5216),
 'test_precision_at_0.7': tensor(0.5924),
 'test_precision_at_0.9': tensor(0.6667),
 'test_recall_at_0.3': tensor(0.6891),
 'test_recall_at_0.5': tensor(0.6269),
 'test_recall_at_0.7': tensor(0.5648),
 'test_recall_at_0.9': tensor(0.4974)}