In [None]:
#| default_exp metrics

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *
# Run nbdev's export step whenever this hidden cell is executed.
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import torch, numpy as np
from collections import OrderedDict
from scipy import sparse
from fastcore.utils import *
from fastcore.meta import *
from fastai.metrics import AccumMetric
import xclib.evaluation.xc_metrics as xm

## XCMetric

In [None]:
#| export
class XCMetric:
    """Accumulate batched sparse predictions/targets and score them with `func`.

    Each batch is passed as keyword tensors. `value` reads the keys
    `pred_score`, `pred_idx`, `pred_ptr`, `targ_idx` and `targ_ptr`:
    the `*_ptr` tensors hold per-row entry counts (they are turned into CSR
    row offsets with a cumulative sum), while `*_idx` are label ids and
    `pred_score` the matching scores.

    `func` is called as `func(pred, targ, **kwargs)` with two
    `scipy.sparse.csr_matrix` objects whose second dimension is `n_lbl`.
    `filterer`, when given, is indexed as `filterer[:, 0]` (rows) and
    `filterer[:, 1]` (columns); those entries are removed from both matrices.
    """

    def __init__(self, func, n_lbl:int, filterer:Optional[Union[np.ndarray,sparse.csr_matrix]]=None, **kwargs):
        self.func, self.n_lbl, self.filterer, self.kwargs = func, n_lbl, filterer, kwargs
        # Start with an empty buffer so `accumulate`/`value` are usable
        # before anyone calls `reset` explicitly.
        self.reset()

    def reset(self):
        "Discard every batch accumulated so far."
        self.output = []

    def accumulate(self, **kwargs):
        "Store one batch of keyword tensors; they are concatenated lazily in `value`."
        self.output.append(kwargs)

    def __call__(self, **kwargs):
        "Score a single batch: reset, accumulate `kwargs`, return `value`."
        self.reset()
        self.accumulate(**kwargs)
        return self.value

    def apply_filter(self, data):
        """Return `data` with every (row, col) pair listed in `self.filterer` removed.

        The filtered cells are cleared by subtracting a masked copy instead of
        assigning into the CSR matrix, which avoids SciPy's
        `SparseEfficiencyWarning` for item assignment on CSR.
        """
        if self.filterer is None: return data
        # NOTE(review): a sparse `filterer` is allowed by the annotation but
        # `len()` of a sparse slice is undefined - confirm the expected layout.
        rows, cols = np.asarray(self.filterer[:, 0]), np.asarray(self.filterer[:, 1])
        mask = sparse.csr_matrix((np.ones(len(rows), dtype=data.dtype), (rows, cols)), shape=data.shape)
        # Repeated pairs are summed by the constructor; force the mask back to 0/1.
        mask.data[:] = 1
        data = (data - data.multiply(mask)).tocsr()
        data.eliminate_zeros()
        return data

    def get_pred(self, output):
        "Build the (n_rows, n_lbl) CSR prediction matrix from `output` and filter it."
        data = (output['pred_score'], output['pred_idx'], output['pred_ptr'])
        # `pred_ptr` already carries the leading 0, hence one row fewer than its length.
        pred = sparse.csr_matrix(data, shape=(len(data[2])-1, self.n_lbl))
        pred.sum_duplicates()
        return self.apply_filter(pred)

    def get_targ(self, output):
        "Build the (n_rows, n_lbl) CSR target matrix (all stored values are 1) and filter it."
        data = (torch.full((len(output['targ_idx']),), 1), output['targ_idx'], output['targ_ptr'])
        targ = sparse.csr_matrix(data, shape=(len(data[2])-1, self.n_lbl))
        targ.sum_duplicates()
        return self.apply_filter(targ)

    @property
    def value(self):
        "Result of `func` over everything accumulated, or `None` when nothing was accumulated."
        if len(self.output) == 0: return
        # Concatenate batches key by key; detach and move to CPU so the SciPy
        # conversion also works for tensors that live on an accelerator or
        # carry gradients.
        output = {k:torch.cat([o[k] for o in self.output]).detach().cpu() for k in self.output[0]}
        # Per-row counts -> CSR row offsets (prepend 0, then running sum).
        for ptr in ('targ_ptr', 'pred_ptr'):
            output[ptr] = torch.cat([torch.tensor([0]), output[ptr].cumsum(dim=0)])

        pred, targ = self.get_pred(output), self.get_targ(output)
        return self.func(pred, targ, **self.kwargs)


In [None]:
#| export
def precision(inp:sparse.csr_matrix, 
              targ:sparse.csr_matrix, 
              prop:sparse.csr_matrix=None, 
              k:Optional[int]=5, 
              pa:Optional[float]=0.55, 
              pb:Optional[float]=1.5, 
              repk:Optional[List]=None):
    """Evaluate `inp` against `targ` with `xm.Metrics` and report selected cutoffs.

    Parameters
    ----------
    inp, targ : predicted scores and ground-truth labels as sparse matrices.
    prop : optional matrix handed to `xm.compute_inv_propesity` (with `A=pa`,
        `B=pb`); when given, the `PSP`/`PSN` entries are reported as well.
    k : largest cutoff that is evaluated.
    repk : extra cutoffs to report; `k` itself is always reported.

    Returns
    -------
    dict mapping `'<name>@<cutoff>'` to the value at that cutoff, for every
    name in `P`, `N` (plus `PSP`, `PSN` with `prop`) and every requested
    cutoff in `1..k`, in ascending cutoff order per name.
    """
    name = ['P', 'N'] if prop is None else ['P', 'N', 'PSP', 'PSN']

    # De-duplicate and sort so key order is deterministic; accept any iterable
    # for `repk` (the previous `repk+[k]` only worked for lists).
    cutoffs = sorted({k} if repk is None else {*repk, k})
    # Cutoffs below 1 would index from the end of the result array and report
    # a wrong value under a nonsensical key, so they are dropped as well.
    cutoffs = [r for r in cutoffs if 1 <= r <= k]

    inv_psp = None if prop is None else xm.compute_inv_propesity(prop, A=pa, B=pb)

    metric = xm.Metrics(true_labels=targ, inv_psp=inv_psp)
    prec = metric.eval(inp, k)
    # `prec[i]` is the i-th metric's array; position r-1 holds the value at cutoff r.
    return {f'{n}@{r}': prec[i][r-1] for i,n in enumerate(name) for r in cutoffs}
    

In [None]:
#| export
@delegates(precision)
def Precision(n_lbl, filterer=None, **kwargs):
    "Build an `XCMetric` over `n_lbl` labels that scores with `precision`; `kwargs` are forwarded to it."
    metric = XCMetric(precision, n_lbl, filterer=filterer, **kwargs)
    return metric
    

In [None]:
#| export
def recall(inp:sparse.csr_matrix, 
           targ:sparse.csr_matrix, 
           k:Optional[int]=5, 
           repk:Optional[List]=None):
    """Compute `xm.recall` up to cutoff `k` and report selected cutoffs.

    Parameters
    ----------
    inp, targ : predicted scores and ground-truth labels as sparse matrices.
    k : largest cutoff that is evaluated.
    repk : extra cutoffs to report; `k` itself is always reported.

    Returns
    -------
    dict mapping `'R@<cutoff>'` to the recall at that cutoff for every
    requested cutoff in `1..k`, in ascending order.
    """
    # De-duplicate and sort so key order is deterministic; accept any iterable
    # for `repk` (the previous `repk+[k]` only worked for lists).
    cutoffs = sorted({k} if repk is None else {*repk, k})
    recl = xm.recall(inp, targ, k=k)
    # Position o-1 holds the value at cutoff o; cutoffs below 1 would wrap
    # around to the end of the array, so they are skipped.
    return {f'R@{o}':recl[o-1] for o in cutoffs if 1 <= o <= k}
    

In [None]:
#| export
# Delegate to `recall` (not `precision`): the advertised keyword arguments must
# be the ones `recall` accepts, since they are forwarded to it unchanged.
@delegates(recall)
def Recall(n_lbl, filterer=None, **kwargs):
    "Build an `XCMetric` over `n_lbl` labels that scores with `recall`; `kwargs` are forwarded to it."
    return XCMetric(recall, n_lbl, filterer, **kwargs)
    

In [None]:
#| export
def prec_recl(inp:sparse.csr_matrix, 
              targ:sparse.csr_matrix,
              prop:sparse.csr_matrix=None,
              pa:Optional[float]=0.55,
              pb:Optional[float]=1.5,
              pk:Optional[int]=5,
              rep_pk:Optional[List]=None,
              rk:Optional[int]=5,
              rep_rk:Optional[List]=None):
    """Combine the results of `precision` and `recall` into one dict.

    `pk`/`rep_pk` are passed to `precision` as `k`/`repk`, and `rk`/`rep_rk`
    to `recall`; `prop`, `pa` and `pb` only affect `precision`.
    """
    prec_scores = precision(inp, targ, prop, k=pk, pa=pa, pb=pb, repk=rep_pk)
    recl_scores = recall(inp, targ, k=rk, repk=rep_rk)
    # Recall entries are merged into the precision dict, which is returned.
    prec_scores.update(recl_scores)
    return prec_scores
    

In [None]:
#| export
@delegates(prec_recl)
def PrecRecl(n_lbl, filterer=None, **kwargs):
    "Build an `XCMetric` over `n_lbl` labels that scores with `prec_recl`; `kwargs` are forwarded to it."
    metric = XCMetric(prec_recl, n_lbl, filterer=filterer, **kwargs)
    return metric
    

In [None]:
#| export
def sort_xm(xm):
    """Return the entries of `xm` as an `OrderedDict` in a canonical order.

    Keys must look like `'<name>@<int>'`. Entries are ordered first by name
    (P, N, PSP, PSN, R, PSR, then anything else) and then by the integer
    after the `@`.
    """
    rank = {'P':1, 'N':2, 'PSP':3, 'PSN':4, 'R':5, 'PSR':6}

    def _rank(prefix, cutoff):
        # Unknown names sort after every known one.
        return (rank.get(prefix, 7), int(cutoff))

    ordered_keys = sorted(xm, key=lambda label: _rank(*label.split('@')))
    return OrderedDict((label, xm[label]) for label in ordered_keys)
    

### Example

In [None]:
# One toy batch of three rows. The `*_ptr` tensors hold per-row entry counts
# (XCMetric.value converts them to CSR offsets with a cumulative sum).
output = {
    'targ_idx': torch.tensor([1, 3, 5, 6, 9]),
    'targ_ptr': torch.tensor([2, 2, 1]),
    'pred_idx': torch.tensor([1, 2, 5, 5, 6, 9]),
    'pred_score': torch.tensor([0.5, 0.4, 0.2, 0.3, 0.1, 0.6]),
    'pred_ptr': torch.tensor([3, 2, 1]),
}

# (row, label) pairs to remove from both predictions and targets.
filterer = np.array([[0, 3]])

In [None]:
# 10 labels; precision-type metrics up to k=10, recall up to k=20.
m = PrecRecl(
    n_lbl=10,
    filterer=filterer,
    pk=10,
    rk=20,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 15, 20],
)

In [None]:
# Calling the metric resets it, accumulates this single batch and returns the result of its scoring function.
m(**output)

  self._set_arrayXarray(i, j, x)


{'P@1': 1.0,
 'P@10': 0.13333333333333333,
 'P@3': 0.4444444444444444,
 'P@5': 0.26666666666666666,
 'N@1': 1.0,
 'N@10': 1.0,
 'N@3': 1.0,
 'N@5': 1.0,
 'R@10': 1.0,
 'R@20': 1.0,
 'R@15': 1.0}