## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
# Reload the logging module before configuring it — presumably to discard any
# handlers installed earlier in the kernel, because basicConfig is a no-op when
# the root logger already has handlers. TODO confirm this is still required.
reload(logging)
# Emit INFO and above, each line prefixed with an HH:MM:SS timestamp and the level name.
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

# Put the parent directory first on the import path — presumably so that
# entity_embed is imported from the enclosing checkout; verify.
sys.path.insert(0, '..')

In [5]:
import os
# Resolve the user's home directory portably. os.getenv('HOME') returns None
# when HOME is unset (e.g. on Windows), which would turn the dataset path built
# from home_dir into 'None/Downloads/...'. When HOME is set on POSIX the
# value is the same as before.
home_dir = os.path.expanduser('~')

Dataset source: https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

Field descriptions below are quoted from the dataset readme: https://www.informatik.uni-leipzig.de/~saeedi/musicBrainz_readme.txt

```
5 sources
---------- 
TID: a unique record's id (in the complete dataset).
CID: cluster id (records having the same CID are duplicate)
CTID: a unique id within a cluster (if two records belong to the same cluster they will have the same CID but different CTIDs). These ids (CTID) start with 1 and grow until cluster size.
SourceID: identifies to which source a record belongs (there are five sources). The sources are deduplicated.
Id: the original id from the source. Each source has its own Id-Format. Uniqueness is not guaranteed!! (can be ignored).
number: track or song number in the album.
length: the length of the track.
artist: the interpreter (artist or band) of the track.
year: date of publication.
language: language of the track.
```

In [6]:
import glob
import csv
from tqdm.auto import tqdm

# Load every CSV record into row_dict, keyed by a sequential integer id.
current_row_id = 0  # next id to assign; equals the number of loaded rows afterwards
row_dict = {}
rows_total = 19375  # expected record count; only used to size the progress bar
clusters_total = 10000  # used further down to derive the size of the test split
cluster_attr = 'CID'

# newline='' is what the csv module asks for, so that quoted fields containing
# line breaks are parsed correctly.
# NOTE(review): the file is read with the platform default encoding — confirm
# whether an explicit encoding (e.g. 'utf-8') should be passed.
with tqdm(total=rows_total) as pbar, \
        open(f'{home_dir}/Downloads/musicbrainz-20-A01.csv.dapo', newline='') as f:
    for row in csv.DictReader(f):
        row['id'] = current_row_id
        row[cluster_attr] = int(row[cluster_attr])  # cluster id is compared as an int
        row_dict[current_row_id] = row
        current_row_id += 1
        pbar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19375.0), HTML(value='')))




## Preprocess

In [7]:
attr_list = ['title', 'artist', 'album']

In [8]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s):
    """Return a normalised, length-capped version of the string ``s``.

    ``s`` is passed through ``unidecode.unidecode``, lower-cased and stripped,
    then split with ``default_tokenizer``. Every token is cut to at most 30
    characters, the tokens are joined with single spaces, and the joined
    string is cut to at most 100 characters.
    """
    normalized = unidecode.unidecode(s).lower().strip()
    capped_tokens = [token[:30] for token in default_tokenizer(normalized)]
    return ' '.join(capped_tokens)[:100]

# Normalise the attributes named in attr_list on every record, in place.
for row in tqdm(row_dict.values()):
    row.update({attr: clean_str(row[attr]) for attr in attr_list})

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19375.0), HTML(value='')))




## Init Data Module

In [9]:
import torch
import numpy as np

# One seed for the whole run: it seeds torch and numpy's global RNGs here and
# is also passed to the model constructor further down.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [10]:
import string

# Built from the string-module constants instead of a hand-typed literal, so a
# character cannot be dropped or duplicated by accident. The order is the same
# as before: digits, lower-case ASCII letters, ASCII punctuation, then space.
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [11]:
# Every attribute uses identical settings, so derive the per-attribute config
# from attr_list instead of repeating the same block by hand; this keeps the
# configured attributes and the preprocessed attributes from drifting apart.
# Each attribute still gets its own dict object, in attr_list order.
attr_info_dict = {
    attr: {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
    for attr in attr_list
}

In [12]:
from entity_embed import MultiSigDedupEmbed

# Split sizes are given as cluster counts; the test split receives whatever
# remains of clusters_total after the train and validation splits.
train_cluster_len = 1500
valid_cluster_len = 2500
# ann_k is handed to the model here and reused as `k` for the pair search later on.
ann_k = 10
use_mask = True
model = MultiSigDedupEmbed(
    # data kwargs
    row_dict=row_dict,
    attr_info_dict=attr_info_dict,
    cluster_attr=cluster_attr,
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    row_batch_size=16,
    train_cluster_len=train_cluster_len,
    valid_cluster_len=valid_cluster_len,
    test_cluster_len=clusters_total - valid_cluster_len - train_cluster_len,
    only_plural_clusters=True,
    random_seed=random_seed,
    # model kwargs
    use_mask=use_mask,
    ann_k=ann_k,
)

11:14:34 INFO:For attr='title', computing actual alphabet and max_str_len
11:14:34 INFO:For attr='title', using actual_max_str_len=30
11:14:34 INFO:For attr='artist', computing actual alphabet and max_str_len
11:14:34 INFO:For attr='artist', using actual_max_str_len=30
11:14:34 INFO:For attr='album', computing actual alphabet and max_str_len
11:14:34 INFO:For attr='album', using actual_max_str_len=30


## Training

In [13]:
# Training settings, collected here so they are easy to tune, then forwarded
# unchanged to model.fit.
# NOTE(review): gpus = 1 assumes a CUDA device is available (the recorded run
# used one) — confirm before running elsewhere.
gpus = 1
max_epochs = 50
check_val_every_n_epoch = 1
# Name of the validation metric passed as the early-stopping monitor.
early_stopping_monitor = 'valid_recall_at_0.5'
tb_log_dir = 'tb_logs'
tb_name = 'music'

model.fit(
    gpus=gpus,
    max_epochs=max_epochs,
    check_val_every_n_epoch=check_val_every_n_epoch,
    early_stopping_monitor=early_stopping_monitor,
    tb_log_dir=tb_log_dir,
    tb_name=tb_name,
)

11:14:34 INFO:Fit model_sig_i=0, learning signature with unused_attr_list=['title', 'artist', 'album']
11:14:34 INFO:GPU available: True, used: True
11:14:34 INFO:TPU available: None, using: 0 TPU cores
11:14:34 INFO:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
11:14:34 INFO:Train pair count: 4883
11:14:34 INFO:Valid pair count: 8092
11:14:34 INFO:Test pair count: 3275
11:14:36 INFO:
  | Name        | Type           | Params
-----------------------------------------------
0 | blocker_net | BlockerNet     | 3.4 M 
1 | losser      | NTXentLoss     | 0     
2 | miner       | BatchHardMiner | 0     
-----------------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [14]:
# Display the per-attribute signature weights of each module in
# model.lt_module_list, one output per module.
for lt_module in model.lt_module_list:
    display(lt_module.get_signature_weights())

{'title': 0.3368707597255707,
 'artist': 0.25958314538002014,
 'album': 0.40354612469673157}

## Testing manually 

In [15]:
# Run the datamodule's setup for the test stage — presumably this is what
# populates the test_* attributes read in the following cells; verify.
model.datamodule.setup(stage='test')

11:27:07 INFO:Train pair count: 4883
11:27:07 INFO:Valid pair count: 8092
11:27:07 INFO:Test pair count: 3275


In [16]:
# Embed the records of the test split. The result is a dict keyed by signature
# (a tuple of attribute names, as the displayed keys show); each value is
# indexed by record id further down.
test_row_dict = model.datamodule.test_row_dict
test_multisig_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16
)
test_multisig_dict.keys()

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=181.0), HTML(value='')))




dict_keys([('title', 'artist', 'album')])

In [17]:
# Ground-truth pair set of the test split; its size matches the
# "Test pair count" reported in the log output.
test_true_pair_set = model.datamodule.test_true_pair_set
len(test_true_pair_set)

3275

In [18]:
%%time

from entity_embed import MultiSigANNDedupIndex

# Create the index with the model's signature keys and embedding size, load
# the test embeddings into it, then build it so it can be searched.
ann_index = MultiSigANNDedupIndex(
    multisig_dict_keys=model.multisig_dict_keys,
    embedding_size=model.embedding_size,
)
ann_index.insert_multisig_dict(test_multisig_dict)
ann_index.build()

CPU times: user 691 ms, sys: 3.06 ms, total: 694 ms
Wall time: 103 ms


In [33]:
%%time

# One similarity threshold per signature key; the key here is the single
# signature present in test_multisig_dict.
# NOTE(review): presumably pairs whose similarity falls below the threshold
# are not returned — confirm against search_pairs.
sim_threshold_dict = {
    ('title', 'artist', 'album'): 0.3,
}
# Search with the same k that was given to the model (ann_k).
found_pair_set = ann_index.search_pairs(
    k=ann_k,
    sim_threshold_dict=sim_threshold_dict,
)

CPU times: user 908 ms, sys: 3.63 ms, total: 911 ms
Wall time: 88 ms


In [34]:
from entity_embed.evaluation import pair_entity_ratio

# Relate the number of candidate pairs found to the number of test records.
pair_entity_ratio(len(found_pair_set), len(test_row_dict))

5.662396121883656

In [35]:
from entity_embed.evaluation import precision_and_recall

# Compare the found pairs against the ground-truth pairs of the test split.
# NOTE(review): the result is presumably ordered (precision, recall), per the
# function name — confirm.
precision_and_recall(found_pair_set, test_true_pair_set)

(0.19721152082186755, 0.9847328244274809)

In [36]:
# Pairs that were found but are not in the ground truth.
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

13128

In [37]:
# Ground-truth pairs that the search did not find.
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

50

In [38]:
def cos_similarity(a, b):
    """Return the cosine similarity between the vectors ``a`` and ``b``.

    The dot product is divided by the product of the two Euclidean norms, so
    the result is a true cosine even when the inputs are not unit-length. For
    unit-length inputs it equals the plain dot product. If either vector has
    zero norm, 0.0 is returned instead of dividing by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return np.dot(a, b) / norm_product

In [39]:
# List the signature keys available in the test embeddings; the next cell
# indexes the dict by one of them.
test_multisig_dict.keys()

dict_keys([('title', 'artist', 'album')])

In [40]:
# Embeddings of the test records for the ('title', 'artist', 'album') signature.
test_vector_dict = test_multisig_dict[('title', 'artist', 'album')]

# For the first ten missed pairs, show the similarity of their embeddings
# next to both records.
for id_left, id_right in false_negatives[:10]:
    vec_left = test_vector_dict[id_left]
    vec_right = test_vector_dict[id_right]
    display((cos_similarity(vec_left, vec_right), row_dict[id_left], row_dict[id_right]))

(0.14874573,
 {'TID': '4914',
  'CID': 8662,
  'CTID': '2',
  'SourceID': '4',
  'id': 4913,
  'number': '14',
  'title': '014 - suite andalucia : gitanerias',
  'length': '1m 58sec',
  'artist': 'frank flynn emilio',
  'album': 'tribute to ernesto lecuona ( 1997 )',
  'year': 'null',
  'language': 'pa.'},
 {'TID': '16840',
  'CID': 8662,
  'CTID': '1',
  'SourceID': '3',
  'id': 16839,
  'number': 'Suite Andalucia: Gitanerias - Tribute to Ernesto Lecuona',
  'title': '14',
  'length': '1.974',
  'artist': 'frank emilio flynn',
  'album': '',
  'year': "'97",
  'language': 'Spanish'})

(0.22217415,
 {'TID': '3422',
  'CID': 8904,
  'CTID': '3',
  'SourceID': '2',
  'id': 3421,
  'number': '1',
  'title': 'null',
  'length': '364',
  'artist': '',
  'album': 'untitled 2 / bad brother',
  'year': '05',
  'language': 'English'},
 {'TID': '12727',
  'CID': 8904,
  'CTID': '4',
  'SourceID': '3',
  'id': 12726,
  'number': '1',
  'title': 'untitled 2 - untitled 2 / bad brother',
  'length': '6.067',
  'artist': 'phrenetic',
  'album': '',
  'year': "'05",
  'language': 'English'})

(0.11789996,
 {'TID': '15081',
  'CID': 5099,
  'CTID': '5',
  'SourceID': '1',
  'id': 15080,
  'number': '011',
  'title': "i ' ve never been to me ( hits of the 80s : the ultimate collection )",
  'length': '04:00',
  'artist': 'chrlene',
  'album': 'hits of the 80s : the ultimate collection',
  'year': '2006',
  'language': ''},
 {'TID': '17910',
  'CID': 5099,
  'CTID': '4',
  'SourceID': '5',
  'id': 17909,
  'number': '5612990',
  'title': '11',
  'length': "I've Never Been to Me",
  'artist': '240440',
  'album': 'charlene',
  'year': 'Hits of the 80s: The Ultimate Collection',
  'language': '2006'})

(0.026767042,
 {'TID': '2905',
  'CID': 2371,
  'CTID': '2',
  'SourceID': '3',
  'id': 2904,
  'number': '6',
  'title': 'unnown',
  'length': '4.4',
  'artist': 'sorg uten tarer',
  'album': '',
  'year': "'07",
  'language': 'English'},
 {'TID': '9261',
  'CID': 2371,
  'CTID': '3',
  'SourceID': '4',
  'id': 9260,
  'number': '6',
  'title': '006 - moonkissed eyes',
  'length': '4m 24sec',
  'artist': 'sorg uten tarer',
  'album': 'moonsilver ( 2007 )',
  'year': 'null',
  'language': 'Eng.'})

(0.26439455,
 {'TID': '8494',
  'CID': 4390,
  'CTID': '1',
  'SourceID': '1',
  'id': 8493,
  'number': '9',
  'title': 'one ton ( volunteered slavery )',
  'length': '05:01',
  'artist': 'rahsaan roland kirk',
  'album': 'volunteered slavery',
  'year': '2005',
  'language': ''},
 {'TID': '19331',
  'CID': 4390,
  'CTID': '2',
  'SourceID': '2',
  'id': 19330,
  'number': '7',
  'title': 'rahsaan roland kirk - one ton',
  'length': '301',
  'artist': '',
  'album': 'unknown',
  'year': '05',
  'language': 'English'})

(0.332804,
 {'TID': '2208',
  'CID': 1154,
  'CTID': '1',
  'SourceID': '1',
  'id': 2207,
  'number': '020',
  'title': 'space jerk ( promos 5 : kidz stuff )',
  'length': '00:29',
  'artist': 'kevin jarvis',
  'album': 'promos 5 : kidz stuff',
  'year': '',
  'language': ''},
 {'TID': '14346',
  'CID': 1154,
  'CTID': '3',
  'SourceID': '3',
  'id': 14345,
  'number': '20',
  'title': 'nulll',
  'length': '0.483',
  'artist': 'evin jarvis',
  'album': '',
  'year': '',
  'language': 'English'})

(0.021027733,
 {'TID': '2965',
  'CID': 1544,
  'CTID': '1',
  'SourceID': '2',
  'id': 2964,
  'number': '8',
  'title': 'darryl way - juliet',
  'length': '247',
  'artist': '',
  'album': 'under the soft',
  'year': '91',
  'language': 'English'},
 {'TID': '18300',
  'CID': 1544,
  'CTID': '2',
  'SourceID': '3',
  'id': 18299,
  'number': '8',
  'title': '8',
  'length': '4.117',
  'artist': 'darryl way',
  'album': '',
  'year': "'91",
  'language': 'English'})

(0.29932302,
 {'TID': '9364',
  'CID': 901,
  'CTID': '2',
  'SourceID': '2',
  'id': 9363,
  'number': 'A5',
  'title': 'du bian man li nai - marinanoxia',
  'length': 'null',
  'artist': '',
  'album': "oniyanzi sailing meng gong chang ' 87 live",
  'year': '87',
  'language': 'Japanese'},
 {'TID': '16923',
  'CID': 901,
  'CTID': '4',
  'SourceID': '4',
  'id': 16922,
  'number': 'A5',
  'title': '0a5 -',
  'length': 'unknown',
  'artist': '',
  'album': 'sailin 87 live ( 1987 )',
  'year': 'null',
  'language': 'Jap.'})

(0.22890994,
 {'TID': '895',
  'CID': 2198,
  'CTID': '2',
  'SourceID': '3',
  'id': 894,
  'number': '4',
  'title': 'the cn tower belongs to the dead ( many ives version ) - many lives - 49 mp',
  'length': '5.467',
  'artist': 'final fantasy',
  'album': '',
  'year': "'06",
  'language': 'English'},
 {'TID': '4240',
  'CID': 2198,
  'CTID': '1',
  'SourceID': '2',
  'id': 4239,
  'number': '4',
  'title': 'n . a .',
  'length': '328',
  'artist': '',
  'album': 'many lives - 49 mp',
  'year': '06',
  'language': 'English'})

(0.17193685,
 {'TID': '5427',
  'CID': 5816,
  'CTID': '2',
  'SourceID': '3',
  'id': 5426,
  'number': '1',
  'title': 'null',
  'length': '3.45',
  'artist': 'akvarium',
  'album': '',
  'year': "'99",
  'language': 'Russian'},
 {'TID': '11177',
  'CID': 5816,
  'CTID': '1',
  'SourceID': '2',
  'id': 11176,
  'number': '1',
  'title': 'akvarium - russkaia nirvana',
  'length': '207',
  'artist': '',
  'album': 'luchshie pesni ( disc 3 : 1991 - 1996 )',
  'year': '99',
  'language': 'Russian'})