# Attention Scores Example

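Run [Record-Linkage-Example.ipynb](Record-Linkage-Example.ipynb) before this notebook: it produces the trained model at `../trained-models/notebooks/rl/rl-model.ckpt`. This notebook also loads the `rl-*-records.json` and `rl-*-pos-pairs.json` files from that same directory.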

## Boilerplate

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging

# Re-import `logging` so the basicConfig call below actually applies: basicConfig
# does nothing once the root logger already has handlers (e.g. from a previous
# run of this cell), and reloading the module gives a fresh root logger.
logging = reload(logging)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%H:%M:%S',
)

In [3]:
import sys

# Put the parent directory (which presumably holds the local `entity_embed`
# package) first on the import path. The guard keeps re-running this cell from
# stacking duplicate '..' entries onto sys.path.
if not sys.path or sys.path[0] != '..':
    sys.path.insert(0, '..')

In [4]:
import entity_embed

In [5]:
import torch
import numpy as np

import random

# Seed every RNG this notebook may touch (Python's `random`, NumPy, PyTorch)
# so that re-running it reproduces the same results.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

## Loading Test Data

In [6]:
import json
from ordered_set import OrderedSet

def load_pair_set(filepath):
    """Load a JSON array of id pairs into an insertion-ordered set of tuples.

    Each inner JSON array is converted to a tuple so it is hashable (later
    cells unpack them as ``(id_left, id_right)``). File order is kept, so
    slicing the results of set operations on the returned set is reproducible.

    Args:
        filepath: path to the JSON file.

    Returns:
        An ``OrderedSet`` of pair tuples.
    """
    with open(filepath, 'r') as f:
        # Generic name: this loader serves the train, valid and test splits alike.
        pair_list = json.load(f)
    return OrderedSet(tuple(pair) for pair in pair_list)

# Labelled positive (matching) pairs for each split, loaded in train/valid/test order.
train_pos_pair_set, valid_pos_pair_set, test_pos_pair_set = (
    load_pair_set('../trained-models/notebooks/rl/rl-{}-pos-pairs.json'.format(split))
    for split in ('train', 'valid', 'test')
)

In [7]:
import json

def load_record_dict(filepath):
    """Read a JSON object of records and return it keyed by integer id.

    JSON object keys are always strings, so each key is converted back to ``int``
    while the record values are kept unchanged.
    """
    with open(filepath, 'r') as f:
        raw_records = json.load(f)
    return dict((int(key), value) for key, value in raw_records.items())

# id -> record dicts for each split, loaded in train/valid/test order.
train_record_dict, valid_record_dict, test_record_dict = [
    load_record_dict('../trained-models/notebooks/rl/rl-%s-records.json' % split)
    for split in ['train', 'valid', 'test']
]

## Loading Model

In [8]:
from entity_embed import LinkageEmbed

model = LinkageEmbed.load_from_checkpoint('../trained-models/notebooks/rl/rl-model.ckpt')
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (moving to a hard-coded 'cuda' fails there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

## Blocking

In [9]:
%%time

# Blocking parameters; eval_batch_size and sim_threshold are reused by later cells.
eval_batch_size, ann_k, sim_threshold = 64, 100, 0.5

# Candidate pairs for the test records. ann_k is presumably the number of
# approximate nearest neighbours retrieved per record (TODO confirm in
# LinkageEmbed.predict_pairs).
test_found_pair_set = model.predict_pairs(
    show_progress=True,
    sim_threshold=sim_threshold,
    ann_k=ann_k,
    batch_size=eval_batch_size,
    record_dict=test_record_dict,
)

len(test_found_pair_set)

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

CPU times: user 1.53 s, sys: 578 ms, total: 2.1 s
Wall time: 1.32 s


1027

In [10]:
%%time

# Embed every test record. predict returns two id -> vector dicts (left and
# right side); later cells look up a pair's vectors by id_left / id_right.
(test_left_vector_dict,
 test_right_vector_dict) = model.predict(
    show_progress=True,
    batch_size=eval_batch_size,
    record_dict=test_record_dict,
)

tuple(len(vector_dict) for vector_dict in (test_left_vector_dict, test_right_vector_dict))

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

CPU times: user 447 ms, sys: 586 ms, total: 1.03 s
Wall time: 1.14 s


(406, 963)

In [11]:
# Attention scores over the 'title' field for every test record, keyed by record
# id; the display helpers below pair them with the title's whitespace tokens.
test_attn_scores_dict = model.interpret_attention(
    field='title',
    batch_size=eval_batch_size,
    record_dict=test_record_dict,
)

len(test_attn_scores_dict)

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

1369

In [12]:
from entity_embed.evaluation import pair_entity_ratio

# Candidate pairs found per test record: in the run below this is
# len(test_found_pair_set) / len(test_record_dict) = 1027 / 1369 ~ 0.750.
pair_entity_ratio(len(test_found_pair_set), len(test_record_dict))

0.75018261504748

In [13]:
from entity_embed.evaluation import precision_and_recall

# (precision, recall) of the found candidate pairs against the labelled test
# positives. Consistent with the counts printed further down (1027 found,
# 664 false positives, 18 false negatives): 363 / 1027 ~ 0.353 and 363 / 381 ~ 0.953.
precision_and_recall(test_found_pair_set, test_pos_pair_set)

(0.35345666991236613, 0.952755905511811)

In [14]:
# "Hard" positives: labelled matches that blocking did find, but whose embedding
# score (dot product of the left and right vectors) only just clears the
# threshold, i.e. lies in [sim_threshold, sim_threshold + hard_margin].
hard_margin = 0.1
found_true_pairs = list(test_pos_pair_set & test_found_pair_set)
hard_positives = [
    (id_left, id_right)
    for (id_left, id_right) in found_true_pairs
    if sim_threshold
    <= np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right])
    <= sim_threshold + hard_margin
]
len(hard_positives)

18

In [15]:
import pandas as pd
import seaborn as sns

def display_attention(id_, field):
    val = test_record_dict[id_][field]
    attn_scores = test_attn_scores_dict[id_][:len(val.split())]
    attn_df = pd.DataFrame(dict(zip(val.split(), attn_scores)), index=[id_])
    cm = sns.light_palette("red", as_cmap=True)
    display(attn_df.style.background_gradient(cmap=cm, axis=1))

def display_pair_attention(pair, field):
    left_id, right_id = pair
    display_attention(left_id, field)
    display_attention(right_id, field)

# Inspect the first five hard positives: print the pair's score, then show the
# title attention of the left and right record.
for pair in hard_positives[:5]:
    id_left, id_right = pair
    left_vector = test_left_vector_dict[id_left]
    right_vector = test_right_vector_dict[id_right]
    print(np.dot(left_vector, right_vector))
    display_pair_attention((id_left, id_right), 'title')

0.56413823


Unnamed: 0,micromat,podlock,(,mac,)
855,0.200589,0.290565,0.136797,0.090351,0.092011


Unnamed: 0,micromat,podlock,ipod,utility,software
1552,0.211529,0.308867,0.190773,0.104421,0.029679


0.5060007


Unnamed: 0,hijack2
1260,0.309989


Unnamed: 0,me,too,software,800801,-,hijack2,win,98,nt,2000,xp,/,mac,10,.,0,or,higher
2405,0.061568,0.034789,0.012571,0.033064,0.105356,0.118844,0.091359,0.059822,0.054685,0.042801,0.0441,0.032442,0.028481,0.039138,0.04355,0.037384,0.030343,0.04003


0.51596963


Unnamed: 0,zero,-,g,pro,pack,for,garageband,(,appleloops,)
797,0.131953,0.09584,0.097956,0.090986,0.113048,0.095525,0.234135,0.055736,0.041469,0.043352


Unnamed: 0,east,west,propack,for,garageband,av,production,software
3960,0.089039,0.133312,0.185782,0.159489,0.3416,0.058214,0.021221,0.011344


0.56924784


Unnamed: 0,school,zone,pencil,-,pal,software,big,phonics,(,cd,rom,&,book,)
19,0.117748,0.114365,0.075442,0.045753,0.063288,0.015395,0.047318,0.05882,0.063923,0.052791,0.063104,0.074729,0.069363,0.050097


Unnamed: 0,pencil,-,pal,big,phonics
2600,0.134765,0.212887,0.242428,0.205859,0.204061


0.54849666


Unnamed: 0,faxstf,pro,mac,os,10,.,3,9,or,above
1161,0.071674,0.094214,0.095564,0.115048,0.133837,0.080537,0.135922,0.0429,0.040706,0.035808


Unnamed: 0,allume,smith,micro,faxstf,pro
2273,0.114426,0.063109,0.142721,0.145955,0.164066


In [16]:
# Labelled matches that blocking did not return, in test_pos_pair_set order.
false_negatives = [pair for pair in test_pos_pair_set if pair not in test_found_pair_set]
len(false_negatives)

18

In [17]:
# Inspect the first five missed matches: score first, then title attention for both records.
for id_left, id_right in false_negatives[:5]:
    similarity = np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right])
    print(similarity)
    display_pair_attention((id_left, id_right), 'title')

0.4181442


Unnamed: 0,corel,snapfire,plus,[,ignite,photo,fun,]
552,0.138126,0.090277,0.18028,0.118514,0.108722,0.181631,0.097803,0.084646


Unnamed: 0,corel,corporation,snapfire,plus
1904,0.147772,0.125614,0.082408,0.144304


0.44709966


Unnamed: 0,foreign,policy,&,reform,(,win,/,mac,),jewel,case
112,0.069001,0.138821,0.145052,0.172942,0.034655,0.087579,0.06101,0.043694,0.026506,0.02239,0.028819


Unnamed: 0,fogware,publishing,-,10356,high,school,us,history,2,foreign,policy,&,reform
3052,0.027249,0.028246,0.05008,0.047381,0.103982,0.12756,0.088985,0.160765,0.103958,0.036985,0.048687,0.036727,0.036096


0.4741135


Unnamed: 0,clifford,the,big,red,dog,-,thinking,adventures
50,0.159681,0.124132,0.149954,0.149953,0.160517,0.123887,0.084409,0.047465


Unnamed: 0,clifford,thinking
3453,0.442445,0.236488


0.49276412


Unnamed: 0,omniweb,5,.,0
725,0.227882,0.217475,0.20088,0.172564


Unnamed: 0,omni,web,5,.,0
1803,0.140417,0.181534,0.211243,0.194329,0.163552


0.48059654


Unnamed: 0,train,sim,modeler,design,studio
1350,0.15098,0.170325,0.185783,0.212958,0.279954


Unnamed: 0,abacus,train,sim,modeler
1958,0.184054,0.214456,0.267194,0.31334


In [18]:
# Pairs returned by blocking that are not labelled matches.
false_positives = [*(test_found_pair_set - test_pos_pair_set)]
len(false_positives)

664

In [19]:
# Inspect the first five wrongly found pairs: score first, then title attention for both records.
shown_false_positives = false_positives[:5]
for id_left, id_right in shown_false_positives:
    pair_similarity = np.dot(
        test_left_vector_dict[id_left],
        test_right_vector_dict[id_right],
    )
    print(pair_similarity)
    display_pair_attention((id_left, id_right), 'title')

0.7204137


Unnamed: 0,print,shop,22,pro,publisher,deluxe
1259,0.104937,0.152767,0.198373,0.179916,0.229915,0.134091


Unnamed: 0,printshop,20,pro,publisher
2800,0.201209,0.309976,0.23144,0.257375


0.58963317


Unnamed: 0,ae,mappoint,2006,cd
469,0.229766,0.283581,0.281881,0.204772


Unnamed: 0,microsoft,mappoint,2006,with,gps,locator,(,pc,)
4392,0.115509,0.185339,0.179359,0.140424,0.131453,0.108565,0.062387,0.041956,0.035008


0.5331933


Unnamed: 0,upg,serverlock,for,solaris,gold
200,0.190985,0.209826,0.197998,0.25131,0.149881


Unnamed: 0,freeverse,software,5014,-,toysight,gold,(,mac,10,.,2,or,higher,)
2785,0.026781,0.013017,0.056985,0.063934,0.148294,0.13435,0.0871,0.072952,0.093101,0.100122,0.074046,0.042554,0.048317,0.038447


0.54700345


Unnamed: 0,apple,ilife,',06,family,pack,(,mac,dvd,),[,older,version,]
860,0.090943,0.09371,0.125803,0.118287,0.11129,0.117449,0.066811,0.050793,0.04273,0.044668,0.045425,0.043254,0.025143,0.023696


Unnamed: 0,apple,.,mac,family,pack,software,-,five,user,license,with,1,year,subscription,&
3478,0.057929,0.098221,0.014846,0.07033,0.056706,0.005276,0.058178,0.025641,0.051894,0.046556,0.060568,0.057592,0.083535,0.066279,0.035797


0.6331956


Unnamed: 0,apple,ilife,',06,(,mac,dvd,),[,older,version,]
542,0.116475,0.118553,0.155587,0.137826,0.086079,0.067373,0.058156,0.06251,0.064748,0.062026,0.036361,0.034306


Unnamed: 0,apple,ilife,',06,family,pack
1591,0.147593,0.148411,0.193483,0.174719,0.160424,0.175372
