In [None]:
#| default_exp transform

In [None]:
#| hide 
# Re-import edited modules (e.g. `xcai`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, numpy as np, re, pickle
from tqdm.auto import tqdm
from scipy import sparse
from transformers import AutoTokenizer, BatchEncoding
from itertools import chain

from fastcore.utils import *
from fastcore.meta import *
from fastcore.dispatch import *
from fastprogress.fastprogress import master_bar, progress_bar

from xcai.core import *
from xcai.generation.trie import *
from xcai.data import XCDataBlock, BaseXCDataBlock

In [None]:
#| hide
# Regenerate the library module(s) from the `#| export` cells (nbdev).
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [None]:
from xcai.block import *
# NOTE(review): hard-coded absolute, user-specific data path; consider a configurable DATA_DIR constant.
# NOTE(review): `block` is overwritten by the pickle-loading cell further down.
block = XCBlock.from_cfg('/home/scai/phd/aiz218323/Projects/XC/data', 'train_meta', tokenizer='distilbert-base-uncased')



In [None]:
# Location of a pre-processed, pickled `block`.
# NOTE(review): absolute, user-specific path. `pkl_dir` already ends in '/', so the f-string below
# produces '//' (harmless on POSIX, but `os.path.join` would be cleaner).
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/processed/'
pkl_file = f'{pkl_dir}/wikiseealso_data-meta_distilbert-base-uncased_rm_ramen-cat.pkl'

In [None]:
# SECURITY: `pickle.load` can execute arbitrary code -- only load pickles you created or trust.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Shared keyword configuration; transforms below are built with `**PARAM` and ignore keys they do not use.
PARAM = {
    
    # Collator arguments
    #-------------------------
    'transform_type': 'xc', 
    # (feature prefix, nesting level, samples per item) triples, unpacked by `XCSamplePadFeatTfm`.
    'smp_features': [('lbl2data',1,2), ('hlk2data',1,1), ('hlk2lbl2data',2,1)],
    # NOTE(review): not read by any transform visible in this notebook section -- confirm it is still needed.
    'sampling_features': [('lbl2data',2), ('hlk2data',1), ('hlk2lbl2data',1)],
    
    # Arguments for Info class
    #-------------------------
    'info_column_names': ['identifier', 'input_text'], 
    'use_tokenizer': True, 
    # NOTE(review): the data above was tokenized with 'distilbert-base-uncased' (see its path); `sep_tok` /
    # `pad_tok` below come from this tokenizer instead -- confirm the ids agree.
    'tokenizer': 'bert-base-cased',
    'tokenization_column': 'input_text',
    'max_sequence_length': 32,
    
    # PadFeatTfm arguments
    #-------------------------
    # NOTE(review): 'collapse' and 'device' are not `PadFeatTfm` parameters ('device' belongs to `AlignInputIdsTfm`).
    'pad_side': 'right', 'drop': True, 'ret_t': True, 'in_place': True, 'collapse': True, 'device': 'cpu',
    
    # AlignInputIdsTfm arguments
    #-------------------------
    'inp': 'data', 'targ': 'lbl2data', 'ptr': 'lbl2data_data2ptr',
    
    # Data arguments
    #-------------------------
    'n_data_meta_samples': None,
    'n_lbl_meta_samples': None,
    'n_lbl_samples': None,
    
}

# Separator / padding token ids used by the pad and align transforms.
tokz = AutoTokenizer.from_pretrained(PARAM['tokenizer'])
PARAM['sep_tok'] = tokz.sep_token_id
PARAM['pad_tok'] = tokz.pad_token_id

## Batch transforms

### `PadTfm`: PAD FEATURE

In [None]:
#| export
class PadTfm:
    "Pad an arbitrarily nested (ragged) list so every nesting level has a uniform length."

    def __init__(self, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]='right', 
                 ret_t:Optional[bool]=True,
                 in_place:Optional[bool]=True,
                 **kwargs):
        """
        pad_tok: value used to fill missing positions (must be set by the time `__call__` runs).
        pad_side: 'right' appends padding; any other value prepends it.
        ret_t: try to return a `torch.Tensor`; the padded list is returned if conversion fails.
        in_place: when False, every nested list of the input is copied before padding.
        **kwargs: ignored, so a shared config dict can be splatted in.
        """
        store_attr('pad_tok,pad_side,ret_t,in_place')

    def _sz_help(self, x:List, sz:List, lev:int):
        "Recursively record in `sz[lev]` the longest list length found at nesting level `lev`."
        if len(x) and isinstance(x[0], list):
            l = max(len(o) for o in x)
            if len(sz) > lev: sz[lev] = max(sz[lev], l)
            else: sz.append(l)
            for o in x: self._sz_help(o, sz, lev+1)

    def get_sz(self, x:List):
        "Padded shape of `x`: `sz[i]` is the maximum list length at nesting level `i`."
        sz = [len(x)]
        self._sz_help(x, sz, len(sz))
        return sz

    def _copy(self, x):
        "Copy every nested list level of `x`; non-list leaves are shared, not copied."
        return [self._copy(o) for o in x] if isinstance(x, list) else x

    def _pad_help(self, x:List, sz:List, pads:List, lev:int):
        "Pad `x` (at nesting level `lev`) and, recursively, all of its sub-lists to the lengths in `sz`."
        if len(x) and isinstance(x[0], list):
            for i,o in enumerate(x): x[i] = self._pad_help(o, sz, pads, lev+1)
        # A fresh copy per padding entry, so padded rows never alias one another.
        rem = [self._copy(pads[lev]) for _ in range(sz[lev] - len(x))]
        return x+rem if self.pad_side == 'right' else rem+x

    def __call__(self, 
                 x:List, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]=None, 
                 ret_t:Optional[bool]=None, 
                 in_place:Optional[bool]=None):
        """Pad `x`; non-None arguments overwrite the stored settings (and persist for later calls).

        Raises ValueError when no `pad_tok` has been configured.
        """
        store_attr('pad_tok,pad_side,ret_t,in_place', is_none=False)
        if self.pad_tok is None: raise ValueError('`pad_tok` cannot be None.')
        
        sz = self.get_sz(x)
        # pads[i] fills one missing element at level i, e.g. sz=[2,2,3] gives
        # pads = [[[p,p,p],[p,p,p]], [p,p,p], p].
        pads = [self.pad_tok]
        for s in sz[:0:-1]: pads.insert(0, [pads[0]]*s)
        # A shallow `x.copy()` is not enough: `_pad_help` rewrites the inner lists as well.
        if not self.in_place: x = self._copy(x)
        x = self._pad_help(x, sz, pads, 0)
        if not self.ret_t: return x
        # Non-numeric leaves (e.g. strings) cannot form a tensor; fall back to the padded list.
        try: return torch.tensor(x)
        except Exception: return x
        

#### Example

In [None]:
tfm = PadTfm(pad_tok=0, pad_side='right')

arr = [[[1, 2, 3], [1, 2]],
       [[1]]]

In [None]:
# `in_place` keeps its default (True), so `arr` itself is padded as well.
o = tfm(arr, pad_tok=0)
print(o)

tensor([[[1, 2, 3],
         [1, 2, 0]],

        [[1, 0, 0],
         [0, 0, 0]]])


In [None]:
# For comparison: Hugging Face's own `tokenizer.pad`.
# NOTE(review): this re-binds `tokz` (previously the `PARAM['tokenizer']` tokenizer), and
# `AutoTokenizer` is already imported at the top of the notebook.
from transformers import AutoTokenizer
tokz = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [None]:
# Three levels of nesting are not padded correctly by `tokz.pad` (see the output) -- hence `PadTfm`.
tokz.pad({'input_ids':[[[1, 2, 3], [1, 2]], [[1]]]})

{'input_ids': [[[1, 2, 3], [1, 2]], [[1], 0]], 'attention_mask': [[1, 1], [1, 0]]}

In [None]:
# Two levels (batch x tokens) are padded as expected.
tokz.pad({'input_ids':[[1, 2, 3], [1]]})

{'input_ids': [[1, 2, 3], [1, 0, 0]], 'attention_mask': [[1, 1, 1], [1, 0, 0]]}

### `CollapseTfm`: COLLAPSE FEATURE

In [None]:
#| export
class CollapseTfm:
    "Flatten a nested list down to nesting level `lev`, recording how many items each sub-list contributed."

    def __init__(self, lev:int=0, use_ptr:int=True, **kwargs):
        store_attr('lev,use_ptr')

    def collapse(self, x:List, ptr:Dict, lev:int):
        """Flatten `x` (sitting at depth `lev`) down to `self.lev` and append its resulting length to `ptr[lev]`.

        Raises ValueError when a non-list is reached before depth `self.lev`.
        """
        if not isinstance(x, list): raise ValueError(f'`x` should be a list, check the `lev`({self.lev}).')
        if lev != self.lev:
            # Children are flattened first, so deeper levels are recorded before this one.
            parts = [self.collapse(o, ptr, lev+1) for o in x]
            x = list(chain(*parts))
        ptr.setdefault(lev, []).append(len(x))
        return x

    def _get_ptr(self, ptr):
        "Turn every list of counts in `ptr` into running totals, in place."
        for counts in ptr.values():
            for i in range(1, len(counts)): counts[i] += counts[i-1]
        
    def __call__(self, x:List, lev:int=None, use_ptr:Optional[int]=None):
        """Return `(flat, ptr)`: `x` flattened to level `lev`, plus per-level item counts.

        With `use_ptr` the counts are converted to cumulative offsets.
        """
        store_attr('lev,use_ptr', is_none=False)
        
        ptr = {}
        flat = self.collapse(x, ptr, 0)
        if self.use_ptr: self._get_ptr(ptr)
        return flat, ptr


#### Example

In [None]:
tfm = CollapseTfm()
x = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
      [[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
     [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
      [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]
# Flatten to level 2 and turn the per-level counts into cumulative offsets.
tfm(x, use_ptr=True, lev=2)

([[1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9]],
 {2: [3, 6, 9, 12], 1: [6, 12], 0: [12]})

### `CollateFeatTfm`: COLLATE FEATURE

In [None]:
#| export
class CollateFeatTfm:
    "Gather the features whose names start with `prefix` from a batch, optionally flattening nested ones."

    def __init__(self, prefix:Optional[str]=None, drop:Optional[bool]=True, lev:Optional[int]=0, **kwargs):
        store_attr('prefix,drop,lev')
        self.colps_proc = CollapseTfm(lev, use_ptr=False)

    def proc(self, x:Union[Dict, List], prefix:Optional[str]=None, drop:Optional[bool]=True, lev:Optional[int]=0):
        """Collect the features of `x` whose names match the regex `^{prefix}` (all when `prefix` is None).

        x: a list of per-item feature dicts, or a single mapping (dict, `BatchEncoding`, ...).
        drop: pop the collected features from `x` instead of just reading them.
        lev: for a list input with `lev > 0`, each feature is flattened to that nesting level and
            the per-level item counts are added as `<name>_ptr-<level>` (level 0 is omitted).
        Returns a dict (empty for an empty list). Raises TypeError for any other input type.
        """
        from collections.abc import Mapping
        def _keep(k): return prefix is None or re.match(f'^{prefix}',k)

        if isinstance(x, list):
            if not x: return {}
            name = [k for k in x[0] if _keep(k)]
            feat = {k: [o.pop(k) if drop else o[k] for o in x] for k in name}
            if lev > 0:
                for k in name: 
                    feat[k], ptr = self.colps_proc(feat[k], lev)
                    for p,q in ptr.items(): 
                        if p != 0: feat[f'{k}_ptr-{p}'] = q
        # `Mapping` (not `dict`) so `BatchEncoding`/`UserDict` inputs are accepted too.
        elif isinstance(x, Mapping):
            name = [k for k in x if _keep(k)]
            feat = {k: x.pop(k) if drop else x[k] for k in name}
        else:
            raise TypeError(f'`x` should be a list or a mapping, got {type(x).__name__}.')
        return feat

    def __call__(self, x:Union[Dict, List], prefix:Optional[str]=None, drop:Optional[bool]=None, lev:Optional[int]=None):
        # Non-None arguments replace the stored settings for this and later calls.
        store_attr('prefix,drop,lev', is_none=False)
        return self.proc(x, self.prefix, self.drop, self.lev)
        
        

#### Example

In [None]:
tfm = CollateFeatTfm(prefix='a', drop=False)
obj = [{'a': [[1, 2, 3], [1]], 'b': [1, 2, 3]},
       {'a': [[1, 2]], 'b': [1]},
       {'a': [[1]], 'b': [1, 2]},
       {'a': [[1], [1, 2, 3, 4, 5]], 'b': [1, 2, 3, 4]}]

# Flatten `a` two levels deep; `a_ptr-<level>` hold the item counts per level.
o = tfm(obj, lev=2)
for k, v in o.items():
    print(k, ':')
    print(v)

a :
[1, 2, 3, 1, 1, 2, 1, 1, 1, 2, 3, 4, 5]
a_ptr-2 :
[3, 1, 2, 1, 1, 5]
a_ptr-1 :
[4, 2, 1, 6]


### `PadFeatTfm`: PAD BATCH

In [None]:
#| export
class PadFeatTfm:
    "Collect the features whose names match `prefix` from a batch and pad each one (to a tensor by default)."

    def __init__(self,
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=True, 
                 pad_tok:Optional[int]=0, 
                 pad_side:Optional[str]='right', 
                 ret_t:Optional[bool]=True,
                 in_place:Optional[bool]=True,
                 lev:Optional[int]=0,
                 **kwargs):
        # `**kwargs` lets a shared config dict (e.g. `PARAM`) be splatted in; unknown keys are ignored.
        store_attr('prefix,drop,pad_tok,pad_side,ret_t,in_place,lev')
        self.pad_proc, self.coll_proc = PadTfm(), CollateFeatTfm(prefix=prefix, drop=drop, lev=lev)

    def get_feat(self, 
                 x:Union[Dict, List], 
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=True, 
                 lev:Optional[int]=0):
        "Collect (and, for `lev > 0`, flatten) the `prefix`-matching features of `x` without padding them."
        # Reuse `CollateFeatTfm.proc` rather than keeping a second copy of the collation logic.
        return self.coll_proc.proc(x, prefix, drop, lev)

    def proc(self, x):
        "Pad every feature of `x` into a `BatchEncoding`; attention masks and token-type ids are padded with 0."
        return BatchEncoding({
            k: (self.pad_proc(v, 0, self.pad_side, self.ret_t, self.in_place) 
                if re.match('(.*_attention_mask|.*_token_type_ids)', k) else 
                self.pad_proc(v, self.pad_tok, self.pad_side, self.ret_t, self.in_place)) 
            for k,v in x.items()
        })
        
    def __call__(self, x:Union[Dict, List], 
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=None, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]=None, 
                 ret_t:Optional[bool]=None, 
                 in_place:Optional[bool]=None,
                 lev:Optional[int]=None):
        """Collate, then pad, the `prefix`-matching features of `x`; returns a `BatchEncoding`.

        Non-None arguments replace the stored settings for this and later calls. `lev` defaults
        to None like the other options, so the `lev` given to `__init__` applies when omitted.
        """
        store_attr('prefix,drop,pad_tok,pad_side,ret_t,in_place,lev', is_none=False)
        feat = self.coll_proc(x, self.prefix, self.drop, self.lev)
        return self.proc(feat)
        

#### Example 1

In [None]:
tfm = PadFeatTfm(pad_tok=0)

obj = [{'a': [[1, 2, 3], [1]], 'b': [1, 2, 3]},
       {'a': [[1, 2]], 'b': [1]},
       {'a': [[1]], 'b': [1, 2]},
       {'a': [[1], [1, 2, 3, 4, 5]], 'b': [1, 2, 3, 4]}]

# Only `a` is collected (and, since `drop` defaults to True, popped from `obj`); `b` is left alone.
o = tfm(obj, prefix='a')
for k, v in o.items():
    print(k, ':')
    print(v)

a :
tensor([[[1, 2, 3, 0, 0],
         [1, 0, 0, 0, 0]],

        [[1, 2, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [1, 2, 3, 4, 5]]])


#### Example 2

In [None]:
# One raw batch: a list holding one feature dict per item.
batch = block.train.dset.one_batch()
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
# Pad only the `lbl2data_*` features, flattening labels across the batch (lev=1);
# `drop=False` leaves them in `batch`.
o = (PadFeatTfm(**PARAM))(batch, prefix='lbl2data', lev=1, drop=False)

In [None]:
for k, v in o.items():
    print(k, ': ', v.shape if hasattr(v, 'shape') else len(v))

lbl2data_idx :  torch.Size([16])
lbl2data_identifier :  16
lbl2data_input_text :  16
lbl2data_input_ids :  torch.Size([16, 16])
lbl2data_token_type_ids :  torch.Size([16, 16])
lbl2data_attention_mask :  torch.Size([16, 16])
lbl2data_idx_ptr-1 :  torch.Size([10])
lbl2data_identifier_ptr-1 :  torch.Size([10])
lbl2data_input_text_ptr-1 :  torch.Size([10])
lbl2data_input_ids_ptr-1 :  torch.Size([10])
lbl2data_token_type_ids_ptr-1 :  torch.Size([10])
lbl2data_attention_mask_ptr-1 :  torch.Size([10])


### `AlignInputIdsTfm`: ALIGN TOKEN SEQUENCE

In [None]:
#| export
class AlignInputIdsTfm:
    """Clip target token sequences so none is longer than the input sequence it belongs to.

    Input ids live under `{inp}_input_ids`, target ids under `{targ}_input_ids` (with optional
    `{targ}_attention_mask` / `{targ}_token_type_ids`); `x[ptr]` tells which targets belong to
    which input.
    """

    def __init__(self,
                 inp:Optional[str]='data',
                 targ:Optional[str]='lbl2data',
                 ptr:Optional[str]='lbl2data_data2ptr',
                 sep_tok:Optional[int]=0, 
                 pad_tok:Optional[int]=0,
                 device:Union[str,torch.device]='cpu', 
                 **kwargs):
        # NOTE(review): `device` is stored but not used by any method of this class.
        store_attr('inp,targ,ptr,sep_tok,pad_tok,device')

    @typedispatch
    def proc(self, inp_ids:List, targ_ids:List, sep_tok:int, targ_mask:Optional[List]=None, targ_tok:Optional[List]=None, **kwargs):
        # List overload (chosen by fastcore `typedispatch`, presumably from the types of `inp_ids` / `targ_ids`).
        # `targ_ids[i]` holds the target sequences of input `i`. Any target longer than its input is cut
        # to `len(input) - 1` tokens plus a separator; its mask / token-type lists are cut to `len(input)`.
        # Lists are edited in place. `**kwargs` absorbs `ptr` / `pad_tok`, which this overload does not use.
        for i,ids in enumerate(inp_ids):
            inp_len = len(ids)
            for j,t in enumerate(targ_ids[i]):
                if len(t) > inp_len: 
                    # NOTE(review): uses `self.sep_tok`, not the `sep_tok` argument (equal when called via `__call__`).
                    targ_ids[i][j] = t[:inp_len-1]+[self.sep_tok]
                    if targ_mask is not None: targ_mask[i][j] = targ_mask[i][j][:inp_len]
                    if targ_tok is not None: targ_tok[i][j] = targ_tok[i][j][:inp_len] 
        return targ_ids, targ_mask, targ_tok

    @typedispatch
    def proc(self, inp_ids:torch.Tensor, targ_ids:torch.Tensor, ptr:torch.Tensor, sep_tok:int, pad_tok:int,
             targ_mask:Optional[torch.Tensor]=None, targ_tok:Optional[torch.Tensor]=None):
        # Tensor overload: one row per input in `inp_ids`, one row per target in `targ_ids`.
        # Row length = index of the row's last `sep_tok` + 1: the cumulative count of separators first
        # reaches its maximum there and `argmax` returns the first maximum. A row without any
        # `sep_tok` gets length 1.
        inp_len = (inp_ids == sep_tok).cumsum(1).argmax(1) + 1
        # Repeat each input length once per target row belonging to it (`ptr` holds per-input counts).
        inp_len = torch.repeat_interleave(inp_len, ptr)
        targ_len = (targ_ids == sep_tok).cumsum(1).argmax(1) + 1
        # Aligned length = min(input length, target length).
        seq_len = torch.where(inp_len < targ_len, inp_len, targ_len)
        
        # End each target row at `seq_len` with a separator and pad the clipped tail (mask and
        # token-type ids zeroed there); tensors are modified in place. A row that contains `sep_tok`
        # and is not longer than its input is left unchanged.
        # NOTE(review): a target row with no `sep_tok` has its first token overwritten with `sep_tok`.
        for i,(p,q) in enumerate(zip(seq_len, targ_len)):
            targ_ids[i,p-1] = sep_tok
            targ_ids[i,p:q] = pad_tok 
            if targ_mask is not None: targ_mask[i,p:q] = 0
            if targ_tok is not None: targ_tok[i,p:q] = 0
        return targ_ids, targ_mask, targ_tok
        
    def __call__(self, x:Dict, 
                 inp:Optional[str]=None, 
                 targ:Optional[str]=None,
                 ptr:Optional[str]=None, 
                 sep_tok:Optional[int]=None, 
                 pad_tok:Optional[int]=None):
        """Align the target features of `x` in place and return `x`.

        Non-None arguments replace the stored settings for this and later calls. `x` is returned
        untouched when either input-id feature is missing; `x[self.ptr]` must exist otherwise
        (unless `self.ptr` is None).
        """
        store_attr('inp,targ,ptr,sep_tok,pad_tok', is_none=False)

        # Look up each comma-separated key, using None for missing keys (or raising if `required`).
        def get_attr(x, keys, required=False):
            attr = []
            for k in keys.split(','):
                if k not in x: 
                    if required: raise ValueError(f'"{k}" not in `x`')
                    else: attr.append(None)
                else: attr.append(x[k])
            return attr
            
        inp_ids, targ_ids = get_attr(x, f'{self.inp}_input_ids,{self.targ}_input_ids')
        if inp_ids is None or targ_ids is None: return x
        targ_mask, targ_tok = get_attr(x, f'{self.targ}_attention_mask,{self.targ}_token_type_ids') 
        ptr = None if self.ptr is None else x[self.ptr]
        
        targ_ids, targ_mask, targ_tok = self.proc(inp_ids, targ_ids, ptr=ptr, targ_mask=targ_mask, targ_tok=targ_tok, 
                                                  sep_tok=self.sep_tok, pad_tok=self.pad_tok)
        # Write back only the features that exist (None means the feature was absent).
        def set_attr(x, keys, vals):
            for i,(k,v) in enumerate(zip(keys.split(','),vals)):
                if v is not None: x[k] = v
                    
        set_attr(x, f'{self.targ}_input_ids,{self.targ}_attention_mask,{self.targ}_token_type_ids', [targ_ids,targ_mask,targ_tok])
        
        return x
        

In [None]:
def verify_align(p:torch.Tensor, q:torch.Tensor, r:torch.Tensor, tok:int):
    """Print each target row's length next to the length of the input row it belongs to.

    p: input ids, one row per input.
    q: target ids, one row per target.
    r: number of target rows belonging to each input row (so `r.sum() == len(q)`).
    tok: separator token id. A row's length runs up to and including its last `tok`; a row
        without `tok` counts as length 1 (same measure as `AlignInputIdsTfm`).
    Each line reads `inp  <  targ`, `inp  ==  targ` or `inp  >  targ`; after alignment no `<` should remain.
    Raises ValueError when `r` does not expand to exactly one entry per row of `q`.
    """
    # cumsum/argmax yields exactly one length per row, even with zero or several separators;
    # `torch.where(p == tok)[1]` would yield one entry per separator and misalign the rows.
    inp_len = ((p == tok).cumsum(1).argmax(1) + 1).repeat_interleave(r)
    targ_len = (q == tok).cumsum(1).argmax(1) + 1
    if len(inp_len) != len(targ_len):
        raise ValueError(f'`r` expands to {len(inp_len)} rows but `q` has {len(targ_len)} rows.')
    for a, b in zip(inp_len.tolist(), targ_len.tolist()):
        sym = '<' if a < b else ('==' if a == b else '>')
        print(a, f' {sym} ', b)

#### Example 1

In [None]:
pad_tfm = PadFeatTfm(pad_tok=0)
obj = dict(
    data_input_ids=[[1, 2, 11], [11], [1, 2, 3, 4, 5, 11]],
    lbl2data_input_ids=[[1, 2, 11], [5, 11], [5, 11], [5, 11], [5, 6, 11], [13, 4, 11]],
    lbl2data_data2ptr=[1, 2, 3],
    lbl2data_attention_mask=[[1, 1, 1], [1, 1], [1, 1], [1, 1], [1, 1, 1], [1, 1, 1]],
    lbl2data_token_type_ids=[[0, 0, 0], [0, 0], [0, 0], [0, 0], [0, 0, 0], [0, 0, 0]],
)

# With the default `drop=True`, every feature is popped out of `obj` into the padded result.
o = pad_tfm(obj)
for k, v in o.items():
    print(k, ':')
    print(v)

data_input_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
lbl2data_input_ids :
tensor([[ 1,  2, 11],
        [ 5, 11,  0],
        [ 5, 11,  0],
        [ 5, 11,  0],
        [ 5,  6, 11],
        [13,  4, 11]])
lbl2data_data2ptr :
tensor([1, 2, 3])
lbl2data_attention_mask :
tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1]])
lbl2data_token_type_ids :
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


In [None]:
# Token 11 ends every row of the toy example above; `ptr` names the per-input label counts.
align_tfm = AlignInputIdsTfm(pad_tok=0, sep_tok=11, inp='data', 
                             targ='lbl2data', ptr='lbl2data_data2ptr')
o = align_tfm(o)
for k,v in o.items(): print(k,':'); print(v)

data_input_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
lbl2data_input_ids :
tensor([[ 1,  2, 11],
        [11,  0,  0],
        [11,  0,  0],
        [ 5, 11,  0],
        [ 5,  6, 11],
        [13,  4, 11]])
lbl2data_data2ptr :
tensor([1, 2, 3])
lbl2data_attention_mask :
tensor([[1, 1, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1]])
lbl2data_token_type_ids :
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


#### Example 2

In [None]:
batch = block.train.dset.one_batch()
pad_tfm = PadFeatTfm(**PARAM)
alg_tfm = AlignInputIdsTfm(**PARAM)

In [None]:
# Feature names carried by each item of the batch.
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
# Pad the data features as they are, and flatten the label features across the batch (lev=1).
o = pad_tfm(batch, prefix='data', lev=0, in_place=True, drop=True)
o.update(pad_tfm(batch, prefix='lbl2data', lev=1, in_place=True, drop=True))
# Every `lbl2data_*_ptr-1` holds the number of labels per data point; use one as the alignment pointer.
o['lbl2data_data2ptr'] = o['lbl2data_idx_ptr-1']

In [None]:
# Before alignment: a `<` line means the label sequence is longer than its data point's sequence.
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

7  >  7
4  <  5
4  <  9
4  <  5
4  <  8
9  >  7
7  <  11
5  >  4
5  >  4
5  <  8
6  <  9
6  >  4
6  <  8
6  >  5
7  >  7
7  >  7
6  >  3
11  >  3
11  >  6


In [None]:
# Clip every label sequence to the length of its data point's sequence.
o = alg_tfm(o)

In [None]:
# After alignment no `<` lines should remain.
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

7  >  7
4  >  4
4  >  4
4  >  4
4  >  4
9  >  7
7  >  7
5  >  4
5  >  4
5  >  5
6  >  6
6  >  4
6  >  6
6  >  5
7  >  7
7  >  7
6  >  3
11  >  3
11  >  6


### `XCPadFeatTfm`: PAD XC BATCH

In [None]:
#| export
class XCPadFeatTfm:
    "Pad an XC batch: `data_*`, `lbl2data_*`, and every metadata group `<meta>2data_*` / `<meta>2lbl2data_*`."

    @delegates(PadFeatTfm.__init__)
    def __init__(self, **kwargs):
        self.tfm = PadFeatTfm(**kwargs)

    def extract_ptr(self, x:Dict, suffix:str):
        "Pop every key of `x` matching `.*{suffix}$` and return the first popped value (None if none match)."
        # Every feature of a group carries the same pointer, so the duplicates are discarded.
        ptr_name = [k for k in x if re.match(f'.*{suffix}$',k)]
        ptr = [x.pop(k) for k in ptr_name]
        return ptr[0] if len(ptr) else None

    def _set_ptr(self, o, name:str, suffix:str):
        "Move the pointer matching `suffix` to `o[name]`; add no key when the group produced no pointer."
        ptr = self.extract_ptr(o, suffix)
        if ptr is not None: o[name] = ptr

    def __call__(self, x):
        # Metadata types = leading name of each key (e.g. 'hlk' in 'hlk2lbl2data_idx'), minus lbl/data.
        meta_name = set([k.split('_',maxsplit=1)[0].split('2')[0] for k in x[0]]).difference(['lbl', 'data'])
        out = self.tfm(x, prefix='lbl2data', lev=1, in_place=True, drop=True)
        self._set_ptr(out, 'lbl2data_data2ptr', 'ptr-1')
        out.update(self.tfm(x, prefix='data', lev=0, in_place=True, drop=True))
        for k in meta_name:
            # A metadata type may have only one of the two groups; missing groups add no `None` pointers.
            o = self.tfm(x, prefix=f'{k}2lbl2data', lev=2, in_place=True, drop=True)
            self._set_ptr(o, f'{k}2lbl2data_data2ptr', 'ptr-1')
            self._set_ptr(o, f'{k}2lbl2data_lbl2ptr', 'ptr-2')
            out.update(o)
            o = self.tfm(x, prefix=f'{k}2data', lev=1, in_place=True, drop=True)
            self._set_ptr(o, f'{k}2data_data2ptr', 'ptr-1')
            out.update(o)
        return out
        

#### Example

In [None]:
batch = block.train.dset.one_batch()
pad_tfm = XCPadFeatTfm(**PARAM)
align_tfm = AlignInputIdsTfm(**PARAM)

In [None]:
# Pad the whole XC batch, then clip label sequences to their data point's length.
o = align_tfm(pad_tfm(batch))

In [None]:
# After padding and alignment no `<` lines should remain.
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

6  >  6
7  >  7
5  >  5
7  >  7
7  >  7
10  >  7
6  >  5
6  >  6
16  >  16
16  >  6
16  >  7
16  >  6
16  >  5
16  >  4
16  >  5
5  >  4
5  >  5
5  >  5
5  >  5
5  >  5
5  >  5
7  >  5


### `XCPadOutputTfm`: PAD XC OUTPUT

In [None]:
#| export
class XCPadOutputTfm:
    "Pad generated outputs: the `info2seq_*` and `seq_*` feature groups, each without flattening."

    @delegates(PadFeatTfm.__init__)
    def __init__(self, **kwargs):
        self.tfm = PadFeatTfm(**kwargs)

    def extract_ptr(self, x:Dict, suffix:str):
        "Pop every key of `x` matching `.*{suffix}$` and return the first popped value, or None if none match."
        # Returning None (rather than raising IndexError) matches the other `extract_ptr` helpers in this module.
        ptr_name = [k for k in x if re.match(f'.*{suffix}$',k)]
        ptr = [x.pop(k) for k in ptr_name]
        return ptr[0] if ptr else None

    def __call__(self, x):
        out = self.tfm(x, prefix='info2seq', lev=0, in_place=True, drop=True)
        out.update(self.tfm(x, prefix='seq', lev=0, in_place=True, drop=True))
        return out
        

#### Example

In [None]:
pad_tfm = XCPadOutputTfm(pad_tok=0)
obj = dict(
    seq_output_ids=[[1, 2, 11], [11], [1, 2, 3, 4, 5, 11]],
    seq_score=[1.2, 3.3, 4.5],
    info2seq_idx=[1, 2, 11, 5, 11, 13],
    info2seq_seq2ptr=[1, 2, 3],
)

o = pad_tfm(obj)
for k, v in o.items():
    print(k, ':')
    print(v)

info2seq_idx :
tensor([ 1,  2, 11,  5, 11, 13])
info2seq_seq2ptr :
tensor([1, 2, 3])
seq_output_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
seq_score :
tensor([1.2000, 3.3000, 4.5000])


### `SampleFeatTfm`: SAMPLE LABELS

In [None]:
#| export
class SampleFeatTfm:
    "Randomly sample up to `n_samples` entries, per batch item, from every feature named `<feat_type>_*`."

    def __init__(self, feat_type:Optional[str]=None, smp_prefix:Optional[str]='', **kwargs):
        # feat_type: regex for the part of a key before '_' (e.g. '(hlk2lbl2data|lbl2data)').
        # smp_prefix: if non-empty, sampled keys are renamed to '<smp_prefix>2<original key>'.
        store_attr('feat_type,smp_prefix')

    def _get_feat(self, x:Dict):
        "Names in `x` matching `^(<feat_type>)_.*`; raises ValueError when `feat_type` is unset."
        if self.feat_type is None: raise ValueError('`feat_type` is None.')
        return [o for o in x if re.match(f'^({self.feat_type})_.*', o)]

    def proc(self, x:List, lev:int, n_samples:Optional[int]=1):
        """Return one dict of sampled features per item of `x` (a list of feature dicts).

        Each matching feature is flattened to nesting level `lev`. One random permutation per item,
        sized by the first matching feature, selects the same positions in every feature, so the
        samples stay aligned. Items with nothing to sample get empty lists.
        """
        feat = self._get_feat(x[0])
        # NOTE: when nothing matches, an empty dict (not a list) is returned; kept for compatibility.
        if len(feat) == 0: return {}

        coll_proc = CollapseTfm()
        smp_prefix = self.smp_prefix if self.smp_prefix == '' else f'{self.smp_prefix}2'
        out = []
        for o in x:
            # Flatten each feature once and reuse it for both sizing and sampling.
            flat = {k: coll_proc(o[k], lev)[0] for k in feat}
            size = len(flat[feat[0]])
            idx = np.random.permutation(size)[:n_samples] if size else []
            out.append({smp_prefix+k: [c[i] for i in idx] for k, c in flat.items()})
        return out

    def __call__(self, x:[List,Dict,BatchEncoding], lev:int, n_samples:Optional[int]=1, 
                 feat_type:Optional[str]=None, smp_prefix:Optional[str]=None, **kwargs):
        # Non-None `feat_type` / `smp_prefix` replace the stored settings for this and later calls.
        store_attr('feat_type,smp_prefix', is_none=False)
        return self.proc(x, lev, n_samples)
        

#### Example

In [None]:
# Fresh batch: earlier transforms popped (`drop=True`) features out of the previous one.
batch = block.train.dset.one_batch()

In [None]:
# Feature names of one item, including the `hlk2data_*` and `hlk2lbl2data_*` metadata groups.
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask'])

In [None]:
feat_type = '(hlk2lbl2data|lbl2data)'
# Sample 3 entries per item from both groups at once; results are stored under `smp2<original key>`.
smp_tfm = SampleFeatTfm(feat_type=feat_type, smp_prefix='smp')
out = smp_tfm(batch, lev=0, n_samples=3)

In [None]:
# `feat_type` is a regex alternation, so look up each concrete feature prefix it names.
for i,(o, b) in enumerate(zip(out, batch)): 
    k='data_input_text'; print(f'{i+1}.{k} ',': ', b[k])
    for ft in feat_type.strip('()').split('|'):
        k=f'smp2{ft}_input_text'; print(f'  {k}', ': ', o[k])
        k=f'{ft}_input_text'; print(f'  {k}', ': ', b[k])

### `XCSamplePadFeatTfm`: SAMPLE AND THEN PAD BATCH

In [None]:
#| export
class XCSamplePadFeatTfm:
    "Sample the feature groups listed in `smp_features`, then pad them together with the rest of the batch."

    def __init__(self, smp_features:Optional[List]=None, **kwargs):
        # smp_features: iterable of `(feature_prefix, nesting_level, n_samples)` triples.
        store_attr('smp_features')
        self.smp_proc, self.pad_proc = SampleFeatTfm(**kwargs), PadFeatTfm(**kwargs)
        
    def extract_ptr(self, x:Dict, suffix:str):
        "Pop every key of `x` matching `.*{suffix}$` and return the first popped value (None if none match)."
        ptr_name = [k for k in x if re.match(f'.*{suffix}$',k)]
        ptr = [x.pop(k) for k in ptr_name]
        return ptr[0] if len(ptr) else None
        
    def sample_feat(self, x:List, feat:str, lev:int, n_samples:Optional[int]=1):
        """Pad all `{feat}_idx` values as `p{feat}_idx` / `p{feat}_data2ptr`, then sample and pad `{feat}_*`.

        The full index list is collected with `drop=False`, so `x` still holds `{feat}_*` for sampling.
        `n_samples` entries are drawn per item at nesting level `lev - 1`, padded under their original
        names (assuming the sampler's `smp_prefix` is empty), and their per-item counts are stored as
        `{feat}_data2ptr`. Nothing is sampled when the batch has no `{feat}_idx` feature.
        """
        out = self.pad_proc(x, prefix=f'{feat}_idx', lev=lev, in_place=False, drop=False)
        
        if f'{feat}_idx' in out:
            
            out[f'p{feat}_idx'] = out.pop(f'{feat}_idx')
            out[f'p{feat}_data2ptr'] = out.pop(f'{feat}_idx_ptr-1')
            if f'{feat}_idx_ptr-2' in out: out.pop(f'{feat}_idx_ptr-2')
                
            o = self.pad_proc(self.smp_proc(x, lev=lev-1, n_samples=n_samples, feat_type=feat), 
                              prefix=feat, lev=1, in_place=True, drop=True)
            o[f'{feat}_data2ptr'] = self.extract_ptr(o, 'ptr-1')
            self.extract_ptr(o, 'ptr-2')
            
            out.update(o)
        return out

    def __call__(self, x:List, smp_features:Optional[List]=None):
        "Sample and pad `smp_features`, pad `data_*`, then pad every remaining `*2data` / `*2lbl2data` group."
        store_attr('smp_features', is_none=False)
        
        out, smp_features = BatchEncoding({}), () 
        if self.smp_features is not None:
            for feat,lev,n in self.smp_features: out.update(self.sample_feat(x, feat, lev, n))
            # Names of the sampled groups; an empty `smp_features` list simply yields `()`.
            smp_features = tuple(feat for feat,*_ in self.smp_features)
            
        out.update(self.pad_proc(x, prefix='data', lev=0, in_place=True, drop=True))
        
        # Remaining groups = key prefixes before the first '_' (e.g. 'hlk2lbl2data'), minus sampled ones.
        meta_names = set([o.split('_',maxsplit=1)[0] for o in x[0]]).difference(smp_features+('data',))
        if 'lbl2data' in meta_names:
            out.update(self.pad_proc(x, prefix='lbl2data', lev=1, in_place=True, drop=True))
            out['lbl2data_data2ptr'] = self.extract_ptr(out, 'ptr-1')
        
        for k in meta_names.difference(['lbl2data']):
            if k.endswith('2lbl2data'): 
                o = self.pad_proc(x, prefix=k, lev=2, in_place=True, drop=True)
                o[f'{k}_data2ptr'] = self.extract_ptr(o, 'ptr-1')
                # The label-level pointer is named after whether labels were sampled ('plbl2data') or not.
                if 'lbl2data' in meta_names: o[f'{k}_lbl2data2ptr'] = self.extract_ptr(o, 'ptr-2')
                else: o[f'{k}_plbl2data2ptr'] = self.extract_ptr(o, 'ptr-2')
            elif k.endswith('2data'): 
                o = self.pad_proc(x, prefix=k, lev=1, in_place=True, drop=True)
                o[f'{k}_data2ptr'] = self.extract_ptr(o, 'ptr-1')
            else: raise ValueError(f'Invalid metadata name ({k})')
            out.update(o)
        return out
        

#### Example 1

In [None]:
# A small batch (the argument is presumably the batch size -- the sizes printed below suggest 3 items).
batch = block.train.dset.one_batch(3)

In [None]:
# Feature names carried by each item of the batch.
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask'])

In [None]:
tfm = XCSamplePadFeatTfm(**PARAM)
# (feature prefix, nesting level, samples per item) triples -- the same values as `PARAM['smp_features']`.
o = tfm(batch, smp_features=[('lbl2data',1, 2), ('hlk2data',1, 1), ('hlk2lbl2data',2, 1)])

In [None]:
for k, v in o.items():
    if not k.startswith('lbl2data_'): continue
    print(f'{k}: ', v.shape if isinstance(v, torch.Tensor) else len(v))

lbl2data_idx:  torch.Size([5])
lbl2data_identifier:  5
lbl2data_input_text:  5
lbl2data_input_ids:  torch.Size([5, 12])
lbl2data_attention_mask:  torch.Size([5, 12])
lbl2data_data2ptr:  torch.Size([3])


In [None]:
# Print tensor shapes and list lengths in one consistent `key:  value` format.
for k, v in o.items():
    if not k.startswith('hlk2data_'): continue
    print(f'{k}: ', v.shape if isinstance(v, torch.Tensor) else len(v))

hlk2data_idx:  torch.Size([3])
hlk2data_identifier 3
hlk2data_input_text 3
hlk2data_input_ids:  torch.Size([3, 8])
hlk2data_attention_mask:  torch.Size([3, 8])
hlk2data_data2ptr:  torch.Size([3])


In [None]:
# Print tensor shapes and list lengths in one consistent `key:  value` format.
for k, v in o.items():
    if not k.startswith('hlk2lbl2data_'): continue
    print(f'{k}: ', v.shape if isinstance(v, torch.Tensor) else len(v))

hlk2lbl2data_idx:  torch.Size([3])
hlk2lbl2data_identifier 3
hlk2lbl2data_input_text 3
hlk2lbl2data_input_ids:  torch.Size([3, 8])
hlk2lbl2data_attention_mask:  torch.Size([3, 8])
hlk2lbl2data_data2ptr:  torch.Size([3])


In [None]:
# Sampled groups also keep their full index lists under `p<feature>_idx` / `p<feature>_data2ptr`.
o.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'phlk2data_idx', 'phlk2data_data2ptr', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr', 'phlk2lbl2data_idx', 'phlk2lbl2data_data2ptr', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask', 'hlk2lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask'])

### `RamenPadFeatTfm`: SAMPLE AND THEN PAD BATCH

In [None]:
#| export
class RamenPadFeatTfm:
    """Collate a batch (a list of per-item feature dicts): optionally sample some features, then pad everything.

    `smp_features` is a list of `(feat, lev, n_samples)` triples, e.g. `('lbl2data|cat2lbl2data', 1, (1,1))`.
    `feat` is a `|`-separated chain. The head feature is sampled from the batch, and each following feature
    is sampled from the head's sampled output. Features not named in `smp_features` are padded without sampling.
    """

    def __init__(self, smp_features:Optional[List]=None, **kwargs):
        store_attr('smp_features')
        # The sampler and the padder are built from the same keyword configuration.
        self.smp_proc, self.pad_proc = SampleFeatTfm(**kwargs), PadFeatTfm(**kwargs)
        
    def extract_ptr(self, x:Dict, suffix:str):
        "Pop every key of `x` that ends with `suffix`; return the first popped value, or None if none matched."
        ptr_name = [k for k in x if re.match(f'.*{suffix}$',k)]
        ptr = [x.pop(k) for k in ptr_name]
        return ptr[0] if len(ptr) else None
    
    def get_feat(self, x, feat_type):
        "Keys of the first item of `x` that start with `'{feat_type}_'`."
        return [o for o in x[0] if o.startswith(f'{feat_type}_')]

    def smp_feat(self, x:List, feat_type:str, n_samples:Optional[int]=1):
        """For every item, sample up to `n_samples` entries from each inner list of the `feat_type` features.

        The random positions come from the first matching key. They are applied to all keys of that feature
        so the columns stay aligned. Empty inner lists stay empty.
        """
        feat = self.get_feat(x, feat_type)
        # -1 marks "nothing to sample" for an empty inner list; it is filtered out below.
        rnd_idx = [[np.random.permutation(len(v))[:n_samples] if len(v) else [-1] for v in o[feat[0]]] for o in x]

        out = []
        for o,idx in zip(x, rnd_idx):
            out.append({k:[[v[i] for i in ii if i >= 0] for v,ii in zip(o[k], idx)] for k in feat})
        return out

        
    def sample_feat(self, x:List, feat:str, lev:int, n_samples:Optional[Union[int,List]]=1):
        """Sample and pad the `|`-separated feature chain `feat` (e.g. `'lbl2data|cat2lbl2data'`).

        For the head feature `f`:
        - the full padded index set is kept as `p{f}_idx`, together with its pointer;
        - `n_samples[0]` entries are sampled and padded under `{f}_*`.
        Each following feature is sampled from the head's sampled output with the matching entry of `n_samples`.
        Its prefix is shortened by its last `2xxx` hop (e.g. `cat2lbl2data_*` -> `cat2lbl_*`).
        `n_samples` may be an int, which is applied to every feature. `lev` is accepted but not used here.
        Raises ValueError if `n_samples` is a sequence whose length differs from the number of features.
        """
        f = feat.split("|")
        
        if isinstance(n_samples, int):
            n_samples = (n_samples,)*len(f)
        else:
            if len(n_samples) != len(f): 
                raise ValueError(f"Size of `n_samples`({len(n_samples)}) should be equal to number of features.")
        
        f,of = f[0], f[1:]
        n_sample, n_samples = n_samples[0], n_samples[1:]
        
        # Pad only the head feature's indices (not in place) to keep the full, unsampled set.
        out = self.pad_proc(x, prefix=f'{f}_idx', lev=1, in_place=False, drop=False)
        if f'{f}_idx' in out:
            out[f'p{f}_idx'],out[f'p{f}_{f.split("2")[-1]}2ptr'] = out.pop(f'{f}_idx'), self.extract_ptr(out, 'ptr-1')
            # Discard any second-level pointer of the index-only padding.
            self.extract_ptr(out, 'ptr-2')

            smp_out = self.smp_proc(x, lev=0, n_samples=n_sample, feat_type=feat)

            o = self.pad_proc(smp_out, prefix=f, lev=1, in_place=True, drop=True)
            o[f'{f}_{f.split("2")[-1]}2ptr'] = self.extract_ptr(o, 'ptr-1')
            self.extract_ptr(o, 'ptr-2')
            out.update(o)

            # NOTE: `f` and `n_sample` are rebound here to each following feature of the chain.
            for f,n_sample in zip(of,n_samples):
                o = self.pad_proc(smp_out, prefix=f'{f}_idx', lev=2, in_place=False, drop=False)
                if f'{f}_idx' in o:
                    o[f'p{"2".join(f.split("2")[:-1])}_idx'] = o.pop(f'{f}_idx')
                    o[f'p{"2".join(f.split("2")[:-1])}_{"2".join(f.split("2")[-2:])}2ptr'] = self.extract_ptr(o, 'ptr-2')
                    o[f'p{"2".join(f.split("2")[:-1])}_{f.split("2")[-1]}2ptr'] = self.extract_ptr(o, 'ptr-1')
                    out.update(o)

                    o = self.pad_proc(self.smp_feat(smp_out, f, n_sample), prefix=f, lev=2, in_place=True, drop=True)
                    feat = list(o.keys())
                    # Rename `{prefix}_{rest}` to `{prefix minus its last '2xxx' hop}_{rest}`.
                    for k in feat:
                        p,q = k.split('_', maxsplit=1)
                        o["_".join(["2".join(p.split("2")[:-1]),q])] = o.pop(k)
                    # NOTE(review): `p` is the prefix of the last renamed key; if `o` were empty it would be stale/undefined.
                    o[f'{"2".join(p.split("2")[:-1])}_{"2".join(f.split("2")[-2:])}2ptr'] = self.extract_ptr(o, 'ptr-2')
                    o[f'{"2".join(p.split("2")[:-1])}_{f.split("2")[-1]}2ptr'] = self.extract_ptr(o, 'ptr-1')
                    out.update(o)
        return out
    

    def __call__(self, x:List, smp_features:Optional[List]=None):
        """Sample (per `smp_features`) and pad the batch `x`, a list of per-item feature dicts.

        A non-None `smp_features` argument replaces the stored configuration (via `store_attr(..., is_none=False)`).
        Raises ValueError for a metadata prefix that does not end in '2data'.
        """
        store_attr('smp_features', is_none=False)
        
        # `smp_features` must be a list: it is concatenated with `['data']` below. A tuple here made
        # `smp_features+['data']` raise TypeError whenever no sampling was configured.
        out, smp_features = BatchEncoding({}), []
        if self.smp_features is not None:
            for feat,lev,n in self.smp_features: out.update(self.sample_feat(x, feat, lev, n))
            # Every feature named in a sampled chain is excluded from the plain padding below.
            smp_features = list(chain(*[o[0].split('|') for o in self.smp_features]))
            
        out.update(self.pad_proc(x, prefix='data', lev=0, in_place=True, drop=True))
        
        # Remaining feature prefixes: the part of each key before the first '_'.
        meta_names = set([o.split('_',maxsplit=1)[0] for o in x[0]]).difference(smp_features+['data'])
        if 'lbl2data' in meta_names:
            out.update(self.pad_proc(x, prefix='lbl2data', lev=1, in_place=True, drop=True))
            out['lbl2data_data2ptr'] = self.extract_ptr(out, 'ptr-1')
        
        for k in meta_names.difference(['lbl2data']):
            if k.endswith('2lbl2data'): 
                o = self.pad_proc(x, prefix=k, lev=2, in_place=True, drop=True)
                o[f'{k}_data2ptr'] = self.extract_ptr(o, 'ptr-1')
                if 'lbl2data' in meta_names: o[f'{k}_lbl2data2ptr'] = self.extract_ptr(o, 'ptr-2')
                else: o[f'{k}_plbl2data2ptr'] = self.extract_ptr(o, 'ptr-2')
            elif k.endswith('2data'): 
                o = self.pad_proc(x, prefix=k, lev=1, in_place=True, drop=True)
                o[f'{k}_data2ptr'] = self.extract_ptr(o, 'ptr-1')
            else: raise ValueError(f'Invalid metadata name ({k})')
            out.update(o)
        return out
        

#### Example

In [None]:
batch = block.train.dset.one_batch(4)

In [None]:
tfm = RamenPadFeatTfm(**PARAM)
o = tfm(batch, smp_features=[('lbl2data|cat2lbl2data',1, (1,1)), ('cat2data',1, 2)])

In [None]:
o['plbl2data_idx']

tensor([ 11412,  31970,  31971,  89688, 100218, 193510, 223526, 230797, 115246,
        219431, 219547, 229016, 237079])

In [None]:
o['pcat2lbl_lbl2data2ptr']

tensor([2, 4, 2, 2])

In [None]:
o['lbl2data_idx']

tensor([ 31970, 100218, 223526, 115246])

In [None]:
o['cat2lbl_idx']

tensor([ 98820, 149008,  91424, 157547])

In [None]:
for k,v in o.items():
    if isinstance(v, torch.Tensor): print(k, ': ', v.shape)
    else: print(k, ': ', len(v))

plbl2data_idx :  torch.Size([13])
plbl2data_data2ptr :  torch.Size([4])
lbl2data_idx :  torch.Size([4])
lbl2data_identifier :  4
lbl2data_input_text :  4
lbl2data_input_ids :  torch.Size([4, 11])
lbl2data_attention_mask :  torch.Size([4, 11])
lbl2data_data2ptr :  torch.Size([4])
pcat2lbl_idx :  torch.Size([10])
pcat2lbl_lbl2data2ptr :  torch.Size([4])
pcat2lbl_data2ptr :  torch.Size([4])
cat2lbl_idx :  torch.Size([4])
cat2lbl_identifier :  4
cat2lbl_input_text :  4
cat2lbl_input_ids :  torch.Size([4, 9])
cat2lbl_attention_mask :  torch.Size([4, 9])
cat2lbl_lbl2data2ptr :  torch.Size([4])
cat2lbl_data2ptr :  torch.Size([4])
pcat2data_idx :  torch.Size([11])
pcat2data_data2ptr :  torch.Size([4])
cat2data_idx :  torch.Size([7])
cat2data_identifier :  7
cat2data_input_text :  7
cat2data_input_ids :  torch.Size([7, 10])
cat2data_attention_mask :  torch.Size([7, 10])
cat2data_data2ptr :  torch.Size([4])
data_identifier :  4
data_input_text :  4
data_input_ids :  torch.Size([4, 7])
data_atten

### `RemoveColumnTfm`: REMOVE COLUMNS

In [None]:
#| export
class RemoveColumnTfm:
    "Delete the keys listed in `column` from a batch mapping (in place) and hand back the same mapping."
    
    def __init__(self, column:List, **kwargs):
        # Extra keyword arguments are ignored, presumably so the transform can be built from a shared config such as PARAM.
        self.column = column
    
    def __call__(self, x:Dict):
        # Absent keys are skipped silently.
        for name in self.column:
            x.pop(name, None)
        return x

### `NGPadFeatTfm`: PAD NGAME BATCH

In [None]:
#| export
class NGPadFeatTfm:
    """Collate an NGAME-style batch.

    The full set of positive label indices is kept under `plbl2data_idx` / `plbl2data_data2ptr`. Labels are
    then sampled and padded under `lbl2data_*`, and the `data_*` features are padded. Batches without label
    indices only get the `data_*` padding. As the example output shows, the `lbl2data_*_ptr-1` keys produced
    while padding the sampled labels are left in the result.
    """

    def __init__(self, **kwargs):
        sampler = SampleFeatTfm(**kwargs)
        padder = PadFeatTfm(**kwargs)
        self.smp_proc, self.pad_proc = sampler, padder

    def __call__(self, x:Dict):
        # NOTE(review): annotated as Dict, but the examples pass a list of per-item dicts.
        pad = self.pad_proc
        # Pad only the label indices, without consuming them from `x`.
        out = pad(x, prefix='lbl2data_idx', lev=1, in_place=False, drop=False)
        if 'lbl2data_idx' in out:
            ptr = out.pop('lbl2data_idx_ptr-1')
            out['plbl2data_idx'] = out['lbl2data_idx']
            out['plbl2data_data2ptr'] = ptr
            sampled = self.smp_proc(x, lev=0, feat_type='lbl2data', ptr_type='data')
            # This overwrites `lbl2data_idx` with the sampled subset.
            out.update(pad(sampled, prefix='lbl2data', lev=1, in_place=True, drop=True))
        out.update(pad(x, prefix='data', lev=0, in_place=True, drop=True))
        return out
        

#### Example 1

In [None]:
batch = block.train.dset.one_batch()
tfm = NGPadFeatTfm(**PARAM)

o = tfm(batch)

In [None]:
o.keys()

dict_keys(['lbl2data_idx', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_idx_ptr-1', 'lbl2data_identifier_ptr-1', 'lbl2data_input_text_ptr-1', 'lbl2data_input_ids_ptr-1', 'lbl2data_attention_mask_ptr-1', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask'])

In [None]:
prefix = 'lbl2data'
for k,v in o.items():
    if re.match(f'^{prefix}_.*',k): print(k, ': ', v.shape if isinstance(v, torch.Tensor) else len(v))
        

lbl2data_idx :  torch.Size([10])
lbl2data_identifier :  10
lbl2data_input_text :  10
lbl2data_input_ids :  torch.Size([10, 13])
lbl2data_attention_mask :  torch.Size([10, 13])
lbl2data_idx_ptr-1 :  torch.Size([10])
lbl2data_identifier_ptr-1 :  torch.Size([10])
lbl2data_input_text_ptr-1 :  torch.Size([10])
lbl2data_input_ids_ptr-1 :  torch.Size([10])
lbl2data_attention_mask_ptr-1 :  torch.Size([10])


#### Example 2

In [None]:
# Keep only the `data_*` fields of each item, so the batch carries no label features.
# (`re.match(r'^data_*', k)` matched any key merely starting with "data", since `_*` allows zero underscores.)
batch = [{k:v for k,v in o.items() if k.startswith('data_')} for o in block.train.dset.one_batch()]
tfm = NGPadFeatTfm(**PARAM)

o = tfm(batch)

In [None]:
o.keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask'])

### `TfmPipeline`: APPLY TRANSFORMS

In [None]:
#| export
class TfmPipeline:
    "Chain callables: the output of each transform in `tfms` becomes the input of the next."

    def __init__(self, tfms:List):
        self.tfms = tfms

    def __call__(self, x):
        result = x
        for step in self.tfms:
            result = step(result)
        return result
        

#### Example

In [None]:
tfms = [XCPadFeatTfm(**PARAM), AlignInputIdsTfm(**PARAM)]

tfm = TfmPipeline(tfms)
batch = block.train.dset.one_batch()

o = tfm(batch)

In [None]:
o.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

## Mid transforms

### `AlignInputIdsTfm`: ALIGN TOKEN SEQUENCE

In [None]:
@typedispatch
def verify_align(src_ids:List, targ_ids:List):
    "Print how each source sequence's length compares ('<', '>' or '==') with each of its target sequences."
    for p,qs in zip(src_ids, targ_ids):
        for q in qs: 
            # Equal lengths are reported as '==' rather than being lumped in with '>'.
            sym = '<' if len(p) < len(q) else ('>' if len(p) > len(q) else '==')
            print(len(p), f' {sym} ', len(q))
    

In [None]:
batch = block.train.dset.one_batch()
batch = {k:[o[k] for o in batch] for k,v in batch[0].items()}

In [None]:
batch.keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
PARAM['ptr'] = None

In [None]:
tfm = AlignInputIdsTfm(**PARAM)
o = tfm(batch)

In [None]:
verify_align(o['data_input_ids'], o['lbl2data_input_ids'])

5  >  5
5  >  4
11  >  7
11  >  5
11  >  5
11  >  6
11  >  5
6  >  6
7  >  7
16  >  9
4  >  4
6  >  6
7  >  7
4  >  4


### `TriePruneInputIdsTfm` (batch-level): PRUNE TOKEN SEQUENCES IN A BATCH

In [None]:
class TriePruneInputIdsTfm:
    """Batch-level pruning of `{prefix}_input_ids` with a trie built from every sequence in the batch.

    The input may be nested lists of token ids at any depth. Each leaf sequence is replaced by its trie
    prefix, and `{prefix}_attention_mask` / `{prefix}_token_type_ids` (when present) are truncated to the
    pruned lengths. `x` is modified in place and returned.
    NOTE: a block-level class with the same name is defined later in this notebook and shadows this one.
    """

    def __init__(self, prefix:str='lbl2data'):
        self.prefix = prefix

    @staticmethod
    def _flatten(x:List, o:List):
        "Append every leaf sequence of `x` (a list whose first element is not a list) to `o`."
        if not x: return  # an empty list holds no leaf sequences (and has no x[0] to inspect)
        if not isinstance(x[0], list): o.append(x)
        else: 
            for i in x: TriePruneInputIdsTfm._flatten(i, o)

    @staticmethod
    def flatten(x:List):
        "Return the leaf sequences of the nested list `x`, in depth-first order."
        flat_x = []
        TriePruneInputIdsTfm._flatten(x, flat_x)
        return flat_x
        
    @staticmethod
    def _prune_feature(x:List, trie:Trie):
        "Replace every leaf sequence of `x` with `trie.prefix(leaf)`, preserving the nesting."
        if not x: return x
        if not isinstance(x[0], list): return trie.prefix(x)
        return [TriePruneInputIdsTfm._prune_feature(o, trie) for o in x]

    def prune_feature(self, x:Dict, fld:str):
        "Prune `x[fld]` in place with a trie built from all its leaf sequences. Raises ValueError if `fld` is missing."
        if fld not in x: raise ValueError(f'`{fld}` not in `x`')
        v = self.flatten(x[fld])
        trie = Trie.from_list(v)
        trie.prune()
        x[fld] = self._prune_feature(x[fld], trie)

    @staticmethod
    def _align_feature(inp:List, targ:List):
        "Truncate each leaf of `targ` to the length of the matching leaf of `inp`; returns the aligned `targ`."
        if not inp or not isinstance(inp[0], list): return targ[:len(inp)]
        for i,(p,q) in enumerate(zip(inp, targ)): targ[i] = TriePruneInputIdsTfm._align_feature(p,q)
        return targ

    def align_feature(self, x:Dict, inp:str, targ:str):
        "Truncate `x[targ]` to the lengths of `x[inp]` and store the result in `x`. Does nothing if `targ` is missing."
        if targ not in x: return
        # Store the result: for a flat (un-nested) `inp`, the truncated list is a new object, not an in-place edit.
        x[targ] = self._align_feature(x[inp], x[targ])
        
    def __call__(self, x:Dict, 
                 prefix:Optional[str]=None):
        "Prune `{prefix}_input_ids` in `x` and align its mask/type ids; a given `prefix` is remembered for later calls."
        self.prefix = self.prefix if prefix is None else prefix
        
        self.prune_feature(x, f'{self.prefix}_input_ids')
        self.align_feature(x, f'{self.prefix}_input_ids', f'{self.prefix}_attention_mask')
        self.align_feature(x, f'{self.prefix}_input_ids', f'{self.prefix}_token_type_ids')
        return x


#### Example 1

In [None]:
lbl = [[[101, 100, 200, 300, 102],
        [101, 200, 100, 100, 109, 102]],
       [[101, 200, 100, 100, 301, 102],
        [101, 300, 301, 200, 400, 500, 102],
        [101, 300, 301, 102]],
       [[101, 200, 100, 222, 301, 401, 501, 444, 102]]]

mask = [[[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1]]]

x = {'lbl2data_input_ids': lbl, 'lbl2data_attention_mask': mask}
for k,v in x.items(): print(k, ':'); print(v)

lbl2data_input_ids :
[[[101, 100, 200, 300, 102], [101, 200, 100, 100, 109, 102]], [[101, 200, 100, 100, 301, 102], [101, 300, 301, 200, 400, 500, 102], [101, 300, 301, 102]], [[101, 200, 100, 222, 301, 401, 501, 444, 102]]]
lbl2data_attention_mask :
[[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1, 1, 1, 1]]]


In [None]:
tfm = TriePruneInputIdsTfm(prefix='lbl2data')
o = tfm(x)
for k,v in o.items(): 
    print(k, ':'); print(v)

  0%|          | 0/6 [00:00<?, ?it/s]

lbl2data_input_ids :
[[[101, 100, 102], [101, 200, 100, 100, 109, 102]], [[101, 200, 100, 100, 301, 102], [101, 300, 301, 200, 102], [101, 300, 301, 102]], [[101, 200, 100, 222, 102]]]
lbl2data_attention_mask :
[[[1, 1, 1], [1, 1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1, 1]]]


#### Example 2

In [None]:
batch = block.train.dset.one_batch()
batch = {k:[o[k] for o in batch] for k,v in batch[0].items()}

In [None]:
lbl2data_input_ids = batch['lbl2data_input_ids'].copy()

In [None]:
batch.keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
tfm = TriePruneInputIdsTfm(prefix='lbl2data')
o = tfm(batch)

  0%|          | 0/14 [00:00<?, ?it/s]

In [None]:
def verify_prune(x:List, y:List):
    "Print how each leaf sequence of `x` (original) compares in length ('<', '>' or '==') with the matching leaf of `y` (pruned)."
    def _leaves(o):
        # Depth-first leaf sequences (lists whose first element is not a list). This is self-contained, so it
        # keeps working after `TriePruneInputIdsTfm` is redefined later without `flatten`.
        if not o: return
        if not isinstance(o[0], list): yield o
        else:
            for i in o: yield from _leaves(i)
    for p,q in zip(_leaves(x), _leaves(y)):
        sym = '<' if len(p) < len(q) else ('>' if len(p) > len(q) else '==')
        print(len(p), f' {sym} ', len(q))
            

In [None]:
def verify_alignment(x:List, y:List):
    "Return True when `x` and `y` have the same number of leaf sequences and every pair has equal length."
    def _leaves(o):
        # Depth-first leaf sequences (lists whose first element is not a list); empty lists contribute nothing.
        # Self-contained, so it does not depend on `TriePruneInputIdsTfm.flatten`, which a later redefinition drops.
        if not o: return
        if not isinstance(o[0], list): yield o
        else:
            for i in o: yield from _leaves(i)
    x, y = list(_leaves(x)), list(_leaves(y))
    # Compare counts first: `zip` alone would silently ignore extra sequences.
    return len(x) == len(y) and all(len(p) == len(q) for p,q in zip(x, y))
    

In [None]:
verify_alignment(batch['lbl2data_input_ids'], batch['lbl2data_attention_mask'])

True

In [None]:
verify_prune(lbl2data_input_ids, batch['lbl2data_input_ids'])

5  >  3
10  >  5
13  >  5
8  >  3
8  >  3
4  >  3
7  >  3
6  >  3
3  >  3
11  >  5
7  >  5
5  >  3
7  >  7
7  >  7


## Item transforms

### `AugmentMetaInputIdsTfm`: AUGMENT INPUT TOKENS

In [None]:
#| export
class AugmentMetaInputIdsTfm:
    """Append the token sequences of each row's metadata to the row's own sequence, for every split of a block.

    Results are stored next to the originals as `{fld}_aug_{meta prefix}`, e.g. `input_ids_aug_hlk` for
    `meta='hlk_meta'`. This is done for `input_ids`, `attention_mask` and `token_type_ids`.
    """

    def __init__(self, meta:str, max_len:Optional[int]=None, exclude_sep:Optional[bool]=False):
        self.meta, self.max_len, self.exclude_sep = meta, max_len, exclude_sep
    
    def augment(self, data_ids:List, data_meta:sparse.csr_matrix, meta_ids:List):
        """Return, for every row, its sequence followed by its metadata sequences in random order.

        - The row's trailing token (its separator) is split off and re-appended exactly once at the end.
        - The first and last token of each metadata sequence are dropped.
        - When `exclude_sep` is False, the row's separator token is inserted before each metadata segment.
        - When `max_len` is set, every result (including rows without metadata) is capped at `max_len` tokens,
          separator included.
        """
        meta2data_ids = []
        for d_ids, d_meta in progress_bar(zip(data_ids, data_meta), total=len(data_ids)):
            body, sep_tok = list(d_ids[:-1]), list(d_ids[-1:])
            for o in d_meta.indices[np.random.permutation(len(d_meta.indices))]:
                if not self.exclude_sep: body.extend(sep_tok)
                body.extend(meta_ids[o][1:-1])
                # Stop early: anything past `max_len` is cut below anyway.
                if self.max_len is not None and len(body) >= self.max_len: break
            # Truncate every row uniformly, leaving room for the final separator.
            if self.max_len is not None: body = body[:max(self.max_len-1, 0)]
            meta2data_ids.append(body + sep_tok)
        return meta2data_ids

    def proc(self, block:XCDataBlock, split:str, fld:str, side:Optional[str]='data'):
        """Augment field `fld` of `{side}_info` in `block.{split}` with the `self.meta` metadata.

        Raises ValueError if `side` is not 'data' or 'lbl'. Does nothing when `fld` is missing from `{side}_info`.
        """
        if side not in ['data', 'lbl']: 
            raise ValueError("Invalid `side`, it should be in ['data','lbl']")
        
        info = get_attr(block, f'{split}.dset.data.{side}_info')
        if fld in info:
            data_ids = info[fld]
            meta_ids = get_attr(block, f'{split}.dset.meta.{self.meta}.meta_info')[fld]
            data_meta = get_attr(block, f'{split}.dset.meta.{self.meta}.{side}_meta')
            info[f'{fld}_aug_{self.meta.split("_")[0]}'] = self.augment(data_ids, data_meta, meta_ids)

    def __call__(self, block:XCDataBlock, meta:str, side:Optional[str]='data', max_len:Optional[int]=None, 
                 exclude_sep:Optional[bool]=None):
        """Augment `input_ids`, `attention_mask` and `token_type_ids` for each of the train/valid/test splits on `block`.

        Arguments that are not None replace the stored settings. Returns `block`, whose info dicts are updated in place.
        """
        store_attr('meta,max_len,exclude_sep', is_none=False)
        for split in master_bar(['train', 'valid', 'test']):
            if hasattr(block, split) and get_attr(block, split) is not None: 
                for fld in ['input_ids', 'attention_mask', 'token_type_ids']: self.proc(block, split, fld, side)
        return block
        
    @classmethod
    def apply(cls, block:XCDataBlock, meta:str, side:Optional[str]='data', max_len:Optional[int]=None, exclude_sep:Optional[bool]=False):
        "Build the transform and run it. Positional order is `(block, meta, side, max_len, exclude_sep)`."
        self = cls(meta, max_len, exclude_sep)
        return self(block, meta, side, max_len, exclude_sep)
        

#### Example

##### Example 1

In [None]:
# Pass `max_len`/`exclude_sep` by keyword: positionally, 15 and True would bind to `side` and `max_len`.
o = AugmentMetaInputIdsTfm.apply(block, 'hlk_meta', side='data', max_len=15, exclude_sep=True)

In [None]:
block.train.dset.data.data_info.keys()

dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask', 'input_ids_aug_hlk', 'attention_mask_aug_hlk'])

In [None]:
p = block.train.dset.data.data_info['attention_mask']
q = block.train.dset.data.data_info['attention_mask_aug_hlk']

In [None]:
for i in np.random.permutation(len(p))[:10]: print(len(p[i]), len(q[i]))

7 15
13 15
9 15
10 15
4 15
9 15
8 15
6 15
9 15
4 15


In [None]:
p = block.train.dset.data.data_info['input_ids_aug_hlk']
q = block.train.dset.data.data_info['attention_mask_aug_hlk']

In [None]:
for i in np.random.permutation(len(p))[:10]: print(len(p[i]), len(q[i]))

15 15
15 15
15 15
15 15
15 15
15 15
15 15
13 13
15 15
15 15


In [None]:
from transformers import AutoTokenizer

In [None]:
tokz = AutoTokenizer.from_pretrained('distilbert-base-uncased')

n = 1000
print(p[n])
print("Length: ", len(p[n]))
print(tokz.decode(p[n]))

[101, 2522, 17040, 2239, 16569, 1997, 2605, 1996, 17830, 2063, 16944, 2099, 2605, 7640, 102]
Length:  15
[CLS] couzon communes of france thermae allier france departments [SEP]


##### Example 2

In [None]:
tfm = AugmentMetaInputIdsTfm('hlk_meta', 512, True)

In [None]:
o = AugmentMetaInputIdsTfm.apply(block, 'hlk_meta', 'data', 15, True)

In [None]:
o = AugmentMetaInputIdsTfm.apply(block, 'hlk_meta', 'lbl', 32, True)

### `TriePruneInputIdsTfm` (block-level): PRUNE TOKEN SEQUENCES IN A DATA BLOCK

In [None]:
#| export
class TriePruneInputIdsTfm:
    """Prune the `input_ids` stored at a dotted location of a block with a trie, keeping the originals.

    Pruned sequences are stored as `input_ids_prn_tre`. `attention_mask` and `token_type_ids` (when present)
    get truncated copies with the same `_prn_tre` suffix.
    NOTE: this redefinition shadows the batch-level `TriePruneInputIdsTfm` defined earlier in the notebook,
    including its `flatten` helper.
    """

    def prune(self, block:XCDataBlock, loc:str, fld:str):
        "Store trie-pruned copies of the sequences in `x[fld]` as `x[f'{fld}_prn_tre']`, where `x = get_attr(block, loc)`."
        x = get_attr(block, loc)
        if fld in x:
            trie = Trie.from_list(x[fld], None)
            trie.prune()
            x[f'{fld}_prn_tre'] = [trie.prefix(o) for o in x[fld]]

    def align(self, block:XCDataBlock, loc:str, inp:str, targ:str):
        "Store `x[targ]` truncated to the lengths of the sequences in `x[inp]` as `x[f'{targ}_prn_tre']`."
        x = get_attr(block, loc)
        if inp in x and targ in x:
            x[f'{targ}_prn_tre'] = [q[:len(p)] for p,q in zip(x[inp], x[targ])]
        
    def proc(self, block:XCDataBlock, loc:str):
        "Prune `input_ids` at `loc`, then align `attention_mask` and `token_type_ids` to the pruned lengths."
        self.prune(block, loc, 'input_ids')
        self.align(block, loc, 'input_ids_prn_tre', 'attention_mask')
        self.align(block, loc, 'input_ids_prn_tre', 'token_type_ids')
        return block

    def __call__(self, block:XCDataBlock, loc:str):
        return self.proc(block, loc)

    @classmethod
    def apply(cls, block:XCDataBlock, loc:str):
        "Convenience constructor-and-call; returns `block`, updated in place."
        self = cls()
        return self(block, loc)
        

#### Example

In [None]:
block = TriePruneInputIdsTfm.apply(block, 'train.dset.data.lbl_info')

  0%|          | 0/312330 [00:00<?, ?it/s]

In [None]:
block.train.dset.data.lbl_info.keys()

dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask', 'input_ids_prn_tre', 'attention_mask_prn_tre'])

In [None]:
x = block.train.dset.data.lbl_info
rnd_idx = np.random.permutation(len(x['input_ids_prn_tre']))[:10]
p,q = x['input_ids_prn_tre'],x['input_ids']
for idx in rnd_idx: print(f'{p[idx]}: {len(p[idx])} ; {q[idx]}: {len(q[idx])}')

[101, 2293, 2818, 102]: 4 ; [101, 2293, 2818, 102]: 4
[101, 16359, 102]: 3 ; [101, 16359, 102]: 3
[101, 4794, 7120, 102]: 4 ; [101, 4794, 7120, 102]: 4
[101, 2697, 1059, 102]: 4 ; [101, 2697, 1059, 16584, 9286, 7952, 102]: 7
[101, 10067, 19948, 102]: 4 ; [101, 10067, 19948, 102]: 4
[101, 25353, 4571, 13592, 22498, 2015, 102]: 7 ; [101, 25353, 4571, 13592, 22498, 2015, 102]: 7
[101, 16215, 10735, 102]: 4 ; [101, 16215, 10735, 21007, 3686, 102]: 6
[101, 1060, 22540, 102]: 4 ; [101, 1060, 22540, 9099, 3401, 16402, 102]: 7
[101, 2862, 1997, 4291, 4290, 3152, 1997, 2901, 102]: 9 ; [101, 2862, 1997, 4291, 4290, 3152, 1997, 2901, 102]: 9
[101, 13873, 8316, 102]: 4 ; [101, 13873, 8316, 1006, 4623, 1007, 102]: 7
