In [1]:
#| default_exp 04_msmarco-entity-conflation

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
import scipy.sparse as sp, numpy as np
from tqdm.auto import tqdm
from termcolor import colored, COLORS
from scipy.sparse.csgraph import connected_components

from xcai.main import *
from xcai.data import XCDataset
from xcai.analysis import *

import xclib.evaluation.xc_metrics as xc_metrics
from xclib.utils.sparse import retain_topk

## Linker predictions

### `LLaMA`

In [4]:
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-entity-llama_distilbert-base-uncased_sxc_exact.joblib'

config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_llama_exact.json'
config_key = 'data_entity-llama_exact'

block = build_block(pkl_file, config_file, use_sxc=True, config_key=config_key)
linker_block = block.linker_dset('ent_meta', remove_empty=False)

In [None]:
pred_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-llama-entity-linker-002/predictions/'
pred_lbl = sp.load_npz(f'{pred_dir}/test_predictions_full.npz')

pred_block = get_pred_dset(retain_topk(pred_lbl, k=5), linker_block.test.dset)

disp = CompareDataset(XCDataset._initialize(linker_block.test.dset), pred_block, '1. ', '2. ')

In [None]:
metric = pointwise_eval(pred_lbl, linker_block.test.dset.data.data_lbl, topk=5)

In [None]:
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[:-100-1:-1]
disp.show(idxs)

[5m[7m[35m1.  data_input_text[0m [35m: how much protein and iron does fish have[0m
[5m[7m[35m2.  data_input_text[0m [35m: how much protein and iron does fish have[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['protein', 'iron', 'fish'][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['protein', 'iron', 'fish', 'fish symbol', 'vitamin iron'][0m

[5m[7m[35m1.  data_input_text[0m [35m: how far is lake como to lugano switzerland[0m
[5m[7m[35m2.  data_input_text[0m [35m: how far is lake como to lugano switzerland[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['lake como', 'lugano', 'switzerland'][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['switzerland', 'lake', 'lake como', 'lake geneva', 'lugano'][0m

[5m[7m[35m1.  data_input_text[0m [35m: phone number for autozone in calhoun ga[0m
[5m[7m[35m2.  data_input_text[0m [35m: phone number for autozone in calhoun ga[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: ['autozone', 'calh

In [None]:
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[:100]
disp.show(idxs)

[5m[7m[35m1.  data_input_text[0m [35m: blood diseases that are sexually transmitted[0m
[5m[7m[35m2.  data_input_text[0m [35m: blood diseases that are sexually transmitted[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['sexually transmitted disease', 'sexually transmitted', 'sexually transmitted diseases', 'bacterial sexually transmitted disease', 'sexual transmitted diseases'][0m

[5m[7m[35m1.  data_input_text[0m [35m: define bona fides[0m
[5m[7m[35m2.  data_input_text[0m [35m: define bona fides[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96m2.  lbl2data_input_text[0m [96m: ['bona fide', 'fides', 'bona', 'bona fide need rule', 'bona fide needs rule'][0m

[5m[7m[35m1.  data_input_text[0m [35m: effects of detox juice cleanse[0m
[5m[7m[35m2.  data_input_text[0m [35m: effects of detox juice cleanse[0m
[5m[7m[96m1.  lbl2data_input_text[0m [96m: [][0m
[5m[7m[96

## `GPT`

In [None]:
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-entity-gpt_distilbert-base-uncased_sxc_exact.joblib'

config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

block = build_block(pkl_file, config_file, use_sxc=True, config_key=config_key)
linker_block = block.linker_dset('ent_meta', remove_empty=False)

In [None]:
pred_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-gpt-entity-linker-001/predictions/'
pred_lbl = sp.load_npz(f'{pred_dir}/test_predictions_full.npz')

pred_block = get_pred_dset(retain_topk(pred_lbl, k=5), linker_block.test.dset)

disp = CompareDataset(XCDataset._initialize(linker_block.test.dset), pred_block, '1. ', '2. ')

In [None]:
metric = pointwise_eval(pred_lbl, linker_block.test.dset.data.data_lbl, topk=5)

In [None]:
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[:-100-1:-1]
disp.show(idxs)

[5m[7m[92m1.  data_input_text[0m [92m: does adult acne rosacea give you blepharitis[0m
[5m[7m[92m2.  data_input_text[0m [92m: does adult acne rosacea give you blepharitis[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['adult acne', 'rosacea', 'blepharitis'][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['blepharitis', 'rosacea', 'skin rosacea', 'acne vulgaris', 'adult acne'][0m

[5m[7m[92m1.  data_input_text[0m [92m: how much protein and iron does fish have[0m
[5m[7m[92m2.  data_input_text[0m [92m: how much protein and iron does fish have[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['protein', 'iron', 'fish'][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['fish', 'iron', 'protein', 'vitamin iron', 'fish diet'][0m

[5m[7m[92m1.  data_input_text[0m [92m: where does qatar fly[0m
[5m[7m[92m2.  data_input_text[0m [92m: where does qatar fly[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: ['qatar', 'qatar airways', 'doha'][0m


In [None]:
idxs = np.argsort(np.ravel(metric.sum(axis=1)))[:100]
disp.show(idxs)

[5m[7m[92m1.  data_input_text[0m [92m: was hocus pocus filmed in salem[0m
[5m[7m[92m2.  data_input_text[0m [92m: was hocus pocus filmed in salem[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['salem', 'salem, oregon', 'salem, or', 'hocus pocus', 'salem, massachusetts'][0m

[5m[7m[92m1.  data_input_text[0m [92m: what beach is close to disneyland california[0m
[5m[7m[92m2.  data_input_text[0m [92m: what beach is close to disneyland california[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_input_text[0m [91m: ['disneyland', 'disneyland resort', 'venice beach, california', 'disneyland park', 'disneyland resort hotel'][0m

[5m[7m[92m1.  data_input_text[0m [92m: unemployment rate honolulu hawaii[0m
[5m[7m[92m2.  data_input_text[0m [92m: unemployment rate honolulu hawaii[0m
[5m[7m[91m1.  lbl2data_input_text[0m [91m: [][0m
[5m[7m[91m2.  lbl2data_inpu

## Conflation

In [13]:
from xcai.graph.operations import *
import matplotlib.pyplot as plt

In [33]:
data_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogicX/01-msmarco-llama-entity-linker-002/predictions/'

trn_preds = sp.load_npz(f'{data_dir}/train_predictions_full.npz')
tst_preds = sp.load_npz(f'{data_dir}/test_predictions_full.npz')

In [34]:
trn_preds = retain_topk(trn_preds, k=10)
tst_preds = retain_topk(tst_preds, k=10)

In [72]:
tst_preds = sp.load_npz(f'{data_dir}/test_predictions_full.npz')
tst_preds = retain_topk(tst_preds, k=3)

### Random walk

In [76]:
preds = Graph.threshold_on_degree(tst_preds.transpose().tocsr(), thresh=5)
preds = preds.transpose().tocsr()

In [87]:
rnd_tst_pred = Graph.random_walk(preds.transpose().tocsr(), walk_to=2, batch_size=1024, prob_reset=0.8, n_hops=2)

  0%|          | 0/196 [00:00<?, ?it/s]

In [54]:
sp.save_npz(f'{data_dir}/random_walk_train_predictions.npz', rnd_tst_pred)

### Connected components

In [41]:
rnd_trn_lbl = sp.load_npz(f'{data_dir}/random_walk_train_predictions.npz')

In [88]:
n_comp, labels = connected_components(rnd_tst_pred, directed=False, return_labels=True)

components = {}
for lbl,txt in zip(labels,linker_block.train.dset.data.lbl_info['input_text']):
    components.setdefault(lbl, []).append(txt)

In [95]:
n_lbl = 100
colors = list(COLORS.keys())
colors = [colors[i] for i in np.random.permutation(len(colors))]

In [114]:
group_lengths = np.array([len(components[i]) for i in range(len(components))])

In [116]:
idxs = np.argsort(group_lengths)[:-n_lbl]

In [118]:
idxs = np.logical_and(group_lengths > 2, group_lengths < 6)

In [108]:
idxs = np.random.permutation(len(components))[:n_lbl]

In [120]:
for i,idx in enumerate(idxs):
    if i % 2 == 0: print(colored(f'{idx+1}. {" || ".join(components[idx])}', colors[i % len(colors)]))
    else: print(colored(f'{idx+1}. {" || ".join(components[idx])}', colors[i % len(colors)], attrs=["reverse", "blink"]))
        

[36m1. manhattan project[0m
[5m[7m[91m1. manhattan project[0m
[93m1. manhattan project[0m
[5m[7m[90m1. manhattan project[0m
[94m1. manhattan project[0m
[5m[7m[30m2. justice || united states department of justice[0m
[35m1. manhattan project[0m
[5m[7m[32m1. manhattan project[0m
[95m1. manhattan project[0m
[5m[7m[33m1. manhattan project[0m
[34m1. manhattan project[0m
[5m[7m[37m1. manhattan project[0m
[30m1. manhattan project[0m
[5m[7m[96m1. manhattan project[0m
[92m2. justice || united states department of justice[0m
[5m[7m[31m1. manhattan project[0m
[97m1. manhattan project[0m
[5m[7m[36m2. justice || united states department of justice[0m
[91m1. manhattan project[0m
[5m[7m[93m1. manhattan project[0m
[90m1. manhattan project[0m
[5m[7m[94m2. justice || united states department of justice[0m
[30m1. manhattan project[0m
[5m[7m[35m2. justice || united states department of justice[0m
[32m1. manhattan project[0m
[5m[7m

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

