In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules are re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
from importlib import reload
import logging
# Re-import the logging module before configuring it.
# NOTE(review): presumably this drops handlers already attached to the root
# logger by the kernel, so that basicConfig below is not a no-op — confirm.
reload(logging)
# Emit INFO and above, formatted as "HH:MM:SS LEVEL:message".
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [None]:
import sys

# Put the directory two levels above the working directory first on the import
# path. NOTE(review): presumably so `entity_embed` is imported from the local
# checkout rather than from an installed copy — confirm.
sys.path.insert(0, '../..')

## Load Dataset

In [None]:
from entity_embed.benchmarks import CompanyBenchmark

# Create the benchmark object with ../data/ as its data directory.
# NOTE(review): whether this downloads data or only reads local files is not
# visible from this notebook.
benchmark = CompanyBenchmark(data_dir_path="../data/")
# Bare expression: the benchmark's repr becomes the cell output.
benchmark

## Preprocess

In [None]:
# Names of the record attributes that the cleaning loop below rewrites.
attr_list = ['content']

In [None]:
from tqdm.auto import tqdm
import unidecode
import itertools
from entity_embed import default_tokenizer

def clean_str(s, max_token_len=30, max_tokens=100):
    """Normalize a raw string into a bounded, space-separated token string.

    The text is passed through ``unidecode.unidecode``, lower-cased and
    stripped of surrounding whitespace, then split with ``default_tokenizer``.
    Each token is cut to at most ``max_token_len`` characters and only the
    first ``max_tokens`` tokens are kept.

    Args:
        s: The raw text to clean.
        max_token_len: Maximum number of characters kept per token
            (default 30, the previously hard-coded limit).
        max_tokens: Maximum number of tokens kept
            (default 100, the previously hard-coded limit).

    Returns:
        The kept, truncated tokens joined by single spaces.
    """
    normalized = unidecode.unidecode(s).lower().strip()
    # Generator + islice: tokens past the cap are never sliced or joined.
    truncated_tokens = (token[:max_token_len] for token in default_tokenizer(normalized))
    return ' '.join(itertools.islice(truncated_tokens, max_tokens))

# Rewrite, in place, every attribute named in attr_list on every benchmark
# record with its cleaned form; tqdm shows progress over the records.
for record in tqdm(benchmark.row_dict.values()):
    for field_name in attr_list:
        record[field_name] = clean_str(record[field_name])

## Init Data Module

In [None]:
import torch
import numpy as np

# Fixed seed shared by the RNG calls below and by the datamodule built later.
random_seed = 42
# Seed torch's and NumPy's global random number generators with it.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [None]:
import string

# Alphabet handed to the 'content' field configuration below: digits, lowercase
# ASCII letters, ASCII punctuation and the space character, in that order.
# Built from the `string` constants rather than a hand-typed literal so that no
# character can be dropped or duplicated by accident; the list is unchanged.
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [None]:
# Field configuration keyed by field name, consumed by AttrInfoDictParser below.
# 'semantic_content' names 'content' as its source_attr; both fields use the
# same tokenizer path and enable use_mask.
attr_info_dict = dict(
    content=dict(
        field_type="MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        alphabet=alphabet,
        # Left as None ("compute") — presumably filled in by the parser; confirm.
        max_str_len=None,
        use_mask=True,
    ),
    semantic_content=dict(
        source_attr='content',
        field_type="SEMANTIC_MULTITOKEN",
        tokenizer="entity_embed.default_tokenizer",
        vocab="fasttext.en.300d",
        use_mask=True,
    ),
)

In [None]:
from entity_embed import AttrInfoDictParser

# Build `row_numericalizer` from the field configuration via
# AttrInfoDictParser.from_dict, also passing the benchmark rows (which the
# cleaning loop earlier has already rewritten in place).
row_numericalizer = AttrInfoDictParser.from_dict(attr_info_dict, row_dict=benchmark.row_dict)
# Bare expression: show the resulting attr_info_dict as the cell output.
row_numericalizer.attr_info_dict

In [None]:
# Build the datamodule from the benchmark, using the numericalizer created
# above and the shared random seed.
# NOTE(review): the exact meaning of batch_size vs. row_batch_size is defined by
# build_pairwise_datamodule and is not visible here — confirm before tuning.
datamodule = benchmark.build_pairwise_datamodule(
    row_numericalizer=row_numericalizer,
    batch_size=50,
    row_batch_size=16,
    random_seed=random_seed
)

## Training

In [None]:
from entity_embed import LinkageEmbed

# Kept as a named variable; in this notebook it is only forwarded to the model.
ann_k = 100
# Instantiate LinkageEmbed on the datamodule with embedding_size=300.
model = LinkageEmbed(
    datamodule,
    ann_k=ann_k,
    embedding_size=300
)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Training budget and TensorBoard destination.
max_epochs = 20
tb_log_dir = '../tb_logs'
tb_name = f'f1-{benchmark.dataset_name}'

# Early stopping on 'valid_f1_at_0.9' (higher is better) with a patience of 3.
early_stop_callback = EarlyStopping(
    monitor='valid_f1_at_0.9', mode='max', min_delta=0.0, patience=3, verbose=True
)

# Single-GPU trainer that validates every epoch and logs to TensorBoard.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback],
    logger=TensorBoardLogger(tb_log_dir, name=tb_name),
)

In [None]:
# Train the model on the datamodule with the trainer configured above.
trainer.fit(model, datamodule)

In [None]:
# Bare expression: display whatever blocker_net.get_signature_weights() returns.
# NOTE(review): presumably per-field weights learned during training — confirm.
model.blocker_net.get_signature_weights()

In [None]:
from entity_embed import validate_best

# Call entity_embed's validate_best on the trainer; its return value is the
# cell output. NOTE(review): presumably re-runs validation on the best
# checkpoint — confirm against the entity_embed source.
validate_best(trainer)

## Testing

In [None]:
# Run the test loop, asking the trainer to use the 'best' checkpoint;
# verbose=False keeps the trainer from printing the results itself, so they
# appear only as the cell's return value.
trainer.test(ckpt_path='best', verbose=False)