## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging

# Reload the logging module before configuring it.
# NOTE(review): presumably this discards any handlers already attached to the
# root logger so that basicConfig() below is not a no-op — confirm.
reload(logging)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%H:%M:%S',
)

In [3]:
# libgomp issue, must import n2 before torch
from n2 import HnswIndex

In [4]:
import sys

sys.path.insert(0, '..')

In [5]:
import os

# Resolve the current user's home directory.
# os.getenv('HOME') returns None when HOME is unset (e.g. on Windows), which
# would turn the f-string paths built from home_dir into 'None/...'.
# os.path.expanduser('~') uses $HOME when it is set and falls back to the
# platform's own lookup otherwise, so home_dir is always a string.
home_dir = os.path.expanduser('~')

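Dataset source (benchmark datasets for entity resolution, University of Leipzig): https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

The field descriptions below are taken from the MusicBrainz dataset README: https://www.informatik.uni-leipzig.de/~saeedi/musicBrainz_readme.txt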

```
5 sources
---------- 
TID: a unique record's id (in the complete dataset).
CID: cluster id (records having the same CID are duplicate)
CTID: a unique id within a cluster (if two records belong to the same cluster they will have the same CID but different CTIDs). These ids (CTID) start with 1 and grow until cluster size.
SourceID: identifies to which source a record belongs (there are five sources). The sources are deduplicated.
Id: the original id from the source. Each source has its own Id-Format. Uniqueness is not guaranteed!! (can be ignored).
number: track or song number in the album.
length: the length of the track.
artist: the interpreter (artist or band) of the track.
year: date of publication.
language: language of the track.
```

In [6]:
import glob
import csv
import tqdm

current_row_id = 0
row_dict = {}
rows_total = 19375  # only used as the progress-bar total
cluster_attr = 'CID'

# Resolve the input file(s) up front: glob.glob() returns an empty list for a
# missing path, which would otherwise leave row_dict silently empty and make
# later cells fail with unrelated errors.
dataset_paths = sorted(glob.glob(f'{home_dir}/Downloads/musicbrainz-20-A01.csv.dapo'))
if not dataset_paths:
    raise FileNotFoundError(
        f'No dataset file found: {home_dir}/Downloads/musicbrainz-20-A01.csv.dapo')

with tqdm.tqdm(total=rows_total) as pbar:
    for filename in dataset_paths:
        # newline='' is what the csv module asks for so that line breaks inside
        # quoted fields are parsed correctly. The explicit encoding keeps the
        # parse independent of the machine's locale.
        # NOTE(review): assumes the file is UTF-8 encoded — confirm.
        with open(filename, newline='', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                # Sequential notebook-local id, also used as the row_dict key.
                # NOTE(review): the displayed row suggests the CSV already has
                # an 'id' column, which this assignment overwrites.
                row['id'] = current_row_id
                row[cluster_attr] = int(row[cluster_attr])  # convert cluster_attr to int
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

100%|██████████| 19375/19375 [00:00<00:00, 199046.32it/s]


In [7]:
row_dict[1]

{'TID': '2',
 'CID': 2512,
 'CTID': '5',
 'SourceID': '4',
 'id': 1,
 'number': '7',
 'title': '007',
 'length': '1m 58sec',
 'artist': '[unknown]',
 'album': 'Cantigas de roda (unknown)',
 'year': 'null',
 'language': 'Por.'}

## Preprocess

In [8]:
attr_list = ['title', 'artist', 'album']

In [9]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s, max_token_len=30, max_str_len=100):
    """Normalize a raw attribute string for encoding.

    The string is passed through ``unidecode.unidecode``, lowercased and
    stripped, then split with ``default_tokenizer``. Each token is truncated
    to ``max_token_len`` characters, the tokens are joined with single spaces,
    and the joined result is truncated to ``max_str_len`` characters.

    Args:
        s: the raw string to clean.
        max_token_len: maximum number of characters kept per token
            (default 30, the previously hard-coded value).
        max_str_len: maximum number of characters of the joined result
            (default 100, the previously hard-coded value). The final cut is
            applied to the joined string, so it may fall in the middle of a
            token.

    Returns:
        The cleaned string.
    """
    s = unidecode.unidecode(s).lower().strip()
    truncated_tokens = (token[:max_token_len] for token in default_tokenizer(s))
    return ' '.join(truncated_tokens)[:max_str_len]

# Clean the selected text attributes of every record. The rows are modified
# in place, so row_dict holds the cleaned strings from here on.
for row in row_dict.values():
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

## Get Pairs

In [10]:
import torch
import numpy as np

# Seed torch and NumPy's global RNG with the same value. random_seed is also
# passed explicitly to split_clusters() and model.train() later on.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
from entity_embed.data_utils.utils import row_dict_to_cluster_dict

cluster_dict = row_dict_to_cluster_dict(row_dict, cluster_attr)
len(cluster_dict)

10000

In [12]:
from entity_embed.data_utils.utils import split_clusters

# Split sizes are given in clusters, not rows (the displayed lengths are the
# cluster counts of each split).
# NOTE(review): presumably every cluster not assigned to train/valid ends up
# in the test split — confirm against split_clusters.
train_len = 1_500
valid_len = 2_500
train_cluster_dict, valid_cluster_dict, test_cluster_dict = split_clusters(
    cluster_dict, train_len=train_len, valid_len=valid_len, random_seed=random_seed)
display(len(train_cluster_dict), len(valid_cluster_dict), len(test_cluster_dict))

1500

2500

6000

In [13]:
# For each split, count the clusters that hold more than one record id.
for d in [train_cluster_dict, valid_cluster_dict, test_cluster_dict]:
    display(sum(1 for cluster_id_list in d.values() if len(cluster_id_list) > 1))

1500

2500

1000

In [14]:
from entity_embed.data_utils.utils import cluster_dict_to_id_pairs

# Show how many true pairs each split contains. The train and validation pair
# sets are only needed for their size here and are deleted right after being
# counted; the test pair set is kept because the evaluation cells use it.
train_true_pair_set = cluster_dict_to_id_pairs(train_cluster_dict)
display(len(train_true_pair_set))
del train_true_pair_set
valid_true_pair_set = cluster_dict_to_id_pairs(valid_cluster_dict)
display(len(valid_true_pair_set))
del valid_true_pair_set
test_true_pair_set = cluster_dict_to_id_pairs(test_cluster_dict)
display(len(test_true_pair_set))

4883

8092

3275

In [15]:
# Collect the records of the training clusters into an {id: row} dict.
train_row_dict = {id_: row_dict[id_] for cluster_id_list in train_cluster_dict.values() for id_ in cluster_id_list}
len(train_row_dict)

4317

In [16]:
# Collect the records of the validation clusters into an {id: row} dict.
valid_row_dict = {id_: row_dict[id_] for cluster_id_list in valid_cluster_dict.values() for id_ in cluster_id_list}
len(valid_row_dict)

7170

In [17]:
# Collect the records of the test clusters into an {id: row} dict.
test_row_dict = {id_: row_dict[id_] for cluster_id_list in test_cluster_dict.values() for id_ in cluster_id_list}
len(test_row_dict)

7888

## Training

In [18]:
import string

# Character vocabulary handed to each attribute's encoding config: digits,
# lowercase ASCII letters, ASCII punctuation, and the space character.
# Built from the string module constants rather than a hand-typed literal with
# manual escapes; the resulting list is identical, in the same order.
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [19]:
# TODO: support multiple attrs for same source attr

# All three text attributes use the same encoding configuration. A separate
# dict is built for each attribute so the entries stay independent objects.
attr_info_dict = {
    attr_name: {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
    for attr_name in ('title', 'artist', 'album')
}

In [20]:
from entity_embed import EntityEmbed

# Build the model from the per-attribute configuration. The full row_dict
# (all splits) is passed here; according to the log output it is used to
# compute the actual alphabet and max_str_len of each attribute.
model = EntityEmbed(
    attr_info_dict=attr_info_dict,
    row_dict=row_dict
)

22:08:23 INFO:For attr='title', computing actual alphabet and max_str_len
22:08:23 INFO:For attr='title', using actual_max_str_len=30
22:08:23 INFO:For attr='artist', computing actual alphabet and max_str_len
22:08:23 INFO:For attr='artist', using actual_max_str_len=30
22:08:23 INFO:For attr='album', computing actual alphabet and max_str_len
22:08:23 INFO:For attr='album', using actual_max_str_len=30


In [21]:
model.attr_info_dict

{'title': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7ff041bd4940>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'artist': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7ff041bd4940>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'album': OneHotEncodingInfo(is_multi

In [22]:
epochs = 20

In [23]:
# Train for `epochs` epochs on the training records.
# NOTE(review): valid_row_dict is presumably used for validation during
# training — confirm against EntityEmbed.train.
# NOTE(review): the batch sizes 45 and 1225 are unexplained magic numbers
# (1225 == 50 * 49 / 2); consider naming them or documenting their origin.
model.train(
    epochs=epochs,
    train_row_dict=train_row_dict,
    cluster_attr=cluster_attr,
    train_pos_pair_batch_size=45,
    train_neg_pair_batch_size=1225,
    valid_row_dict=valid_row_dict,
    random_seed=random_seed
)

HBox(children=(HTML(value='# training'), FloatProgress(value=0.0, max=2180.0), HTML(value='')))

22:08:37 INFO:[('title', tensor(0.3444, device='cuda:0')), ('artist', tensor(0.3104, device='cuda:0')), ('album', tensor(0.3453, device='cuda:0'))]
22:08:41 INFO:# Train Epoch:   0 Time: 15.551 Loss: 0.149, Precision: 0.308 Recall: 0.972
22:08:53 INFO:[('title', tensor(0.3477, device='cuda:0')), ('artist', tensor(0.3093, device='cuda:0')), ('album', tensor(0.3429, device='cuda:0'))]
22:08:56 INFO:# Train Epoch:   1 Time: 15.195 Loss: 0.027, Precision: 0.505 Recall: 0.972
22:09:08 INFO:[('title', tensor(0.3500, device='cuda:0')), ('artist', tensor(0.3071, device='cuda:0')), ('album', tensor(0.3428, device='cuda:0'))]
22:09:12 INFO:# Train Epoch:   2 Time: 15.600 Loss: 0.016, Precision: 0.568 Recall: 0.974
22:09:24 INFO:[('title', tensor(0.3510, device='cuda:0')), ('artist', tensor(0.3054, device='cuda:0')), ('album', tensor(0.3436, device='cuda:0'))]
22:09:28 INFO:# Train Epoch:   3 Time: 16.084 Loss: 0.012, Precision: 0.608 Recall: 0.975
22:09:40 INFO:[('title', tensor(0.3522, device='




In [24]:
# torch.save(model, "music_model.torch")

In [25]:
# model = torch.load("music_model.torch")

In [26]:
# Embed the test records. The result is indexed by record id further down and
# is asserted to have one entry per test record.
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=493.0), HTML(value='')))




In [27]:
assert len(test_vector_dict) == len(test_row_dict)

In [28]:
%%time

from entity_embed.entity_embed import ANNEntityIndex

# Index the test embeddings for neighbour search: create the index with the
# model's embedding size, insert all test vectors, then build it.
# NOTE(review): ANNEntityIndex is presumably an approximate-nearest-neighbour
# index — confirm against entity_embed.
ann_index = ANNEntityIndex(embedding_size=model.embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 2.91 s, sys: 32.2 ms, total: 2.94 s
Wall time: 351 ms


In [29]:
%%time

# Search parameters for candidate-pair retrieval.
# NOTE(review): ntop (passed as k) is presumably the number of neighbours
# considered per record, and sim_threshold the minimum similarity for a pair
# to be reported — confirm against ANNEntityIndex.search_pairs.
ntop = 10
sim_threshold = 0.5

found_pair_set = ann_index.search_pairs(
    k=ntop,
    sim_threshold=sim_threshold
)

CPU times: user 4.56 s, sys: 2.44 ms, total: 4.56 s
Wall time: 424 ms


In [30]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(found_pair_set, test_row_dict)

0.7682555780933062

In [31]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(found_pair_set, test_true_pair_set)

(0.53003300330033, 0.980763358778626)

In [32]:
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

2848

In [33]:
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

63

In [34]:
def cos_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    The dot product is divided by the product of the two Euclidean norms, so
    the result is a true cosine similarity even when the inputs are not
    unit-length. For unit-length inputs this equals the plain dot product.

    Args:
        a: 1-D array-like vector.
        b: 1-D array-like vector of the same length as ``a``.

    Returns:
        The cosine similarity, or ``0.0`` if either vector has zero norm
        (the cosine is undefined in that case).
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # Avoid a division by zero (NaN result) for all-zero vectors.
        return 0.0
    return np.dot(a, b) / norm_product

In [35]:
# Inspect the first 10 missed true pairs: for each pair, show the similarity
# of the two test embeddings followed by both records from row_dict.
for (id_left, id_right) in false_negatives[:10]:
    display(
        (
            cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),
            row_dict[id_left], row_dict[id_right]
        )
    )

(0.4264195,
 {'TID': '4914',
  'CID': 8662,
  'CTID': '2',
  'SourceID': '4',
  'id': 4913,
  'number': '14',
  'title': '014 - suite andalucia : gitanerias',
  'length': '1m 58sec',
  'artist': 'frank flynn emilio',
  'album': 'tribute to ernesto lecuona ( 1997 )',
  'year': 'null',
  'language': 'pa.'},
 {'TID': '16840',
  'CID': 8662,
  'CTID': '1',
  'SourceID': '3',
  'id': 16839,
  'number': 'Suite Andalucia: Gitanerias - Tribute to Ernesto Lecuona',
  'title': '14',
  'length': '1.974',
  'artist': 'frank emilio flynn',
  'album': '',
  'year': "'97",
  'language': 'Spanish'})

(-0.012160443,
 {'TID': '15081',
  'CID': 5099,
  'CTID': '5',
  'SourceID': '1',
  'id': 15080,
  'number': '011',
  'title': "i ' ve never been to me ( hits of the 80s : the ultimate collection )",
  'length': '04:00',
  'artist': 'chrlene',
  'album': 'hits of the 80s : the ultimate collection',
  'year': '2006',
  'language': ''},
 {'TID': '17910',
  'CID': 5099,
  'CTID': '4',
  'SourceID': '5',
  'id': 17909,
  'number': '5612990',
  'title': '11',
  'length': "I've Never Been to Me",
  'artist': '240440',
  'album': 'charlene',
  'year': 'Hits of the 80s: The Ultimate Collection',
  'language': '2006'})

(0.36968943,
 {'TID': '1819',
  'CID': 966,
  'CTID': '1',
  'SourceID': '5',
  'id': 1818,
  'number': '9',
  'title': 'barcarolleinfsharp , op . 60',
  'length': '451586',
  'artist': 'fryderykchopin',
  'album': 'kefultimateresolutiondemonstra , editionno . 2',
  'year': '2010',
  'language': 'English'},
 {'TID': '9052',
  'CID': 966,
  'CTID': '2',
  'SourceID': '1',
  'id': 9051,
  'number': '009',
  'title': 'barcarolle in f sharp , op . 60 ( kef ultimate resolution demonstration disc , edition no . 2 )',
  'length': '07:31',
  'artist': 'fryderyk chopin',
  'album': 'kef ultimate resolution demonstration disc , edition no . 2',
  'year': '2010',
  'language': ''})

(0.09819428,
 {'TID': '2905',
  'CID': 2371,
  'CTID': '2',
  'SourceID': '3',
  'id': 2904,
  'number': '6',
  'title': 'unnown',
  'length': '4.4',
  'artist': 'sorg uten tarer',
  'album': '',
  'year': "'07",
  'language': 'English'},
 {'TID': '9261',
  'CID': 2371,
  'CTID': '3',
  'SourceID': '4',
  'id': 9260,
  'number': '6',
  'title': '006 - moonkissed eyes',
  'length': '4m 24sec',
  'artist': 'sorg uten tarer',
  'album': 'moonsilver ( 2007 )',
  'year': 'null',
  'language': 'Eng.'})

(0.27951378,
 {'TID': '10134',
  'CID': 5327,
  'CTID': '3',
  'SourceID': '2',
  'id': 10133,
  'number': '14',
  'title': '[ unknown ] - action in the dusk',
  'length': 'null',
  'artist': '',
  'album': 'gaewa neugdaeyi sigan',
  'year': '',
  'language': 'Korean'},
 {'TID': '17936',
  'CID': 5327,
  'CTID': '5',
  'SourceID': '4',
  'id': 17935,
  'number': '14',
  'title': '014 - action in the dusk',
  'length': 'unknown',
  'artist': '[ unknown ]',
  'album': '( unknown )',
  'year': 'null',
  'language': 'Kor.'})

(0.42684764,
 {'TID': '2208',
  'CID': 1154,
  'CTID': '1',
  'SourceID': '1',
  'id': 2207,
  'number': '020',
  'title': 'space jerk ( promos 5 : kidz stuff )',
  'length': '00:29',
  'artist': 'kevin jarvis',
  'album': 'promos 5 : kidz stuff',
  'year': '',
  'language': ''},
 {'TID': '14346',
  'CID': 1154,
  'CTID': '3',
  'SourceID': '3',
  'id': 14345,
  'number': '20',
  'title': 'nulll',
  'length': '0.483',
  'artist': 'evin jarvis',
  'album': '',
  'year': '',
  'language': 'English'})

(0.39567196,
 {'TID': '8494',
  'CID': 4390,
  'CTID': '1',
  'SourceID': '1',
  'id': 8493,
  'number': '9',
  'title': 'one ton ( volunteered slavery )',
  'length': '05:01',
  'artist': 'rahsaan roland kirk',
  'album': 'volunteered slavery',
  'year': '2005',
  'language': ''},
 {'TID': '19331',
  'CID': 4390,
  'CTID': '2',
  'SourceID': '2',
  'id': 19330,
  'number': '7',
  'title': 'rahsaan roland kirk - one ton',
  'length': '301',
  'artist': '',
  'album': 'unknown',
  'year': '05',
  'language': 'English'})

(0.33611,
 {'TID': '2965',
  'CID': 1544,
  'CTID': '1',
  'SourceID': '2',
  'id': 2964,
  'number': '8',
  'title': 'darryl way - juliet',
  'length': '247',
  'artist': '',
  'album': 'under the soft',
  'year': '91',
  'language': 'English'},
 {'TID': '18300',
  'CID': 1544,
  'CTID': '2',
  'SourceID': '3',
  'id': 18299,
  'number': '8',
  'title': '8',
  'length': '4.117',
  'artist': 'darryl way',
  'album': '',
  'year': "'91",
  'language': 'English'})

(0.34793422,
 {'TID': '9364',
  'CID': 901,
  'CTID': '2',
  'SourceID': '2',
  'id': 9363,
  'number': 'A5',
  'title': 'du bian man li nai - marinanoxia',
  'length': 'null',
  'artist': '',
  'album': "oniyanzi sailing meng gong chang ' 87 live",
  'year': '87',
  'language': 'Japanese'},
 {'TID': '16923',
  'CID': 901,
  'CTID': '4',
  'SourceID': '4',
  'id': 16922,
  'number': 'A5',
  'title': '0a5 -',
  'length': 'unknown',
  'artist': '',
  'album': 'sailin 87 live ( 1987 )',
  'year': 'null',
  'language': 'Jap.'})

(0.4339402,
 {'TID': '1050',
  'CID': 8798,
  'CTID': '2',
  'SourceID': '3',
  'id': 1049,
  'number': '27',
  'title': 'r . i . p . ( millie ) - the essential',
  'length': '4.583',
  'artist': 'noiseworks',
  'album': '',
  'year': '',
  'language': 'English'},
 {'TID': '17095',
  'CID': 8798,
  'CTID': '1',
  'SourceID': '2',
  'id': 17094,
  'number': '16',
  'title': 'noiseworks - r . i . p . ( millie )',
  'length': '275',
  'artist': '',
  'album': 'the essential',
  'year': '',
  'language': 'English'})