## Load Dataset

In [1]:
# Automatically reload edited Python modules before each cell executes,
# so changes to locally developed code are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from importlib import reload

# logging.basicConfig() is a no-op once the root logger has handlers (for
# example when this cell is re-run), so reload the module first to start from
# a fresh, unconfigured root logger.
reload(logging)
logging.basicConfig(
    format='%(asctime)s %(levelname)s:%(message)s',
    level=logging.INFO,
    datefmt='%H:%M:%S',
)

In [3]:
# libgomp issue, must import n2 before torch
# Keep this cell above every cell that imports torch, directly or (presumably)
# indirectly via entity_embed. HnswIndex is not referenced elsewhere in the
# cells shown here; the import exists for its load-order side effect.
from n2 import HnswIndex

In [4]:
import sys

# Make the parent directory importable (presumably where the entity_embed
# package lives). Guarded so that re-running this cell does not keep
# prepending another '..' to sys.path.
if sys.path[:1] != ['..']:
    sys.path.insert(0, '..')

In [5]:
import os

# Resolve the user's home directory portably. os.getenv('HOME') returns None
# when HOME is unset (e.g. on Windows), which would silently turn the dataset
# path below into 'None/Downloads/...'.
home_dir = os.path.expanduser('~')

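Dataset: the MusicBrainz benchmark from the [Leipzig benchmark datasets for entity resolution](https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution).

Field descriptions come from the [dataset README](https://www.informatik.uni-leipzig.de/~saeedi/musicBrainz_readme.txt). The excerpt below does not describe the `title` and `album` fields, which are also present in the data: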

```
5 sources
---------- 
TID: a unique record's id (in the complete dataset).
CID: cluster id (records having the same CID are duplicate)
CTID: a unique id within a cluster (if two records belong to the same cluster they will have the same CID but different CTIDs). These ids (CTID) start with 1 and grow until cluster size.
SourceID: identifies to which source a record belongs (there are five sources). The sources are deduplicated.
Id: the original id from the source. Each source has its own Id-Format. Uniqueness is not guaranteed!! (can be ignored).
number: track or song number in the album.
length: the length of the track.
artist: the interpreter (artist or band) of the track.
year: date of publication.
language: language of the track.
```

In [6]:
import glob
import csv
import tqdm

# Expected number of data rows; only used to size the progress bar.
rows_total = 1937500
# Column holding the ground-truth cluster id (duplicates share the same CID).
cluster_attr = 'CID'
dataset_glob = f'{home_dir}/Downloads/musicbrainz-2000-A01.csv.dapo'

current_row_id = 0
row_dict = {}

# Sorted for a deterministic row-id assignment if the pattern matches several
# files; fail loudly instead of silently producing an empty row_dict.
filenames = sorted(glob.glob(dataset_glob))
assert filenames, f'no dataset file matches {dataset_glob!r}'

with tqdm.tqdm(total=rows_total) as pbar:
    for filename in filenames:
        # newline='' is required by the csv module so quoted fields with
        # embedded newlines are parsed correctly; an explicit UTF-8 encoding
        # avoids depending on the platform locale (the data has non-ASCII text).
        with open(filename, newline='', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                row['id'] = current_row_id
                row[cluster_attr] = int(row[cluster_attr])  # convert cluster_attr to int
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

100%|██████████| 1937500/1937500 [00:08<00:00, 229301.06it/s]


In [7]:
# Peek at one loaded record: apart from the int 'CID' and 'id', values are raw strings.
row_dict[1]

{'TID': '2',
 'CID': 2,
 'CTID': '1',
 'SourceID': '5',
 'id': 1,
 'number': '17',
 'title': 'Mustard Gas',
 'length': '129000',
 'artist': 'Action Painting!',
 'album': 'There and Back Again Lane',
 'year': '1995',
 'language': 'English'}

## Preprocess

In [8]:
# Text attributes that are cleaned below and embedded by the model.
attr_list = ['title', 'artist', 'album']

In [9]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s, max_token_len=30, max_str_len=100):
    """Normalize a raw attribute value before embedding.

    Transliterates ``s`` to ASCII with ``unidecode``, lowercases and strips it,
    splits it with ``default_tokenizer``, truncates every token to
    ``max_token_len`` characters, re-joins the tokens with single spaces and
    truncates the joined string to ``max_str_len`` characters.

    The 30-character token cap is why EntityEmbed later reports
    ``actual_max_str_len=30`` for every attribute.
    """
    s = unidecode.unidecode(s).lower().strip()
    tokens = (token[:max_token_len] for token in default_tokenizer(s))
    return ' '.join(tokens)[:max_str_len]

# Clean the embedded text attributes in place on every row.
for row in row_dict.values():
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

## Get Pairs

In [10]:
import random

import torch
import numpy as np

# Seed every RNG the notebook may rely on (Python, torch, numpy) so reruns are
# reproducible. split_clusters and model.train also receive random_seed explicitly.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [11]:
from entity_embed.data_utils.utils import row_dict_to_cluster_dict

# Group row ids by cluster: each value is a list of row ids (see how the
# train/valid/test row dicts are built below), presumably keyed by CID.
# The displayed length is the number of clusters.
cluster_dict = row_dict_to_cluster_dict(row_dict, cluster_attr)
len(cluster_dict)

1000000

In [12]:
from entity_embed.data_utils.utils import split_clusters

# Hold out small train/valid cluster sets; all remaining clusters form the test set.
train_len, valid_len = 2_000, 5_000
train_cluster_dict, valid_cluster_dict, test_cluster_dict = split_clusters(
    cluster_dict, train_len=train_len, valid_len=valid_len, random_seed=random_seed)
display(*(len(split) for split in (train_cluster_dict, valid_cluster_dict, test_cluster_dict)))

2000

5000

993000

In [13]:
# Per split, count the clusters containing more than one record,
# i.e. the clusters that contribute at least one true pair.
for split_cluster_dict in (train_cluster_dict, valid_cluster_dict, test_cluster_dict):
    display(sum(len(row_ids) > 1 for row_ids in split_cluster_dict.values()))

2000

5000

493000

In [14]:
from entity_embed.data_utils.utils import cluster_dict_to_id_pairs

# Only the sizes of the train/valid true-pair sets are reported, so those sets
# are discarded immediately; the test pairs are kept for the final evaluation.
for _split_cluster_dict in (train_cluster_dict, valid_cluster_dict):
    display(len(cluster_dict_to_id_pairs(_split_cluster_dict)))
del _split_cluster_dict
test_true_pair_set = cluster_dict_to_id_pairs(test_cluster_dict)
display(len(test_true_pair_set))

6618

15749

1602633

In [15]:
# Rows belonging to the training clusters, keyed by row id.
train_row_dict = {
    row_id: row_dict[row_id]
    for row_ids in train_cluster_dict.values()
    for row_id in row_ids
}
len(train_row_dict)

5798

In [16]:
# Rows belonging to the validation clusters, keyed by row id.
valid_row_dict = {
    row_id: row_dict[row_id]
    for row_ids in valid_cluster_dict.values()
    for row_id in row_ids
}
len(valid_row_dict)

14203

In [17]:
# Rows belonging to the test clusters, keyed by row id.
test_row_dict = {
    row_id: row_dict[row_id]
    for row_ids in test_cluster_dict.values()
    for row_id in row_ids
}
len(test_row_dict)

1917499

## Training

In [18]:
import string

# Characters for the one-hot encoders: digits, lowercase ASCII letters, ASCII
# punctuation and space. Only lowercase letters are needed because clean_str()
# lowercases and transliterates to ASCII. Built from the `string` constants
# rather than a hand-typed literal so no character can be accidentally dropped.
alphabet = list(string.digits + string.ascii_lowercase + string.punctuation + ' ')

In [19]:
# TODO: support multiple attrs for same source attr

# Every embedded attribute uses the same multi-token one-hot encoding config.
# 'max_str_len': None asks EntityEmbed to compute the value from the data
# (see its "computing actual alphabet and max_str_len" log lines).
# Derived from attr_list so the cleaned attributes and the embedded ones
# cannot drift apart; each attribute gets its own config dict.
attr_info_dict = {
    attr: {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
    for attr in attr_list
}

In [20]:
from entity_embed import EntityEmbed

# NOTE(review): the full row_dict (train + valid + test rows) is passed here.
# Per the log output, EntityEmbed scans it to compute each attribute's actual
# alphabet and max_str_len. Confirm that using test rows for these statistics
# is intended; train_row_dict might avoid touching test data.
model = EntityEmbed(
    attr_info_dict=attr_info_dict,
    row_dict=row_dict
)

22:26:42 INFO:For attr='title', computing actual alphabet and max_str_len
22:26:51 INFO:For attr='title', using actual_max_str_len=30
22:26:51 INFO:For attr='artist', computing actual alphabet and max_str_len
22:26:56 INFO:For attr='artist', using actual_max_str_len=30
22:26:56 INFO:For attr='album', computing actual alphabet and max_str_len
22:27:02 INFO:For attr='album', using actual_max_str_len=30


In [21]:
# Inspect the encoding config EntityEmbed resolved per attribute
# (alphabet plus the computed max_str_len).
model.attr_info_dict

{'title': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7ff5ef338790>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'artist': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7ff5ef338790>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'album': OneHotEncodingInfo(is_multi

In [22]:
# Number of training epochs passed to model.train below.
epochs = 20

In [23]:
# Train the embedding model on the training rows. cluster_attr presumably tells
# it which rows are duplicates (same CID); valid_row_dict is used for validation.
# NOTE(review): 45 == 10*9/2 and 1225 == 50*49/2. The intent behind these pair
# batch sizes is not documented here; consider naming them in a config cell.
model.train(
    epochs=epochs,
    train_row_dict=train_row_dict,
    cluster_attr=cluster_attr,
    train_pos_pair_batch_size=45,
    train_neg_pair_batch_size=1225,
    valid_row_dict=valid_row_dict,
    random_seed=random_seed
)

HBox(children=(HTML(value='# training'), FloatProgress(value=0.0, max=2960.0), HTML(value='')))

22:27:25 INFO:[('title', tensor(0.3422, device='cuda:0')), ('artist', tensor(0.3195, device='cuda:0')), ('album', tensor(0.3383, device='cuda:0'))]
22:27:33 INFO:# Train Epoch:   0 Time: 23.174 Loss: 0.154, Precision: 0.273 Recall: 0.973
22:27:48 INFO:[('title', tensor(0.3468, device='cuda:0')), ('artist', tensor(0.3170, device='cuda:0')), ('album', tensor(0.3362, device='cuda:0'))]
22:27:56 INFO:# Train Epoch:   1 Time: 23.094 Loss: 0.028, Precision: 0.383 Recall: 0.975
22:28:12 INFO:[('title', tensor(0.3485, device='cuda:0')), ('artist', tensor(0.3144, device='cuda:0')), ('album', tensor(0.3371, device='cuda:0'))]
22:28:20 INFO:# Train Epoch:   2 Time: 23.817 Loss: 0.016, Precision: 0.433 Recall: 0.975
22:28:36 INFO:[('title', tensor(0.3505, device='cuda:0')), ('artist', tensor(0.3121, device='cuda:0')), ('album', tensor(0.3375, device='cuda:0'))]
22:28:44 INFO:# Train Epoch:   3 Time: 23.949 Loss: 0.012, Precision: 0.487 Recall: 0.976
22:29:00 INFO:[('title', tensor(0.3521, device='




In [24]:
# Persist the trained model. torch.save serializes the whole object with pickle,
# so loading it later requires the same entity_embed code to be importable.
torch.save(model, "music_model.torch")

In [25]:
# Uncomment to reuse a previously saved model instead of retraining.
# torch.load unpickles the file, so only load model files you trust.
# model = torch.load("music_model.torch")

In [26]:
# Embed every test row. The result is indexed by row id in the cells below.
# With batch_size=16, the 1,917,499 test rows give the 119,844 batches shown
# in the progress bar.
test_vector_dict = model.predict(
    row_dict=test_row_dict,
    batch_size=16)

HBox(children=(HTML(value='# batch embedding'), FloatProgress(value=0.0, max=119844.0), HTML(value='')))




In [27]:
# Every test row must have received exactly one embedding.
assert len(test_vector_dict) == len(test_row_dict), (
    f'{len(test_vector_dict)} embeddings for {len(test_row_dict)} test rows')

In [28]:
%%time

# Build an approximate-nearest-neighbour index over all test embeddings
# (about 12.5 minutes of wall time in the recorded run).
from entity_embed.entity_embed import ANNEntityIndex

ann_index = ANNEntityIndex(embedding_size=model.embedding_size)
ann_index.insert_vector_dict(test_vector_dict)
ann_index.build()

CPU times: user 2h 21min 38s, sys: 9.38 s, total: 2h 21min 47s
Wall time: 12min 32s


In [29]:
%%time

# Candidate duplicate pairs: presumably each vector's ntop nearest neighbours,
# kept when their similarity passes sim_threshold (confirm the exact comparison
# in ANNEntityIndex.search_pairs).
ntop = 10
sim_threshold = 0.5

found_pair_set = ann_index.search_pairs(
    k=ntop,
    sim_threshold=sim_threshold
)

CPU times: user 5h 22min 15s, sys: 5.09 s, total: 5h 22min 20s
Wall time: 27min 37s


In [30]:
from entity_embed.evaluation import pair_entity_ratio

# Found pairs per test row: in the recorded run, 12,411,103 found pairs
# (true + false positives below) / 1,917,499 test rows ~= 6.47.
pair_entity_ratio(found_pair_set, test_row_dict)

6.472547312932106

In [31]:
from entity_embed.evaluation import precision_and_recall

# Returns (precision, recall) of the found pairs against the true test pairs;
# the recorded values match the false-positive/false-negative counts below.
precision_and_recall(found_pair_set, test_true_pair_set)

(0.12382952586889336, 0.9589600363901155)

In [32]:
# Found pairs that are not true duplicates.
# NOTE(review): this materializes ~10.9M pairs as a list just to show the count;
# consider keeping the set (or only its length) if memory is tight.
false_positives = list(found_pair_set - test_true_pair_set)
len(false_positives)

10874242

In [33]:
# True duplicate pairs that the ANN search did not return.
false_negatives = list(test_true_pair_set - found_pair_set)
len(false_negatives)

65772

In [34]:
def cos_similarity(a, b):
    """Cosine similarity between two 1-D embedding vectors.

    A bare ``np.dot(a, b)`` equals cosine similarity only for unit-length
    vectors, so both norms are divided out here. Returns 0.0 when either
    vector has zero norm instead of producing NaN.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return np.dot(a, b) / norm_product

In [35]:
# Inspect a sample of missed duplicate pairs: the pair's similarity score
# followed by both full records.
for left_id, right_id in false_negatives[:10]:
    similarity = cos_similarity(test_vector_dict[left_id], test_vector_dict[right_id])
    display((similarity, row_dict[left_id], row_dict[right_id]))

(0.47730002,
 {'TID': '1254312',
  'CID': 443895,
  'CTID': '5',
  'SourceID': '3',
  'id': 1254311,
  'number': '17',
  'title': "saythatyou ' reherebildhits2002 : dieerste",
  'length': '3.462',
  'artist': 'fragma',
  'album': '',
  'year': "'01",
  'language': 'English'},
 {'TID': '1479456',
  'CID': 443895,
  'CTID': '2',
  'SourceID': '5',
  'id': 1479455,
  'number': '17',
  'title': "say that you ' re heer",
  'length': '207720',
  'artist': 'fragma',
  'album': 'bildhits 2002 : die erste',
  'year': '2001',
  'language': 'English'})

(0.13932073,
 {'TID': '865485',
  'CID': 839820,
  'CTID': '2',
  'SourceID': '4',
  'id': 865484,
  'number': '5',
  'title': '005 ---',
  'length': '3m 13sec',
  'artist': '',
  'album': '( 2002 )',
  'year': 'null',
  'language': 'Jap.'},
 {'TID': '1626816',
  'CID': 839820,
  'CTID': '1',
  'SourceID': '3',
  'id': 1626815,
  'number': '5',
  'title': 'hurusato - nan bu sheng gang - - ren sheng yi shi chuan',
  'length': '3.217',
  'artist': 'jiu shi rang',
  'album': '',
  'year': "'02",
  'language': 'Japanese'})

(0.44234428,
 {'TID': '948128',
  'CID': 815016,
  'CTID': '3',
  'SourceID': '4',
  'id': 948127,
  'number': '9',
  'title': '009 - lean on me',
  'length': '4m 18sec',
  'artist': 'bill withers',
  'album': 'singersandsongwriterscollectio ( unknown )',
  'year': 'null',
  'language': 'Eng.'},
 {'TID': '1578856',
  'CID': 815016,
  'CTID': '1',
  'SourceID': '2',
  'id': 1578855,
  'number': '9',
  'title': 'bill ithers - lean on me',
  'length': '258',
  'artist': '',
  'album': 'singers and songwriters collection',
  'year': '',
  'language': 'English'})

(0.680475,
 {'TID': '676589',
  'CID': 376179,
  'CTID': '3',
  'SourceID': '3',
  'id': 676588,
  'number': '1',
  'title': 'symphony no . 9 in e minor " from the new world ", op . 95 : i . adagio -- allegro molto - symphony ',
  'length': '8.883',
  'artist': 'antonin',
  'album': '',
  'year': '',
  'language': 'English'},
 {'TID': '1291343',
  'CID': 376179,
  'CTID': '2',
  'SourceID': '2',
  'id': 1291342,
  'number': '1',
  'title': 'antonin dvorak - symphony no . 9 in e minor " from the new world ", op . 95 : i . adagio -- allegro ',
  'length': '533',
  'artist': '',
  'album': 'symphony no . 9 " from the new world " / slavonic dances',
  'year': '',
  'language': 'English'})

(0.36638412,
 {'TID': '420117',
  'CID': 216997,
  'CTID': '1',
  'SourceID': '4',
  'id': 420116,
  'number': '13',
  'title': '013 - dea thwish ( live )',
  'length': 'unknown',
  'artist': "einstein ' s sister",
  'album': 'intertwined : i hope this helps in some way ( unknown )',
  'year': 'null',
  'language': 'Eng.'},
 {'TID': '729647',
  'CID': 216997,
  'CTID': '2',
  'SourceID': '5',
  'id': 729646,
  'number': '13',
  'title': 'deathwish ( live )',
  'length': '',
  'artist': "einstein ' s sister",
  'album': 'english',
  'year': '',
  'language': 'English'})

(0.5775426,
 {'TID': '1533165',
  'CID': 941684,
  'CTID': '2',
  'SourceID': '3',
  'id': 1533164,
  'number': '5',
  'title': 'angels we have heard on high - christmas carols',
  'length': '1.514',
  'artist': '[ christmas music ]',
  'album': '',
  'year': "'95",
  'language': 'English'},
 {'TID': '1824672',
  'CID': 941684,
  'CTID': '1',
  'SourceID': '2',
  'id': 1824671,
  'number': '5',
  'title': '[ christmasmusic ] angelswehaveheardonhigh',
  'length': '90',
  'artist': '',
  'album': 'christmas carols',
  'year': '95',
  'language': 'English'})

(0.3763197,
 {'TID': '924644',
  'CID': 864128,
  'CTID': '3',
  'SourceID': '1',
  'id': 924643,
  'number': '006',
  'title': 'got to rock on ( audio - visions )',
  'length': '03:21',
  'artist': '',
  'album': 'audio - visions',
  'year': '1996',
  'language': ''},
 {'TID': '1001741',
  'CID': 864128,
  'CTID': '2',
  'SourceID': '5',
  'id': 1001740,
  'number': '6',
  'title': 'got to rock on',
  'length': '2 01866',
  'artist': 'kansas',
  'album': '',
  'year': '1996',
  'language': 'English'})

(0.27092427,
 {'TID': '300168',
  'CID': 611569,
  'CTID': '3',
  'SourceID': '4',
  'id': 300167,
  'number': '7',
  'title': "007idon ' ttrustmenwithearringsintheirea",
  'length': '3m 59sec',
  'artist': 'gilbert osullivan',
  'album': 'in the key of g . ( 1989 )',
  'year': 'null',
  'language': 'ng.'},
 {'TID': '1184403',
  'CID': 611569,
  'CTID': '1',
  'SourceID': '2',
  'id': 1184402,
  'number': '7',
  'title': "gilbert o ' sullivan - i don ' t trust men with earrings in their ears",
  'length': '239',
  'artist': '',
  'album': 'in the ekey of g ..',
  'year': '89',
  'language': 'English'})

(0.8126404,
 {'TID': '32432',
  'CID': 943304,
  'CTID': '2',
  'SourceID': '1',
  'id': 32431,
  'number': '014',
  'title': "keeps gettin ' better ( nrj 200 % hits )",
  'length': '03:01',
  'artist': 'christina aguilera',
  'album': 'nrj 200 % hits',
  'year': '2009',
  'language': ''},
 {'TID': '753442',
  'CID': 943304,
  'CTID': '3',
  'SourceID': '2',
  'id': 753441,
  'number': '14',
  'title': "christina aguilera - keeps gettin ' better",
  'length': '181',
  'artist': '',
  'album': 'nrj 200 hits',
  'year': '09',
  'language': '[Multiple languages]'})

(0.3331428,
 {'TID': '1203564',
  'CID': 723671,
  'CTID': '2',
  'SourceID': '5',
  'id': 1203563,
  'number': '9',
  'title': 'zen yao hui zhe yang',
  'length': '233866',
  'artist': 'li zhen xian',
  'album': 'qian mian nu hai love me',
  'year': '2008',
  'language': 'Chinese'},
 {'TID': '1401392',
  'CID': 723671,
  'CTID': '1',
  'SourceID': '4',
  'id': 1401391,
  'number': '9',
  'title': '009 -',
  'length': '3m 53sec',
  'artist': '',
  'album': 'love me ( 2008 )',
  'year': 'null',
  'language': 'Chi.'})