# core

In [None]:
#| default_exp core

In [None]:
#| export
import pandas as pd, numpy as np, logging, sys, re, os, torch
import torch.nn.functional as F
from torch.utils.data import Sampler
from itertools import chain
from scipy import sparse
from IPython.display import display
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import List, Dict, Union, Optional, Any
from fastcore.dispatch import *
from fastcore.basics import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## `Info`: LOADS METADATA

In [None]:
#| export
def show_data(x:Dict, n:Optional[int]=10, seed:Optional[int]=None):
    """Display a random sample of rows of `x` rendered as a `pd.DataFrame`.

    `x` is handed to `pd.DataFrame` unchanged, `n` is the number of rows passed to
    `DataFrame.sample` and `seed` is forwarded as its `random_state`. Column-width
    truncation is switched off while the frame is displayed.
    """
    sampled = pd.DataFrame(x).sample(n, random_state=seed)
    with pd.option_context('display.max_colwidth', None):
        display(sampled)

In [None]:
#| export
class Info():
    """Load separator-delimited metadata from a text file and optionally tokenize one field.

    Attributes
    ----------
    info : dict or None
        Maps each column name to a tuple holding that column's value for every
        line of the file; `None` until `read_info` has been called.
    tokz : tokenizer or None
        Tokenizer used by `tokenize`; created on the first call and reused afterwards.

    Note: the default encoding `latin-1` can decode any byte sequence, but it turns
    multi-byte UTF-8 characters into mojibake; pass `enc='utf-8'` for UTF-8 files.
    """

    def __init__(self):
        self.tokz, self.info = None, None

    @staticmethod
    def _read_text(fname:str, enc:Optional[str]='latin-1'):
        """Return the lines of `fname` as a list with the line terminator removed."""
        with open(fname, encoding=enc) as f:
            # Strip only an actual newline: a final line that is not newline-terminated
            # must keep its last character.
            return [o[:-1] if o.endswith('\n') else o for o in f]

    @staticmethod
    def _read_info(fname:str, sep:Optional[str]='->', cols:Optional[List]=None, enc:Optional[str]='latin-1'):
        """Split every line of `fname` on `sep` and return the fields column-wise.

        Returns a dict `{column name: tuple of values}`. Column names default to
        `0..k-1` when `cols` is None. Because the columns are built with `zip`, the
        number of columns equals the smallest number of fields found on any line.

        Raises `ValueError` when `cols` does not have one name per column.
        """
        rows = [o.split(sep) for o in Info._read_text(fname, enc=enc)]
        info = list(zip(*rows))
        cols = list(range(len(info))) if cols is None else cols
        if len(cols) != len(info): raise ValueError('`cols` and `info` should have same number of elements.')
        return dict(zip(cols, info))

    def read_info(self, fname:Optional[str], sep:Optional[str]='->', cols:Optional[List]=None, enc:Optional[str]='latin-1'):
        """Load `fname` into `self.info` (see `_read_info`) and return the resulting dict."""
        self.info = Info._read_info(fname, sep, cols, enc)
        return self.info

    def tokenize(self, fld:Optional[Union[int, str]], tokz:Union[str, PreTrainedTokenizerBase], max_len:Optional[int]=None):
        """Tokenize column `fld` and merge the tokenizer output into `self.info`.

        `tokz` is either a tokenizer instance or a name given to
        `AutoTokenizer.from_pretrained`; it is ignored if a tokenizer has already been
        set on this instance. When `fld` is None the first column of `info` is used.
        The tokenizer is called with `truncation=True` and `max_length=max_len`.

        Returns the updated `self.info`.
        Raises `ValueError` if nothing has been loaded yet or `fld` is not a column.
        """
        if not self.info: raise ValueError('`info` is empty, load it with `read_info` before calling `tokenize`.')
        if self.tokz is None: self.tokz = tokz if isinstance(tokz, PreTrainedTokenizerBase) else AutoTokenizer.from_pretrained(tokz)
        if fld is None:
            # Fall back to the first column and say so.
            fld = next(iter(self.info))
            logging.info(f'`fld` not given as input, so value set to {fld}.')
        if fld not in self.info: raise ValueError(f'`{fld}` is invalid `fld` value.')
        self.info.update(self.tokz(self.info[fld], truncation=True, max_length=max_len))
        return self.info

    def show_data(self, n:Optional[int]=10, seed:Optional[int]=None):
        """Display `n` randomly sampled rows of `self.info` as a DataFrame (`seed` -> `random_state`)."""
        with pd.option_context('display.max_colwidth', None):
            display(pd.DataFrame(self.info).sample(n, random_state=seed))

    def __len__(self):
        """Number of entries per column; 0 when nothing has been loaded.

        Raises `ValueError` if `info` has no columns or the columns differ in length.
        """
        if self.info is None: return 0
        n_info = [len(v) for v in self.info.values()]
        if not n_info: raise ValueError('`info` cannot be empty.')
        if any(o != n_info[0] for o in n_info): raise ValueError('`info` should contain features with same length.')
        return n_info[0]

    @classmethod
    def from_txt(cls, 
                 fname:str, 
                 sep:Optional[str]='->', 
                 cols:Optional[List]=None, 
                 enc:Optional[str]='latin-1',
                 use_tokz:Optional[bool]=False,
                 tokz:Optional[Union[str,PreTrainedTokenizerBase]]=None,
                 fld:Optional[str]=None,
                 max_len:Optional[int]=None, 
                 **kwargs):
        """Read `fname` and, when `use_tokz` is true, tokenize column `fld` with `tokz`.

        Returns the loaded `info` dict (not an `Info` instance). Extra `kwargs` are
        accepted and ignored.
        """
        self = cls()
        self.read_info(fname, sep, cols, enc)
        if use_tokz: self.tokenize(fld, tokz, max_len)
        return self.info
        

### Example

In [None]:
# NOTE(review): absolute path on the author's machine; point `fname` at a local copy to run this example.
fname = f'/home/scai/phd/aiz218323/Projects/XC_NLG/data/(mapped)LF-WikiSeeAlsoTitles-320K/raw_data/train.raw.txt'
# Names given to the `->`-separated fields of each line (two fields, judging by the output shown below).
cols = ['identifier', 'input_text']

In [None]:
info = Info.from_txt(fname, cols=cols, use_tokz=True, tokz='bert-base-uncased', fld=cols[1], max_len=32)

In [None]:
show_data(info, n=5)

Unnamed: 0,identifier,input_text,input_ids,token_type_ids,attention_mask
91182,Convertible_arbitrage,Convertible arbitrage,"[101, 22840, 12098, 16313, 24449, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]"
31680,Leopoldo_Lugones,Leopoldo Lugones,"[101, 12752, 2080, 11320, 7446, 2229, 102]","[0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1]"
158476,Interstate_269,Interstate 269,"[101, 7553, 25717, 102]","[0, 0, 0, 0]","[1, 1, 1, 1]"
390138,Institute_for_Health_Freedom,Institute for Health Freedom,"[101, 2820, 2005, 2740, 4071, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]"
248643,Abuja_Securities_and_Commodities_Exchange,Abuja Securities and Commodities Exchange,"[101, 8273, 3900, 12012, 1998, 21955, 3863, 102]","[0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1]"


In [None]:
info = Info()
_ = info.read_info(fname, cols=cols)

In [None]:
info.show_data(n=10)

Unnamed: 0,identifier,input_text
247245,20_minutes_(Switzerland),20 minutes (Switzerland)
2305,Isoelectric_point,Isoelectric point
538866,Moldova_in_the_Eurovision_Song_Contest_2013,Moldova in the Eurovision Song Contest 2013
100310,Lemhi_Pass,Lemhi Pass
639082,Order_of_precedence_in_Kelantan,Order of precedence in Kelantan
510129,"RTS,S","RTS,S"
216809,Geering_(automobile),Geering (automobile)
569292,Sole_Survivor_(2013_film),Sole Survivor (2013 film)
175818,Puerto_Rico_National_Cemetery,Puerto Rico National Cemetery
615747,Anna_IllÃ©s,Anna IllÃ©s


In [None]:
len(info)

693082

## `Filterer`: XC FILTERING

In [None]:
#| export
class Filterer:
    """Helpers around (row, label) filter pairs for sparse label matrices.

    A filter is an integer array with two columns: column 0 holds row indices and
    column 1 holds label (column) indices of matrix entries that must be zeroed.
    """

    @staticmethod
    def load_filter(fname:str):
        """Load a filter from the text file `fname`.

        Returns a 2-D `int64` array, or `None` when `fname` is None or does not exist.
        """
        if fname is None or not os.path.exists(fname): return None
        # `ndmin=2` keeps a file containing a single pair two-dimensional, so the
        # `f[:, 0]` / `f[:, 1]` indexing used by the other methods keeps working.
        return np.loadtxt(fname, dtype=np.int64, ndmin=2)
        
    @staticmethod
    def generate(train_id:List, test_id:List, lbl_id:List, train_lbl:sparse.csr_matrix, test_lbl:sparse.csr_matrix):
        """Build the train and test filters from shared identifiers.

        Every data point whose identifier also occurs in `lbl_id` contributes the pair
        `(data index, label index)`. The test filter additionally gets one pair per
        nonzero of `train_lbl` restricted to such train rows and to label columns that
        coincide with a test point (see inline comments). `test_lbl` is accepted but
        not used.

        Returns `(train_lbl_filterer, test_lbl_filterer)`, two arrays with two columns.
        """
        # Positions of identifiers shared between train points and labels.
        _, train_idx, lbl2train_idx = np.intersect1d(train_id, lbl_id, return_indices=True)
        train_lbl_filterer = np.vstack([train_idx, lbl2train_idx]).T
        
        # Same for test points and labels.
        _, test_idx, lbl2test_idx = np.intersect1d(test_id, lbl_id, return_indices=True)
        test_lbl_filterer = np.vstack([test_idx, lbl2test_idx]).T
        
        # Sorted unique indices plus the position of each inside the arrays above,
        # used below to map sub-matrix coordinates back to the original pairs.
        train_udx, train_udx2idx = np.unique(train_idx, return_index=True)
        lbl2test_udx, lbl2test_udx2idx = np.unique(lbl2test_idx, return_index=True)
        
        # Rows: labels that are also test points; columns: train points that are also labels.
        _test_lbl_filterer = train_lbl[train_udx][:, lbl2test_udx].T
        
        # A nonzero at (row, col) means that train point is tagged with that label.
        # The emitted pair is (test point sharing the label's id, label sharing the
        # train point's id) - presumably the reciprocal relation; confirm with the author.
        rows, cols = _test_lbl_filterer.nonzero()
        test_idx = test_idx[lbl2test_udx2idx[rows]]
        lbl2test_idx = lbl2train_idx[train_udx2idx[cols]]
        
        _test_lbl_filterer = np.vstack([test_idx, lbl2test_idx]).T
        test_lbl_filterer = np.vstack([test_lbl_filterer, _test_lbl_filterer])
    
        return train_lbl_filterer, test_lbl_filterer

    @staticmethod
    def sample(f:np.ndarray, sz:tuple, idx:List):
        """Restrict filter `f` (defined on a matrix of shape `sz`) to the rows `idx`.

        Row indices of the returned filter are positions within `idx`, not the
        original row numbers.
        """
        f = sparse.coo_matrix((np.full(f.shape[0],1), (f[:, 0], f[:, 1])), shape=sz).tocsr()
        f = f[idx].tocoo()
        return np.vstack([f.row, f.col]).T

    @staticmethod
    def prune(data:sparse.csr_matrix, data_filterer:np.ndarray):
        """Zero the filtered entries on a copy of `data` and drop rows left empty.

        Returns `(pruned data, filter re-indexed to the kept rows, kept row indices)`.
        The input matrix is not modified.
        """
        data = data.copy()
        data[data_filterer[:,0], data_filterer[:,1]] = 0
        data.eliminate_zeros()
        
        idx = np.where(data.getnnz(axis=1) > 0)[0]
        return data[idx], Filterer.sample(data_filterer, data.shape, idx), idx

    @staticmethod
    def apply(data:sparse.csr_matrix, data_filterer:np.ndarray):
        """Zero the filtered entries of `data` in place and return it.

        Unlike `prune`, this mutates the matrix that is passed in.
        """
        data[data_filterer[:,0], data_filterer[:,1]] = 0
        data.eliminate_zeros()
        return data

        

## Helper

In [None]:
#| export
def store_attr(names=None, self=None, but='', cast=False, store_args=None, is_none=True, **attrs):
    """Store the calling function's arguments as attributes on `self`.

    Must be called directly from the function whose arguments are to be stored,
    because the values are read from the caller's frame.

    names      : comma-separated string or iterable of argument names to store; when
                 None, `self.__slots__` is used if present, otherwise every caller
                 argument after the first.
    self       : object to set attributes on; when None, the caller's first argument.
    but        : comma-separated string or iterable of names to leave out.
    cast       : when true, each value whose name appears in `annotations(self)` is
                 passed through that annotation before being stored.
    store_args : record the stored values in `self.__stored_args__`; defaults to true
                 unless `self` defines `__slots__`.
    is_none    : when false, values that are None are not set as attributes.
    attrs      : extra name/value pairs to store; a caller local with the same name
                 takes precedence.

    `argnames` and `annotations` are not defined in this file - presumably they come
    from the `fastcore.basics` star import.
    """
    fr = sys._getframe(1)
    args = argnames(fr, True)
    # Identity check, not truthiness: objects defining `__len__`/`__bool__` (several
    # classes in this file do) may be falsy, or even raise, before they are initialised.
    if self is not None: args = ('self', *args)
    else: self = fr.f_locals[args[0]]
    if store_args is None: store_args = not hasattr(self,'__slots__')
    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
    anno = annotations(self) if cast else {}
    if names and isinstance(names,str): names = re.split(', *', names)
    ns = names if names is not None else getattr(self, '__slots__', args[1:])
    added = {n:fr.f_locals[n] for n in ns}
    attrs = {**attrs, **added}
    if isinstance(but,str): but = re.split(', *', but)
    attrs = {k:v for k,v in attrs.items() if k not in but}
    return _store_attr(self, anno, is_none, **attrs)
    

In [None]:
#| export
def _store_attr(self, anno, is_none, **attrs):
    stored = getattr(self, '__stored_args__', None)
    for n,v in attrs.items():
        if n in anno: v = anno[n](v)
        if is_none or v is not None: setattr(self, n, v)
        if stored is not None: stored[n] = v
       

In [None]:
#| export
def get_attr(x, attr:str):
    """Resolve the dotted attribute path `attr` (e.g. `'a.b.c'`) starting from `x`."""
    target = x
    for part in attr.split('.'):
        target = getattr(target, part)
    return target

## `BalancedClusters`: CLUSTERING

In [None]:
#| export
class BalancedClusters:
    """Balanced clustering of row vectors by repeatedly halving every cluster."""

    @staticmethod
    def binary_kmeans(x:torch.Tensor, idx:Optional[torch.Tensor]=None, tol:Optional[float]=1e-4):
        """Split the rows of `x` into two (near) equal halves with a 2-means style loop.

        Rows are L2-normalised, two random rows seed the centres, and each iteration
        sorts the rows by `sim to centre 1 - sim to centre 0`, cuts that ordering in
        half and recomputes the centres as the mean of each half. Iteration stops once
        the mean similarity of the rows to their own centre improves by less than `tol`.

        idx : optional tensor with the original index of every row of `x`; when given
              the halves are returned as entries of `idx`, otherwise as row positions.

        Returns a tuple with the two halves, or a 1-tuple holding all indices when
        `x` has fewer than two rows (nothing to split).
        """
        n, x = x.shape[0], F.normalize(x)
        if n <= 1:
            # Return real indices even when `idx` was not supplied.
            return (torch.arange(n, device=x.device),) if idx is None else (idx,)
            
        rnd_idx = torch.randperm(n)[:2]
        c = x[rnd_idx]
        sim = x@c.T
        
        old_sc, new_sc = None, None
        # Runs at least twice: `old_sc` stays None until a second score is available.
        while old_sc is None or new_sc - old_sc >= tol:
            p,q = torch.chunk(torch.argsort(sim[:,1]-sim[:,0]), 2)
            c = torch.vstack([x[p].mean(dim=0, keepdim=True), x[q].mean(dim=0, keepdim=True)])
            sim = x@c.T
            sc = sim[p,0].sum() + sim[q,1].sum()
            new_sc, old_sc = sc/n, new_sc
        if idx is None: return p,q
        # `p`/`q` live on `x`'s device; move them next to `idx` before indexing it.
        return idx[p.to(idx.device)], idx[q.to(idx.device)]

    @staticmethod
    def proc(x:torch.Tensor, n_cluster:int, cluster:Optional[List[torch.Tensor]]=None):
        """Cluster the rows of `x` into (up to) `n_cluster` balanced groups.

        `n_cluster` is rounded up to the next power of two. Starting from `cluster`
        (default: a single cluster with every row) each cluster is halved with
        `binary_kmeans` until the target count is reached or the number of clusters
        stops growing.

        Returns `(cluster, x2cluster)`: the index tensors of every cluster and an
        int64 tensor giving the cluster id of each row of `x`.
        Raises `ValueError` when `n_cluster` is smaller than 1.
        """
        if n_cluster < 1: raise ValueError('`n_cluster` should be a positive integer.')
        def _nearest_two_power(v): return 2**int(np.ceil(np.log2(v)))
        n_cluster = _nearest_two_power(n_cluster)
        cluster = (torch.arange(x.shape[0]),) if cluster is None else cluster
        nsz, osz = None, None
        # `nsz`/`osz` hold the cluster count after the last two rounds.
        while len(cluster) < n_cluster and (nsz != osz or osz is None):
            cluster = list(chain.from_iterable(BalancedClusters.binary_kmeans(x[o],o) for o in cluster))
            nsz,osz = len(cluster),nsz
        x2cluster = torch.zeros(len(x), dtype=torch.int64)
        for i,c in enumerate(cluster): x2cluster[c] = i
        return cluster, x2cluster
        

### Example

In [None]:
x = torch.randn(10, 3)

In [None]:
BalancedClusters.proc(x, 4)

([tensor([8, 0, 7]), tensor([1, 6]), tensor([9, 5, 2]), tensor([3, 4])],
 tensor([0, 1, 2, 3, 3, 2, 1, 0, 0, 2]))

## `ClusterGroupedSampler`: CLUSTER BASED SAMPLING

In [None]:
#| export
class ClusterGroupedSampler(Sampler):
    """Sampler that yields indices cluster by cluster, in random order.

    n         : total number of indices produced per epoch.
    cluster   : sequence of index tensors that together contain `n` elements; when
                None the sampler yields a plain random permutation of `range(n)`.
    generator : optional `torch.Generator` driving every shuffle, for reproducible
                epochs. Assumed to be a CPU generator - the permutations are created
                on the CPU.
    """

    def __init__(self, n:int, cluster:Optional[List]=None, generator:Optional[Any]=None):
        store_attr('n,cluster,generator')

    def __len__(self):
        return self.n

    def set_cluster(self, cluster):
        """Replace the clusters used by subsequent iterations."""
        self.cluster = cluster

    def __iter__(self):
        """Iterate over shuffled indices; elements of one cluster stay contiguous.

        Raises `ValueError` when the clusters do not hold exactly `n` elements.
        """
        if self.cluster is None: return iter(torch.randperm(self.n, generator=self.generator).tolist())
        csz = sum(len(o) for o in self.cluster)
        if len(self) != csz: raise ValueError(f'`n`({len(self)}) should be equal to total elements in `cluster`({csz})')
        # `torch.hstack` rejects an empty list, so handle "no clusters" explicitly.
        if len(self.cluster) == 0: return iter([])
        # Shuffle the order of the clusters, then the elements inside each one.
        order = torch.randperm(len(self.cluster), generator=self.generator).tolist()
        cluster = [self.cluster[i] for i in order]
        indices = torch.hstack([o[torch.randperm(len(o), generator=self.generator)] for o in cluster]).tolist()
        return iter(indices)
        

### Example

In [None]:
from torch.utils.data import DataLoader

In [None]:
n = 16
x = torch.randn(n, 3)

# Print the cluster id of every element of each batch yielded by `dl`.
# NOTE(review): relies on the global `x2cluster`, which in this notebook is only
# assigned by the later `BalancedClusters.proc(x, 5)` cells - run one of those first.
def dlo(dl): 
    for b in dl: print(x2cluster[b])


In [None]:
sampler = ClusterGroupedSampler(n)
dl = DataLoader(torch.arange(len(x)), batch_size=2, sampler=sampler)

In [None]:
dlo(dl)

tensor([3, 4])
tensor([4, 7])
tensor([0, 7])
tensor([5, 2])
tensor([6, 2])
tensor([1, 3])
tensor([1, 5])
tensor([0, 6])


In [None]:
cluster, x2cluster = BalancedClusters.proc(x, 5)
dl.sampler.set_cluster(cluster)

In [None]:
dlo(dl)

tensor([1, 1])
tensor([3, 3])
tensor([5, 5])
tensor([2, 2])
tensor([6, 6])
tensor([4, 4])
tensor([7, 7])
tensor([0, 0])


In [None]:
cluster, x2cluster = BalancedClusters.proc(x, 5)
dl.sampler.set_cluster(cluster)

In [None]:
dlo(dl)

tensor([0, 0])
tensor([1, 1])
tensor([3, 3])
tensor([6, 6])
tensor([4, 4])
tensor([2, 2])
tensor([5, 5])
tensor([7, 7])
