In [1]:
#| default_exp 04_msmarco-entity-conflation

In [2]:
# Reload edited project modules (e.g. xcai) automatically before each cell executes,
# so library changes are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
import scipy.sparse as sp, numpy as np
from tqdm.auto import tqdm
from termcolor import colored, COLORS
from scipy.sparse.csgraph import connected_components

from xcai.main import *
from xcai.data import XCDataset
from xcai.analysis import *

import xclib.evaluation.xc_metrics as xc_metrics
from xclib.utils.sparse import retain_topk

  from .autonotebook import tqdm as notebook_tqdm


## Linker predictions

### `LLaMA`

In [10]:
# Build the MS MARCO LLaMA-entity block (sxc, exact-match config) and derive its
# entity-linker view. The conflation section further down reads `linker_block`.
config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_llama_exact.json'
config_key = 'data_entity-llama_exact'

pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-entity-llama_distilbert-base-uncased_sxc_exact.joblib'

block = build_block(
    pkl_file,
    config_file,
    use_sxc=True,
    config_key=config_key,
)
# remove_empty=False: presumably keeps queries with no gold entity (such queries do show
# up in the comparison output below with an empty label list) — confirm in xcai.
linker_block = block.linker_dset('ent_meta', remove_empty=False)

In [None]:
# Load the LLaMA linker's full test-set prediction matrix and set up a side-by-side view:
# '1. ' rows come from the gold test dataset, '2. ' rows from the top-5 predictions.
pred_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-llama-entity-linker-002/predictions/'
pred_lbl = sp.load_npz(pred_dir + '/test_predictions_full.npz')

pred_block = get_pred_dset(
    retain_topk(pred_lbl, k=5),
    linker_block.test.dset,
)

disp = CompareDataset(
    XCDataset._initialize(linker_block.test.dset),
    pred_block,
    '1. ',
    '2. ',
)

In [None]:
# Pointwise (per-query) scores of the full prediction matrix against the gold test
# entities, evaluated at top-5; rows are later ranked to pick queries to display.
metric = pointwise_eval(
    pred_lbl,
    linker_block.test.dset.data.data_lbl,
    topk=5,
)

In [None]:
# Show the 100 best-scoring test queries: sort by summed per-query metric, highest first.
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[::-1][:100]
disp.show(idxs)

[5m[7m[35m1.  data_input_text[0m [35m: how much protein and iron does fish have[0m
[5m[7m[35m2.  data_input_text[0m [35m: how much protein and iron does fish have[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['protein', 'iron', 'fish'][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['protein', 'iron', 'fish', 'fish symbol', 'vitamin iron'][0m

[5m[7m[35m1.  data_input_text[0m [35m: how far is lake como to lugano switzerland[0m
[5m[7m[35m2.  data_input_text[0m [35m: how far is lake como to lugano switzerland[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['lake como', 'lugano', 'switzerland'][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['switzerland', 'lake', 'lake como', 'lake geneva', 'lugano'][0m

[5m[7m[35m1.  data_input_text[0m [35m: phone number for autozone in calhoun ga[0m
[5m[7m[35m2.  data_input_text[0m [35m: phone number for autozone in calhoun ga[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['autozone', 'calh

In [None]:
# Show the 100 worst-scoring test queries (lowest summed per-query metric first).
# NOTE(review): queries with an empty gold entity list presumably score 0 whatever is
# predicted, and they fill this list in the saved output; filter them out if the goal
# is to inspect genuine linker mistakes.
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[0:100]
disp.show(idxs)

[5m[7m[35m1.  data_input_text[0m [35m: blood diseases that are sexually transmitted[0m
[5m[7m[35m2.  data_input_text[0m [35m: blood diseases that are sexually transmitted[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['sexually transmitted disease', 'sexually transmitted', 'sexually transmitted diseases', 'bacterial sexually transmitted disease', 'sexual transmitted diseases'][0m

[5m[7m[35m1.  data_input_text[0m [35m: define bona fides[0m
[5m[7m[35m2.  data_input_text[0m [35m: define bona fides[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['bona fide', 'fides', 'bona', 'bona fide need rule', 'bona fide needs rule'][0m

[5m[7m[35m1.  data_input_text[0m [35m: effects of detox juice cleanse[0m
[5m[7m[35m2.  data_input_text[0m [35m: effects of detox juice cleanse[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96

### `GPT`

In [None]:
# Build the MS MARCO GPT-entity block (sxc, exact-match config) and derive its
# entity-linker view. This rebinds `block` / `linker_block`, replacing the LLaMA ones.
config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-entity-gpt_distilbert-base-uncased_sxc_exact.joblib'

block = build_block(
    pkl_file,
    config_file,
    use_sxc=True,
    config_key=config_key,
)
# remove_empty=False: presumably keeps queries with no gold entity — confirm in xcai.
linker_block = block.linker_dset('ent_meta', remove_empty=False)

In [None]:
# Load the GPT linker's full test-set prediction matrix and set up a side-by-side view:
# '1. ' rows come from the gold test dataset, '2. ' rows from the top-5 predictions.
pred_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-gpt-entity-linker-001/predictions/'
pred_lbl = sp.load_npz(pred_dir + '/test_predictions_full.npz')

pred_block = get_pred_dset(
    retain_topk(pred_lbl, k=5),
    linker_block.test.dset,
)

disp = CompareDataset(
    XCDataset._initialize(linker_block.test.dset),
    pred_block,
    '1. ',
    '2. ',
)

In [None]:
# Pointwise (per-query) scores of the GPT linker's predictions against the gold test
# entities, evaluated at top-5.
metric = pointwise_eval(
    pred_lbl,
    linker_block.test.dset.data.data_lbl,
    topk=5,
)

In [None]:
# Show the 100 best-scoring test queries: sort by summed per-query metric, highest first.
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[::-1][:100]
disp.show(idxs)

[5m[7m[92m1.  data_input_text[0m [92m: does adult acne rosacea give you blepharitis[0m
[5m[7m[92m2.  data_input_text[0m [92m: does adult acne rosacea give you blepharitis[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['adult acne', 'rosacea', 'blepharitis'][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['blepharitis', 'rosacea', 'skin rosacea', 'acne vulgaris', 'adult acne'][0m

[5m[7m[92m1.  data_input_text[0m [92m: how much protein and iron does fish have[0m
[5m[7m[92m2.  data_input_text[0m [92m: how much protein and iron does fish have[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['protein', 'iron', 'fish'][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['fish', 'iron', 'protein', 'vitamin iron', 'fish diet'][0m

[5m[7m[92m1.  data_input_text[0m [92m: where does qatar fly[0m
[5m[7m[92m2.  data_input_text[0m [92m: where does qatar fly[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['qatar', 'qatar airways', 'doha'][0m


In [None]:
# Show the 100 worst-scoring test queries (lowest summed per-query metric first).
# NOTE(review): as in the LLaMA section, queries with an empty gold entity list fill
# this view in the saved output; filter them to focus on genuine linker errors.
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[0:100]
disp.show(idxs)

[5m[7m[92m1.  data_input_text[0m [92m: was hocus pocus filmed in salem[0m
[5m[7m[92m2.  data_input_text[0m [92m: was hocus pocus filmed in salem[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['salem', 'salem, oregon', 'salem, or', 'hocus pocus', 'salem, massachusetts'][0m

[5m[7m[92m1.  data_input_text[0m [92m: what beach is close to disneyland california[0m
[5m[7m[92m2.  data_input_text[0m [92m: what beach is close to disneyland california[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['disneyland', 'disneyland resort', 'venice beach, california', 'disneyland park', 'disneyland resort hotel'][0m

[5m[7m[92m1.  data_input_text[0m [92m: unemployment rate honolulu hawaii[0m
[5m[7m[92m2.  data_input_text[0m [92m: unemployment rate honolulu hawaii[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_inpu

## Conflation

In [4]:
from xcai.graph.random_walk import random_walk

In [32]:
# Conflation operates on the LLaMA linker's predictions. The cells below also read
# `linker_block` (entity texts, gold labels), so rebuild the matching LLaMA linker block
# here instead of relying on kernel state: under Restart & Run All the last block built
# above is the GPT one, whose entity vocabulary may not line up with these prediction
# columns.
data_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-llama-entity-linker-002/predictions/'

pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-entity-llama_distilbert-base-uncased_sxc_exact.joblib'
config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_llama_exact.json'
config_key = 'data_entity-llama_exact'

block = build_block(pkl_file, config_file, use_sxc=True, config_key=config_key)
linker_block = block.linker_dset('ent_meta', remove_empty=False)

trn_preds = sp.load_npz(f'{data_dir}/train_predictions_full.npz')
tst_preds = sp.load_npz(f'{data_dir}/test_predictions_full.npz')

# Prediction columns must index the same entity list as the block's gold labels;
# otherwise the component -> entity-text mapping built later is silently wrong.
assert trn_preds.shape[1] == linker_block.train.dset.data.data_lbl.shape[1], \
    'train predictions do not match the linker entity vocabulary'
assert tst_preds.shape[1] == linker_block.test.dset.data.data_lbl.shape[1], \
    'test predictions do not match the linker entity vocabulary'

In [33]:
# Keep only the 3 highest-scoring entities per row of each prediction matrix before
# running the random walks below.
trn_preds, tst_preds = (retain_topk(scores, k=3) for scores in (trn_preds, tst_preds))

### Random walk

In [39]:
# Random walk over the transposed top-3 train prediction matrix (presumably entity x query).
# Takes ~10 min (see progress bar); the result is saved and reloaded by the
# connected-components section.
rnd_trn_preds = random_walk(
    trn_preds.T,
    row_head_thresh=500,
    col_head_thresh=500,
    walk_length=2,
)
sp.save_npz(data_dir + '/random_walk_train_predictions.npz', rnd_trn_preds)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [10:36<00:00,  3.25s/it]


In [40]:
# Same random walk on the transposed top-3 test prediction matrix (~10 min); saved to disk.
rnd_tst_preds = random_walk(
    tst_preds.T,
    row_head_thresh=500,
    col_head_thresh=500,
    walk_length=2,
)
sp.save_npz(data_dir + '/random_walk_test_predictions.npz', rnd_tst_preds)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [10:33<00:00,  3.23s/it]


In [None]:
# Random walk over the transposed gold train entity labels (tighter head thresholds,
# random_walk's default walk length). This run had not completed in the saved notebook.
rnd_trn_lbl = random_walk(
    linker_block.train.dset.data.data_lbl.T,
    row_head_thresh=50,
    col_head_thresh=50,
)
sp.save_npz(data_dir + '/random_walk_train_labels.npz', rnd_trn_lbl)

  0%|          | 0/196 [00:00<?, ?it/s]

In [None]:
# Random walk over the transposed gold test entity labels (same settings as for train).
rnd_tst_lbl = random_walk(
    linker_block.test.dset.data.data_lbl.T,
    row_head_thresh=50,
    col_head_thresh=50,
)
sp.save_npz(data_dir + '/random_walk_test_labels.npz', rnd_tst_lbl)

  0%|          | 0/196 [00:00<?, ?it/s]

### Connected components

In [41]:
# NOTE(review): despite the `_lbl` name, this loads the random-walk graph built from the
# train *predictions* (not random_walk_train_labels.npz), overwriting any label graph
# computed above. Confirm which graph the component analysis is meant to use.
rnd_trn_lbl = sp.load_npz(data_dir + '/random_walk_train_predictions.npz')

In [45]:
# Split the random-walk graph into connected components (edge direction ignored) and
# collect the entity texts belonging to each component, keyed by component id.
n_comp, labels = connected_components(rnd_trn_lbl, directed=False, return_labels=True)

lbl_texts = linker_block.train.dset.data.lbl_info['input_text']
# `zip` would silently truncate on a size mismatch and attach texts to the wrong
# components, e.g. if the graph was built against a different entity vocabulary.
assert len(labels) == len(lbl_texts), \
    f'graph has {len(labels)} nodes but there are {len(lbl_texts)} entity texts'

components = {}
for lbl, txt in zip(labels, lbl_texts):
    components.setdefault(lbl, []).append(txt)

In [46]:
# Number of components to print, and a shuffled termcolor palette used to alternate
# line colours when printing them.
n_lbl = 100
colors = list(COLORS.keys())
# Seeded so the colour assignment is reproducible across runs.
colors = [colors[i] for i in np.random.default_rng(0).permutation(len(colors))]

In [47]:
# Size of each component (ids are 0..n-1), then the ids of the n_lbl largest, biggest first.
group_lengths = np.array([len(components[comp_id]) for comp_id in range(len(components))])
idxs = np.argsort(group_lengths)[::-1][:n_lbl]

In [28]:
# Replaces the largest-first selection from the previous cell with a random sample of
# n_lbl component ids. Run only one of the two cells, depending on which view you want.
# Seeded so the sample is reproducible.
idxs = np.random.default_rng(0).permutation(len(components))[:n_lbl]

In [29]:
# Print each selected component as "<id+1>. text || text || ...", alternating between
# plain and reverse+blink styling. idxs[0] is skipped — presumably to drop the largest
# (giant) component when idxs is size-sorted.
for i, idx in enumerate(idxs[1:]):
    style = None if i % 2 == 0 else ["reverse", "blink"]
    print(colored(f'{idx+1}. {" || ".join(components[idx])}', colors[i % len(colors)], attrs=style))
        

[37m8272. korotkoff sounds[0m
[5m[7m[32m231. nemesis[0m
[34m6753. nerve trunk[0m
[5m[7m[91m5282. genetic contents[0m
[90m7620. acoustic[0m
[5m[7m[30m14937. talas[0m
[95m5139. william crawford gorgas[0m
[5m[7m[30m15848. aviator park[0m
[96m6950. transgressive dune[0m
[5m[7m[35m4042. reeperbahn[0m
[94m20067. caravan of love[0m
[5m[7m[31m1067. avast secureline vpn[0m
[93m12367. blanchester[0m
[5m[7m[92m3548. remote access servers[0m
[97m7635. chovva dosham[0m
[5m[7m[33m2436. ecoli[0m
[5m[7m[37m19358. tireless scorer badge[0m
[32m134. kamea[0m
[5m[7m[34m11898. libeau lane[0m
[91m20673. collaborative divorce[0m
[5m[7m[90m12359. callisto[0m
[30m19045. xo5w20bsp[0m
[5m[7m[95m3318. keiji inafune[0m
[30m6739. competition brisket[0m
[5m[7m[96m22754. shibumi[0m
[35m17664. funtimation[0m
[5m[7m[94m17433. c2h5oh[0m
[31m7883. solodyn[0m
[5m[7m[93m11005. monaka[0m
[92m13394. ble rined octopus[0m
[5m[7m[97m318