In [1]:
#| default_exp generation.generate

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [3]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import torch, math
import torch.multiprocessing as mp
from multiprocessing import Pool
import torch.nn.functional as F
from itertools import chain
from tqdm.auto import tqdm
from typing import Optional, Sequence, Any, Dict, List
from dataclasses import dataclass

from fastcore.utils import *
from fastcore.meta import *
from fastcore.parallel import *

from xcai.core import *
from xcai.transform import *
from xcai.generation.trie import *

In [None]:
import os, torch, numpy as np
from typing import Dict, Optional, List, Any
from tqdm.auto import tqdm
from scipy import stats
import scipy.sparse as sp
import torch.nn.functional as F
from itertools import chain

from fastcore.meta import *

from xcai.core import *
from xcai.data import XCDataBlock

## Setup

In [5]:
import numpy as np
from xcai.block import *
from xcai.models.MMM00X import DBT007
from xcai.metrics import *

In [205]:
# NOTE: the dataset root below is a machine-specific absolute path; point it at the local
# dataset directory before running this cell on another machine.
block = XCBlock.from_cfg('/home/aiscuser/scratch/datasets', 'data', tokenizer='distilbert-base-uncased')
# Keep one training batch and the label count around for the examples further down.
b, n_lbl = block.train.one_batch(), block.n_lbl

  self._set_arrayXarray(i, j, x)


In [202]:
mname = f'/home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000'
m = DBT007.from_pretrained(mname, tn_targ=10_000, ig_tok=0)

Some weights of DBT007 were not initialized from the model checkpoint at /home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000 and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
o = m(**b)

In [9]:
o.logits.shape, len(block.lbl_info['input_ids'])

(torch.Size([10, 11, 30522]), 312330)

In [10]:
# Build the trie from the label token-id sequences; each sequence carries a one-element
# info list holding its own position in `toks` (presumably the label index -- confirm).
toks = block.lbl_info['input_ids']
info = [[i] for i in range(len(toks))]
t = Trie.from_list(toks, info, max_info=20)

  0%|          | 0/312330 [00:00<?, ?it/s]

## Trie Pointer

In [11]:
#| export
class TriePtr:
    """Cursor over a trie that remembers the token path walked so far.

    The pointer starts at `trie.root`; `hyp` holds the tokens on the path from the
    root (the root's own token included) to the current node `ptr`.
    """

    def __init__(self, trie, max_info:Optional[int]=None):
        """Start at the root of `trie`; `max_info` caps how many info entries `value` reports."""
        store_attr('trie,max_info')
        self.ptr, self.hyp = trie.root, [trie.root.tok]

    @property
    def tokens(self):
        """Tokens that may follow the current node, as a list."""
        return list(self.ptr.nxt_toks.keys())

    def next(self, val:int):
        """Advance to the child reached through token `val` and return `self`.

        Raises:
            ValueError: if `val` is not a child of the current node (the pointer is left unchanged).
        """
        children = self.ptr.nxt_toks
        # Test membership on the mapping itself instead of materialising the key list
        # on every step; this method sits on the hot path of beam search.
        if val not in children: raise ValueError(f'`{val}` not a valid next token.')
        self.ptr = children[val]
        self.hyp.append(val)
        return self

    def suffixes(self):
        """Completions found by `Trie._search` below the current node, highest `cnt` first."""
        o = []
        Trie._search(self.ptr, self.hyp, o, self.max_info)
        return sorted(o, key=lambda x: x.cnt, reverse=True)

    @property
    def is_end(self):
        """The current node's `is_end` flag."""
        return self.ptr.is_end

    @property
    def value(self):
        """`TrieOutput` for the current node: path tokens, node count and (capped) info list."""
        info = list(self.ptr.info) if self.max_info is None else list(self.ptr.info)[:self.max_info]
        return TrieOutput(self.hyp, self.ptr.cnt, info)

    def copy(self):
        """Independent pointer at the same node; the token path is copied, the trie is shared."""
        t = TriePtr(self.trie, self.max_info)
        t.ptr,t.hyp = self.ptr,self.hyp.copy()
        return t
        

### Example 1

In [None]:
tp = TriePtr(t)

In [None]:
tp.tokens

[]

In [None]:
tp.is_end

True

In [None]:
tp.next(102)

In [None]:
tp.suffixes()

[TrieOutput(s=[101, 200, 100, 222, 102], cnt=1, info=None)]

In [None]:
tp.value

TrieOutput(s=[101, 200, 100, 222, 102], cnt=1, info=None)

### Example 2

In [None]:
tp = TriePtr(t)
l = [tp.copy(), tp.copy()]

In [None]:
l[0].next(100)

In [None]:
l[0].tokens, l[1].tokens

([102], [100, 200, 300])

In [None]:
l[0].hyp, l[1].hyp

([101, 100], [101])

## Hypothesis

In [12]:
#| export
class Hypothesis:
    """Fixed-capacity pool of the best finished hypotheses for one example.

    `beams` holds `(score, hypothesis)` pairs, where the score is the summed logits
    divided by `length ** len_penalty`; `worst_sc` tracks the lowest score in the pool.
    """

    def __init__(self, n_bm:int, len_penalty:Optional[float]=1.0):
        store_attr('n_bm,len_penalty')
        self.worst_sc, self.beams = 1e9, []

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, sum_logits:float, gen_len:Optional[int]=None):
        """Offer `hyp` to the pool; it is kept only if there is room or it beats the worst entry.

        The length used for normalisation is `gen_len` when given, otherwise `len(hyp.s)`.
        """
        length = len(hyp.s) if gen_len is None else gen_len
        score = sum_logits/length**self.len_penalty

        # Pool is full and the candidate does not improve on the worst entry: drop it.
        if len(self) >= self.n_bm and not score > self.worst_sc: return

        self.beams.append((score, hyp))
        if len(self) <= self.n_bm:
            self.worst_sc = min(score, self.worst_sc)
            return

        # Over capacity: evict the lowest-scoring entry; the runner-up becomes the new worst.
        ranked = sorted((s, pos) for pos, (s, _) in enumerate(self.beams))
        del self.beams[ranked[0][1]]
        self.worst_sc = ranked[1][0]

    def is_done(self, best_sc:float, cur_len:int):
        """True when the pool is full and `best_sc`, normalised at `cur_len`, cannot beat its worst entry."""
        if len(self) < self.n_bm: return False
        return self.worst_sc >= best_sc/cur_len**self.len_penalty
        

### Example

In [None]:
hyp = Hypothesis(5, 0.5)

In [None]:
len(hyp)

2

In [None]:
hyp.add(TrieOutput([1, 3, 6, 11, 12, 14], 2, [2, 5]), sum_logits=-1.2)

In [None]:
hyp.beams

[(-0.75, TrieOutput(s=[1, 2, 3, 4], cnt=2, info=[0, 1, 2])),
 (-0.6, TrieOutput(s=[1, 3, 6, 11], cnt=2, info=[2, 5])),
 (-0.6, TrieOutput(s=[1, 3, 6, 11], cnt=2, info=[2, 5])),
 (-0.5366563145999494, TrieOutput(s=[1, 3, 6, 11, 12], cnt=2, info=[2, 5])),
 (-0.48989794855663565,
  TrieOutput(s=[1, 3, 6, 11, 12, 14], cnt=2, info=[2, 5]))]

## Trie Beam

In [183]:
#| export
def pad_tensor(tensor, fill_value):
    """Right-pad a sequence of 1-D tensors into one 2-D tensor.

    Returns `(padded, mask)`: `padded` has one row per input tensor, uses the dtype of the
    first tensor and is filled with `fill_value` past each row's length; `mask` is a boolean
    tensor of the same shape that is True on the real (non-padding) entries.
    """
    lengths = [len(t) for t in tensor]
    shape = (len(tensor), max(lengths))
    padded_tensor = torch.full(shape, fill_value, dtype=tensor[0].dtype)
    mask = torch.zeros(shape, dtype=torch.bool)
    for row, (values, n) in enumerate(zip(tensor, lengths)):
        padded_tensor[row, :n] = values
        mask[row, :n] = True
    return padded_tensor, mask


In [228]:
#| export
class TrieBeam:
    """Trie-constrained beam search over a tensor of per-position token scores.

    At every step only tokens that continue a path in `trie` are scored, so every live
    beam is a valid trie prefix. Sequences that emit `eos_tok` are moved into a per-example
    `Hypothesis` pool (`self.hyp`, created by `proc`).
    """

    def __init__(self, trie:Trie, eos_tok:int, n_bm:Optional[int]=5, len_penalty:Optional[float]=1.0, 
                 max_info:Optional[int]=None, **kwargs):
        """Store the search settings; remaining `kwargs` are passed to `XCPadOutputTfm`."""
        store_attr('trie,eos_tok,n_bm,len_penalty,max_info')
        self.tfm = XCPadOutputTfm(**kwargs)

    def valid(self, pointers:List, scores:torch.FloatTensor):
        """Collect, per example, every trie-valid (beam, next-token) candidate and its score.

        `pointers[b]` is the list of live `TriePtr`s of example `b`; `scores[b][i]` is indexed
        by token id and gives the score of extending beam `i` with that token.

        Returns `(tok, sc, idx, mask)`, each with one row per example: candidate token ids
        (padded with -100), their scores (padded with -inf), the index of the beam each
        candidate extends (padded with -100) and a boolean mask of the real entries.
        """
        all_tok, all_sc, all_idx = [], [], []
        for ptr,sc in zip(pointers, scores):
            # Seed each list with a typed empty tensor so `torch.concat` still works (and keeps
            # the right dtype) when an example has no live beams.
            batch_tok = [torch.tensor([], dtype=torch.long)]
            batch_sc = [torch.tensor([], dtype=scores.dtype)]
            batch_idx = [torch.tensor([], dtype=torch.long)]
            for i,(p,s) in enumerate(zip(ptr,sc)):
                # Force an integer dtype: a pointer with no outgoing tokens would otherwise give
                # an empty float tensor, which is not a valid index for `s[toks]`.
                toks = torch.tensor(p.tokens, dtype=torch.long)
                batch_tok.append(toks)
                batch_sc.append(s[toks])
                batch_idx.append(torch.full((len(toks),), i, dtype=torch.long))
            all_tok.append(torch.concat(batch_tok))
            all_sc.append(torch.concat(batch_sc))
            all_idx.append(torch.concat(batch_idx))
        all_tok, mask = pad_tensor(all_tok, -100)
        all_sc, _ = pad_tensor(all_sc, -float('Inf'))
        all_idx,_ = pad_tensor(all_idx, -100)
        return all_tok, all_sc, all_idx, mask

    def topk(self, tok:torch.Tensor, sc:torch.Tensor, idx:torch.Tensor, mask:torch.Tensor):
        """Keep at most `2*n_bm` best-scoring candidates per example, sorted by descending score.

        When there are no more than `2*n_bm` candidates the whole row is just sorted.
        Returns the tokens, scores, beam indices and mask re-ordered consistently.
        """
        top_sc, top_i = (
            torch.topk(sc, 2*self.n_bm, dim=1) 
            if sc.shape[1] > 2*self.n_bm else torch.sort(sc, dim=1, descending=True)
        )
        top_idx, top_tok, mask = idx.gather(1,top_i), tok.gather(1,top_i), mask.gather(1, top_i)
        return top_tok, top_sc, top_idx, mask

    def next(self, pointers:List, tokens:torch.Tensor, scores:torch.Tensor, indices:torch.Tensor, masks:torch.Tensor):
        """Advance the beams with the selected candidates.

        Candidates ending in `eos_tok` are added to the example's `Hypothesis` pool; of the
        remaining ones the first `n_bm` (in the order given, i.e. best first after `topk`)
        become the new live beams.

        Returns `(all_ptr, all_sc)`: the new pointer lists per example and a 2-D tensor of
        their cumulative scores, padded with -inf.
        """
        all_ptr, all_sc = [], []
        for hyp,ptr,tok,sc,idx,mask in zip(self.hyp, pointers, tokens, scores, indices, masks):
            batch_tok, batch_sc, batch_idx = [], [], []
            for t,s,i,m in zip(tok,sc,idx,mask):
                t,s,i,m = t.item(),s.item(),i.item(),m.item()
                if not m: continue  # padding entry, no real candidate here
                if t == self.eos_tok: hyp.add(ptr[i].copy().next(t).value, s)
                else: batch_tok.append(t); batch_sc.append(s); batch_idx.append(i)
            all_sc.append(torch.tensor(batch_sc)[:self.n_bm])
            batch_ptr = [ptr[i].copy().next(t) for t,i in zip(batch_tok[:self.n_bm], batch_idx[:self.n_bm])]
            all_ptr.append(batch_ptr)
        all_sc, _ = pad_tensor(all_sc, -float('Inf'))
        return all_ptr, all_sc

    def finalize(self, pointers:List, scores:torch.Tensor):
        """Turn the per-example hypothesis pools into flat output dictionaries.

        Pools holding fewer than `n_bm` entries are topped up with the trie completions
        (`TriePtr.suffixes`) of the still-live beams, scored with the beam's current score.

        Raises:
            ValueError: if a pool still has fewer than `n_bm` entries afterwards.
        """
        outputs = []
        for i,(hyp,ptr,sc) in enumerate(zip(self.hyp,pointers,scores)):
            if len(hyp) < self.n_bm:
                for p,s in zip(ptr, sc):
                    # Store plain floats, like `next` does, so the pool's scores are homogeneous.
                    s = s.item()
                    for o in p.suffixes(): hyp.add(o, s)
            if len(hyp) < self.n_bm: raise ValueError(f'`len(hyp)`({len(hyp)}) < `n_bm`({self.n_bm})')
            seq_sc, seq_ids, info, n_info = list(map(list, zip(*[(s,h.s,h.info,len(h.info)) for s,h in hyp.beams])))
            outputs.append({
                'seq2data_data2ptr':[self.n_bm],
                'seq2data_score':seq_sc, 
                'seq2data_output_ids':seq_ids, 
                'info2seq2data_idx':list(chain(*info)),
                'info2seq2data_seq2ptr':n_info,
                'info2seq2data_data2ptr':[sum(n_info)],
            })
        return outputs
    
    def proc(self, logits:torch.FloatTensor, n_bm:Optional[int]=None, len_penalty:Optional[float]=None, 
             max_info:Optional[int]=None):
        """Run the search on `logits` and return the padded/collated result of `self.tfm`.

        `logits` is indexed as `[example, position, token id]`. Decoding starts at position 1;
        position 0 is not scored (the trie root token already opens every hypothesis).
        `n_bm`, `len_penalty` and `max_info` are handed to `store_attr(..., is_none=False)`;
        presumably only non-None values replace the stored settings -- confirm against the
        `store_attr` in scope.
        """
        store_attr('n_bm,len_penalty,max_info', is_none=False)
        bsz,seq_len,cur_len = logits.shape[0],logits.shape[1],1
        
        self.hyp = [Hypothesis(self.n_bm, self.len_penalty) for _ in range(bsz)]
        # One beam per example to start with: zero score, pointer at the trie root.
        sc, ptr = torch.zeros((bsz,1,1)), [[TriePtr(self.trie,self.max_info)] for _ in range(bsz)]
        
        while True:
            # Broadcast the running beam scores over the token axis of the current position.
            sc = logits[:, cur_len:cur_len+1] + sc
            v_tok, v_sc, v_idx, mask = self.valid(ptr, sc)
            top_tok, top_sc, top_idx, mask = self.topk(v_tok, v_sc, v_idx, mask)
            ptr, sc = self.next(ptr, top_tok, top_sc, top_idx, mask)
            sc = sc.unsqueeze(2)
            cur_len += 1
            
            # Stop when the positions are exhausted, no example has a live beam left, or every
            # pool reports that its worst entry can no longer be beaten. The order matters:
            # `s.max()` is only reached when at least one example still has live beams.
            if (cur_len == seq_len 
                or torch.all(torch.tensor([len(p) for p in ptr]) == 0) 
                or torch.all(torch.tensor([hyp.is_done(s.max().item(), cur_len) for hyp,s in zip(self.hyp,sc)]))):
                break
                
        outputs = self.finalize(ptr, sc.squeeze(2))
        # Concatenate the per-example lists key by key before handing them to the transform.
        outputs = self.tfm({k:list(chain(*[o[k] for o in outputs])) for k in outputs[0]})
        return outputs
    

### Example

In [90]:
PARAM = {
    'pad_tok': 0,
    'pad_side': 'right',
    'drop': True,
    'ret_t': True,
    'in_place': True,
    'collapse': True,
    'device': 'cpu',
    'n_bm': 10,
    'len_penalty': 1.2,
}

In [180]:
logits = F.log_softmax(o.logits, dim=2).cpu().detach()

In [177]:
# Zero the scores at input positions the attention mask marks as padding, so those
# positions add nothing to a beam's cumulative score.
attention_mask = b['data_attention_mask'].bool().cpu().detach()
# Expand the per-position mask over the vocabulary axis and invert it: True = padded position.
mask = torch.logical_not(attention_mask.unsqueeze(2).expand(logits.size()))
logits[mask] = 0

In [157]:
tb = TrieBeam(t, 102, n_bm=5, len_penalty=1.0)

In [181]:
%%time
r = tb.proc(logits)

CPU times: user 1min 17s, sys: 276 ms, total: 1min 17s
Wall time: 904 ms


In [None]:
%debug

In [71]:
r

{'info2seq2data_idx': tensor([117331, 130922, 130925, 130923, 130924,    372,  15979, 112620, 233843,
        213298, 229942, 253767,  85690, 256334, 268492, 255377, 210582, 268034,
        291174, 281488, 129448,  80862, 268164, 114470, 289712,   3646, 112036,
         10865, 105174, 100471, 229942,  93578,  58547,  58548,  61310,  71006,
        290881, 179869,  70274,   1651,  69818,  90164, 260660, 195205, 265863,
        268665, 218792,   8984, 287489, 275267]), 'info2seq2data_seq2ptr': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]), 'info2seq2data_data2ptr': tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5]), 'seq2data_data2ptr': tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5]), 'seq2data_score': tensor([    inf,     inf,     inf,     inf,     inf, -1.8678, -2.0272, -2.0925,
        -2.0996,     inf, -1.7307, -2.9457, -1.3470,     inf,     inf, -0.5059,
        -1.3359, -0.48

## TrieBeamSearch

In [229]:
#| export
class TrieBeamSearch:
    """Run a model and decode its output logits with trie-constrained beam search (`TrieBeam`)."""

    @delegates(XCPadOutputTfm.__init__)
    def __init__(self, trie:Trie, eos_tok:int, n_bm:Optional[int]=5, len_penalty:Optional[float]=1.0, 
                 max_info:Optional[int]=None, **kwargs):
        """Store the search settings and build the underlying `TrieBeam`."""
        store_attr('trie,eos_tok,n_bm,len_penalty,max_info')
        # NOTE(review): `**kwargs` is advertised through `delegates(XCPadOutputTfm.__init__)` but is
        # not forwarded to `TrieBeam`, so pad/output options given here are silently ignored.
        # Left unchanged because existing callers pass extra keywords -- confirm intended behaviour.
        self.tb = TrieBeam(trie, eos_tok, n_bm=n_bm, len_penalty=len_penalty, max_info=max_info)
        
    def proc(self, model, inputs:Dict, n_bm:int=None, len_penalty:Optional[float]=None, 
             max_info:Optional[int]=None):
        """Score `inputs` with `model` and return the beam-search output dictionary.

        `inputs` is unpacked into the model call and must contain `data_attention_mask`.
        The model output's `.logits` are log-softmaxed over the last axis, moved to the CPU,
        and zeroed wherever the attention mask is 0 before decoding. The result of
        `TrieBeam.proc` is extended with `info2seq2data_score`: each sequence score repeated
        once per info entry of that sequence.
        """
        store_attr('n_bm,len_penalty,max_info', is_none=False)
        
        # Inference only: skip autograd bookkeeping instead of building a graph and detaching it.
        with torch.no_grad():
            logits = F.log_softmax(model(**inputs).logits, dim=2).cpu().detach()
        attention_mask = inputs['data_attention_mask'].bool().cpu().detach()
        # True where the position is padding; expanded over the vocabulary axis.
        mask = torch.logical_not(attention_mask.unsqueeze(2).expand(logits.size()))
        logits[mask] = 0
        
        outputs = self.tb.proc(logits, n_bm=self.n_bm, len_penalty=self.len_penalty, max_info=self.max_info)
        outputs['info2seq2data_score'] = torch.repeat_interleave(outputs['seq2data_score'], outputs['info2seq2data_seq2ptr'], dim=0)
        return outputs
        

### Example 1

In [185]:
PARAM = {
    'pad_tok': 0,
    'pad_side': 'right',
    'drop': True,
    'ret_t': True,
    'in_place': True,
    'collapse': True,
    'device': 'cpu',
    'n_bm': 10,
    'len_penalty': 1.2,
}

In [230]:
#| hide
tbs = TrieBeamSearch(t, 102, n_bm=20, max_bm=20, len_penalty=1, max_info=1)

In [189]:
#| hide
bsz = 64
b = block.train.one_batch(bsz)
b = prepare_batch(m, b, m_args=['lbl2data_idx'])

In [190]:
b.keys()

dict_keys(['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

In [191]:
m, b = m.to('cuda'), b.to('cuda')

In [193]:
%%time
r = tbs.proc(m, b)

CPU times: user 5min, sys: 2.74 s, total: 5min 3s
Wall time: 4.79 s


In [None]:
%prun tbs.proc(m, b)

In [194]:
r

{'info2seq2data_idx': tensor([ 13080,  56531,  53600,  ...,  42710, 181002, 186359]), 'info2seq2data_seq2ptr': tensor([1, 1, 1,  ..., 1, 1, 1]), 'info2seq2data_data2ptr': tensor([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20]), 'seq2data_data2ptr': tensor([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20]), 'seq2data_score': tensor([-0.9979, -0.9979, -0.9979,  ..., -0.5846, -0.7420, -2.3247]), 'seq2data_output_ids': tensor([[ 101, 2862, 1997,  ...,    0,    0,    0],
        [ 101, 2862, 1997,  ...,    0,    0,    0],
        [ 101

In [195]:
#| hide
for k,v in r.items(): print(k, v.shape)

info2seq2data_idx torch.Size([1280])
info2seq2data_seq2ptr torch.Size([1280])
info2seq2data_data2ptr torch.Size([64])
seq2data_data2ptr torch.Size([64])
seq2data_score torch.Size([1280])
seq2data_output_ids torch.Size([1280, 18])
info2seq2data_score torch.Size([1280])


In [196]:
#| hide
output = {}
output['targ_idx'] = b['lbl2data_idx'].cpu()
output['targ_ptr'] = b['lbl2data_data2ptr'].cpu()

output['pred_idx'] = r['info2seq2data_idx'].cpu()
output['pred_score'] = r['info2seq2data_score'].cpu()
output['pred_ptr'] = r['info2seq2data_data2ptr'].cpu()


In [197]:
#| hide
metric = PrecRecl(n_lbl, prop=block.train.dset.data.data_lbl, pk=5, rk=5, rep_pk=[1, 3, 5], rep_rk=[5])

In [198]:
#| hide
metric(**output)

{'P@1': 0.4375,
 'P@3': 0.2552083333333333,
 'P@5': 0.18750000000000003,
 'N@1': 0.4375,
 'N@3': 0.42743546,
 'N@5': 0.4478654,
 'PSP@1': 0.40621228431290085,
 'PSP@3': 0.42684336299334974,
 'PSP@5': 0.4485064426200951,
 'PSN@1': 0.4062123,
 'PSN@3': 0.412833,
 'PSN@5': 0.43182746,
 'R@5': 0.4727149714052288}

### Example 2

In [242]:
#| hide
tbs = TrieBeamSearch(t, 102, n_bm=5, len_penalty=1.5)

test_dset = block.test.sample(n=200, seed=50)
metric = PrecRecl(block.n_lbl, test_dset.data_lbl_filterer, prop=block.train.dset.data.data_lbl, pk=5, rk=5, 
                  rep_pk=[1, 3, 5], rep_rk=[5])


In [232]:
#| hide
def get_xo(inp, targ):
    """Rename batch and prediction fields to the keyword names the metric call expects.

    Targets are read from `inp` (the batch), predictions from `targ` (the search output).
    """
    fields = (
        ('targ_idx', inp, 'lbl2data_idx'),
        ('targ_ptr', inp, 'lbl2data_data2ptr'),
        ('pred_idx', targ, 'info2seq2data_idx'),
        ('pred_score', targ, 'info2seq2data_score'),
        ('pred_ptr', targ, 'info2seq2data_data2ptr'),
    )
    return {name: source[key] for name, source, key in fields}
    

In [243]:
m = m.to('cuda')

In [244]:
# Reset the metric before the evaluation pass (presumably clears previously accumulated
# batches -- confirm against `PrecRecl`).
metric.reset()

# Decode every test batch on the GPU and accumulate predictions against the targets.
for b in tqdm(test_dset.dl, total=len(test_dset.dl)):
    b = prepare_batch(m, b, m_args=['lbl2data_idx']).to('cuda')
    r = tbs.proc(m, b)
    # NOTE: this rebinds `o`, which an earlier cell used for the model output.
    o = get_xo(b.to('cpu'), r)
    metric.accumulate(**o)
    

  0%|          | 0/20 [00:00<?, ?it/s]

In [246]:
#| hide
metric.value

{'P@1': 0.165,
 'P@3': 0.10666666666666665,
 'P@5': 0.07899999999999996,
 'N@1': 0.165,
 'N@3': 0.16029894,
 'N@5': 0.16409001,
 'PSP@1': 0.10177405823432056,
 'PSP@3': 0.10478122301584751,
 'PSP@5': 0.11126934422627434,
 'PSN@1': 0.10177407,
 'PSN@3': 0.10880884,
 'PSN@5': 0.11646156,
 'R@5': 0.16957666519304446}

## TrieBeamSearch (NumPy variant; redefines the exported class above)

In [135]:
class TrieBeamSearch:
    """NumPy beam search over a nested-dict trie of token sequences.

    NOTE: this class re-uses the name of the exported `TrieBeamSearch` defined earlier in the
    notebook; once this cell runs, the name refers to this implementation.

    Each trie node is a dict with keys `next` (children by token), `occurs` (number of
    sequences through the node), `lbls` (indices of those sequences), `point_to` (the node's
    token) and `is_leaf` (set on `eos_id` nodes). `self.hash[i]` is the info attached to
    sequence `i`.
    """

    def __init__(self, max_height=32, sos_id=101, eos_id=102, pad_token=0, n_bm=10, len_penalty=0.0):
        store_attr('max_height,sos_id,eos_id,pad_token,n_bm,len_penalty')
        self.trie, self.hash = {}, None
    
    def build(self, X, y):
        """Build the trie from token sequences `X`; `y[i]` is kept as the info of `X[i]`.

        Sequences are truncated to `max_height` tokens and insertion stops at the first `eos_id`.
        """
        assert len(X) == len(y), '`X` and `y` must have the same length.'
        self.hash = y
        trie_dict = {}
        for seq_id, seq in enumerate(tqdm(X)):
            next_dict = trie_dict
            for token in seq[:self.max_height]:
                node = next_dict.setdefault(
                    token, {"next": {}, "occurs":0, "lbls": [],
                            "point_to": token, "is_leaf": False}
                )
                node["lbls"].append(seq_id)
                node["occurs"] += 1
                if token == self.eos_id:
                    node["is_leaf"] = True
                    break
                next_dict = node["next"]
        self.trie = trie_dict
    
    def decode_text(self, X):
        """Follow tokens `X` through the trie and return the `lbls` of the deepest node reached.

        Returns `[-1]` when not even the first token matches, or when `X` is empty.
        """
        next_level = {'next':self.trie, 'lbls': [-1]}
        for token in X:
            items = next_level['next'].get(token, None)
            if items is None:
                return next_level["lbls"]
            if items["is_leaf"]:
                return items["lbls"]
            next_level = items
        # `next_level` is the last node matched; unlike a loop-local name it is also
        # defined when `X` is empty.
        return next_level["lbls"]
    
    def _padded_np(self, lol, fill_value, max_seq):
        """Right-pad the list of lists `lol` to `(len(lol), max_seq)`; returns `(values, mask)`.

        The array dtype is taken from the first element of the first row, so `fill_value` is
        cast to that type (a fractional fill value is truncated for integer rows).
        """
        tokens = np.full((len(lol), max_seq), fill_value, dtype=type(lol[0][0]))
        masks = np.zeros((len(lol), max_seq), dtype=np.int32)
        for i, row in enumerate(lol):
            tokens[i, :len(row)] = row
            masks[i, :len(row)] = 1
        return tokens, masks
    
    def row_topk(self, score, k=10, return_scores=False, sort=False):
        """Indices of the `k` largest entries of 1-D `score` (ascending by score if `sort`).

        `k` is clamped to the number of entries. With `return_scores` the result is
        `(index, score[index])`.
        """
        k = min(k, len(score))  # argpartition rejects k larger than the axis length
        index = np.argpartition(score, -k, axis=0)[-k:]
        if sort:
            _score = score[index]
            _index = np.argsort(_score, axis=0)
            index = index[_index]
        if return_scores:
            score = score[index]
            return index, score
        return index
    
    def batch_topk(self, score, k=10, return_scores=False, sort=False):
        """Row-wise version of `row_topk` for 2-D `score`; `k` is clamped to the row length."""
        k = min(k, score.shape[1])  # argpartition rejects k larger than the axis length
        index = np.argpartition(score, -k, axis=1)[:, -k:]
        if sort:
            _score = np.take_along_axis(score, index, axis=1)
            _index = np.argsort(_score, axis=1)
            index = np.take_along_axis(index, _index, axis=1)
        if return_scores:
            score = np.take_along_axis(score, index, axis=1)
            return index, score
        return index
    
    def _snl_one(self, tries, old_scores, curr_score, len_score, top_k_index):
        """Expand the selected beams of one example by one token.

        For every selected beam `col`, each child node becomes a candidate scored
        `curr_score[token] + old_scores[col]` with length `len_score[col] + 1`. A beam without
        children is carried forward as a childless stub scored with `pad_token` and its
        length unchanged. Returns the candidate nodes, scores and lengths as flat lists.
        """
        _tries, _score, _l_scr = [], [], []
        for col in top_k_index:
            items = list(tries[col]['next'].items())
            if len(items) == 0:
                key, value = [self.pad_token], [{"next":{}, "lbls": tries[col]["lbls"]}]
                l_items = [len_score[col]]
            else:
                key, value = list(zip(*items))
                l_items = np.full((len(key),), len_score[col] + 1)
            _tries.extend(value)
            _score.extend(curr_score[list(key)] + old_scores[col])
            _l_scr.extend(l_items)
        return _tries, _score, _l_scr
    
    def _agl_one(self, lol, sorted_index, sorted_scores):
        """Map the selected nodes of one example to their stored info and matching scores.

        For each selected node the `self.hash` entries of all its `lbls` are concatenated,
        and the node's score is repeated once per resulting item.
        """
        lol_lbs, lbl_scr = [], []
        for col, item in enumerate(sorted_index):
            _items = np.concatenate([self.hash[x] for x in lol[item]["lbls"]])
            lol_lbs.extend(_items)
            lbl_scr.extend(np.full((_items.size,), sorted_scores[col], dtype=np.float32))
        return lol_lbs, lbl_scr

    def snl(self, tries, old_scores, curr_score, len_score, top_k_index):
        """Batched `_snl_one`; scores are padded with -inf and lengths with 0.001 to a common width."""
        _tries, _score, _l_scr, _max_seq = [], [], [], -1
        for i in range(top_k_index.shape[0]):
            __tries, __score, __l_scr = self._snl_one(
                tries[i], old_scores[i], curr_score[i], len_score[i], top_k_index[i])
            _max_seq = max(_max_seq, len(__tries))
            _tries.append(__tries)
            _score.append(__score)
            _l_scr.append(__l_scr)
        _score, _ = self._padded_np(_score, -np.inf, _max_seq)
        # NOTE(review): the lengths are integers, so `_padded_np` truncates the 0.001 fill to 0.
        # Harmless here because the matching padded scores are -inf -- confirm this is intended.
        _l_scr, _ = self._padded_np(_l_scr, 0.001, _max_seq)
        return _tries, _score, _l_scr
    
    def agl(self, lol_trie, sorted_index, sorted_scores):
        """Batched `_agl_one`; returns per-example lists of info items and of their scores."""
        lol_lbs, lbl_scr = [], []
        for rid in range(sorted_index.shape[0]):
            _lol_lbs, _lbl_scr = self._agl_one(lol_trie[rid], sorted_index[rid], sorted_scores[rid])
            lol_lbs.append(_lol_lbs)
            lbl_scr.append(_lbl_scr)
        return lol_lbs, lbl_scr
    
    def decode_batch(self, preds, beam=10, l_penalty=0, start_seq=1):
        """Beam-search a batch of score arrays `preds`, indexed `[example, position, token id]`.

        Starts from the `sos_id` node and consumes positions `start_seq .. max_height-1`, so
        `preds` needs at least `max_height` positions. After every expansion the scores are
        multiplied by `length ** -l_penalty` and the `beam` best candidates are kept; the last
        step also sorts them and returns their scores, which are then mapped to info items.
        """
        n = len(preds)
        _tries = [[self.trie[self.sos_id]] for _ in range(n)]
        _score = np.zeros((n, 1), dtype=np.float32)
        _index = np.zeros((n, 1), dtype=np.int32)
        _l_scr = np.ones((n, 1), dtype=np.int32)
        # Float exponent: NumPy refuses integer arrays raised to negative integer powers.
        exponent = -float(l_penalty)
        for i in range(start_seq, self.max_height):
            is_last = i+1 == self.max_height
            _tries, _score, _l_scr = self.snl(_tries, _score, preds[:, i], _l_scr, _index)
            _score = np.multiply(_score, np.power(_l_scr, exponent))
            _index = self.batch_topk(_score, beam, return_scores=is_last, sort=is_last)
        # On the last step `_index` is the `(index, score)` pair returned by `batch_topk`.
        return self.agl(_tries, _index[0], _index[1])

    def proc(self, model, inputs:Dict, n_bm:int=None, max_bm:Optional[int]=None, len_penalty:Optional[float]=None, 
             max_info:Optional[int]=None):
        """Score `inputs` with `model` and decode the log-softmaxed logits with `decode_batch`.

        The logits are zero-padded along the position axis up to `max_height`. `max_bm` and
        `max_info` are accepted but not used by this implementation. Returns a dict with the
        flattened info indices, their scores and the number of items per example.
        """
        store_attr('n_bm,len_penalty', is_none=False)
        # Inference only: skip autograd bookkeeping instead of building a graph and detaching it.
        with torch.no_grad():
            logits = F.log_softmax(model(**inputs).logits, dim=2).cpu().detach().numpy()
        n_pad = max(0, self.max_height-logits.shape[1])
        logits = np.concatenate([logits, np.zeros((logits.shape[0], n_pad, logits.shape[2]))], axis=1)
        idx, scores = self.decode_batch(logits, beam=self.n_bm, l_penalty=self.len_penalty)
        outputs = {
            'info2seq2data_idx': torch.tensor(list(chain(*idx))),
            'info2seq2data_score': torch.tensor(list(chain(*scores))),
            'info2seq2data_data2ptr': torch.tensor([len(o) for o in idx]),
        }
        return outputs
        
        
    def decode_one(self, pred, beam, l_penalty, start_seq=1):
        """Single-example counterpart of `decode_batch`; `pred` is indexed `[position, token id]`."""
        _tries = [self.trie[self.sos_id]]
        _score = np.zeros((1,), dtype=np.float32)
        _index = np.zeros((1,), dtype=np.int32)
        _l_scr = np.ones((1, ), dtype=np.int32)
        # Float exponent: NumPy refuses integer arrays raised to negative integer powers.
        exponent = -float(l_penalty)
        for i in range(start_seq, self.max_height):
            is_last = i+1 == self.max_height
            _tries, _score, _l_scr = self._snl_one(_tries, _score, pred[i], _l_scr, _index)
            _score = np.multiply(_score, np.power(_l_scr, exponent))
            _index = self.row_topk(_score, beam, return_scores=is_last, sort=is_last)
        # On the last step `_index` is the `(index, score)` pair returned by `row_topk`.
        return self._agl_one(_tries, _index[0], _index[1])
    
    def decode_serial(self, preds, beam=10, l_penalty=0):
        """Decode the examples of `preds` one at a time with `decode_one`."""
        labels = []
        scores = []
        for pred in preds:
            _labels, _scores = self.decode_one(pred, beam, l_penalty)
            labels.append(_labels)
            scores.append(_scores)
        return labels, scores
    
    @classmethod
    @delegates(__init__)
    def from_list(cls, toks:List, info:List, **kwargs):
        """Create an instance from `kwargs` and build its trie from `toks` and `info`."""
        self = cls(**kwargs)
        self.build(toks, info)
        return self
    

### Example

In [130]:
tbs = TrieBeamSearch(max_height=32, sos_id=101, eos_id=102, pad_token=0, n_bm=20, len_penalty=0.0)

In [131]:
toks = block.lbl_info['input_ids']
info = [[i] for i in range(len(toks))]
tbs.build(toks, info)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [134]:
%%time
r = tbs.proc(m, b)

> [0;32m/tmp/ipykernel_1270042/750450751.py[0m(116)[0;36mdecode_batch[0;34m()[0m
[0;32m    114 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m        [0;31m#debug[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 116 [0;31m        [0m_token[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mfull[0m[0;34m([0m[0;34m([0m[0mlen[0m[0;34m([0m[0mpreds[0m[0;34m)[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0msos_id[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m        [0m_tries[0m [0;34m=[0m [0;34m[[0m[0;34m[[0m[0mself[0m[0;34m.[0m[0mtrie[0m[0;34m[[0m[0mself[0m[0;34m.[0m[0msos_id[0m[0;34m][0m[0;34m][0m [0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mlen[0m[0;34m([0m[0mpreds[0m[0;34m)[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118

ipdb> 
> [0;32m/tmp/ipykernel_1270042/750450751.py[0m(122)[0;36mdecode_batch[0;34m()[0m
[0;32m    120 [0;31m        [0m_l_scr[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mones[0m[0;34m([0m[0;34m([0m[0mlen[0m[0;34m([0m[0mpreds[0m[0;34m)[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m,[0m [0mdtype[0m[0;34m=[0m[0mnp[0m[0;34m.[0m[0mint32[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mfor[0m [0mi[0m [0;32min[0m [0mnp[0m[0;34m.[0m[0marange[0m[0;34m([0m[0mstart_seq[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmax_height[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 122 [0;31m            [0m_tries[0m[0;34m,[0m [0m_score[0m[0;34m,[0m [0m_l_scr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0msnl[0m[0;34m([0m[0m_tries[0m[0;34m,[0m [0m_score[0m[0;34m,[0m [0mpreds[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mi[0m[0;34m][0m[0;34m,[0m [0m_l_scr[0m[0;34m,[0m [0m_index[0m[0;

ipdb> index
array([[   68,  3721,  1496, ...,     7, 10100,  1249],
       [  935,  9782,   910, ...,  6760,  7161,  7872],
       [ 1603,  2005,  1138, ...,  5188,     7,  1688],
       ...,
       [ 2005,  1750,  6483, ..., 13943,     7,  1741],
       [ 9655,  6609,   552, ...,  6838,  1534,  2455],
       [  294,  7320,  7344, ...,  7578,     7,  1741]])
ipdb> n
> [0;32m/tmp/ipykernel_1270042/750450751.py[0m(61)[0;36mbatch_topk[0;34m()[0m
[0;32m     59 [0;31m            [0m_index[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0margsort[0m[0;34m([0m[0m_score[0m[0;34m,[0m [0maxis[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m            [0mindex[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mtake_along_axis[0m[0;34m([0m[0mindex[0m[0;34m,[0m [0m_index[0m[0;34m,[0m [0maxis[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 61 [0;31m        [0;32mif[0m [0mreturn_scores[0m[0;34m:[0m[0;

BdbQuit: 

## XCTrieBeamSearch

In [17]:
class XCTrieBeamSearch:
    """Factory that builds a `TrieBeamSearch` from the label (and optional metadata) tokens of a data block."""
    
    @classmethod
    @delegates(TrieBeamSearch.from_list)
    def from_block(cls, block:XCDataBlock, meta:Optional[List]=None, **kwargs):
        """Build a `TrieBeamSearch` over the token sequences found in `block`.

        Every label sequence in `block.lbl_info['input_ids']` is inserted with its own index as
        info. For each name in `meta`, the sequences of `block.train.dset.meta[f'{name}_meta']`
        are added as well, each carrying the column indices of its row in the transposed
        `lbl_meta` matrix (presumably the labels linked to that metadata item -- confirm).
        Remaining `kwargs` go to `TrieBeamSearch.from_list`.

        Raises:
            ValueError: if a requested metadata set is missing, or if its token list and
                info list differ in length.
        """
        # Work on a copy: extending the block's own list would grow `block.lbl_info` on every call.
        toks = list(block.lbl_info['input_ids'])
        info = [[i] for i in range(len(toks))]
        
        if meta is not None:
            meta_dset = block.train.dset.meta
            for name in meta:
                key = f'{name}_meta'
                if key not in meta_dset: raise ValueError(f'`{key}` does not exist.')
                meta_toks = meta_dset[key].meta_info['input_ids']
                # Transpose so that iterating yields one sparse row per metadata item.
                lbl_meta = meta_dset[key].lbl_meta.T.tocsr()
                meta_info = [row.indices.tolist() for row in lbl_meta]
                if len(meta_toks) != len(meta_info): raise ValueError(f'`meta_toks` and `meta_info` should have equal length.')
                toks.extend(meta_toks); info.extend(meta_info)
                
        return TrieBeamSearch.from_list(toks, info, **kwargs)
                        

### Example

In [34]:
tbs = XCTrieBeamSearch.from_block(block, max_height=32, sos_id=101, eos_id=102, pad_token=0, n_bm=5, len_penalty=0.0)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [41]:
o = tbs.proc(model, b, n_bm=10)