In [None]:
#| default_exp generation.generate

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import torch, math
from torch.multiprocessing import Pool
import torch.nn.functional as F
from itertools import chain
from tqdm.auto import tqdm
from typing import Optional, Sequence, Any, Dict, List
from dataclasses import dataclass

from fastcore.utils import *
from fastcore.dispatch import *
from fastcore.meta import *
from fastcore.parallel import *

from xcai.core import *
from xcai.transform import *
from xcai.generation.trie import *

## Setup

In [None]:
#| hide
import numpy as np
from xcai.block import *
from xcai.models.MMM00X import BT0002
from xcai.metrics import *

In [None]:
#| hide
# Build the training data block with the `bert-base-cased` tokenizer; keep one batch and the label count.
block = XCBlock.from_cfg('train', tokz='bert-base-cased')
b, n_lbl = block.train.one_batch(), block.n_lbl

  self._set_arrayXarray(i, j, x)


In [None]:
#| hide
# NOTE(review): machine-specific absolute checkpoint path; it will not resolve on other machines —
# consider reading it from a configurable location.
fname = '/home/scai/phd/aiz218323/Projects/XC_NLG/code/models/bert-base-cased_RB33-NAR-3+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-242000'
m = BT0002.from_pretrained(fname)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [None]:
#| hide
# Forward pass on the sample batch; `o.logits` is used by the `TrieBeam` example further down.
o = m(**b)

In [None]:
#| hide
# Logit shape next to the number of label token sequences.
o.logits.shape, len(block.lbl_info['input_ids'])

(torch.Size([10, 13, 28996]), 312330)

In [None]:
#| hide
# Trie over every label's token ids; each label carries its own index as payload, and those payloads
# later become the predicted label ids (`info2seq2data_idx`).
toks = block.lbl_info['input_ids']
info = [[i] for i in range(len(toks))]
t = Trie.from_list(toks, info, max_info=20)

  0%|          | 0/312330 [00:00<?, ?it/s]

## Trie Pointer

In [None]:
#| export
class TriePtr:
    "A cursor that walks a `Trie` one token at a time, recording the token path (`hyp`) taken from the root."

    def __init__(self, trie, max_info:Optional[int]=None):
        # `max_info` caps how many `info` entries are reported; `None` means no cap.
        store_attr('trie,max_info')
        self.ptr, self.hyp = trie.root, [trie.root.tok]

    @property
    def tokens(self):
        "Tokens that may legally follow the current node."
        return list(self.ptr.nxt_toks.keys())

    def next(self, val:int):
        "Advance to the child reached by token `val`; raises `ValueError` if `val` is not a valid continuation."
        # Test membership on the child mapping directly instead of materialising `self.tokens` each step.
        if val not in self.ptr.nxt_toks: raise ValueError(f'`{val}` not a valid next token.')
        self.ptr = self.ptr.nxt_toks[val]
        self.hyp.append(val)

    def suffixes(self):
        "Completions reachable from the current node (collected by `Trie._search`), highest `cnt` first."
        o = []
        Trie._search(self.ptr, self.hyp, o, self.max_info)
        return sorted(o, key=lambda x: x.cnt, reverse=True)

    @property
    def is_end(self):
        "Whether the current node is marked as the end of a stored sequence."
        return self.ptr.is_end

    @property
    def value(self):
        "Snapshot of the path so far as `TrieOutput(s, cnt, info)`; `info` is capped at `max_info` entries."
        info = self.ptr.info
        # Nodes without a payload keep `info=None` instead of failing in `list(None)`.
        if info is not None:
            info = list(info) if self.max_info is None else list(info)[:self.max_info]
        # Copy `hyp` so that advancing this pointer later cannot mutate an output already handed out.
        return TrieOutput(self.hyp.copy(), self.ptr.cnt, info)

    def copy(self):
        "An independent pointer at the same node, with its own copy of the path."
        t = TriePtr(self.trie, self.max_info)
        t.ptr,t.hyp = self.ptr,self.hyp.copy()
        return t
        

### Example 1

In [None]:
# NOTE(review): the stored outputs in this example (e.g. tokens 100/200/300/102) appear to come from a small
# toy trie that is no longer defined in this notebook; with `t` from the setup section they will differ.
tp = TriePtr(t)

In [None]:
# Tokens allowed right after the root.
tp.tokens

[]

In [None]:
# True if the current node ends a stored sequence.
tp.is_end

True

In [None]:
# Advance by token 102; raises `ValueError` if 102 is not a child of the current node.
tp.next(102)

In [None]:
# Completions reachable from the current node, highest count first.
tp.suffixes()

[TrieOutput(s=[101, 200, 100, 222, 102], cnt=1, info=None)]

In [None]:
# The path walked so far as a `TrieOutput(s, cnt, info)`.
tp.value

TrieOutput(s=[101, 200, 100, 222, 102], cnt=1, info=None)

### Example 2

In [None]:
# Copies start at the same node but own separate path lists, so advancing one leaves the other in place.
tp = TriePtr(t)
l = [tp.copy(), tp.copy()]

In [None]:
# Advance only the first copy.
l[0].next(100)

In [None]:
# Allowed next tokens now differ between the two copies.
l[0].tokens, l[1].tokens

([102], [100, 200, 300])

In [None]:
# Paths walked by each copy: only l[0] has advanced.
l[0].hyp, l[1].hyp

([101, 100], [101])

## Hypothesis

In [None]:
#| export
class Hypothesis:
    "Bounded pool holding the `n_bm` best finished sequences, ranked by length-normalised score."

    def __init__(self, n_bm:int, len_penalty:Optional[float]=1.0):
        store_attr('n_bm,len_penalty')
        # Empty pool: the worst score starts at a large sentinel and is lowered by `add`.
        self.beams = []
        self.worst_sc = 1e9

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, sum_logits:float, gen_len:Optional[int]=None):
        """Score `hyp` as `sum_logits / length**len_penalty` and keep it if it ranks in the best `n_bm`.

        `length` is `gen_len` when given, otherwise `len(hyp.s)`. When the pool overflows, the lowest
        scoring beam (earliest on ties) is evicted and `worst_sc` moves up to the next-lowest score.
        """
        length = len(hyp.s) if gen_len is None else gen_len
        sc = sum_logits/length**self.len_penalty

        accept = len(self) < self.n_bm or sc > self.worst_sc
        if not accept: return

        self.beams.append((sc, hyp))
        if len(self) <= self.n_bm:
            self.worst_sc = min(sc, self.worst_sc)
            return
        ranked = sorted((s, pos) for pos, (s, _) in enumerate(self.beams))
        (_, drop), (self.worst_sc, _) = ranked[0], ranked[1]
        del self.beams[drop]

    def is_done(self, best_sc:float, cur_len:int):
        "True once the pool is full and `best_sc` at length `cur_len` cannot beat the worst kept beam."
        if len(self) >= self.n_bm:
            return self.worst_sc >= best_sc/cur_len**self.len_penalty
        return False
        

### Example

In [None]:
# Keep at most 5 finished sequences; scores are divided by length**0.5.
hyp = Hypothesis(5, 0.5)

In [None]:
# NOTE(review): a freshly built `Hypothesis` has no beams, so this evaluates to 0; the stored output (2)
# is presumably left over from cells that were since removed.
len(hyp)

2

In [None]:
# 6 tokens with len_penalty 0.5: score = -1.2 / 6**0.5, about -0.49.
hyp.add(TrieOutput([1, 3, 6, 11, 12, 14], 2, [2, 5]), sum_logits=-1.2)

In [None]:
# NOTE(review): only one `add` happens in this notebook, so the five stored beams shown below are stale.
hyp.beams

[(-0.75, TrieOutput(s=[1, 2, 3, 4], cnt=2, info=[0, 1, 2])),
 (-0.6, TrieOutput(s=[1, 3, 6, 11], cnt=2, info=[2, 5])),
 (-0.6, TrieOutput(s=[1, 3, 6, 11], cnt=2, info=[2, 5])),
 (-0.5366563145999494, TrieOutput(s=[1, 3, 6, 11, 12], cnt=2, info=[2, 5])),
 (-0.48989794855663565,
  TrieOutput(s=[1, 3, 6, 11, 12, 14], cnt=2, info=[2, 5]))]

## Trie Beam

In [None]:
#| export
class TrieBeam:
    "Beam search over one sequence of per-position token scores, restricted to token paths stored in a `Trie`."

    def __init__(self, trie:Trie, n_bm:Optional[int]=5, max_bm:Optional[int]=None, len_penalty:Optional[float]=1.0, 
                 max_info:Optional[int]=None):
        store_attr('trie,n_bm,len_penalty,max_info')
        # `max_bm` (if given) is the completion budget used by `finalize`; it is never allowed below 2*n_bm.
        self.max_bm, self.hyp = max_bm if max_bm is None else max(max_bm, 2*n_bm), None

    def valid(self, ptr:List, sc:torch.FloatTensor):
        """Collect every trie-allowed continuation of every beam.

        `sc` holds one score row per beam, indexed by token id. Returns three parallel lists:
        candidate tokens, their scores, and the index of the beam each candidate extends.
        """
        v_tok, v_sc, v_idx = [], [], []
        for i,(p,s) in enumerate(zip(ptr,sc)):
            toks = p.tokens
            v_tok.extend(toks)
            v_sc.extend(s[toks].tolist())
            v_idx.extend([i]*len(toks))
        return v_tok, v_sc, v_idx

    def topk(self, ptr:List, tok:List, sc:List, idx:List):
        """Select up to `2*n_bm` best candidates from `valid`'s output.

        Returns copies of the source beams advanced by the chosen tokens, and their scores, best first.
        """
        # No beam has a legal continuation: return nothing instead of failing to unpack an empty selection.
        if len(sc) == 0: return [], []
        k = 2*self.n_bm
        top_sc, top_i = (
            torch.topk(torch.tensor(sc), k, dim=0) 
            if len(sc) > k else torch.sort(torch.tensor(sc), dim=0, descending=True)
        )
        top_ptr = []
        for i in top_i.tolist():
            # Copy first: several candidates may extend the same source beam.
            p = ptr[idx[i]].copy()
            p.next(tok[i])
            top_ptr.append(p)
        return top_ptr, top_sc.tolist()

    def next(self, ptr:List, sc:List):
        "Move finished beams into `self.hyp`; keep at most `n_bm` unfinished beams and their scores as a `(k, 1)` tensor."
        nxt_ptr, nxt_sc = [], []
        for p,s in zip(ptr, sc):
            if p.is_end: self.hyp.add(p.value, s)
            else: nxt_ptr.append(p); nxt_sc.append(s)
        return nxt_ptr[:self.n_bm], torch.tensor(nxt_sc[:self.n_bm]).unsqueeze(1)

    def finalize(self, ptr:List, sc:List):
        """Top up the finished hypotheses from trie completions if needed, then flatten them into an output dict.

        `ptr`/`sc` are the beams still live when decoding stopped and their scores. Raises `ValueError`
        if fewer than `n_bm` hypotheses exist even after the top-up.
        """
        if len(self.hyp) < self.n_bm:
            # Share the remaining `max_bm` budget across live beams (rounded up); `None` = take every completion.
            nh = int(math.ceil((self.max_bm-len(self.hyp))/len(ptr))) if self.max_bm is not None and len(ptr) else None
            for p,s in zip(ptr, sc):
                hyps = p.suffixes() if nh is None else p.suffixes()[:nh]
                # Each completion is credited with its prefix's running score `s`.
                for o in hyps: self.hyp.add(o, s)
        if len(self.hyp) < self.n_bm: raise ValueError(f'`len(self.hyp)`({len(self.hyp)}) < `n_bm`({self.n_bm})')
        seq_sc, seq_ids, info, n_info = list(map(list, zip(*[(score,h.s,h.info,len(h.info)) for score,h in self.hyp.beams])))
        return {
            'seq2data_data2ptr':[self.n_bm],
            'seq2data_score':seq_sc, 
            'seq2data_output_ids':seq_ids, 
            'info2seq2data_idx':list(chain(*info)),
            'info2seq2data_seq2ptr':n_info,
            'info2seq2data_data2ptr':[sum(n_info)],
        }

    def proc(self, logits:torch.FloatTensor, n_bm:Optional[int]=None, max_bm:Optional[int]=None, len_penalty:Optional[float]=None, 
             max_info:Optional[int]=None):
        """Decode `logits` (one row of token scores per position) and return `finalize`'s output dict.

        Non-`None` keyword arguments override the stored settings (assumes `store_attr(..., is_none=False)`
        skips `None` values — confirm against `xcai.core`).
        """
        store_attr('n_bm,len_penalty,max_info', is_none=False)
        if max_bm is not None: self.max_bm = max_bm
        # Re-apply the `max_bm >= 2*n_bm` floor from `__init__`, also when only `n_bm` was overridden.
        if self.max_bm is not None: self.max_bm = max(self.max_bm, 2*self.n_bm)

        self.hyp = Hypothesis(self.n_bm, self.len_penalty)
        # All beams start at the root; only beam 0 is live at the first step (the others carry -1e9), so the
        # first expansion effectively uses a single copy of the root.
        sc = torch.full((self.n_bm,1), -1e9); sc[0,0] = 0
        ptr = [TriePtr(self.trie,self.max_info) for _ in range(self.n_bm)]

        # Position 0 is not scored: every path already starts with the root's token.
        cur_len,seq_len = 1,logits.shape[0]
        while True:
            sc = logits[cur_len:cur_len+1].expand(sc.shape[0],-1) + sc
            v_tok, v_sc, v_idx = self.valid(ptr, sc)
            top_ptr, top_sc = self.topk(ptr, v_tok, v_sc, v_idx)
            ptr, sc = self.next(top_ptr, top_sc)
            cur_len += 1

            if cur_len >= seq_len or len(ptr) == 0 or self.hyp.is_done(sc.max().item(), cur_len):
                break
        return self.finalize(ptr, sc.squeeze(1).tolist())
        

### Example

In [None]:
#| hide
# Single-sequence constrained beam search over the trie `t`, keeping 5 hypotheses.
tb = TrieBeam(t, n_bm=5, len_penalty=1.0)

In [None]:
#| hide
# Decode the first sequence of the sample batch; `len_penalty=5` overrides the constructor's 1.0.
# NOTE(review): `i[0]` is not trimmed with the attention mask, unlike `TrieBeamSearch.proc`.
i = F.log_softmax(o.logits, dim=-1)
r = tb.proc(i[0], len_penalty=5)

In [None]:
#| hide
# Flat output dict produced by `TrieBeam.finalize`.
r

{'seq2data_data2ptr': [5],
 'seq2data_score': [-0.0007121877170385026,
  -0.000994869279160678,
  -0.0009977314937062937,
  -0.0010254948953590068,
  -0.0006374765653163195],
 'seq2data_output_ids': [[101, 5619, 1104, 11765, 1107, 1726, 102],
  [101, 5619, 1104, 11765, 1107, 12247, 102],
  [101, 5619, 1104, 11765, 1107, 7217, 102],
  [101, 5619, 1104, 11765, 1107, 3900, 102],
  [101, 5619, 1104, 11765, 1107, 4471, 6722, 102]],
 'info2seq2data_idx': [150667, 297701, 208470, 276666, 278542],
 'info2seq2data_seq2ptr': [1, 1, 1, 1, 1],
 'info2seq2data_data2ptr': [5]}

## TrieBeamSearch

In [None]:
#| hide
# Keyword settings (beam size, length penalty, padding options); presumably meant for
# `TrieBeamSearch(**PARAM)` — not referenced by the cells shown here.
PARAM = dict(
    pad_tok=0, pad_side='right',
    drop=True, ret_t=True, in_place=True, collapse=True,
    device='cpu',
    n_bm=5, len_penalty=1.2,
)

In [None]:
#| export
def tbs_proc(x):
    "Worker for `Pool.map`: run `beam.proc` on one `(beam, logits)` or `(beam, logits, attention_mask)` item."
    beam, logits, *mask = x
    # Drop padded positions when a mask is supplied, matching the sequential `h.proc(l[a])` path.
    if mask and mask[0] is not None: logits = logits[mask[0]]
    return beam.proc(logits)

In [None]:
#| export
class TrieBeamSearch:
    "Batched trie-constrained beam search: one `TrieBeam` per sequence, outputs post-processed by `XCPadOutputTfm`."

    @delegates(XCPadOutputTfm.__init__)
    def __init__(self, trie:Trie, n_bm:int=5, max_bm:Optional[int]=None, len_penalty:Optional[float]=1.0, max_info:Optional[int]=None,
                 n_threads=3, **kwargs):
        store_attr('trie,n_bm,max_bm,len_penalty,max_info,n_threads')
        # Remaining keyword arguments configure the output transform.
        self.tfm = XCPadOutputTfm(**kwargs)

    def _log_probs(self, model, inputs:Dict):
        "Log-softmax of `model(**inputs).logits` and the boolean `data_attention_mask`, both on CPU without autograd."
        # Decoding needs no gradients; this also keeps the tensors shippable to worker processes.
        with torch.no_grad():
            logits = F.log_softmax(model(**inputs).logits, dim=-1).cpu()
        return logits, inputs['data_attention_mask'].bool().cpu()

    def _beams(self, n:int):
        "One fresh `TrieBeam` per sequence, built from the current settings."
        return [TrieBeam(self.trie, self.n_bm, self.max_bm, self.len_penalty, self.max_info) for _ in range(n)]

    def _collate(self, outputs:List[Dict]):
        "Concatenate per-sequence output dicts key-wise, apply `self.tfm`, and add one score per info entry."
        outputs = self.tfm({k:list(chain(*[o[k] for o in outputs])) for k in outputs[0]})
        # Each sequence's score is repeated once per info entry attached to that sequence.
        outputs['info2seq2data_score'] = torch.repeat_interleave(outputs['seq2data_score'], outputs['info2seq2data_seq2ptr'], dim=0)
        return outputs

    def proc(self, model, inputs:Dict, n_bm:int=None, max_bm:Optional[int]=None, len_penalty:Optional[float]=None, 
             max_info:Optional[int]=None):
        "Decode every sequence in `inputs` one after another; non-`None` overrides replace the stored settings."
        store_attr('n_bm,max_bm,len_penalty,max_info', is_none=False)
        logits, attention_mask = self._log_probs(model, inputs)
        hyps = self._beams(logits.shape[0])
        # `l[a]` keeps only the unpadded positions of each sequence.
        return self._collate([h.proc(l[a]) for h,l,a in zip(hyps, logits, attention_mask)])

    def proc_parallel(self, model, inputs:Dict, n_bm:int=None, max_bm:Optional[int]=None, len_penalty:Optional[float]=None, 
                      max_info:Optional[int]=None, n_threads=None):
        "Like `proc`, but runs the per-sequence searches in a process pool of `self.n_threads` workers."
        store_attr('n_bm,max_bm,len_penalty,max_info,n_threads', is_none=False)
        logits, attention_mask = self._log_probs(model, inputs)
        hyps = self._beams(logits.shape[0])
        # Mask before dispatch so workers decode exactly what `proc` decodes.
        work = [(h, l[a].share_memory_()) for h,l,a in zip(hyps, logits, attention_mask)]
        # NOTE(review): every work item pickles its `TrieBeam`, which references the whole trie; for large
        # tries that transfer may dominate the run time.
        with Pool(processes=self.n_threads) as pool: outputs = pool.map(tbs_proc, work)
        return self._collate(outputs)
        

### Example 1

In [None]:
#| hide
# 20 beams per example; inside `TrieBeam`, `max_bm=20` is raised to 2*n_bm = 40. `max_info=1` keeps at most
# one label id per decoded sequence.
tbs = TrieBeamSearch(t, n_bm=20, max_bm=20, len_penalty=1, max_info=1)

In [None]:
#| hide
# A 64-example batch; `m_args=['lbl2data_idx']` presumably keeps the label ids needed for evaluation.
bsz = 64
b = block.train.one_batch(bsz)
b = prepare_batch(m, b, m_args=['lbl2data_idx'])

In [None]:
#| hide
b.keys()

dict_keys(['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
#| hide
# Requires a CUDA device.
m, b = m.to('cuda'), b.to('cuda')

In [None]:
#| hide
# Sequential decoding of the whole batch.
r = tbs.proc(m, b)

In [None]:
r.keys()

dict_keys(['info2seq2data_idx', 'info2seq2data_seq2ptr', 'info2seq2data_data2ptr', 'seq2data_data2ptr', 'seq2data_score', 'seq2data_output_ids', 'info2seq2data_score'])

In [None]:
# Number of predicted label ids for each example in the batch.
r.info2seq2data_data2ptr

tensor([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20])

In [None]:
#| hide
# NOTE(review): the stored run of this cell was interrupted (KeyboardInterrupt), so `r` in later cells
# may still hold the sequential result from `tbs.proc`.
r = tbs.proc_parallel(m, b)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

KeyboardInterrupt: 

In [None]:
#| hide
# NOTE(review): the stored shapes (320 = 64 x 5 sequences) do not match n_bm=20 above and look stale.
for k,v in r.items(): print(k, v.shape)

info2seq2data_idx torch.Size([2040])
info2seq2data_seq2ptr torch.Size([320])
info2seq2data_data2ptr torch.Size([64])
seq2data_data2ptr torch.Size([64])
seq2data_score torch.Size([320])
seq2data_output_ids torch.Size([320, 17])
info2seq2data_score torch.Size([2040])


In [None]:
#| hide
# Ground truth from the batch and predictions from the beam search, as `PrecRecl` keyword arguments.
output = {}
output['targ_idx'] = b['lbl2data_idx'].cpu()
output['targ_ptr'] = b['lbl2data_data2ptr'].cpu()

output['pred_idx'] = r['info2seq2data_idx'].cpu()
output['pred_score'] = r['info2seq2data_score'].cpu()
output['pred_ptr'] = r['info2seq2data_data2ptr'].cpu()


In [None]:
#| hide
# Precision/recall/nDCG up to k=5; `prop` presumably supplies label propensities for the PSP/PSN variants.
metric = PrecRecl(n_lbl, prop=block.train.dset.data.data_lbl, pk=5, rk=5, rep_pk=[1, 3, 5], rep_rk=[5])

In [None]:
#| hide
metric(**output)

{'P@1': 0.5,
 'P@3': 0.26562500000000006,
 'P@5': 0.18437499999999998,
 'N@1': 0.5,
 'N@3': 0.51767075,
 'N@5': 0.5278427,
 'PSP@1': 0.42532231190544867,
 'PSP@3': 0.44209107351860044,
 'PSP@5': 0.46167468811559276,
 'PSN@1': 0.42532232,
 'PSN@3': 0.47837257,
 'PSN@5': 0.5009967,
 'R@5': 0.5492559523809524}

### Example 2

In [None]:
#| hide
# NOTE(review): machine-specific absolute data path; presumably lists (data, label) pairs to exclude from scoring — confirm.
filterer = np.loadtxt('/home/scai/phd/aiz218323/Projects/XC_NLG/data/(mapped)LF-WikiSeeAlsoTitles-320K/filter_labels_test.txt')

In [None]:
#| hide
# Test-set evaluation with 5 beams and a stronger length penalty; `filterer` is passed to the metric.
tbs = TrieBeamSearch(t, n_bm=5, len_penalty=1.5)
metric = PrecRecl(n_lbl, filterer, prop=block.train.dset.data.data_lbl, pk=5, rk=5, rep_pk=[1, 3, 5], rep_rk=[5])

# Changes the shared test block's batch size in place.
block.test.bsz = 100

In [None]:
#| hide
def get_xo(inp, targ):
    "Build `PrecRecl` keyword arguments from a batch (`inp`) and a `TrieBeamSearch` output (`targ`)."
    # Despite its name, `targ` carries the predictions; the ground truth comes from `inp`.
    truth = dict(targ_idx=inp['lbl2data_idx'], targ_ptr=inp['lbl2data_data2ptr'])
    preds = dict(pred_idx=targ['info2seq2data_idx'], pred_score=targ['info2seq2data_score'],
                 pred_ptr=targ['info2seq2data_data2ptr'])
    return {**truth, **preds}
    

In [None]:
#| hide
# Requires a CUDA device.
m = m.to('cuda')

In [None]:
#| hide
# Accumulate metrics over the full test loader.
# NOTE(review): reuses the names `b`, `r` and `o` from earlier cells, so re-running those display cells
# afterwards shows this loop's last batch instead.
metric.reset()

for b in tqdm(block.test.dl, total=len(block.test.dl)):
    b = prepare_batch(m, b).to('cuda')
    r = tbs.proc(m, b)
    o = get_xo(b, r)
    metric.accumulate(**o)
    

  0%|          | 0/1776 [00:00<?, ?it/s]

In [None]:
#| hide
# Metrics aggregated over every accumulated test batch.
metric.value

{'P@1': 0.15198715601498464,
 'P@3': 0.09240721441383382,
 'P@5': 0.06510323071290847,
 'N@1': 0.15198715,
 'N@3': 0.14293797,
 'N@5': 0.14415713,
 'PSP@1': 0.09805626244659127,
 'PSP@3': 0.0986113111983527,
 'PSP@5': 0.10019778518095093,
 'PSN@1': 0.098056264,
 'PSN@3': 0.10458853,
 'PSN@5': 0.109316155,
 'R@5': 0.14616284431636717}