# Attention Scores Example

Please run [Record-Linkage-Example.ipynb](Record-Linkage-Example.ipynb) before this one in order to get the trained model at `../trained-models/notebooks/rl/rl-model.ckpt`.

## Boilerplate

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from importlib import reload
import logging
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')

In [3]:
import sys

sys.path.insert(0, '..')

In [4]:
import entity_embed

In [5]:
import torch
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

## Loading Test Data

In [6]:
import json
from ordered_set import OrderedSet

def load_pair_set(filepath):
    with open(filepath, 'r') as f:
        test_pos_pair_set = json.load(f)
        return OrderedSet(tuple(pair) for pair in test_pos_pair_set)

train_pos_pair_set = load_pair_set('../trained-models/notebooks/rl/rl-train-pos-pairs.json')
valid_pos_pair_set = load_pair_set('../trained-models/notebooks/rl/rl-valid-pos-pairs.json')
test_pos_pair_set = load_pair_set('../trained-models/notebooks/rl/rl-test-pos-pairs.json')

In [7]:
import json

def load_record_dict(filepath):
    with open(filepath, 'r') as f:
        record_dict = json.load(f)
        return {int(id_): record for id_, record in record_dict.items()}

train_record_dict = load_record_dict('../trained-models/notebooks/rl/rl-train-records.json')
valid_record_dict = load_record_dict('../trained-models/notebooks/rl/rl-valid-records.json')
test_record_dict = load_record_dict('../trained-models/notebooks/rl/rl-test-records.json')

## Loading Model

In [8]:
from entity_embed import LinkageEmbed

model = LinkageEmbed.load_from_checkpoint('../trained-models/notebooks/rl/rl-model.ckpt')
model = model.to(torch.device('cuda'))

## Blocking

In [9]:
%%time

eval_batch_size = 64
ann_k = 100
sim_threshold = 0.5

test_found_pair_set = model.predict_pairs(
    record_dict=test_record_dict,
    batch_size=eval_batch_size,
    ann_k=ann_k,
    sim_threshold=sim_threshold,
    show_progress=True,
)

len(test_found_pair_set)

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

CPU times: user 1.47 s, sys: 635 ms, total: 2.11 s
Wall time: 1.34 s


960

In [10]:
%%time

test_left_vector_dict, test_right_vector_dict = model.predict(
    record_dict=test_record_dict,
    batch_size=eval_batch_size,
    show_progress=True,
)

len(test_left_vector_dict), len(test_right_vector_dict)

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

CPU times: user 393 ms, sys: 612 ms, total: 1.01 s
Wall time: 1.09 s


(406, 963)

In [11]:
test_attn_scores_dict = model.interpret_attention(
    record_dict=test_record_dict,
    batch_size=eval_batch_size,
    field='title',
)

len(test_attn_scores_dict)

# batch embedding:   0%|          | 0/22 [00:00<?, ?it/s]

1369

In [12]:
[id_ for id_, x in test_attn_scores_dict.items() if x.sum() < 0.99]

[]

In [13]:
from entity_embed.evaluation import pair_entity_ratio

pair_entity_ratio(len(test_found_pair_set), len(test_record_dict))

0.7012417823228634

In [14]:
from entity_embed.evaluation import precision_and_recall

precision_and_recall(test_found_pair_set, test_pos_pair_set)

(0.378125, 0.952755905511811)

In [15]:
hard_positives = list(test_pos_pair_set & test_found_pair_set)
hard_positives = [
    (id_left, id_right)
    for (id_left, id_right) in hard_positives
    if sim_threshold <= np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right]) <= sim_threshold + 0.1
]
len(hard_positives)

24

In [16]:
import seaborn as sns
import pandas as pd
from entity_embed import default_tokenizer

def display_attention(id_, field):
    tokens = default_tokenizer(test_record_dict[id_][field])
    attn_scores = test_attn_scores_dict[id_][:len(tokens)]
    attn_df = pd.DataFrame(dict(zip(tokens, attn_scores)), index=[id_])
    cm = sns.light_palette("red", as_cmap=True)
    display(attn_df.style.background_gradient(cmap=cm, axis=1))

def display_pair_attention(pair, field):
    left_id, right_id = pair
    display_attention(left_id, field)
    display_attention(right_id, field)

for (id_left, id_right) in hard_positives[:5]:
    print(np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right]))
    display_pair_attention((id_left, id_right), 'title')

0.5284208


Unnamed: 0,micromat,podlock,(,mac,)
855,0.215393,0.375849,0.169307,0.123075,0.116376


Unnamed: 0,micromat,podlock,ipod,utility,software
1552,0.21918,0.408113,0.248585,0.106758,0.017365


0.54449064


Unnamed: 0,microsoft,licenses,word,olp,c,(,05903871,)
423,0.068856,0.176699,0.277619,0.23753,0.109214,0.041581,0.040817,0.047685


Unnamed: 0,microsoft,059,-,03871,molpc,word,sa
3727,0.102101,0.202694,0.196563,0.097102,0.111707,0.138139,0.151694


0.55174214


Unnamed: 0,hijack2
1260,1.0


Unnamed: 0,me,too,software,800801,-,hijack2,win,98,nt,2000,xp,/,mac,10,.,0,or,higher
2405,0.06099,0.021998,0.00693,0.018662,0.110718,0.123963,0.101209,0.056377,0.046122,0.035507,0.04082,0.037705,0.04139,0.056448,0.049012,0.04138,0.030806,0.035587


0.54825926


Unnamed: 0,band,in,a,box,2007
1178,0.14964,0.152526,0.252355,0.236936,0.208543


Unnamed: 0,pg,music,band,in,a,box,software,for,windows,production
3513,0.17707,0.077173,0.138193,0.11261,0.112008,0.053614,0.018666,0.030654,0.054105,0.0318


0.5696411


Unnamed: 0,zero,-,g,pro,pack,for,garageband,(,appleloops,)
797,0.121497,0.073245,0.074147,0.071646,0.116185,0.099274,0.300285,0.053835,0.041032,0.048855


Unnamed: 0,east,west,propack,for,garageband,av,production,software
3960,0.113621,0.160905,0.182362,0.151915,0.335159,0.038616,0.012351,0.00507


In [17]:
false_negatives = list(test_pos_pair_set - test_found_pair_set)
len(false_negatives)

18

In [18]:
for (id_left, id_right) in false_negatives[:5]:
    print(np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right]))
    display_pair_attention((id_left, id_right), 'title')

0.4690476


Unnamed: 0,foreign,policy,&,reform,(,win,/,mac,),jewel,case
112,0.081279,0.140057,0.13177,0.150604,0.041769,0.08106,0.059594,0.054723,0.030318,0.023398,0.038201


Unnamed: 0,fogware,publishing,-,10356,high,school,us,history,2,foreign,policy,&,reform
3052,0.022923,0.038884,0.047699,0.052933,0.102283,0.131801,0.104791,0.167174,0.104292,0.034214,0.03619,0.022232,0.02069


0.4419197


Unnamed: 0,microspot,macdraft,pe,(,mac,)
644,0.174039,0.170291,0.261966,0.151162,0.119979,0.122563


Unnamed: 0,microspot,macdraft,pe,personal,edition
1693,0.146526,0.13214,0.223883,0.371348,0.126103


0.45072076


Unnamed: 0,clifford,the,big,red,dog,-,thinking,adventures
50,0.170498,0.120651,0.174603,0.149637,0.16362,0.100046,0.074657,0.046289


Unnamed: 0,clifford,thinking
3453,0.670242,0.329758


0.47478366


Unnamed: 0,omniweb,5,.,0
725,0.327518,0.278055,0.209434,0.184993


Unnamed: 0,omni,web,5,.,0
1803,0.217174,0.238057,0.2278,0.169549,0.147421


0.38709295


Unnamed: 0,hijack2
1260,1.0


Unnamed: 0,hijack2,identity,and,data,security,suite
3367,0.294043,0.213012,0.160768,0.199087,0.07029,0.0628


In [19]:
false_positives = list(test_found_pair_set - test_pos_pair_set)
len(false_positives)

597

In [20]:
for (id_left, id_right) in false_positives[:5]:
    print(np.dot(test_left_vector_dict[id_left], test_right_vector_dict[id_right]))
    display_pair_attention((id_left, id_right), 'title')

0.82730997


Unnamed: 0,print,shop,22,pro,publisher,deluxe
1259,0.085782,0.09686,0.169252,0.188322,0.316136,0.143647


Unnamed: 0,printshop,20,pro,publisher
2800,0.197008,0.283588,0.22428,0.295125


0.5538443


Unnamed: 0,ae,mappoint,2006,cd
469,0.216782,0.254014,0.315596,0.213608


Unnamed: 0,microsoft,mappoint,2006,with,gps,locator,(,pc,)
4392,0.074822,0.176195,0.225445,0.163839,0.160655,0.092336,0.050313,0.03054,0.025854


0.5290783


Unnamed: 0,upg,serverlock,for,solaris,gold
200,0.201077,0.20001,0.185658,0.268595,0.14466


Unnamed: 0,freeverse,software,5014,-,toysight,gold,(,mac,10,.,2,or,higher,)
2785,0.017963,0.008756,0.051649,0.067613,0.180118,0.141541,0.083792,0.080356,0.102073,0.088178,0.06743,0.037098,0.03833,0.035104


0.5653949


Unnamed: 0,apple,ilife,',06,family,pack,(,mac,dvd,),[,older,version,]
860,0.105738,0.120886,0.120593,0.119797,0.133742,0.137395,0.064262,0.050234,0.03472,0.031217,0.027588,0.024085,0.01752,0.012223


Unnamed: 0,apple,.,mac,family,pack,software,-,five,user,license,with,1,year,subscription,&
3478,0.064013,0.072934,0.01367,0.072934,0.050594,0.00308,0.058373,0.021907,0.062834,0.07002,0.075262,0.071686,0.08711,0.065417,0.035534


0.51935875


Unnamed: 0,net,ad,creator
1138,0.351299,0.292414,0.356287


Unnamed: 0,roxio,easy,media,creator,9,suite,software,for,windows,authoring
3742,0.110653,0.192571,0.165696,0.202016,0.112364,0.059316,0.01361,0.024468,0.045945,0.062368
