# analysis

In [1]:
#| default_exp analysis

In [2]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np, re
from typing import Optional, Dict, List, Tuple
from torch.utils.data import Dataset
from scipy import sparse
import matplotlib.pyplot as plt

from fastcore.dispatch import *

from xcai.basics import *
from xcai.data import *
from xcai.learner import XCPredictionOutput

import xclib.utils.sparse as xc_sparse
import xclib.evaluation.xc_metrics as xc_metrics
import xclib.data.data_utils as du 

from IPython.display import HTML

comet_ml is installed but `COMET_API_KEY` is not set.


In [3]:
#| hide
from nbdev.showdoc import *
# Trigger nbdev's notebook-to-module export; this cell is hidden from the rendered docs.
import nbdev; nbdev.nbdev_export()

## Setup

In [79]:
# Directory holding the pre-processed dataset pickles.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/processed/'
# `pkl_dir` already ends with '/', so the resulting path contains a double slash.
pkl_file = f'{pkl_dir}/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

In [80]:
# Unpickle the data block that the example cells below read from.
# pickle can execute arbitrary code while loading, so only open trusted files.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

## Metric

In [6]:
#| export
def pointwise_eval(pred_lbl:sparse.csr_matrix, data_lbl:sparse.csr_matrix, data_lbl_filterer:Optional[np.ndarray]=None,
                   topk:Optional[int]=5, metric:Optional[str]='P', return_type:Optional[str]='M'):
    """Break a top-k metric down into per-(data point, label) contributions.

    Parameters
    ----------
    pred_lbl : predicted scores, one row per data point and one column per label.
    data_lbl : ground-truth matrix with the same layout as `pred_lbl`.
    data_lbl_filterer : when given, `Filterer.apply` is applied to both matrices before scoring.
    topk : number of predictions kept per row by `xc_sparse.retain_topk`.
    metric : 'P'  -> every hit is weighted by 1 / (topk * n_rows);
             'R'  -> every hit is weighted by 1 / (stored labels in its row * n_rows);
             'FN' -> ground-truth entries that were *not* hit, weighted by
                     1 / (stored entries in their column * n_cols);
             anything else leaves the 0/1 hit indicators unscaled.
    return_type : 'M' -> sparse matrix of contributions (always CSR, so it can be sliced);
                  'L' -> 1-D array of per-label (column) sums;
                  anything else -> 1-D array of per-data-point (row) sums.
    """
    if data_lbl_filterer is not None:
        pred_lbl = Filterer.apply(pred_lbl, data_lbl_filterer)
        data_lbl = Filterer.apply(data_lbl, data_lbl_filterer)

    pred_lbl = xc_sparse.retain_topk(pred_lbl, k=topk)

    # Entries stored in both matrices are the hits; overwrite their values with 1.
    scores = pred_lbl.multiply(data_lbl)
    scores.data[:] = 1

    n_data, n_lbl = data_lbl.shape
    if metric == 'P':
        scores = scores.multiply(1/(topk * n_data))
    elif metric == 'R':
        # Rows without labels hold no hits, so clamping their count to 1 leaves the
        # result unchanged while avoiding a divide-by-zero warning.
        d = np.maximum(data_lbl.getnnz(axis=1), 1) * n_data
        scores = scores.multiply(1/d.reshape(-1,1))
    elif metric == 'FN':
        scores = data_lbl - scores
        scores.eliminate_zeros()
        # Same clamping argument as above, applied per label column.
        lbl_cnt = np.maximum(data_lbl.getnnz(axis=0), 1)
        scores = scores.multiply(1/(lbl_cnt * n_lbl))

    # Multiplying by a dense array may hand back a COO matrix, which does not support
    # indexing; normalise to CSR so callers can slice columns/rows.
    if return_type == 'M': return scores.tocsr()
    elif return_type == 'L': return np.ravel(scores.sum(axis=0))
    else: return np.ravel(scores.sum(axis=1))
    

## Decile plot

In [7]:
#| export
def equal_volume_split(data_lbl:sparse.csr_matrix, n_split:int):
    """Group labels, most frequent first, into buckets of roughly equal total frequency.

    Labels are visited in descending order of their column count in `data_lbl`; a
    bucket is closed as soon as its accumulated count exceeds `total / n_split`.
    Exactly `n_split` buckets must result, otherwise a `ValueError` is raised.
    One extra bucket holding every label (in the visiting order) is appended last.

    Returns `(splits, info)` where `info[i]` is the tick text for bucket `i`:
    its 1-based position, its size in thousands of labels, and the mean count of
    its labels (labels with a zero count are ignored in that mean).
    """
    freq = data_lbl.getnnz(axis=0)
    order = np.argsort(-freq)
    budget = freq.sum()/n_split

    buckets, current, volume = [], [], 0
    for lbl in order:
        current.append(lbl)
        volume += freq[lbl]
        if volume > budget:
            buckets.append(current)
            current, volume = [], 0

    # Whatever is left after the last closed bucket forms the final one.
    if len(current): buckets.append(current)

    if len(buckets) != n_split:
        raise ValueError(f'Number of splits created less than {n_split}.')
    buckets.append(order.tolist())

    # Zero counts become NaN so that nanmean skips labels that never occur.
    mean_src = freq.astype(np.float32)
    mean_src[mean_src == 0] = np.nan

    info = []
    for pos, bucket in enumerate(buckets):
        info.append(f'{pos+1}\n{len(bucket)//1000}K\n{np.nanmean(mean_src[bucket]):.2f}')
    return buckets, info
    

In [8]:
#| export
def get_decile_stats(pred_lbl:sparse.csr_matrix, data_lbl:sparse.csr_matrix, data_lbl_filterer:np.ndarray, 
                     n_split:Optional[int]=5, topk:Optional[int]=5, metric:Optional[str]='P'):
    """Aggregate the pointwise metric over equal-volume label buckets.

    The contribution matrix comes from `pointwise_eval` (which receives the filterer),
    while the buckets come from `equal_volume_split` on the `data_lbl` passed in.
    Returns `(info, values)`: the bucket tick texts and, per bucket, the summed
    contribution of its label columns multiplied by 100.
    """
    contrib = pointwise_eval(pred_lbl, data_lbl, data_lbl_filterer, topk=topk, metric=metric)
    buckets, info = equal_volume_split(data_lbl, n_split)

    values = []
    for bucket in buckets:
        values.append(contrib[:, bucket].sum()*100)
    return info, values


In [9]:
#| export
def barplot(scores:Dict, title:Optional[str]='', ylabel:Optional[str]='', figsize:Optional[Tuple]=(15,10)):
    """Draw a grouped bar chart with one bar series per entry of `scores`.

    `scores` maps a series name to an `(info, values)` pair; `values` are the bar
    heights and `info` the tick labels. The tick labels shown are those of the last
    entry in `scores`. Nothing is returned; the figure is created via `plt.subplots`.
    """
    n_proc = len(scores)
    n_split = len(list(scores.values())[0][0])
    idx = np.arange(n_split)
    width = 0.8/n_proc

    # Offset that puts each tick in the middle of its group of bars.
    if n_proc%2:
        tick_offset = (n_proc//2)*width
    else:
        tick_offset = width/2 + ((n_proc-1)//2)*width

    fig, ax = plt.subplots(figsize=figsize)

    # Each series is drawn one bar-width to the right of the previous one.
    offset = 0
    for proc,(info,values) in scores.items():
        ax.bar(x=idx + offset, height=values, width=width, alpha=1, label=proc)
        offset += width

    ax.set_title(title, fontsize=22)
    ax.set_xlabel('Quantiles \n (Increasing Freq.)', fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    ax.set_xticks(idx + tick_offset, info, fontsize=14)
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontsize(14)

    ax.legend(fontsize=14)
    

In [10]:
#| export
def decile_plot(preds:Dict, n_split:Optional[int]=5, topk:Optional[int]=5, metric:Optional[str]='P', 
                figsize:Optional[Tuple]=(15,10), title:Optional[str]='', block=None):
    """Plot the bucketed metric of several prediction matrices side by side.

    Parameters
    ----------
    preds : mapping from method name to its prediction matrix.
    n_split, topk, metric : forwarded to `get_decile_stats` for every method; `metric`
        and `topk` also form the y-axis label (`'{metric}@{topk}'`).
    figsize, title : forwarded to `barplot`.
    block : object providing `test.dset.data.data_lbl` and `test.data_lbl_filterer`.
        When omitted, a module-level variable named `block` is used, which is what
        the function relied on before this parameter existed.

    Raises
    ------
    NameError : if `block` is not passed and no module-level `block` exists.
    """
    if block is None:
        block = globals().get('block')
        if block is None:
            raise NameError("name 'block' is not defined; pass it via the `block` argument")

    scores = {}
    for method, pred in preds.items():
        # Forward the caller's settings so the plotted numbers match the axis label.
        info, values = get_decile_stats(pred, block.test.dset.data.data_lbl, block.test.data_lbl_filterer, 
                                        n_split=n_split, topk=topk, metric=metric)
        scores[method] = (info,values)
    
    barplot(scores, title, f'{metric}@{topk}', figsize)
    

## Display text

### Dataset

In [11]:
#| export
@typedispatch
def get_pred_dset(pred:sparse.csr_matrix, block:XCDataBlock):
    """Clone the block's test dataset with `pred` in place of its label matrix.

    The data info, label info, label filterer and metadata are taken from
    `block.test.dset`; only the data-to-label matrix is swapped for `pred`.
    """
    test_dset = block.test.dset
    src = test_dset.data
    data = MainXCDataset(src.data_info, pred, src.lbl_info, src.data_lbl_filterer)
    return XCDataset(data, **test_dset.meta)


@typedispatch
def get_pred_dset(pred:sparse.csr_matrix, dset:XCDataset):
    """Clone `dset` with `pred` in place of its label matrix, keeping its metadata."""
    src = dset.data
    swapped = MainXCDataset(src.data_info, pred, src.lbl_info, src.data_lbl_filterer)
    return XCDataset(swapped, **dset.meta)

@typedispatch
def get_pred_dset(pred:sparse.csr_matrix, dset:MainXCDataset):
    """Build a `MainXCDataset` from `dset`'s infos and filterer with `pred` as its label matrix."""
    data_info, lbl_info, filterer = dset.data_info, dset.lbl_info, dset.data_lbl_filterer
    return MainXCDataset(data_info, pred, lbl_info, filterer)
    

#### Example

In [None]:
# Empty (all-zero) sparse matrix sized (test n_data, train n_lbl), used as a stand-in for predictions.
pred = sparse.csr_matrix((block.test.dset.n_data, block.train.dset.n_lbl))

In [None]:
# Wrap the stand-in predictions in a dataset that mirrors the test dataset.
o = get_pred_dset(pred, block.test.dset)

In [None]:
# Inspect the first item; with empty predictions the `lbl2data_*` fields come back empty.
o[0]

{'data_identifier': 'Abraham_Lincoln',
 'data_input_text': 'Abraham Lincoln',
 'data_input_ids': [101, 8181, 5367, 102],
 'data_attention_mask': [1, 1, 1, 1],
 'lbl2data_idx': [],
 'lbl2data_identifier': [],
 'lbl2data_input_text': [],
 'lbl2data_input_ids': [],
 'lbl2data_attention_mask': [],
 'cat2data_idx': [19377,
  54316,
  54419,
  62824,
  63173,
  63174,
  63175,
  68208,
  68209,
  69192,
  69193,
  69194,
  69195,
  69196,
  69197,
  69198,
  69199,
  69200,
  69201,
  69202,
  69203,
  69204,
  69205,
  69206,
  69207,
  69208,
  69209,
  69210,
  69211,
  69212,
  69213,
  69214,
  69215,
  69216,
  69217,
  69218,
  69219,
  69220,
  69221],
 'cat2data_identifier': ['Category:Deaths_by_firearm_in_Washington,_D.C.',
  'Category:American_people_of_English_descent',
  'Category:1865_deaths',
  'Category:19th-century_American_politicians',
  'Category:Illinois_Republicans',
  'Category:Members_of_the_Illinois_House_of_Representatives',
  'Category:Members_of_the_United_States_

### Load

In [12]:
#| export
@typedispatch
def get_pred_sparse(out:XCPredictionOutput, n_lbl:int):
    """Assemble a CSR prediction matrix from a prediction output object.

    `out.pred_ptr` is cumulatively summed (with a leading zero) to form the CSR row
    pointer, `out.pred_idx` supplies the column indices and `out.pred_score` the
    values. The result has `len(out.pred_ptr)` rows and `n_lbl` columns.
    """
    n_data = len(out.pred_ptr)
    leading_zero = torch.zeros((1,), dtype=torch.long)
    indptr = torch.cat([leading_zero, out.pred_ptr.cumsum(dim=0)])
    return sparse.csr_matrix((out.pred_score, out.pred_idx, indptr), shape=(n_data, n_lbl))

@typedispatch
def get_pred_sparse(fname:str, n_lbl:int):
    """Unpickle a prediction output from `fname` and assemble it into a CSR matrix.

    The loaded object must expose `pred_ptr`, `pred_idx` and `pred_score`; the matrix
    is built from them exactly as in the overload that takes the object directly.
    pickle can execute arbitrary code while loading, so only pass trusted files.
    """
    with open(fname, 'rb') as f:
        out = pickle.load(f)
    n_data = len(out.pred_ptr)
    leading_zero = torch.zeros((1,), dtype=torch.long)
    indptr = torch.cat([leading_zero, out.pred_ptr.cumsum(dim=0)])
    return sparse.csr_matrix((out.pred_score, out.pred_idx, indptr), shape=(n_data, n_lbl))

@typedispatch
def load_pred_sparse(fname:str):
    """Load a CSR matrix from an archive holding `data`, `indices`, `indptr` and `shape`.

    The values are converted to `float`. The archive is opened in a `with` block so
    its file handle is closed once the four arrays have been read.
    """
    # Arrays are only read from disk when indexed, so pull them out before the archive closes.
    with np.load(fname) as o:
        data, indices, indptr, shape = o['data'], o['indices'], o['indptr'], o['shape']
    return sparse.csr_matrix((data, indices, indptr), dtype=float, shape=shape)

def get_output(fname:str, n_lbl:int, pred_type:Optional[str]='repr_output'):
    """Unpickle a model output from `fname` and convert one of its fields to sparse form.

    The attribute named `pred_type` of the loaded object is unpacked as keyword
    arguments into `get_output_sparse`, together with `n_lbl`.
    NOTE(review): `get_output_sparse` is not defined in this notebook — presumably it
    comes from one of the `xcai` star imports; confirm.

    Returns the `(preds, targ)` pair produced by `get_output_sparse`.
    """
    # pickle can execute arbitrary code while loading; only pass trusted files.
    # The file is opened via the `fname` argument (the code previously referenced an undefined `pname`).
    with open(fname, 'rb') as f: out = pickle.load(f)
    preds,targ = get_output_sparse(**getattr(out,pred_type), n_lbl=n_lbl)
    return preds, targ
    

#### Example

In [18]:
# NOTE(review): `pickle` is already imported in the top import cell; this re-import is redundant.
import pickle

# NOTE(review): absolute, machine-specific checkpoint path — adjust before running elsewhere.
mname = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-2/checkpoint-130000/'
# `mname` already ends with '/', so the resulting path contains a double slash.
fname = f'{mname}/predictions/test_predictions.pkl'

# Unpickle the stored predictions; only do this with trusted files.
with open(fname, 'rb') as file: out = pickle.load(file)

In [19]:
# Build the prediction matrix straight from the pickle path (str overload).
pred = get_pred_sparse(fname, block.n_lbl)

In [20]:
# Same conversion from the already-loaded output object; overwrites `pred` from the previous cell.
pred = get_pred_sparse(out, block.n_lbl)

In [21]:
# Show the matrix summary (shape, dtype, number of stored elements).
pred

<177515x312330 sparse matrix of type '<class 'numpy.float32'>'
	with 35503000 stored elements in Compressed Sparse Row format>

### Formatter

In [13]:
#| export
def html(text:str, c='green'):
    """Wrap `text` in a `<text>` tag whose inline style sets the colour to `c`.

    `text` is inserted as-is (no HTML escaping is performed).
    """
    return '<text style=color:{}>{}</text>'.format(c, text)

In [14]:
#| export
class TextColumns(Dataset):
    """View over an indexable collection of dicts that keeps only keys matching a regex.

    `x` is the wrapped collection and `pat` the pattern handed to `re.match`, so it
    is anchored at the start of each key. The default keeps keys ending in `_text`.
    """

    def __init__(self, x, pat='.*_text$'):
        self.x = x
        self.pat = pat

    def __getitem__(self, idx):
        record = self.x[idx]
        kept = {}
        for key, value in record.items():
            if re.match(self.pat, key):
                kept[key] = value
        return kept
    

In [15]:
#| export
def display_text(pred_dset:Dataset, data_dset:Dataset, idxs:List):
    """Render predicted and reference fields of the selected items as coloured HTML.

    For every index in `idxs`, each key/value pair of `pred_dset[idx]` is rendered as
    one line (key red, value green), followed by the pairs of `data_dset[idx]` (key
    black, value blue). Lines are joined with `<br>`, items with `<br><br>`.
    Returns the resulting HTML string.
    """
    pred_colors, data_colors = ('red','green'), ('black','blue')

    def render(row, colors):
        # One "key: value" line per field, each part wrapped by `html` in its colour.
        key_color, val_color = colors
        return "<br>".join([f'{html(k,key_color)}: {html(v,val_color)}' for k,v in row.items()])

    blocks = []
    for idx in idxs:
        pred_text = render(pred_dset[idx], pred_colors)
        data_text = render(data_dset[idx], data_colors)
        blocks.append("<br>".join([pred_text,data_text]))
    return "<br><br>".join(blocks)


def compare_text(pred1_dset:Dataset, pred2_dset:Dataset, data_dset:Dataset, idxs:List):
    """Render two sets of predictions and the reference fields as coloured HTML.

    For every index in `idxs` the key/value pairs of `pred1_dset[idx]` (key red, value
    green), `pred2_dset[idx]` (key black, value blue) and `data_dset[idx]` (key orange,
    value brown) are rendered one per line. Lines are joined with `<br>`, items with
    `<br><br>`. Returns the resulting HTML string.
    """
    palette = (('red','green'), ('black','blue'), ('orange', 'brown'))

    def render(row, colors):
        # One "key: value" line per field, each part wrapped by `html` in its colour.
        key_color, val_color = colors
        return "<br>".join([f'{html(k,key_color)}: {html(v,val_color)}' for k,v in row.items()])

    blocks = []
    for idx in idxs:
        pred1_text = render(pred1_dset[idx], palette[0])
        pred2_text = render(pred2_dset[idx], palette[1])
        data_text = render(data_dset[idx], palette[2])
        blocks.append("<br>".join([pred1_text,pred2_text,data_text]))
    return "<br><br>".join(blocks)
    

#### Example

In [52]:
from IPython.display import HTML
import pandas as pd
from xclib.utils.sparse import retain_topk

In [82]:
# Keep only the input-text fields of the data point, its labels and its categories.
pattern = r'^(data|lbl2data|cat2data)_input_text$'

# Text view of the predictions (restricted to k=5 via `retain_topk`) and of the ground-truth test set.
pred_dset = TextColumns(get_pred_dset(retain_topk(pred, k=5), block), pat=pattern)
test_dset = TextColumns(block.test.dset, pat=pattern)

In [68]:
# Pointwise contributions with the defaults (topk=5, metric='P', return_type='M').
prec = pointwise_eval(pred, block.test.dset.data.data_lbl, block.test.dset.data.data_lbl_filterer)
# Test indices sorted by descending row-sum of `prec`, i.e. highest-scoring data points first.
idxs = np.argsort(-np.array(prec.sum(axis=1)).squeeze())

  self._set_arrayXarray(i, j, x)


In [None]:
# Render predictions next to the ground truth for the five highest-scoring test points.
HTML(display_text(pred_dset, test_dset, idxs=idxs[:5]))

In [93]:
# Restrict the test set's `lnk` data metadata with `retain_topk(k=3)`; label metadata is passed through unchanged.
data_meta = retain_topk(block.test.dset.meta.lnk_meta.data_meta, k=3)
lbl_meta = block.test.dset.meta.lnk_meta.lbl_meta
# NOTE(review): this appears to modify `block` in place, so cells that read `block.test.dset`
# would then depend on whether this cell has been run — confirm.
block.test.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)

In [83]:
# Keep the input-text fields of the data point, its labels and its linked (`lnk`) metadata.
pattern = r'^(data|lbl2data|lnk2data)_input_text$'

# NOTE(review): both the "MOMOS" and the "OAK" datasets are built from the same `pred`,
# so they are identical — one of them presumably should use a different prediction matrix.
momos_dset = TextColumns(get_pred_dset(retain_topk(pred, k=5), block), pat=pattern)
oak_dset = TextColumns(get_pred_dset(retain_topk(pred, k=5), block), pat=pattern)
test_dset = TextColumns(block.test.dset, pat=pattern)

In [87]:
# NOTE(review): both evaluations use the same `pred`, so the two score vectors are identical.
momos_eval = pointwise_eval(pred, block.test.dset.data.data_lbl, block.test.dset.data.data_lbl_filterer)
oak_eval = pointwise_eval(pred, block.test.dset.data.data_lbl, block.test.dset.data.data_lbl_filterer)

# Collapse the contribution matrices to one score per test data point.
momos_eval = np.array(momos_eval.sum(axis=1)).squeeze()
oak_eval = np.array(oak_eval.sum(axis=1)).squeeze()

In [88]:
# Ten test points with the smallest (oak - momos) difference, i.e. where OAK trails MOMOS the most.
idxs = np.argsort(oak_eval - momos_eval)[:10]

In [127]:
def display_momos(momos_dset:Dataset, oak_dset:Dataset, test_dset:Dataset, idxs:List):
    """Tabulate ground truth and the two models' predicted labels for the given items.

    The items `test_dset[i]` form the base table, with its text columns renamed to
    'Document', 'Ground truth labels' and 'Predicted metadata'. The
    `lbl2data_input_text` field of `momos_dset[i]` and `oak_dset[i]` is appended as
    the columns 'MOMOS predictions' and 'OAK predictions'. Returns the DataFrame.
    """
    readable = {'data_input_text':'Document', 'lbl2data_input_text': 'Ground truth labels', 'lnk2data_input_text':'Predicted metadata'}
    truth_df = pd.DataFrame([test_dset[i] for i in idxs]).rename(readable, axis=1)

    momos_col = [momos_dset[i]['lbl2data_input_text'] for i in idxs]
    momos_df = pd.DataFrame({'MOMOS predictions': momos_col})

    oak_col = [oak_dset[i]['lbl2data_input_text'] for i in idxs]
    oak_df = pd.DataFrame({'OAK predictions': oak_col})

    return pd.concat([truth_df, momos_df, oak_df], axis=1)
    

In [128]:
# Side-by-side table for the selected test points (`idxs` already holds only ten entries).
display_momos(momos_dset, oak_dset, test_dset, idxs[:10])

Unnamed: 0,Document,Ground truth labels,Predicted metadata,MOMOS predictions,OAK predictions
0,Outline of Korean language,"[Outline of Esperanto, Outline of German language]","[Korea-related lists, Korean language, Wikipedia outlines]","[Sino-Korean vocabulary, Korean mixed script, New Korean Orthography, Korean language, History of the Korean language]","[Sino-Korean vocabulary, Korean mixed script, New Korean Orthography, Korean language, History of the Korean language]"
1,Schooners Sports and Entertainment,"[Touchdown Atlantic, CFL Expansion]","[Sports marketing, Culture ministries, Sports ministries]","[List of schooners, Sea shanty, CFL Expansion, Great Old One, Double Dribble]","[List of schooners, Sea shanty, CFL Expansion, Great Old One, Double Dribble]"
2,Abu Ishaq al-Saffar al-Bukhari,"[List of Ash'aris and Maturidis, Abu Hanifa, Abu al-Mu'in al-Nasafi, Abu al-Yusr al-Bazdawi, Abu Mansur al-Maturidi]","[Year of birth unknown, Medieval Persian physicians, 8th-century Iranian people]","[List of Shi'a Muslims, Muhammad al-KulaynÃÂ«, Shaykh al-Hur al-ÃÂmilÃÂ«, Shaykh al-SadÃÂ«q, Abu Bakr Shibli]","[List of Shi'a Muslims, Muhammad al-KulaynÃÂ«, Shaykh al-Hur al-ÃÂmilÃÂ«, Shaykh al-SadÃÂ«q, Abu Bakr Shibli]"
3,Sunchales Aeroclub Airport,"[Aviation, Argentina, Transport in Argentina, List of airports in Argentina]","[Airports in Chile, Airports in Argentina, Buenos Aires Province]","[List of airports in the Czech Republic, List of airports in Brazil, List of airports in Argentina, List of airports in Uruguay, Altiport]","[List of airports in the Czech Republic, List of airports in Brazil, List of airports in Argentina, List of airports in Uruguay, Altiport]"
4,Template:KMDb person/doc,[KMDb documentary],"[People and person external link templates, External link templates using Wikidata, Templates that add a tracking category]","[MongoDB, Comparison of structured storage software, RocksDB, KMDb documentary, KMDb person]","[MongoDB, Comparison of structured storage software, RocksDB, KMDb documentary, KMDb person]"
5,Belgian Office for Intellectual Property,"[European patent law, European Patent Office, Patent offices in Europe]","[Intellectual property organizations, Patent offices, Science and technology in Belgium]","[Copyright law of the European Union, Directive on the enforcement of intellectual property rights, European Union Intellectual Property Office, Directive on criminal measures aimed at ensuring the enforcement of intellectual property rights, European patent law]","[Copyright law of the European Union, Directive on the enforcement of intellectual property rights, European Union Intellectual Property Office, Directive on criminal measures aimed at ensuring the enforcement of intellectual property rights, European patent law]"
6,Clean Needle Technique,[Regulation of acupuncture],"[Human positions, Dance technique, Ballroom dance technique]","[Microbiological culture, Toeprinting assay, Blood culture, Ultrasonic cleaning, Sharps waste]","[Microbiological culture, Toeprinting assay, Blood culture, Ultrasonic cleaning, Sharps waste]"
7,Miguel Arribas,"[List of Puerto Ricans, List of mayors of Ponce, Puerto Rico]","[Living people, Puerto Rican independence activists, Presidents of Venezuela]","[Martyrs of the Spanish Civil War, Roman Catholic Diocese of Antipolo, List of Filipino Saints, Blesseds, and Servants of God, Fifteen Martyrs of Bicol, Roman Catholic Archdiocese of Manila]","[Martyrs of the Spanish Civil War, Roman Catholic Diocese of Antipolo, List of Filipino Saints, Blesseds, and Servants of God, Fifteen Martyrs of Bicol, Roman Catholic Archdiocese of Manila]"
8,Hopen (municipality),[List of former municipalities of Norway],"[1838 establishments in Norway, Former municipalities of Norway, 1964 disestablishments in Norway]","[List of former municipalities of Norway, Runde Bridge, lists of villages in Norway, List of churches in HelsingÃÂ¸r Municipality, List of protected areas of HillerÃÂ¸d Municipality]","[List of former municipalities of Norway, Runde Bridge, lists of villages in Norway, List of churches in HelsingÃÂ¸r Municipality, List of protected areas of HillerÃÂ¸d Municipality]"
9,Witch Story,"[List of Italian films of 1989, List of horror films of 1989]","[1987 horror films, Films about witchcraft, Supernatural horror films]","[Witch trials in the early modern period, North Berwick witch trials, Witch trial, Witchcraft accusations against children, Modern witch-hunts]","[Witch trials in the early modern period, North Berwick witch trials, Witch trial, Witchcraft accusations against children, Modern witch-hunts]"
