## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

sys.path.insert(0, '..')

In [5]:
import os
home_dir = os.getenv('HOME')

https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

https://www.informatik.uni-leipzig.de/~saeedi/musicBrainz_readme.txt

```
5 sources
---------- 
TID: a unique record's id (in the complete dataset).
CID: cluster id (records having the same CID are duplicate)
CTID: a unique id within a cluster (if two records belong to the same cluster they will have the same CID but different CTIDs). These ids (CTID) start with 1 and grow until cluster size.
SourceID: identifies to which source a record belongs (there are five sources). The sources are deduplicated.
Id: the original id from the source. Each source has its own Id-Format. Uniqueness is not guaranteed!! (can be ignored).
number: track or song number in the album.
length: the length of the track.
artist: the interpreter (artist or band) of the track.
year: date of publication.
language: language of the track.
```

In [6]:
import glob
import csv
import tqdm

current_row_id = 0
row_dict = {}
rows_total = 19375
cluster_attr = 'CID'

with tqdm.tqdm(total=rows_total) as pbar:
    for filename in glob.glob(f'{home_dir}/Downloads/musicbrainz-20-A01.csv.dapo'):
        with open(filename) as f:
            for row in csv.DictReader(f):
                row['id'] = current_row_id
                row[cluster_attr] = int(row[cluster_attr])  # convert cluster_attr to int
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

100%|██████████| 19375/19375 [00:00<00:00, 196132.69it/s]


In [7]:
row_dict[1]

{'TID': '2',
 'CID': 2512,
 'CTID': '5',
 'SourceID': '4',
 'id': 1,
 'number': '7',
 'title': '007',
 'length': '1m 58sec',
 'artist': '[unknown]',
 'album': 'Cantigas de roda (unknown)',
 'year': 'null',
 'language': 'Por.'}

## Preprocess

In [8]:
attr_list = ['title', 'artist', 'album']

In [9]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    return ' '.join(s_part[:30] for s_part in default_tokenizer(s))[:100]

for row in row_dict.values():
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

## Get Pairs

In [10]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
from entity_embed.data_utils.utils import row_dict_to_cluster_dict

cluster_dict = row_dict_to_cluster_dict(row_dict, cluster_attr)
len(cluster_dict)

10000

In [12]:
from entity_embed.data_utils.utils import split_clusters

train_len = 1_500
valid_len = 2_500
train_cluster_dict, valid_cluster_dict, test_cluster_dict = split_clusters(
    cluster_dict, train_len=train_len, valid_len=valid_len, random_seed=random_seed)
display(len(train_cluster_dict), len(valid_cluster_dict), len(test_cluster_dict))

1500

2000

6500

In [13]:
for d in [train_cluster_dict, valid_cluster_dict, test_cluster_dict]:
    display(sum(1 for cluster_id_list in d.values() if len(cluster_id_list) > 1))

1500

2000

1500

In [14]:
from entity_embed.data_utils.utils import cluster_dict_to_id_pairs

train_true_pair_set = cluster_dict_to_id_pairs(train_cluster_dict)
display(len(train_true_pair_set))
del train_true_pair_set
valid_true_pair_set = cluster_dict_to_id_pairs(valid_cluster_dict)
display(len(valid_true_pair_set))
del valid_true_pair_set
test_true_pair_set = cluster_dict_to_id_pairs(test_cluster_dict)
display(len(test_true_pair_set))

4883

4883

4900

In [15]:
train_row_dict = {id_: row_dict[id_] for cluster_id_list in train_cluster_dict.values() for id_ in cluster_id_list}
len(train_row_dict)

4317

In [16]:
valid_row_dict = {id_: row_dict[id_] for cluster_id_list in valid_cluster_dict.values() for id_ in cluster_id_list}
len(valid_row_dict)

5742

In [17]:
test_row_dict = {id_: row_dict[id_] for cluster_id_list in test_cluster_dict.values() for id_ in cluster_id_list}
len(test_row_dict)

9316

## Training

In [18]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [19]:
# TODO: support multiple attrs for same source attr

attr_info_dict = {
    'title': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'artist': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'album': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
}

In [20]:
from entity_embed import EntityEmbed

model = EntityEmbed(
    attr_info_dict=attr_info_dict,
    row_dict=row_dict
)

22:07:48 INFO:For attr='title', computing actual alphabet and max_str_len
22:07:48 INFO:For attr='title', using actual_max_str_len=30
22:07:48 INFO:For attr='artist', computing actual alphabet and max_str_len
22:07:48 INFO:For attr='artist', using actual_max_str_len=30
22:07:48 INFO:For attr='album', computing actual alphabet and max_str_len
22:07:48 INFO:For attr='album', using actual_max_str_len=30


In [21]:
model.attr_info_dict

{'title': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7f9409c4cc10>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'artist': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7f9409c4cc10>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'album': OneHotEncodingInfo(is_multi

In [22]:
epochs = 20

In [23]:
model.train(
    epochs=epochs,
    train_row_dict=train_row_dict,
    cluster_attr=cluster_attr,
    train_pos_pair_batch_size=45,
    train_neg_pair_batch_size=1225,
    valid_row_dict=valid_row_dict,
    random_seed=random_seed
)

HBox(children=(HTML(value='# training'), FloatProgress(value=0.0, max=2180.0), HTML(value='')))

22:08:02 INFO:[('title', tensor(0.3458, device='cuda:0')), ('artist', tensor(0.3071, device='cuda:0')), ('album', tensor(0.3471, device='cuda:0'))]
22:08:05 INFO:# Train Epoch:   0 Time: 14.490 Loss: 0.158, Precision: 0.406 Recall: 0.965
22:08:17 INFO:[('title', tensor(0.3459, device='cuda:0')), ('artist', tensor(0.3086, device='cuda:0')), ('album', tensor(0.3455, device='cuda:0'))]
22:08:19 INFO:# Train Epoch:   1 Time: 14.420 Loss: 0.027, Precision: 0.550 Recall: 0.968
22:08:31 INFO:[('title', tensor(0.3488, device='cuda:0')), ('artist', tensor(0.3069, device='cuda:0')), ('album', tensor(0.3443, device='cuda:0'))]
22:08:34 INFO:# Train Epoch:   2 Time: 14.562 Loss: 0.017, Precision: 0.632 Recall: 0.968
22:08:46 INFO:[('title', tensor(0.3493, device='cuda:0')), ('artist', tensor(0.3054, device='cuda:0')), ('album', tensor(0.3453, device='cuda:0'))]
22:08:49 INFO:# Train Epoch:   3 Time: 14.722 Loss: 0.013, Precision: 0.681 Recall: 0.970
22:09:00 INFO:[('title', tensor(0.3505, device='




In [24]:
# torch.save(model, "music_model.torch")

In [25]:
# model = torch.load("music_model.torch")

In [26]:
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=583.0), HTML(value='')))




In [27]:
assert len(test_vector_dict) == len(test_row_dict)

In [28]:
%%time

from entity_embed.entity_embed import ANNEntityIndex

ann_index = ANNEntityIndex(embedding_size=model.embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 3.61 s, sys: 30.4 ms, total: 3.64 s
Wall time: 435 ms


In [29]:
%%time

ntop = 10
sim_threshold = 0.5

found_pair_set = ann_index.search_pairs(
    k=ntop,
    sim_threshold=sim_threshold
)

CPU times: user 5.78 s, sys: 2.56 ms, total: 5.78 s
Wall time: 535 ms


In [30]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(found_pair_set, test_row_dict)

0.9377415199656505

In [31]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.5502518315018315, 0.9810204081632653)

In [32]:
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

3929

In [33]:
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

93

In [34]:
cos_similarity = lambda a, b: np.dot(a, b)

In [35]:
for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),
            row_dict[id_left], row_dict[id_right]
        )
    )

(0.44871008,
 {'TID': '1970',
  'CID': 7215,
  'CTID': '2',
  'SourceID': '2',
  'id': 1969,
  'number': '2',
  'title': 'luci & anenakkav',
  'length': '185',
  'artist': '',
  'album': 'munad vahtu !',
  'year': '02',
  'language': 'Estonian'},
 {'TID': '17571',
  'CID': 7215,
  'CTID': '3',
  'SourceID': '3',
  'id': 17570,
  'number': '2',
  'title': 'nakkav - munad vahtu !',
  'length': '3.083',
  'artist': 'luci & ane',
  'album': '',
  'year': "'02",
  'language': 'Estonian'})

(0.15048532,
 {'TID': '2905',
  'CID': 2371,
  'CTID': '2',
  'SourceID': '3',
  'id': 2904,
  'number': '6',
  'title': 'unnown',
  'length': '4.4',
  'artist': 'sorg uten tarer',
  'album': '',
  'year': "'07",
  'language': 'English'},
 {'TID': '9261',
  'CID': 2371,
  'CTID': '3',
  'SourceID': '4',
  'id': 9260,
  'number': '6',
  'title': '006 - moonkissed eyes',
  'length': '4m 24sec',
  'artist': 'sorg uten tarer',
  'album': 'moonsilver ( 2007 )',
  'year': 'null',
  'language': 'Eng.'})

(0.4312421,
 {'TID': '919',
  'CID': 1149,
  'CTID': '2',
  'SourceID': '5',
  'id': 918,
  'number': '7',
  'title': 'hard',
  'length': '210320',
  'artist': 'asia',
  'album': '',
  'year': '1994',
  'language': 'English'},
 {'TID': '8988',
  'CID': 1149,
  'CTID': '4',
  'SourceID': '2',
  'id': 8987,
  'number': '7',
  'title': 'asia - hard on me',
  'length': '210',
  'artist': '',
  'album': 'gold',
  'year': '05',
  'language': 'English'})

(0.3204738,
 {'TID': '2965',
  'CID': 1544,
  'CTID': '1',
  'SourceID': '2',
  'id': 2964,
  'number': '8',
  'title': 'darryl way - juliet',
  'length': '247',
  'artist': '',
  'album': 'under the soft',
  'year': '91',
  'language': 'English'},
 {'TID': '18300',
  'CID': 1544,
  'CTID': '2',
  'SourceID': '3',
  'id': 18299,
  'number': '8',
  'title': '8',
  'length': '4.117',
  'artist': 'darryl way',
  'album': '',
  'year': "'91",
  'language': 'English'})

(0.17702493,
 {'TID': '9864',
  'CID': 7993,
  'CTID': '2',
  'SourceID': '4',
  'id': 9863,
  'number': '24',
  'title': '024 -',
  'length': '0m 9sec',
  'artist': '',
  'album': 'tv ! !! () ( 2009 )',
  'year': 'null',
  'language': 'Jap.'},
 {'TID': '15482',
  'CID': 7993,
  'CTID': '1',
  'SourceID': '3',
  'id': 15481,
  'number': '24',
  'title': 'kibodointorodakushiyon - tvanime [ keion !] ohuishiyaru bandoyaroyo !! ( bandosukoafu )',
  'length': '0.155',
  'artist': 'fang ke hou teitaimu',
  'album': '',
  'year': "'09",
  'language': 'Japanese'})

(0.20932999,
 {'TID': '9864',
  'CID': 7993,
  'CTID': '2',
  'SourceID': '4',
  'id': 9863,
  'number': '24',
  'title': '024 -',
  'length': '0m 9sec',
  'artist': '',
  'album': 'tv ! !! () ( 2009 )',
  'year': 'null',
  'language': 'Jap.'},
 {'TID': '17288',
  'CID': 7993,
  'CTID': '4',
  'SourceID': '1',
  'id': 17287,
  'number': '024',
  'title': 'kibodointorodakushiyon ( tvanime [ keion !] ohuishiyaru bandoyaroyo !! ( bandosukoafu ))',
  'length': '00:09',
  'artist': 'fang ke hou teitaimu',
  'album': 'tvanime [ keion !] ohuishiyaru bandoyaroyo !! ( bandosukoafu )',
  'year': '2009',
  'language': ''})

(0.46784,
 {'TID': '11097',
  'CID': 5243,
  'CTID': '3',
  'SourceID': '1',
  'id': 11096,
  'number': '006',
  'title': 'apressadinho ( forro de vanguarda )',
  'length': '03:12',
  'artist': 'genario',
  'album': 'forro de vanguarda',
  'year': '1981',
  'language': ''},
 {'TID': '19173',
  'CID': 5243,
  'CTID': '2',
  'SourceID': '5',
  'id': 19172,
  'number': '6',
  'title': 'apressadinho',
  'length': '192000',
  'artist': 'gena ! rio',
  'album': 'forra3devanguarda',
  'year': '1981',
  'language': 'Portu  guese'})

(0.21118999,
 {'TID': '2905',
  'CID': 2371,
  'CTID': '2',
  'SourceID': '3',
  'id': 2904,
  'number': '6',
  'title': 'unnown',
  'length': '4.4',
  'artist': 'sorg uten tarer',
  'album': '',
  'year': "'07",
  'language': 'English'},
 {'TID': '7677',
  'CID': 2371,
  'CTID': '4',
  'SourceID': '5',
  'id': 7676,
  'number': '6',
  'title': 'moonkissed eyes',
  'length': '264000',
  'artist': 'sorg uten tarer',
  'album': 'moonsilver',
  'year': '2007',
  'language': 'English'})

(0.46546796,
 {'TID': '16095',
  'CID': 9077,
  'CTID': '2',
  'SourceID': '3',
  'id': 16094,
  'number': '1',
  'title': 'almost brave - live at the fireside bowl',
  'length': '2.076',
  'artist': "$ wingin ' utter $",
  'album': '',
  'year': "'95",
  'language': 'English'},
 {'TID': '17602',
  'CID': 9077,
  'CTID': '1',
  'SourceID': '2',
  'id': 17601,
  'number': '1',
  'title': "$ wingin ' utter $ - almost brave",
  'length': '124',
  'artist': '',
  'album': 'live at the fireside bowl',
  'year': '95',
  'language': 'English'})

(0.4938655,
 {'TID': '11429',
  'CID': 5243,
  'CTID': '4',
  'SourceID': '2',
  'id': 11428,
  'number': '6',
  'title': 'genario - apressadinho',
  'length': '192',
  'artist': '',
  'album': 'forro de vanguarda',
  'year': '81',
  'language': 'Portuguese'},
 {'TID': '19173',
  'CID': 5243,
  'CTID': '2',
  'SourceID': '5',
  'id': 19172,
  'number': '6',
  'title': 'apressadinho',
  'length': '192000',
  'artist': 'gena ! rio',
  'album': 'forra3devanguarda',
  'year': '1981',
  'language': 'Portu  guese'})