In [1]:
#| default_exp metrics

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [2]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [17]:
#| export
import torch, numpy as np, scipy.sparse as sp
from collections import OrderedDict

from fastcore.utils import *
from fastcore.meta import *

import xclib.evaluation.xc_metrics as xm
from xclib.utils.sparse import rank, binarize

## XCMetric

In [8]:
#| export
class XCMetric:
    """Buffers batched outputs and scores them with `func` on sparse label matrices.

    Every accumulated batch is a dict of tensors. The keys read here are
    `pred_score`, `pred_idx`, `pred_ptr`, `targ_idx` and `targ_ptr`; the `*_ptr`
    entries hold per-row entry counts (they are cumulatively summed into CSR
    `indptr` arrays in `value`).

    Parameters
    ----------
    func : callable invoked as `func(pred, targ, **kwargs)` with two
        `scipy.sparse.csr_matrix` objects of shape `(n_rows, n_lbl)`.
    n_lbl : number of columns (labels) of the matrices that are built.
    filterer : optional array indexed as `filterer[:, 0]` (rows) and
        `filterer[:, 1]` (columns); those positions are zeroed in both the
        prediction and the target matrix.
        NOTE(review): the annotation also allows a sparse matrix, but
        `apply_filter` indexes it like an (n, 2) array of pairs — confirm.
    **kwargs : forwarded unchanged to `func` on every evaluation.
    """

    def __init__(self, func, n_lbl:int, filterer:Optional[Union[np.array,sp.csr_matrix]]=None, **kwargs):
        self.func, self.n_lbl, self.filterer, self.kwargs = func, n_lbl, filterer, kwargs
        # Create the buffer up front so `accumulate`/`value` work on a fresh
        # instance without requiring an explicit `reset()` first.
        self.reset()

    def reset(self):
        "Drop every batch accumulated so far."
        self.output = []

    def accumulate(self, **kwargs):
        "Store one batch (a dict of tensors) for later evaluation."
        self.output.append(kwargs)

    def __call__(self, **kwargs):
        "Score a single batch: clears the buffer, stores `kwargs`, returns `value`."
        self.reset()
        self.accumulate(**kwargs)
        return self.value

    def apply_filter(self, data):
        "Zero the `filterer` positions of `data` in place and drop the stored zeros."
        if self.filterer is not None:
            data[self.filterer[:,0], self.filterer[:,1]] = 0
            data.eliminate_zeros()
        return data

    def get_pred(self, output):
        "Build the filtered `(n_rows, n_lbl)` prediction matrix from CSR-style triplets."
        data = (output['pred_score'], output['pred_idx'], output['pred_ptr'])
        # `pred_ptr` is already an indptr here, hence one more entry than rows.
        pred = sp.csr_matrix(data, shape=(len(data[2])-1, self.n_lbl))
        pred.sum_duplicates()
        return self.apply_filter(pred)

    def get_targ(self, output):
        "Build the filtered `(n_rows, n_lbl)` target matrix; every listed label gets value 1."
        data = (torch.full((len(output['targ_idx']),), 1), output['targ_idx'], output['targ_ptr'])
        targ = sp.csr_matrix(data, shape=(len(data[2])-1, self.n_lbl))
        targ.sum_duplicates()
        return self.apply_filter(targ)
    
    @property
    def value(self):
        "Result of `func` over everything accumulated, or `None` when the buffer is empty."
        if len(self.output) == 0: return
        # Concatenate each key across batches (keys are taken from the first batch).
        output = {k:torch.cat([o[k] for o in self.output]) for k in self.output[0]}
        # Per-row counts -> CSR indptr: prepend 0 and take the running sum.
        # NOTE(review): the prepended zero is created on the CPU; presumably the
        # accumulated tensors are CPU tensors as well — confirm.
        output['targ_ptr'] = torch.cat([torch.tensor([0]), output['targ_ptr'].cumsum(dim=0)])
        output['pred_ptr'] = torch.cat([torch.tensor([0]), output['pred_ptr'].cumsum(dim=0)])
        
        pred, targ = self.get_pred(output), self.get_targ(output)
        return self.func(pred, targ, **self.kwargs)


In [9]:
#| export
def precision(inp:sp.csr_matrix, 
              targ:sp.csr_matrix, 
              prop:sp.csr_matrix=None, 
              k:Optional[int]=5, 
              pa:Optional[float]=0.55, 
              pb:Optional[float]=1.5, 
              repk:Optional[List]=None):
    """Evaluate `inp` against `targ` with `xm.Metrics` and report selected cut-offs.

    Parameters
    ----------
    inp, targ : prediction and ground-truth matrices handed to `xm.Metrics`.
    prop : optional matrix; when given it is converted with
        `xm.compute_inv_propesity(prop, A=pa, B=pb)` and passed as `inv_psp`,
        and the `PSP`/`PSN` families are reported in addition to `P`/`N`.
    k : largest cut-off that is evaluated.
    pa, pb : forwarded as `A`/`B` to `xm.compute_inv_propesity`.
    repk : cut-offs to report; `k` is always added. Any iterable of ints works.

    Returns
    -------
    dict mapping `'<NAME>@<r>'` to `prec[i][r-1]` for every reported cut-off
    `1 <= r <= k`, with cut-offs in ascending order inside each family.
    """
    name = ['P', 'N'] if prop is None else ['P', 'N', 'PSP', 'PSN']
    # Sorted + de-duplicated so the key order is deterministic (a bare set
    # iterates in hash order) and tuples/other iterables are accepted too.
    repk = [k] if repk is None else sorted({*repk, k})
    prop = None if prop is None else xm.compute_inv_propesity(prop, A=pa, B=pb)
    
    metric = xm.Metrics(true_labels=targ, inv_psp=prop)
    prec = metric.eval(inp, k)
    # `r >= 1` keeps a stray 0/negative cut-off from indexing from the end.
    return {f'{n}@{r}': prec[i][r-1] for i,n in enumerate(name) for r in repk if 1 <= r <= k}
    

In [10]:
#| export
@delegates(precision)
def Precision(n_lbl, filterer=None, **kwargs):
    """Create an `XCMetric` over `n_lbl` labels that scores with `precision`.

    `filterer` and any extra keyword arguments are handed to `XCMetric` as-is.
    """
    return XCMetric(func=precision, n_lbl=n_lbl, filterer=filterer, **kwargs)
    

In [11]:
#| export
def recall(inp:sp.csr_matrix, 
           targ:sp.csr_matrix, 
           k:Optional[int]=5, 
           repk:Optional[List]=None):
    """Compute `xm.recall(inp, targ, k=k)` and report selected cut-offs.

    Parameters
    ----------
    inp, targ : prediction and ground-truth matrices handed to `xm.recall`.
    k : largest cut-off that is evaluated.
    repk : cut-offs to report; `k` is always added. Any iterable of ints works.

    Returns
    -------
    dict mapping `'R@<o>'` to `recl[o-1]` for every reported cut-off
    `1 <= o <= k`, in ascending order.
    """
    # Sorted + de-duplicated so the key order is deterministic (a bare set
    # iterates in hash order) and tuples/other iterables are accepted too.
    repk = [k] if repk is None else sorted({*repk, k})
    recl = xm.recall(inp, targ, k=k)
    # `o >= 1` keeps a stray 0/negative cut-off from indexing from the end.
    return {f'R@{o}':recl[o-1] for o in repk if 1 <= o <= k}
    

In [12]:
#| export
# Delegate to `recall` (not `precision`) so the advertised keyword arguments are
# exactly the ones `recall` accepts (`k`, `repk`).
@delegates(recall)
def Recall(n_lbl, filterer=None, **kwargs):
    """Create an `XCMetric` over `n_lbl` labels that scores with `recall`.

    `filterer` and any extra keyword arguments are handed to `XCMetric` as-is;
    the keyword arguments end up in the call to `recall`.
    """
    return XCMetric(recall, n_lbl, filterer, **kwargs)
    

In [13]:
#| export
def prec_recl(inp:sp.csr_matrix, 
              targ:sp.csr_matrix,
              prop:sp.csr_matrix=None,
              pa:Optional[float]=0.55,
              pb:Optional[float]=1.5,
              pk:Optional[int]=5,
              rep_pk:Optional[List]=None,
              rk:Optional[int]=5,
              rep_rk:Optional[List]=None):
    """Merge the results of `precision` and `recall` into a single dict.

    `pk`/`rep_pk` are passed to `precision` as `k`/`repk` (together with `prop`,
    `pa`, `pb`); `rk`/`rep_rk` are passed to `recall` as `k`/`repk`. Precision
    entries come first, recall entries follow.
    """
    prec_part = precision(inp, targ, prop, k=pk, pa=pa, pb=pb, repk=rep_pk)
    recl_part = recall(inp, targ, k=rk, repk=rep_rk)
    return {**prec_part, **recl_part}
    

In [37]:
#| export
@delegates(prec_recl)
def PrecRecl(n_lbl, filterer=None, **kwargs):
    """Create an `XCMetric` over `n_lbl` labels that scores with `prec_recl`.

    `filterer` and any extra keyword arguments are handed to `XCMetric` as-is.
    """
    return XCMetric(func=prec_recl, n_lbl=n_lbl, filterer=filterer, **kwargs)
    

In [38]:
def mrr_bier(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: list[int],
    output_type: str = "mean",
) -> dict:
    """Reciprocal rank of the first relevant document, at several cut-offs.

    Parameters
    ----------
    qrels : `{query_id: {doc_id: relevance}}`; a document counts as relevant
        when its relevance is `> 0`.
    results : `{query_id: {doc_id: score}}`; documents are ranked by
        descending score. Every query in `results` must be present in `qrels`.
    k_values : cut-offs to evaluate; repeated values are evaluated once.
    output_type : `"mean"` returns one rounded (5 decimals) average per
        cut-off; any other value leaves the per-query lists untouched.

    Returns
    -------
    dict keyed `'MRR@<k>'`. In `"mean"` mode the sum of reciprocal ranks is
    divided by `len(qrels)`, so judged queries without results count as 0;
    with no judged queries at all the mean is reported as `0.0`.
    An empty `k_values` yields an empty dict.
    """
    # De-duplicate while keeping order: a repeated cut-off would otherwise be
    # appended to twice per query and break the averaging step below.
    cutoffs = list(dict.fromkeys(k_values))
    MRR = {f"MRR@{k}": [] for k in cutoffs}
    if not cutoffs:
        # Nothing requested; `max()` of an empty sequence would raise.
        return MRR

    k_max = max(cutoffs)
    # Only the best `k_max` documents per query can influence any cut-off.
    top_hits = {
        query_id: sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:k_max]
        for query_id, doc_scores in results.items()
    }

    for query_id, hits in top_hits.items():
        judged = qrels[query_id]
        relevant_docs = {doc_id for doc_id in judged if judged[doc_id] > 0}
        for k in cutoffs:
            rr = 0
            for position, (doc_id, _score) in enumerate(hits[:k], start=1):
                if doc_id in relevant_docs:
                    rr = 1.0 / position
                    break
            MRR[f"MRR@{k}"].append(rr)

    if output_type == "mean":
        n_queries = len(qrels)
        for k in cutoffs:
            total = sum(MRR[f"MRR@{k}"])
            # Guard the division: without judged queries there is nothing to average.
            MRR[f"MRR@{k}"] = round(total / n_queries, 5) if n_queries else 0.0

    return MRR

def mrr(inp:sp.csr_matrix,
        targ:sp.csr_matrix,
        k:Optional[List]=[10]):
    """Reference MRR: converts both matrices to per-row dicts and calls `mrr_bier`.

    Row `r` of `inp` becomes `{column: score}` and row `r` of `targ` becomes
    `{column: 1.0}`; `k` is forwarded as `k_values`.

    NOTE: this cell carries no `#| export` directive, and a later cell defines
    another function named `mrr`, which replaces this one once it has run.
    """
    assert inp.shape[0] == targ.shape[0]
    assert inp.shape[1] == targ.shape[1]

    results = {}
    qrels = {}
    for row in range(inp.shape[0]):
        pred_row = inp[row]
        targ_row = targ[row]
        results[row] = dict(zip(pred_row.indices, pred_row.data))
        qrels[row] = dict.fromkeys(targ_row.indices, 1.0)

    return mrr_bier(qrels, results, k_values=k)
    

In [39]:
#| export
def mrr(inp:sp.csr_matrix,
        targ:sp.csr_matrix,
        k:Optional[List]=[10]):
    """Sparse MRR: mean over rows of the best reciprocal rank among target labels.

    For every cut-off in `k` (largest first), ranks above the cut-off are
    dropped, the remainder is masked with the binarized targets, inverted, and
    the row-wise maximum is averaged. Returns a dict keyed `'MRR@<cut-off>'`.

    NOTE(review): `rank` and `binarize` come from `xclib.utils.sparse`;
    presumably `rank` yields per-row ranks starting at 1 — confirm.
    """
    ranks = rank(inp)
    targ = binarize(targ)

    metric = {}
    for cutoff in sorted(k, reverse=True):
        # Work on a copy so each cut-off starts from the full ranking.
        kept = ranks.copy()
        kept.data[kept.data > cutoff] = 0.0
        kept.eliminate_zeros()
        # Keep only the ranks that fall on target labels, then invert them.
        hits = kept.multiply(targ)
        hits.data = 1.0/hits.data
        metric[f'MRR@{cutoff}'] = hits.max(axis=1).mean()
    return metric

@delegates(mrr)
def Mrr(n_lbl, filterer=None, **kwargs):
    """Create an `XCMetric` over `n_lbl` labels that scores with `mrr`.

    `filterer` and any extra keyword arguments are handed to `XCMetric` as-is.
    """
    return XCMetric(func=mrr, n_lbl=n_lbl, filterer=filterer, **kwargs)
    

In [40]:
#| export
def prec_recl_mrr(inp:sp.csr_matrix,
        targ:sp.csr_matrix,
        prop:sp.csr_matrix=None,
        pa:Optional[float]=0.55,
        pb:Optional[float]=1.5,
        pk:Optional[int]=5,
        rep_pk:Optional[List]=None,
        rk:Optional[int]=5,
        rep_rk:Optional[List]=None,
        mk:Optional[List]=[10]):
    """Merge the results of `prec_recl` and `mrr` into a single dict.

    All arguments except `mk` are forwarded to `prec_recl` under the same
    names; `mk` is passed to `mrr` as `k`. The `prec_recl` entries come first,
    the MRR entries follow.
    """
    pr_part = prec_recl(
        inp, targ,
        prop=prop, pa=pa, pb=pb,
        pk=pk, rep_pk=rep_pk,
        rk=rk, rep_rk=rep_rk,
    )
    mrr_part = mrr(inp, targ, k=mk)
    return {**pr_part, **mrr_part}

@delegates(prec_recl_mrr)
def PrecReclMrr(n_lbl, filterer=None, **kwargs):
    """Create an `XCMetric` over `n_lbl` labels that scores with `prec_recl_mrr`.

    `filterer` and any extra keyword arguments are handed to `XCMetric` as-is.
    """
    return XCMetric(func=prec_recl_mrr, n_lbl=n_lbl, filterer=filterer, **kwargs)


In [42]:
#| export
def sort_xc_metrics(metric):
    """Return a copy of `metric` as an `OrderedDict` in canonical order.

    Keys must look like `'<NAME>@<int>'`. Entries are ordered by metric family
    (`P`, `N`, `PSP`, `PSN`, `R`, `PSR`, then every other name) and, inside a
    family, by the integer cut-off. The input mapping is left unmodified.
    """
    order = {'P':1, 'N':2, 'PSP':3, 'PSN':4, 'R':5, 'PSR':6}
    # Unknown families (e.g. 'MRR') share rank 7 and therefore sort last.
    def get_key(a,b): return (order.get(a,7), int(b)) 
    def sort_fn(k): return get_key(*k.split('@'))
    
    ord_metric = OrderedDict()
    # Copy INTO the ordered result; the reversed assignment looked keys up in
    # the still-empty result and raised KeyError on any non-empty input.
    for k in sorted(metric, key=sort_fn): ord_metric[k] = metric[k]
    return ord_metric
    

### Example

In [43]:
# Toy batch with three rows, in the flattened layout `XCMetric` consumes:
# `*_idx` lists label ids for all rows back to back and `*_ptr` gives the
# number of entries belonging to each row.
output = {}
# Targets: row 0 -> {1, 3}, row 1 -> {5, 6}, row 2 -> {9}   (2 + 2 + 1 = 5 ids)
output['targ_idx'] = torch.tensor([1, 3, 5, 6, 9])
output['targ_ptr'] = torch.tensor([2, 2, 1])

# Predictions: row 0 -> {1, 2, 5}, row 1 -> {5, 6}, row 2 -> {9}   (3 + 2 + 1 = 6 ids),
# with one score per predicted id.
output['pred_idx'] = torch.tensor([1, 2, 5, 5, 6, 9])
output['pred_score'] = torch.tensor([0.5, 0.4, 0.2, 0.3, 0.1, 0.6])
output['pred_ptr'] = torch.tensor([3, 2, 1])

# (row, label) pairs to zero out in both predictions and targets: here (0, 3).
filterer = np.array([[0, 3]])

In [44]:
# 10 labels; precision evaluated up to k=10 (reported at 1/3/5/10) and recall
# up to k=20 (reported at 10/15/20).
# NOTE: `m` is rebound by the next cell before it is ever called.
m = PrecRecl(10, filterer, pk=10, rk=20, rep_pk=[1, 3, 5, 10], rep_rk=[10, 15, 20])

In [49]:
# Same precision/recall settings as above, plus MRR at cut-offs 5 and 10 (`mk`).
m = PrecReclMrr(10, filterer, pk=10, rk=20, rep_pk=[1, 3, 5, 10], rep_rk=[10, 15, 20], mk=[5, 10])

In [50]:
# Score the toy batch in one shot: calling the metric resets its buffer,
# accumulates this single batch and returns the resulting metric dict.
m(**output)

{'P@1': 1.0,
 'P@10': 0.13333333333333333,
 'P@3': 0.4444444444444444,
 'P@5': 0.26666666666666666,
 'N@1': 1.0,
 'N@10': 1.0,
 'N@3': 1.0,
 'N@5': 1.0,
 'R@10': 1.0,
 'R@20': 1.0,
 'R@15': 1.0,
 'MRR@10': 1.0,
 'MRR@5': 1.0}