**TODO**
- Properly handle empty attributes

## Load Dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.insert(0, '..')

In [3]:
# TODO: import here, conflict with libgomp from pytorch
from n2 import HnswIndex
import os

In [4]:
import os
home_dir = os.getenv('HOME')

https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution

https://www.informatik.uni-leipzig.de/~saeedi/musicBrainz_readme.txt

```
5 sources
---------- 
TID: a unique record's id (in the complete dataset).
CID: cluster id (records having the same CID are duplicate)
CTID: a unique id within a cluster (if two records belong to the same cluster they will have the same CID but different CTIDs). These ids (CTID) start with 1 and grow until cluster size.
SourceID: identifies to which source a record belongs (there are five sources). The sources are deduplicated.
Id: the original id from the source. Each source has its own Id-Format. Uniqueness is not guaranteed!! (can be ignored).
number: track or song number in the album.
length: the length of the track.
artist: the interpreter (artist or band) of the track.
year: date of publication.
language: language of the track.
```

In [5]:
import glob
import csv
import tqdm

current_row_id = 0
row_dict = {}
rows_total = 19375
cluster_id_attr = 'CID'

with tqdm.tqdm(total=rows_total) as pbar:
    for filename in glob.glob(f'{home_dir}/Downloads/musicbrainz-20-A01.csv.dapo'):
        with open(filename) as f:
            for row in csv.DictReader(f):
                row['id'] = current_row_id
                row[cluster_id_attr] = int(row[cluster_id_attr])  # convert cluster_id_attr to int
                row_dict[current_row_id] = row
                current_row_id += 1
                pbar.update(1)

100%|██████████| 19375/19375 [00:00<00:00, 174523.21it/s]


In [6]:
row_dict[1]

{'TID': '2',
 'CID': 2512,
 'CTID': '5',
 'SourceID': '4',
 'id': 1,
 'number': '7',
 'title': '007',
 'length': '1m 58sec',
 'artist': '[unknown]',
 'album': 'Cantigas de roda (unknown)',
 'year': 'null',
 'language': 'Por.'}

## Preprocess

In [7]:
attr_list = ['title', 'artist', 'album']

In [8]:
import unidecode
from entity_embed.data_utils.one_hot_encoders import default_tokenizer

def clean_str(s):
    s = unidecode.unidecode(s).lower().strip()
    return ' '.join(s_part[:30] for s_part in default_tokenizer(s))[:100]

for row in row_dict.values():
    for attr in attr_list:
        row[attr] = clean_str(row[attr])

In [9]:
len(set(r[cluster_id_attr] for r in row_dict.values()))

10000

In [10]:
row_list = list(row_dict.values())
row_list.sort(key=lambda row: (row[cluster_id_attr]))

In [11]:
from ordered_set import OrderedSet  # ensure reproducibility
import itertools

true_pair_set = OrderedSet(
    tuple(sorted((row_left['id'], row_right['id'])))
    for __, row_cluster_list in itertools.groupby(row_list, key=lambda row: row[cluster_id_attr])
    for row_left, row_right in itertools.combinations(row_cluster_list, 2)
)
len(true_pair_set)

16250

In [12]:
[row_dict[id_] for id_ in next(iter(true_pair_set))]

[{'TID': '1',
  'CID': 1,
  'CTID': '1',
  'SourceID': '2',
  'id': 0,
  'number': '9',
  'title': "daniel balavoine - l ' enfant aux yeux d ' italie",
  'length': '219',
  'artist': '',
  'album': 'de vous a elle en passant par moi',
  'year': '75',
  'language': 'French'},
 {'TID': '15184',
  'CID': 1,
  'CTID': '2',
  'SourceID': '3',
  'id': 15183,
  'number': '9',
  'title': "l ' enfant aux yeux d ' italie - de vous a elle en passant par moi",
  'length': '3.663',
  'artist': 'daniel balavoine',
  'album': '',
  'year': "'75",
  'language': 'French'}]

In [13]:
from entity_embed.evaluation import precision_and_recall

## Get Pairs

In [14]:
import random

random_seed = 42
random.seed(random_seed)

In [15]:
len(true_pair_set)

16250

In [16]:
import random

train_len = 5_000
train_pair_set = OrderedSet(random.sample(true_pair_set, train_len))
valid_pair_set = true_pair_set - train_pair_set

print(len(train_pair_set))
print(len(valid_pair_set))

5000
11250


In [17]:
train_id_list = OrderedSet(id_ for pair in train_pair_set for id_ in pair)
len(train_id_list)

7582

In [18]:
train_row_dict = {id_: row_dict[id_] for id_ in train_id_list}
len(train_row_dict)

7582

In [19]:
valid_id_list = OrderedSet(id_ for pair in valid_pair_set for id_ in pair)
len(valid_id_list)

12381

In [20]:
valid_row_dict = {id_: row_dict[id_] for id_ in valid_id_list}
len(valid_row_dict)

12381

## Training

In [21]:
import torch
import numpy as np

In [22]:
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [23]:
alphabet = list('0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ')

In [24]:
# TODO: support multiple attrs for same source attr

attr_info_dict = {
    'title': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'artist': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    },
    'album': {
        'is_multitoken': True,
        'tokenizer': default_tokenizer,
        'alphabet': alphabet,
        'max_str_len': None,  # compute
    }
}

In [25]:
from entity_embed import EntityEmbed

model = EntityEmbed(
    attr_info_dict=attr_info_dict,
    row_dict=row_dict
)

In [26]:
model.attr_info_dict

{'title': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7f6bc3390160>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'artist': OneHotEncodingInfo(is_multitoken=True, tokenizer=<function default_tokenizer at 0x7f6bc3390160>, alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '], max_str_len=30),
 'album': OneHotEncodingInfo(is_multi

In [27]:
epochs = 10

In [28]:
model.train(
    epochs=epochs,
    train_row_dict=train_row_dict,
    cluster_id_attr=cluster_id_attr,
    pos_pair_batch_size=45,
    neg_pair_batch_size=1225,
    random_seed=random_seed
)

# Train Epoch:   0 Time: 14.005 Loss: 0.263:  10%|█         | 178/1780 [00:14<02:05, 12.81it/s]

[('title', tensor(0.3527, device='cuda:0')), ('artist', tensor(0.3054, device='cuda:0')), ('album', tensor(0.3419, device='cuda:0'))]


# Train Epoch:   1 Time: 13.756 Loss: 0.067:  20%|██        | 356/1780 [00:27<01:49, 13.02it/s]

[('title', tensor(0.3554, device='cuda:0')), ('artist', tensor(0.3004, device='cuda:0')), ('album', tensor(0.3442, device='cuda:0'))]


# Train Epoch:   2 Time: 14.037 Loss: 0.033:  30%|███       | 534/1780 [00:41<01:35, 13.08it/s]

[('title', tensor(0.3572, device='cuda:0')), ('artist', tensor(0.3017, device='cuda:0')), ('album', tensor(0.3411, device='cuda:0'))]


# Train Epoch:   3 Time: 13.974 Loss: 0.021:  40%|████      | 712/1780 [00:55<01:26, 12.35it/s]

[('title', tensor(0.3599, device='cuda:0')), ('artist', tensor(0.2993, device='cuda:0')), ('album', tensor(0.3408, device='cuda:0'))]


# Train Epoch:   4 Time: 14.032 Loss: 0.017:  50%|█████     | 890/1780 [01:10<01:06, 13.35it/s]

[('title', tensor(0.3630, device='cuda:0')), ('artist', tensor(0.2967, device='cuda:0')), ('album', tensor(0.3403, device='cuda:0'))]


# Train Epoch:   5 Time: 13.780 Loss: 0.015:  60%|██████    | 1068/1780 [01:23<00:53, 13.34it/s]

[('title', tensor(0.3639, device='cuda:0')), ('artist', tensor(0.2935, device='cuda:0')), ('album', tensor(0.3426, device='cuda:0'))]


# Train Epoch:   6 Time: 13.936 Loss: 0.014:  70%|███████   | 1246/1780 [01:37<00:41, 12.84it/s]

[('title', tensor(0.3659, device='cuda:0')), ('artist', tensor(0.2896, device='cuda:0')), ('album', tensor(0.3445, device='cuda:0'))]


# Train Epoch:   7 Time: 13.842 Loss: 0.013:  80%|████████  | 1424/1780 [01:51<00:27, 12.71it/s]

[('title', tensor(0.3678, device='cuda:0')), ('artist', tensor(0.2886, device='cuda:0')), ('album', tensor(0.3436, device='cuda:0'))]


# Train Epoch:   8 Time: 13.999 Loss: 0.013:  90%|█████████ | 1602/1780 [02:05<00:13, 13.22it/s]

[('title', tensor(0.3696, device='cuda:0')), ('artist', tensor(0.2884, device='cuda:0')), ('album', tensor(0.3420, device='cuda:0'))]


# Train Epoch:   9 Time: 13.832 Loss: 0.012: 100%|██████████| 1780/1780 [02:19<00:00, 12.72it/s]

[('title', tensor(0.3730, device='cuda:0')), ('artist', tensor(0.2829, device='cuda:0')), ('album', tensor(0.3441, device='cuda:0'))]





In [29]:
# torch.save(model, "music_model.torch")

In [30]:
# model = torch.load("music_model.torch")

In [31]:
valid_vector_list = model.evaluate(
    row_dict=valid_row_dict,
    batch_size=64)

# batch embedding: 100%|██████████| 194/194 [00:05<00:00, 36.07it/s]


In [32]:
valid_id_to_index = {id_: index for index, id_ in enumerate(valid_id_list)}

In [33]:
%%time

ef_construction = 150
M = 64
metric = 'angular'

# https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md#construction-parameters
approx_knn_index = HnswIndex(dimension=valid_vector_list[0].shape[0], metric=metric)
for valid_vector in valid_vector_list:
    approx_knn_index.add_data(valid_vector)

approx_knn_index.build(    
    m=M,
    max_m0=M,
    ef_construction=ef_construction,
    n_threads=os.cpu_count(),
)

CPU times: user 4.48 s, sys: 25.8 ms, total: 4.51 s
Wall time: 524 ms


In [34]:
%%time

ntop = 10

neighbor_distance_list = approx_knn_index.batch_search_by_ids(
    item_ids=list(range(len(valid_vector_list))),
    k=ntop,
    ef_search=-1,
    num_threads=os.cpu_count(),
    include_distances=True
)

CPU times: user 8.04 s, sys: 5.24 ms, total: 8.04 s
Wall time: 735 ms


In [35]:
neighbor_distance_list[0]

[(0, 0.0),
 (1, 0.18714380264282227),
 (7562, 0.5210917592048645),
 (7561, 0.5268628597259521),
 (1475, 0.5302277207374573),
 (1474, 0.5445627570152283),
 (6979, 0.5527433753013611),
 (1332, 0.5741868019104004),
 (6240, 0.5771656632423401),
 (6899, 0.578498125076294)]

In [36]:
threshold = 0.5
distance_threshold = 1 - threshold

In [37]:
[row_dict[valid_id_list[neighbor]] for neighbor, distance in neighbor_distance_list[0] if distance <= distance_threshold]

[{'TID': '1',
  'CID': 1,
  'CTID': '1',
  'SourceID': '2',
  'id': 0,
  'number': '9',
  'title': "daniel balavoine - l ' enfant aux yeux d ' italie",
  'length': '219',
  'artist': '',
  'album': 'de vous a elle en passant par moi',
  'year': '75',
  'language': 'French'},
 {'TID': '15184',
  'CID': 1,
  'CTID': '2',
  'SourceID': '3',
  'id': 15183,
  'number': '9',
  'title': "l ' enfant aux yeux d ' italie - de vous a elle en passant par moi",
  'length': '3.663',
  'artist': 'daniel balavoine',
  'album': '',
  'year': "'75",
  'language': 'French'}]

In [38]:
found_pair_set = set()

for i, neighbor_distance in tqdm.tqdm_notebook(enumerate(neighbor_distance_list), total=len(neighbor_distance_list)):
    for j, distance in neighbor_distance:
        if i != j and distance <= distance_threshold:
            pair = tuple(sorted([valid_id_list[i], valid_id_list[j]]))
            found_pair_set.add(pair)

len(found_pair_set)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, neighbor_distance in tqdm.tqdm_notebook(enumerate(neighbor_distance_list), total=len(neighbor_distance_list)):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12381.0), HTML(value='')))




27252

In [39]:
def pe_ratio(found_pair_set, valid_id_list):
    return len(found_pair_set) / len(valid_id_list)

pe_ratio(found_pair_set, valid_id_list)

2.20111461109765

In [40]:
precision_and_recall(found_pair_set, valid_pair_set)

(0.4077498899163364, 0.9877333333333334)

In [41]:
false_positives = list(found_pair_set - valid_pair_set)
len(false_positives)

16140

In [42]:
false_negatives = list(valid_pair_set - found_pair_set)
len(false_negatives)

138

In [43]:
cos_similarity = lambda a, b: np.dot(a, b)

In [44]:
for (id_left, id_right) in false_negatives[:10]:
    i = valid_id_to_index[id_left]
    j = valid_id_to_index[id_right]
    display((cos_similarity(valid_vector_list[i], valid_vector_list[j]), row_dict[id_left], row_dict[id_right]))

(0.39653063,
 {'TID': '372',
  'CID': 196,
  'CTID': '1',
  'SourceID': '2',
  'id': 371,
  'number': '24',
  'title': 'ernesto tagliaferri - mandulinata a napoli',
  'length': '175',
  'artist': '',
  'album': 'italian popular songs , volume 1',
  'year': '',
  'language': 'Itlaian'},
 {'TID': '7494',
  'CID': 196,
  'CTID': '2',
  'SourceID': '3',
  'id': 7493,
  'number': '13',
  'title': 'mandulinataanapoliitalianpopul , volume1',
  'length': '2.924',
  'artist': 'ernesto tagliaferri',
  'album': '',
  'year': '',
  'language': 'Italian'})

(0.08300196,
 {'TID': '490',
  'CID': 266,
  'CTID': '1',
  'SourceID': '3',
  'id': 489,
  'number': '5',
  'title': 'onriman - sakurada zhan di ju ge yao quan ji',
  'length': '5.715',
  'artist': 'tachibanamaria',
  'album': '',
  'year': '',
  'language': 'Japanese'},
 {'TID': '1355',
  'CID': 266,
  'CTID': '2',
  'SourceID': '4',
  'id': 1354,
  'number': '5',
  'title': '005 -',
  'length': '5m 42sec',
  'artist': '',
  'album': '( unknown )',
  'year': 'null',
  'language': 'Jap.'})

(0.031204147,
 {'TID': '561',
  'CID': 298,
  'CTID': '1',
  'SourceID': '3',
  'id': 560,
  'number': '6',
  'title': "prikhodi - zn @ menatel '",
  'length': '4.055',
  'artist': 'splin',
  'album': '',
  'year': "'00",
  'language': 'Russian'},
 {'TID': '2355',
  'CID': 298,
  'CTID': '2',
  'SourceID': '4',
  'id': 2354,
  'number': '6',
  'title': '006 -',
  'length': '4m 3sec',
  'artist': '',
  'album': '@ ( 2000 )',
  'year': 'null',
  'language': 'Rus.'})

(-0.0875065,
 {'TID': '2355',
  'CID': 298,
  'CTID': '2',
  'SourceID': '4',
  'id': 2354,
  'number': '6',
  'title': '006 -',
  'length': '4m 3sec',
  'artist': '',
  'album': '@ ( 2000 )',
  'year': 'null',
  'language': 'Rus.'},
 {'TID': '3390',
  'CID': 298,
  'CTID': '3',
  'SourceID': '5',
  'id': 3389,
  'number': '6',
  'title': 'prikhodi',
  'length': '243293',
  'artist': 'splin',
  'album': "zn @ menatel '",
  'year': '2000',
  'language': 'Russian'})

(0.28250557,
 {'TID': '831',
  'CID': 442,
  'CTID': '1',
  'SourceID': '3',
  'id': 830,
  'number': '13',
  'title': '19841031',
  'length': '8.667',
  'artist': 'u2',
  'album': '',
  'year': '',
  'language': 'English'},
 {'TID': '6047',
  'CID': 442,
  'CTID': '2',
  'SourceID': '4',
  'id': 6046,
  'number': '13',
  'title': '013 - bad',
  'length': '8m 40sec',
  'artist': 'u2',
  'album': '1984 - 10 - 31 : sportpaleis ahoy , rotterdam , netherlands ( unknown )',
  'year': 'null',
  'language': 'nEg.'})

(0.35595816,
 {'TID': '831',
  'CID': 442,
  'CTID': '1',
  'SourceID': '3',
  'id': 830,
  'number': '13',
  'title': '19841031',
  'length': '8.667',
  'artist': 'u2',
  'album': '',
  'year': '',
  'language': 'English'},
 {'TID': '11882',
  'CID': 442,
  'CTID': '3',
  'SourceID': '5',
  'id': 11881,
  'number': '13',
  'title': 'bad',
  'length': '520000',
  'artist': 'u2',
  'album': '1984 - 10 - 31 : sportpaleis ahoy , rotterdam , netherlands',
  'year': '',
  'language': 'English'})

(0.28613654,
 {'TID': '831',
  'CID': 442,
  'CTID': '1',
  'SourceID': '3',
  'id': 830,
  'number': '13',
  'title': '19841031',
  'length': '8.667',
  'artist': 'u2',
  'album': '',
  'year': '',
  'language': 'English'},
 {'TID': '18004',
  'CID': 442,
  'CTID': '4',
  'SourceID': '1',
  'id': 18003,
  'number': '013',
  'title': 'bad ( 1984 - 10 - 31 : sportpaleis ahoy , rotterdam , netherlands )',
  'length': '08:40',
  'artist': 'u2',
  'album': '1984 - 10 - 31 : sportpaleis ahoy , rotterdam , netherlands',
  'year': '',
  'language': ''})

(-0.16045132,
 {'TID': '887',
  'CID': 470,
  'CTID': '1',
  'SourceID': '1',
  'id': 886,
  'number': '006',
  'title': "sweet and slow ( guy ' s all - star shoe band )",
  'length': '03:04',
  'artist': "guy ' s all star shoe band",
  'album': "guy ' s all - star shoe band",
  'year': '',
  'language': ''},
 {'TID': '17706',
  'CID': 470,
  'CTID': '2',
  'SourceID': '2',
  'id': 17705,
  'number': 'MBox10988814-HH',
  'title': '6',
  'length': "Guy's All Star Shoe Band - Sweet and Slow",
  'artist': '1184',
  'album': '',
  'year': "Guy's All-Star Shoe Band",
  'language': ''})

(0.41244337,
 {'TID': '948',
  'CID': 502,
  'CTID': '1',
  'SourceID': '3',
  'id': 947,
  'number': '16',
  'title': "let ' s get together - subarashikikonosekai original soundtrack",
  'length': '0.267',
  'artist': 'shi yuan zhang qing',
  'album': '',
  'year': "'07",
  'language': 'Japanese'},
 {'TID': '7146',
  'CID': 502,
  'CTID': '2',
  'SourceID': '4',
  'id': 7145,
  'number': '16',
  'title': '',
  'length': '0m 16sec',
  'artist': '',
  'album': 'soundtrack',
  'year': 'null',
  'language': 'Jap.'})

(0.39290592,
 {'TID': '984',
  'CID': 521,
  'CTID': '1',
  'SourceID': '2',
  'id': 983,
  'number': '11',
  'title': "al martino - you ' re the love of my life",
  'length': '178',
  'artist': '',
  'album': 'best of jose carreras gala',
  'year': '04',
  'language': '[Multiple languages]'},
 {'TID': '2803',
  'CID': 521,
  'CTID': '2',
  'SourceID': '3',
  'id': 2802,
  'number': '11',
  'title': "you ' re",
  'length': '2.967',
  'artist': 'al martino',
  'album': '',
  'year': "'04",
  'language': ''})