In [None]:
#| default_exp transform

In [None]:
#| hide 
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from tqdm.auto import tqdm
from scipy import sparse
import torch, numpy as np
from fastcore.utils import *
from fastcore.meta import *
from fastcore.dispatch import *
from transformers import AutoTokenizer, BatchEncoding
from itertools import chain

from xcai.core import *
from xcai.generation.trie import *
from xcai.data import XCDataBlock, BaseXCDataBlock

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [None]:
from xcai.test_utils import *
block = Test.from_cfg('data_meta')

  self._set_arrayXarray(i, j, x)


In [None]:
# Shared keyword configuration for the transforms demonstrated below. Each transform takes
# the keys it knows and ignores the rest through its `**kwargs`.
PARAM = {
    'cols': ['identifier', 'input_text'],  # not read by the transforms in this notebook (presumably used elsewhere)
    'use_tokz': True,                      # not read by the transforms in this notebook
    'tokz': 'bert-base-uncased',           # tokenizer loaded below to fill `sep_tok` / `pad_tok`
    'max_len': 32,                         # not read by the transforms in this notebook
    'prefix': 'lbl',                       # `PadFeatTfm`: only keys matching `^prefix` are gathered
    'pad_side': 'right',                   # `PadTfm`: side on which padding is added
    'inp': 'data',                         # `AlignInputIdsTfm`: prefix of the reference input ids
    'targ': 'lbl2data',                    # `AlignInputIdsTfm`: prefix of the ids that get truncated
    'ptr': 'lbl2data_data2ptr',            # `AlignInputIdsTfm`: key holding the number of targets per input row
    'drop': True,                          # `PadFeatTfm`: pop the gathered keys from the batch
    'ret_t': True,                         # `PadTfm`: return tensors when possible
    'in_place': True,                      # `PadTfm`: pad the given lists in place
    'collapse': True,                      # not read by the transforms in this notebook
    'device': 'cpu',                       # stored by `AlignInputIdsTfm`, not otherwise used here
}
tokz = AutoTokenizer.from_pretrained(PARAM['tokz'])
PARAM['sep_tok'] = tokz.sep_token_id
PARAM['pad_tok'] = tokz.pad_token_id

## Batch transforms

### `PadTfm`: PAD FEATURE

In [None]:
#| export
class PadTfm:
    """Pad an arbitrarily nested list of scalars to a rectangular shape.

    Every nesting level is padded to the largest length found at that level anywhere in the
    input. The filler is `pad_tok`, wrapped in lists of the matching shape for outer levels.
    Padding goes on `pad_side`: 'right', or any other value for the left.
    The result is returned as a tensor when `ret_t` is true and `torch.tensor` accepts it;
    otherwise the padded list is returned.
    Arguments passed to `__call__` override, and persist as, the instance defaults.
    """

    def __init__(self, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]='right', 
                 ret_t:Optional[bool]=True,
                 in_place:Optional[bool]=True,
                 **kwargs):
        store_attr('pad_tok,pad_side,ret_t,in_place')

    @staticmethod
    def _copy_nested(x):
        "Copy every list level of `x`; non-list leaves are shared."
        return [PadTfm._copy_nested(o) for o in x] if isinstance(x, list) else x

    def _sz_help(self, x:List, sz:List, lev:int):
        # Only recurse into non-empty lists of lists: an empty list has no `x[0]` and no children.
        if len(x) and isinstance(x[0], list):
            l = max(len(o) for o in x)
            if len(sz) > lev: sz[lev] = max(sz[lev], l)
            else: sz.append(l)
            for o in x: self._sz_help(o, sz, lev+1)

    def get_sz(self, x:List):
        "Return the padded shape of `x`: the maximum length found at each nesting level."
        sz = [len(x)]
        self._sz_help(x, sz, len(sz))
        return sz

    def _pad_help(self, x:List, sz:List, pads:List, lev:int):
        "Pad `x` (at nesting level `lev`) and its children to `sz`; `pads[lev]` is this level's filler."
        if len(x) and isinstance(x[0], list):
            for i,o in enumerate(x): x[i] = self._pad_help(o, sz, pads, lev+1)
        # One fresh copy per padded slot, so padded rows of a returned list never alias each other.
        rem = [self._copy_nested(pads[lev]) for _ in range(sz[lev] - len(x))]
        return x+rem if self.pad_side == 'right' else rem+x

    def __call__(self, 
                 x:List, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]=None, 
                 ret_t:Optional[bool]=None, 
                 in_place:Optional[bool]=None):
        """Pad the nested list `x`.

        Raises ValueError when no `pad_tok` is set. With `in_place=False` the caller's
        lists (at every nesting level) are left untouched.
        """
        store_attr('pad_tok,pad_side,ret_t,in_place', is_none=False)
        if self.pad_tok is None: raise ValueError('`pad_tok` cannot be None.')

        sz = self.get_sz(x)
        # pads[lev] is the filler for level `lev`: `pad_tok` nested in lists shaped like sz[lev+1:].
        pads = [self.pad_tok]
        for s in sz[:0:-1]: pads.insert(0, [pads[0]]*s)
        # Copy every nested level so neither the outer nor the inner lists of the caller are modified.
        if not self.in_place: x = self._copy_nested(x)
        x = self._pad_help(x, sz, pads, 0)
        if not self.ret_t: return x
        # Best effort: ragged or non-numeric leaves (e.g. strings) cannot become a tensor.
        try: return torch.tensor(x)
        except Exception: return x
        

In [None]:
tfm = PadTfm(0, 'right')

arr = [[[1, 2, 3], [1, 2]], [[1]]]
o = tfm(arr, 100)
print(o)

tensor([[[  1,   2,   3],
         [  1,   2, 100]],

        [[  1, 100, 100],
         [100, 100, 100]]])


### `CollapseTfm`: COLLAPSE FEATURE

In [None]:
#| export
class CollapseTfm:
    """Flatten a nested list down to nesting level `lev`, recording group sizes per level.

    `__call__` returns `(flat, ptr)`. `ptr[level]` lists, for every list met at `level`
    (0 = the outer list), how many level-`lev` items it contributed. With `use_ptr` these
    counts are turned into running totals (end offsets).
    """

    def __init__(self, lev:int=0, use_ptr:int=True):
        store_attr('lev,use_ptr')

    def collapse(self, x:List, ptr:Dict, lev:int):
        "Flatten `x` (found at level `lev`) to items of level `self.lev`, appending its size to `ptr[lev]`."
        if not isinstance(x, list):
            raise ValueError(f'`x` should be a list, check the `lev`({self.lev}).')
        if lev != self.lev:
            flat = []
            for item in x: flat.extend(self.collapse(item, ptr, lev+1))
            x = flat
        ptr.setdefault(lev, []).append(len(x))
        return x

    def _get_ptr(self, ptr):
        "Turn every list of counts in `ptr` into its running totals, in place."
        for sizes in ptr.values():
            total, offsets = 0, []
            for s in sizes:
                total += s
                offsets.append(total)
            sizes[:] = offsets

    def __call__(self, x:List, lev:int=None, use_ptr:Optional[int]=None):
        store_attr('lev,use_ptr', is_none=False)
        ptr = {}
        flat = self.collapse(x, ptr, 0)
        if self.use_ptr: self._get_ptr(ptr)
        return flat, ptr


In [None]:
tfm = CollapseTfm()
x = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
     [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]
tfm(x, lev=2, use_ptr=True)

([[1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9]],
 {2: [3, 6, 9, 12], 1: [6, 12], 0: [12]})

### `PadFeatTfm`: PAD BATCH

In [None]:
#| export
class PadFeatTfm:
    """Gather the features whose names match `^prefix` from a batch and pad each one with `PadTfm`.

    `x` is either a list of per-item dicts or an already-collated dict.
    - For a list, values are gathered across items. When `lev > 0`, nested values are
      collapsed to depth `lev` with `CollapseTfm`, which adds `{name}_ptr-{level}` features.
    - A dict is used as-is.
    Arguments passed to `__call__` override, and persist as, the instance defaults.
    """

    def __init__(self,
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=True, 
                 pad_tok:Optional[int]=0, 
                 pad_side:Optional[str]='right', 
                 ret_t:Optional[bool]=True,
                 in_place:Optional[bool]=True,
                 lev:Optional[int]=0,
                 **kwargs):
        store_attr('prefix,drop,pad_tok,pad_side,ret_t,in_place,lev')
        # `use_ptr=False`: pointer features hold per-group counts, not cumulative offsets.
        self.pad_proc, self.colps_proc = PadTfm(), CollapseTfm(lev, use_ptr=False)

    def get_feat(self, 
                 x:Union[Dict, List], 
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=True, 
                 lev:Optional[int]=0):
        """Return `{name: values}` for every key of `x` matching `^prefix` (all keys if `prefix` is None).

        For a list of dicts, the keys of the first item are used and values are gathered item
        by item. With `lev > 0` they are collapsed by `self.colps_proc`, and its pointer lists
        for levels other than 0 are added as `{name}_ptr-{level}`.
        Matching keys are popped from `x` (or its items) when `drop` is true.
        Raises TypeError when `x` is neither a list nor a dict.
        """
        if isinstance(x, list):
            keys = x[0] if len(x) else {}
            name = [k for k in keys if prefix is None or re.match(f'^{prefix}',k)]
            feat = {k: [o.pop(k) if drop else o[k] for o in x] for k in name}
            if lev > 0:
                for k in name: 
                    feat[k], ptr = self.colps_proc(feat[k], lev)
                    for p,q in ptr.items(): 
                        if p != 0: feat[f'{k}_ptr-{p}'] = q
        elif isinstance(x, dict):
            name = [k for k in x if prefix is None or re.match(f'^{prefix}',k)]
            feat = {k: x.pop(k) if drop else x[k] for k in name}
        else:
            raise TypeError(f'`x` should be a list or a dict, got {type(x).__name__}.')
        return feat
        
    def __call__(self, x:Union[Dict, List], 
                 prefix:Optional[str]=None, 
                 drop:Optional[bool]=None, 
                 pad_tok:Optional[int]=None, 
                 pad_side:Optional[str]=None, 
                 ret_t:Optional[bool]=None, 
                 in_place:Optional[bool]=None,
                 lev:Optional[int]=None):
        """Gather and pad the matching features of `x`; returns a `BatchEncoding`.

        `lev=None` keeps the current `self.lev`. A default of 0 here would overwrite the
        value given to the constructor on every call.
        """
        store_attr('prefix,drop,pad_tok,pad_side,ret_t,in_place,lev', is_none=False)
        
        feat = self.get_feat(x, self.prefix, self.drop, self.lev)
        return BatchEncoding({k:self.pad_proc(v, self.pad_tok, self.pad_side, self.ret_t, self.in_place) for k,v in feat.items()})
        

#### Example 1

In [None]:
tfm = PadFeatTfm(pad_tok=0)

obj = [{'a':[[1, 2, 3], [1]], 'b':[1, 2, 3]}, 
       {'a':[[1, 2]], 'b':[1]},
       {'a':[[1]], 'b':[1, 2]},
       {'a':[[1], [1, 2, 3, 4, 5]], 'b':[1, 2, 3, 4]}]

o = tfm(obj)
for k,v in o.items(): print(k, ':'); print(v)

a :
tensor([[[1, 2, 3, 0, 0],
         [1, 0, 0, 0, 0]],

        [[1, 2, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [1, 2, 3, 4, 5]]])
b :
tensor([[1, 2, 3, 0],
        [1, 0, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 3, 4]])


#### Example 2

In [None]:
batch = block.train.dset.one_batch()
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
o = (PadFeatTfm(**PARAM))(batch, prefix='lbl2data', lev=1, drop=False)

In [None]:
for k,v in o.items(): 
    if hasattr(v, 'shape'): print(k, ': ', v.shape)
    else: print(k, ': ', len(v))

lbl2data_idx :  torch.Size([33])
lbl2data_identifier :  33
lbl2data_input_text :  33
lbl2data_input_ids :  torch.Size([33, 12])
lbl2data_token_type_ids :  torch.Size([33, 12])
lbl2data_attention_mask :  torch.Size([33, 12])
lbl2data_idx_ptr-1 :  torch.Size([10])
lbl2data_identifier_ptr-1 :  torch.Size([10])
lbl2data_input_text_ptr-1 :  torch.Size([10])
lbl2data_input_ids_ptr-1 :  torch.Size([10])
lbl2data_token_type_ids_ptr-1 :  torch.Size([10])
lbl2data_attention_mask_ptr-1 :  torch.Size([10])


### `AlignInputIdsTfm`: ALIGN TOKEN SEQUENCE

In [None]:
#| export
class AlignInputIdsTfm:
    """Truncate target token sequences so that none is longer than the input it is paired with.

    `encode` has two implementations registered with `typedispatch`, chosen by the types of
    the id containers:
      * lists (ragged): `targ_ids[i]` holds the target sequences of `inp_ids[i]`;
      * tensors (padded): `ptr[i]` consecutive target rows belong to input row `i`.
    A truncated target keeps `sep_tok` as its last token. Its attention mask and token type
    ids are cut (lists) or zeroed (tensors) over the same positions.
    """

    def __init__(self,
                 inp:Optional[str]='data',
                 targ:Optional[str]='lbl2data',
                 ptr:Optional[str]='lbl2data_data2ptr',
                 sep_tok:Optional[int]=0, 
                 pad_tok:Optional[int]=0,
                 device:Union[str,torch.device]='cpu', 
                 **kwargs):
        # NOTE(review): the default `sep_tok` equals the default `pad_tok`. The tensor path needs
        # them to differ, so pass the tokenizer's real separator id. `device` is stored but unused here.
        store_attr('inp,targ,ptr,sep_tok,pad_tok,device')

    @typedispatch
    def encode(self, inp_ids:List, targ_ids:List, sep_tok:int, targ_mask:Optional[List]=None, targ_tok:Optional[List]=None, **kwargs):
        """List version; modifies the given lists in place and returns `(targ_ids, targ_mask, targ_tok)`.

        A target longer than its input is cut to the input length, ending in `sep_tok`.
        Extra keyword arguments (e.g. `ptr`, `pad_tok`) are accepted and ignored.
        """
        for i,ids in enumerate(inp_ids):
            inp_len = len(ids)
            for j,t in enumerate(targ_ids[i]):
                if len(t) > inp_len: 
                    targ_ids[i][j] = t[:inp_len-1]+[sep_tok]
                    if targ_mask is not None: targ_mask[i][j] = targ_mask[i][j][:inp_len]
                    if targ_tok is not None: targ_tok[i][j] = targ_tok[i][j][:inp_len] 
        return targ_ids, targ_mask, targ_tok

    @staticmethod
    def _sep_len(ids:torch.Tensor, sep_tok:int, name:str):
        "Per-row length of `ids` up to and including its last `sep_tok`; every row must contain one."
        is_sep = ids == sep_tok
        if not bool(is_sep.any(dim=1).all()):
            raise ValueError(f'Every {name} row must contain `sep_tok`({sep_tok}).')
        if ids.shape[0] == 0: return torch.zeros(0, dtype=torch.long, device=ids.device)
        pos = torch.arange(1, ids.shape[1]+1, device=ids.device)
        # Last separator position (+1): for single-separator rows this is that separator's position.
        return (is_sep.long() * pos).max(dim=1).values

    @typedispatch
    def encode(self, inp_ids:torch.Tensor, targ_ids:torch.Tensor, ptr:torch.Tensor, sep_tok:int, pad_tok:int,
               targ_mask:Optional[torch.Tensor]=None, targ_tok:Optional[torch.Tensor]=None):
        """Tensor version; modifies the target tensors in place and returns `(targ_ids, targ_mask, targ_tok)`.

        Lengths run up to and including the last `sep_tok` of each row. Positions past the new
        length are set to `pad_tok` in `targ_ids` and to 0 in `targ_mask` / `targ_tok`.
        Raises ValueError if `sep_tok == pad_tok` or if a row has no `sep_tok`.
        """
        if sep_tok == pad_tok:
            raise ValueError('`sep_tok` and `pad_tok` must differ to measure padded sequence lengths.')
        inp_len = torch.repeat_interleave(self._sep_len(inp_ids, sep_tok, 'input'), ptr)
        targ_len = self._sep_len(targ_ids, sep_tok, 'target')
        seq_len = torch.where(inp_len < targ_len, inp_len, targ_len)

        # Positions [seq_len, targ_len) of each target row are cut; one boolean mask instead of a per-row loop.
        pos = torch.arange(targ_ids.shape[1], device=targ_ids.device)
        cut = (pos >= seq_len[:, None]) & (pos < targ_len[:, None])
        targ_ids[torch.arange(targ_ids.shape[0], device=targ_ids.device), seq_len-1] = sep_tok
        targ_ids[cut] = pad_tok
        if targ_mask is not None: targ_mask[cut] = 0
        if targ_tok is not None: targ_tok[cut] = 0
        return targ_ids, targ_mask, targ_tok
        
    def __call__(self, x:Dict, 
                 inp:Optional[str]=None, 
                 targ:Optional[str]=None,
                 ptr:Optional[str]=None, 
                 sep_tok:Optional[int]=None, 
                 pad_tok:Optional[int]=None):
        """Align the `{targ}_*` token features of `x` to `{inp}_input_ids` in place and return `x`.

        Raises ValueError if either input-id feature is missing. Overrides passed here persist
        on the instance.
        """
        store_attr('inp,targ,ptr,sep_tok,pad_tok', is_none=False)

        def get_attr(x, keys, required=False):
            attr = []
            for k in keys.split(','):
                if k not in x: 
                    if required: raise ValueError(f'"{k}" not in `x`')
                    else: attr.append(None)
                else: attr.append(x[k])
            return attr
            
        inp_ids, targ_ids = get_attr(x, f'{self.inp}_input_ids,{self.targ}_input_ids', required=True) 
        targ_mask, targ_tok = get_attr(x, f'{self.targ}_attention_mask,{self.targ}_token_type_ids') 
        ptr = None if self.ptr is None else x[self.ptr]
        
        targ_ids, targ_mask, targ_tok = self.encode(inp_ids, targ_ids, ptr=ptr, targ_mask=targ_mask, targ_tok=targ_tok, 
                                                    sep_tok=self.sep_tok, pad_tok=self.pad_tok)
        def set_attr(x, keys, vals):
            for k,v in zip(keys.split(','),vals):
                if v is not None: x[k] = v
                    
        set_attr(x, f'{self.targ}_input_ids,{self.targ}_attention_mask,{self.targ}_token_type_ids', [targ_ids,targ_mask,targ_tok])
        
        return x
        

In [None]:
@typedispatch
def verify_align(p:torch.Tensor, q:torch.Tensor, r:torch.Tensor, tok:int):
    """Print `input_len <op> target_len` for every target row of a padded batch.

    `p`: input ids; `q`: target ids; `r`: number of target rows per input row; `tok`: separator id.
    Lengths run up to and including the (single) `tok` of each row.
    ` < ` flags a target longer than its input. Equal lengths print as ` >= `, not ` > `.
    """
    p_len = (torch.where(p == tok)[1]+1).repeat_interleave(r)
    q_len = torch.where(q == tok)[1]+1
    for src_len, targ_len in zip(p_len.tolist(), q_len.tolist()):
        if src_len < targ_len: print(src_len,' < ',targ_len)
        else: print(src_len,' >= ',targ_len)

#### Example 1

In [None]:
pad_tfm = PadFeatTfm(pad_tok=0)
obj = {
    'data_input_ids':[[1, 2, 11], [11], [1, 2, 3, 4, 5, 11]], 
    'lbl2data_input_ids':[[1, 2, 11], [5, 11], [5, 11], [5, 11], [5, 6, 11], [13, 4, 11]],
    'lbl2data_data2ptr':[1, 2, 3],
    'lbl2data_attention_mask':[[1, 1, 1], [1, 1], [1, 1], [1, 1], [1, 1, 1], [1, 1, 1]],
    'lbl2data_token_type_ids':[[0, 0, 0], [0, 0], [0, 0], [0, 0], [0, 0, 0], [0, 0, 0]],
}

o = pad_tfm(obj)
for k,v in o.items(): print(k,':'); print(v)

data_input_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
lbl2data_input_ids :
tensor([[ 1,  2, 11],
        [ 5, 11,  0],
        [ 5, 11,  0],
        [ 5, 11,  0],
        [ 5,  6, 11],
        [13,  4, 11]])
lbl2data_data2ptr :
tensor([1, 2, 3])
lbl2data_attention_mask :
tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1]])
lbl2data_token_type_ids :
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


In [None]:
# sep_tok=11 / pad_tok=0 match the toy ids above (AlignInputIdsTfm takes no batch-size argument).
align_tfm = AlignInputIdsTfm(pad_tok=0, sep_tok=11, inp='data', 
                             targ='lbl2data', ptr='lbl2data_data2ptr')
o = align_tfm(o)
for k,v in o.items(): print(k,':'); print(v)

data_input_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
lbl2data_input_ids :
tensor([[ 1,  2, 11],
        [11,  0,  0],
        [11,  0,  0],
        [ 5, 11,  0],
        [ 5,  6, 11],
        [13,  4, 11]])
lbl2data_data2ptr :
tensor([1, 2, 3])
lbl2data_attention_mask :
tensor([[1, 1, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1]])
lbl2data_token_type_ids :
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


#### Example 2

In [None]:
batch = block.train.dset.one_batch()
pad_tfm, alg_tfm = PadFeatTfm(**PARAM), AlignInputIdsTfm(**PARAM)

In [None]:
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
o = pad_tfm(batch, prefix='data', lev=0, in_place=True, drop=True)
o.update(pad_tfm(batch, prefix='lbl2data', lev=1, in_place=True, drop=True))
o['lbl2data_data2ptr'] = o['lbl2data_idx_ptr-1']

In [None]:
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

7  >  5
7  >  7
7  <  9
7  >  4
7  <  8
10  >  3
10  >  3
10  >  4
10  >  6
4  <  5
4  >  4
4  <  5
7  >  7
6  >  4
6  >  4
6  >  4
13  >  8
13  <  14
13  >  8
5  <  9
5  <  14
5  >  5
5  <  11
5  <  6


In [None]:
o = alg_tfm(o)

In [None]:
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

7  >  5
7  >  7
7  >  7
7  >  4
7  >  7
10  >  3
10  >  3
10  >  4
10  >  6
4  >  4
4  >  4
4  >  4
7  >  7
6  >  4
6  >  4
6  >  4
13  >  8
13  >  13
13  >  8
5  >  5
5  >  5
5  >  5
5  >  5
5  >  5


### `XCPadFeatTfm`: PAD XC BATCH

In [None]:
#| export
class XCPadFeatTfm:
    """Pad an XC batch (a list of item dicts) group by group with `PadFeatTfm`.

    Groups are `lbl2data` (collapsed one level), `data`, and, for every metadata name `m`
    (key prefixes before the first `2`, other than `lbl`/`data`), `m2lbl2data` (collapsed two
    levels) and `m2data` (collapsed one level).
    Pointer features are renamed to `{group}_data2ptr` / `{group}_lbl2ptr`.
    A group with no matching feature in the batch is skipped.
    """

    @delegates(PadFeatTfm.__init__)
    def __init__(self, **kwargs):
        self.tfm = PadFeatTfm(**kwargs)

    def extract_ptr(self, x:Dict, suffix:str):
        """Pop every key of `x` ending in `suffix` and return the value of the first one.

        Raises KeyError when no key matches.
        """
        ptr_name = [k for k in x if re.match(f'.*{suffix}$',k)]
        if not ptr_name: raise KeyError(f'No key ending with `{suffix}` in `x`.')
        return [x.pop(k) for k in ptr_name][0]

    def _pad_group(self, x, prefix:str, lev:int, ptrs:Dict):
        "Pad the features matching `prefix` at collapse level `lev`; store pointer `suffix` as `{prefix}_{name}` for each `name: suffix` in `ptrs`."
        o = self.tfm(x, prefix=prefix, lev=lev, in_place=True, drop=True)
        if len(o):
            for name, suffix in ptrs.items(): o[f'{prefix}_{name}'] = self.extract_ptr(o, suffix)
        return o

    def __call__(self, x):
        # Sorted so that the output key order does not depend on set iteration order.
        meta_name = sorted(set([k.split('_',maxsplit=1)[0].split('2')[0] for k in x[0]]).difference(['lbl', 'data']))
        out = self._pad_group(x, 'lbl2data', 1, {'data2ptr': 'ptr-1'})
        out.update(self._pad_group(x, 'data', 0, {}))
        for k in meta_name:
            out.update(self._pad_group(x, f'{k}2lbl2data', 2, {'data2ptr': 'ptr-1', 'lbl2ptr': 'ptr-2'}))
            out.update(self._pad_group(x, f'{k}2data', 1, {'data2ptr': 'ptr-1'}))
        return out
        

#### Example

In [None]:
batch = block.train.dset.one_batch()
pad_tfm, align_tfm = XCPadFeatTfm(**PARAM), AlignInputIdsTfm(**PARAM)

In [None]:
o = pad_tfm(batch)
o = align_tfm(o)

In [None]:
verify_align(o['data_input_ids'], o['lbl2data_input_ids'], o['lbl2data_data2ptr'], PARAM['sep_tok'])

8  >  4
16  >  14
8  >  5
8  >  3
8  >  8
8  >  8
8  >  8
8  >  8
5  >  5
10  >  10
10  >  9
10  >  10
11  >  10
7  >  7
7  >  7
7  >  7
7  >  7
6  >  6
6  >  6
4  >  4
4  >  4
4  >  4
4  >  4


### `XCPadOutputTfm`: PAD XC OUTPUT

In [None]:
#| export
class XCPadOutputTfm:
    """Pad generation outputs with `PadFeatTfm`: the `info2seq*` features, then the `seq*` features.

    The input is collated already, so nothing is collapsed (`lev=0`). Matching keys are popped from `x`.
    """

    @delegates(PadFeatTfm.__init__)
    def __init__(self, **kwargs):
        self.tfm = PadFeatTfm(**kwargs)

    def extract_ptr(self, x:Dict, suffix:str):
        "Pop every key of `x` matching `.*{suffix}$` and return the value of the first one."
        matching = [k for k in x if re.match(f'.*{suffix}$',k)]
        popped = [x.pop(name) for name in matching]
        return popped[0]

    def __call__(self, x):
        padded = self.tfm(x, prefix='info2seq', lev=0, in_place=True, drop=True)
        seq_feats = self.tfm(x, prefix='seq', lev=0, in_place=True, drop=True)
        padded.update(seq_feats)
        return padded
        

#### Example

In [None]:
pad_tfm = XCPadOutputTfm(pad_tok=0)
obj = {
    'seq_output_ids':[[1, 2, 11], [11], [1, 2, 3, 4, 5, 11]],
    'seq_score':[1.2, 3.3, 4.5],
    'info2seq_idx':[1, 2, 11, 5, 11, 13],
    'info2seq_seq2ptr':[1, 2, 3],
}

o = pad_tfm(obj)
for k,v in o.items(): print(k,':'); print(v)

info2seq_idx :
tensor([ 1,  2, 11,  5, 11, 13])
info2seq_seq2ptr :
tensor([1, 2, 3])
seq_output_ids :
tensor([[ 1,  2, 11,  0,  0,  0],
        [11,  0,  0,  0,  0,  0],
        [ 1,  2,  3,  4,  5, 11]])
seq_score :
tensor([1.2000, 3.3000, 4.5000])


### `TfmPipeline`: APPLY TRANSFORMS

In [None]:
#| export
class TfmPipeline:
    """Apply a sequence of transforms, feeding each one the output of the previous one."""

    def __init__(self, tfms:List):
        self.tfms = tfms

    def __call__(self, x):
        out = x
        for step in self.tfms:
            out = step(out)
        return out
        

In [None]:
tfms = [XCPadFeatTfm(**PARAM), AlignInputIdsTfm(**PARAM)]

tfm = TfmPipeline(tfms)
batch = block.train.dset.one_batch()

o = tfm(batch)

In [None]:
o.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

## Item transforms

### `AlignInputIdsTfm`: ALIGN TOKEN SEQUENCE

In [None]:
@typedispatch
def verify_align(src_ids:List, targ_ids:List):
    """Print `input_len <op> target_len` for every target sequence of every input (item-level lists).

    ` < ` flags a target longer than its input. Equal lengths print as ` >= `, not ` > `.
    """
    for p,qs in zip(src_ids, targ_ids):
        for q in qs: 
            if len(p)<len(q): print(len(p),' < ', len(q))
            else: print(len(p),' >= ', len(q))
    

In [None]:
batch = block.train.dset.one_batch()
batch = {k:[o[k] for o in batch] for k,v in batch[0].items()}

In [None]:
batch.keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
PARAM['ptr'] = None

In [None]:
tfm = AlignInputIdsTfm(**PARAM)
o = tfm(batch)

In [None]:
verify_align(o['data_input_ids'], o['lbl2data_input_ids'])

5  >  5
5  >  4
11  >  7
11  >  5
11  >  5
11  >  6
11  >  5
6  >  6
7  >  7
16  >  9
4  >  4
6  >  6
7  >  7
4  >  4


### `TriePruneInputIdsTfm`: PRUNE TOKEN SEQUENCE

In [None]:
#| export
class TriePruneInputIdsTfm:
    """Shorten the token sequences of `{prefix}_input_ids` using a trie built from all of them.

    Every sequence is inserted into a `Trie`, the trie is pruned, and each sequence is replaced
    by `trie.prefix(seq)`. In the example below this keeps a distinguishing prefix plus the
    final token; confirm against `Trie`.
    The attention mask and token type ids are then cut to the new lengths.
    Nested list structure is kept, and an empty list is treated as a group with no sequences.
    """

    def __init__(self, prefix:str='lbl2data'):
        self.prefix = prefix

    @staticmethod
    def _is_seq(x:List):
        "True for a leaf token sequence: a non-empty list whose first element is not a list."
        return len(x) > 0 and not isinstance(x[0], list)

    @staticmethod
    def _flatten(x:List, o:List):
        if TriePruneInputIdsTfm._is_seq(x): o.append(x)
        else: 
            for i in x: TriePruneInputIdsTfm._flatten(i, o)

    @staticmethod
    def flatten(x:List):
        "Return all leaf sequences of the nested list `x`, in order."
        flat_x = []
        TriePruneInputIdsTfm._flatten(x, flat_x)
        return flat_x
        
    @staticmethod
    def _prune_feature(x:List, trie:Trie):
        if TriePruneInputIdsTfm._is_seq(x): return trie.prefix(x)
        return [TriePruneInputIdsTfm._prune_feature(o, trie) for o in x]

    def prune_feature(self, x:Dict, fld:str):
        "Replace every sequence of `x[fld]` by its trie prefix; raises ValueError if `fld` is missing."
        if fld not in x: raise ValueError(f'`{fld}` not in `x`')
        v = self.flatten(x[fld])
        trie = Trie.from_list(v)
        trie.prune()
        x[fld] = self._prune_feature(x[fld], trie)

    @staticmethod
    def _align_feature(inp:List, targ:List):
        if TriePruneInputIdsTfm._is_seq(inp): return targ[:len(inp)]
        for i,(p,q) in enumerate(zip(inp, targ)): targ[i] = TriePruneInputIdsTfm._align_feature(p,q)
        return targ

    def align_feature(self, x:Dict, inp:str, targ:str):
        "Cut every sequence of `x[targ]` to the length of its counterpart in `x[inp]`; no-op if `targ` is absent."
        if targ not in x: return
        # Store the result: for a flat (single-sequence) feature `_align_feature` returns a new list.
        x[targ] = self._align_feature(x[inp], x[targ])
        
    def __call__(self, x:Dict, 
                 prefix:Optional[str]=None):
        self.prefix = self.prefix if prefix is None else prefix
        
        self.prune_feature(x, f'{self.prefix}_input_ids')
        self.align_feature(x, f'{self.prefix}_input_ids', f'{self.prefix}_attention_mask')
        self.align_feature(x, f'{self.prefix}_input_ids', f'{self.prefix}_token_type_ids')
        return x


#### Example 1

In [None]:
lbl = [[[101, 100, 200, 300, 102],
        [101, 200, 100, 100, 109, 102]],
       [[101, 200, 100, 100, 301, 102],
        [101, 300, 301, 200, 400, 500, 102],
        [101, 300, 301, 102]],
       [[101, 200, 100, 222, 301, 401, 501, 444, 102]]]

mask = [[[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1]]]

x = {'lbl2data_input_ids': lbl, 'lbl2data_attention_mask': mask}
for k,v in x.items(): print(k, ':'); print(v)

lbl2data_input_ids :
[[[101, 100, 200, 300, 102], [101, 200, 100, 100, 109, 102]], [[101, 200, 100, 100, 301, 102], [101, 300, 301, 200, 400, 500, 102], [101, 300, 301, 102]], [[101, 200, 100, 222, 301, 401, 501, 444, 102]]]
lbl2data_attention_mask :
[[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1, 1, 1, 1]]]


In [None]:
tfm = TriePruneInputIdsTfm(prefix='lbl2data')
o = tfm(x)
for k,v in o.items(): 
    print(k, ':'); print(v)

  0%|          | 0/6 [00:00<?, ?it/s]

lbl2data_input_ids :
[[[101, 100, 102], [101, 200, 100, 100, 109, 102]], [[101, 200, 100, 100, 301, 102], [101, 300, 301, 200, 102], [101, 300, 301, 102]], [[101, 200, 100, 222, 102]]]
lbl2data_attention_mask :
[[[1, 1, 1], [1, 1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1, 1]]]


#### Example 2

In [None]:
batch = block.train.dset.one_batch()
batch = {k:[o[k] for o in batch] for k,v in batch[0].items()}

In [None]:
lbl2data_input_ids = batch['lbl2data_input_ids'].copy()

In [None]:
batch.keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask'])

In [None]:
tfm = TriePruneInputIdsTfm(prefix='lbl2data')
o = tfm(batch)

  0%|          | 0/14 [00:00<?, ?it/s]

In [None]:
def verify_prune(x:List, y:List):
    """Print `original_len <op> pruned_len` for each pair of leaf sequences of `x` and `y`.

    ` < ` flags a sequence that grew. Equal lengths print as ` >= `, not ` > `.
    """
    x = TriePruneInputIdsTfm.flatten(x)
    y = TriePruneInputIdsTfm.flatten(y)
    for p,q in zip(x,y):
        if len(p) < len(q): print(len(p),' < ',len(q))
        else: print(len(p),' >= ',len(q))
            

In [None]:
def verify_alignment(x:List, y:List):
    """True when `x` and `y` flatten to the same number of sequences and paired sequences have equal length."""
    x = TriePruneInputIdsTfm.flatten(x)
    y = TriePruneInputIdsTfm.flatten(y)
    # `zip` alone would silently ignore extra sequences on either side.
    return len(x) == len(y) and bool(np.all([len(p)==len(q) for p,q in zip(x,y)]))
    

In [None]:
verify_alignment(batch['lbl2data_input_ids'], batch['lbl2data_attention_mask'])

True

In [None]:
verify_prune(lbl2data_input_ids, batch['lbl2data_input_ids'])

5  >  3
10  >  5
13  >  5
8  >  3
8  >  3
4  >  3
7  >  3
6  >  3
3  >  3
11  >  5
7  >  5
5  >  3
7  >  7
7  >  7


### `AugmentMetaInputIdsTfm`: AUGMENT INPUT TOKENS

In [None]:
#| export
class AugmentMetaInputIdsTfm:
    """Append the token features of each data point's metadata to the data point's own features.

    Works in place on a data block. For every split and each of `input_ids`,
    `attention_mask` and `token_type_ids`, a data sequence is replaced by itself followed by
    `meta_ids[m][1:]` for every metadata `m` linked to it in `data_meta`. The first token of
    each metadata sequence is dropped, presumably [CLS]. Results are capped at `max_len`.
    """

    def __init__(self, meta:str, max_len:Optional[int]=None):
        self.meta, self.max_len = meta, max_len

    def _truncate(self, ids:List):
        "Cut `ids` to `max_len` tokens while keeping its final token; no-op when `max_len` is None."
        if self.max_len is None or len(ids) <= self.max_len: return ids
        return ids[:self.max_len-1] + ids[-1:]
    
    def proc(self, data_ids:List, data_meta:sparse.csr_matrix, meta_ids:List):
        """Return the augmented sequences for `data_ids`.

        Row `i` of `data_meta` lists (as column indices, in stored order) the metadata linked to
        data point `i`; `meta_ids[m]` is the sequence of metadata `m`.
        Every result, including data points without metadata, has at most `max_len` tokens.
        """
        data_meta = data_meta.tocsr()
        indptr, indices = data_meta.indptr, data_meta.indices
        meta2data_ids = []
        # Slice the CSR arrays directly instead of materialising a sparse row object per data point.
        for d_ids, start, end in tqdm(zip(data_ids, indptr[:-1], indptr[1:]), total=len(data_ids)):
            m2d_ids = d_ids.copy()
            for o in indices[start:end]:
                m2d_ids.extend(meta_ids[o][1:])
                # `max_len` is checked for None first: comparing an int with None raises TypeError.
                if self.max_len is not None and len(m2d_ids) > self.max_len: break
            meta2data_ids.append(self._truncate(m2d_ids))
        return meta2data_ids

    def feature(self, block:BaseXCDataBlock, fld:str):
        "Replace `block.dset.data.data_info[fld]` by its augmented version when the data has `fld`."
        if fld in block.dset.data.data_info:
            data_ids = block.dset.data.data_info[fld]
            meta_ids = block.dset.meta[self.meta].meta_info[fld]
            data_meta = block.dset.meta[self.meta].data_meta
            block.dset.data.data_info[fld] = self.proc(data_ids, data_meta, meta_ids)

    def split(self, block:XCDataBlock, split:str):
        "Augment the token fields of `block.<split>`; skipped when that split is None, ValueError if `self.meta` is missing."
        split_block = getattr(block, split)
        if split_block is None: return
        if self.meta is None or self.meta not in split_block.dset.meta: raise ValueError(f'`{self.meta}` not in `block`')
        for fld in ['input_ids', 'attention_mask', 'token_type_ids']: self.feature(split_block, fld)

    @classmethod
    def apply(cls, block:XCDataBlock, meta:str, max_len:Optional[int]=None):
        "Augment the train/valid/test splits of `block` with metadata `{meta}_meta`, in place; returns `block`."
        tfm = cls(f'{meta}_meta', max_len)
        for split in ['train', 'valid', 'test']: tfm.split(block, split)
        return block
        

#### Example

In [None]:
import copy
aug_block = copy.deepcopy(block)

##### Example 1

In [None]:
data_ids = aug_block.train.dset.data.data_info['input_ids']
meta_ids = aug_block.train.dset.meta['hlk_meta'].meta_info['input_ids']
data_meta = aug_block.train.dset.meta['hlk_meta'].data_meta

In [None]:
tfm = AugmentMetaInputIdsTfm('hlk', 15)

In [None]:
meta2data_ids = tfm.proc(data_ids, data_meta, meta_ids)

  0%|          | 0/554466 [00:00<?, ?it/s]

##### Example 2

In [None]:
o = AugmentMetaInputIdsTfm.apply(aug_block, 'hlk', 15)

  0%|          | 0/554466 [00:00<?, ?it/s]

  0%|          | 0/554466 [00:00<?, ?it/s]

  0%|          | 0/554466 [00:00<?, ?it/s]

  0%|          | 0/137554 [00:00<?, ?it/s]

  0%|          | 0/137554 [00:00<?, ?it/s]

  0%|          | 0/137554 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

In [None]:
p = block.train.dset.data.data_info['token_type_ids']
q = aug_block.train.dset.data.data_info['token_type_ids']

In [None]:
rng = np.random.default_rng(0)  # fixed seed so the sampled rows are reproducible
for i in rng.permutation(len(p))[:10]: print(len(p[i]), len(q[i]))

8 15
4 15
7 15
6 15
8 15
8 15
4 4
9 15
5 15
6 15


In [None]:
p = aug_block.train.dset.data.data_info['input_ids']
q = aug_block.train.dset.data.data_info['attention_mask']

In [None]:
rng = np.random.default_rng(0)  # fixed seed so the sampled rows are reproducible
for i in rng.permutation(len(p))[:10]: print(len(p[i]), len(q[i]))

15 15
15 15
15 15
16 16
15 15
15 15
15 15
15 15
15 15
9 9


In [None]:
print(p[500])

[101, 23544, 12809, 11428, 1204, 12143, 102, 1498, 102, 1203, 1365, 1392, 102, 3559, 102]
