In [1]:
#| default_exp 09-anshul-trie-implementation

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import os, pandas as pd, warnings, torch, pickle, numpy as np
from typing import Dict, Optional, List
from tqdm.auto import tqdm
from scipy import stats
import scipy.sparse as sp
import torch.nn.functional as F
from itertools import chain

from xcai.basics import *
from xcai.models.MMM00X import DBT007, DBT008, BT0002
from xcai.transform import AugmentMetaInputIdsTfm

In [4]:
#| export
# Set WANDB_MODE to 'disabled' for this process so Weights & Biases tracking is switched off
# before any learner is constructed.
os.environ['WANDB_MODE'] = 'disabled'

## `Trie`

In [86]:
class Trie:
    """Token-level prefix tree over label sequences, plus trie-constrained beam search.

    `build` inserts every token sequence into a nested-dict trie; each node records the
    ids of all sequences that pass through it. The `decode_*` methods walk the trie with
    per-position token scores, so only token paths that exist in the trie are explored.

    Node layout (as created by `build`):
        {"next": {token: node, ...}, "occurs": int, "lbls": [seq_id, ...],
         "point_to": token, "is_leaf": bool}
    """

    def __init__(self, max_height=32, sos_id=101, eos_id=102, pad_token=0, n_bm=10, len_penalty=1.0):
        """Store the configuration and start with an empty trie.

        Args:
            max_height: maximum number of tokens of each sequence inserted / decoded.
            sos_id: token id every decoding run starts from (`self.trie[sos_id]`).
            eos_id: token id that marks a node as a leaf and ends insertion of a sequence.
            pad_token: token id whose score is used to extend beams that have no children.
            n_bm: default beam width used by `proc`.
            len_penalty: default length-penalty exponent used by `proc`.
        """
        # NOTE(review): `store_attr` is not defined in this cell; it is expected to come from
        # the `xcai.basics` star import — confirm.
        store_attr('max_height,sos_id,eos_id,pad_token,n_bm,len_penalty')
        self.trie, self.hash = {}, None

    def build(self, X, y):
        """Build the trie from token sequences `X` and remember `y` as the label lookup.

        Args:
            X: sequence of token-id sequences; only the first `max_height` tokens are used.
            y: sequence with one entry per element of `X`; `y[seq_id]` is later concatenated
               by `_agl_one`, so each entry must be array-like.
        """
        assert len(X) == len(y), "X and y must have the same length"
        self.hash = y
        trie_dict = {}
        for seq_id, seq in enumerate(tqdm(X)):
            next_dict = trie_dict
            for token in seq[:self.max_height]:
                node = next_dict.get(token)
                if node is None:
                    node = next_dict[token] = {"next": {}, "occurs": 0, "lbls": [],
                                               "point_to": token, "is_leaf": False}
                node["lbls"].append(seq_id)
                node["occurs"] += 1
                if token == self.eos_id:
                    # Nothing after the end-of-sequence token is inserted.
                    node["is_leaf"] = True
                    break
                next_dict = node["next"]
        self.trie = trie_dict

    def decode_text(self, X):
        """Follow tokens `X` through the trie and return the sequence ids reached.

        Returns the `lbls` of the first leaf hit, of the deepest node matched before a
        token is missing, or of the last node when `X` is exhausted. `[-1]` is returned
        when not even the first token matches (or `X` is empty).
        """
        next_level = {'next': self.trie, 'lbls': [-1]}
        for token in X:
            items = next_level['next'].get(token, None)
            if items is None:
                return next_level["lbls"]
            if items["is_leaf"]:
                return items["lbls"]
            next_level = items
        # `next_level` is the last matched node; unlike the loop-local `items` it is
        # always bound, so an empty `X` no longer raises.
        return next_level["lbls"]

    def _padded_np(self, lol, fill_value, max_seq):
        """Right-pad a list of lists into a dense `(len(lol), max_seq)` array.

        Returns:
            (tokens, masks): `tokens` holds the values padded with `fill_value`; `masks`
            is an int32 array with 1 on real entries and 0 on padding.
        """
        max_seq = max(int(max_seq), 0)
        # Take the element dtype from the first available value, then promote it so that
        # `fill_value` is representable: an integer row padded with 0.001 must become
        # float instead of silently truncating the fill to 0.
        first = next((row[0] for row in lol if len(row) > 0), fill_value)
        dtype = np.result_type(np.asarray(first).dtype, np.min_scalar_type(fill_value))
        tokens = np.full((len(lol), max_seq), fill_value, dtype=dtype)
        masks = np.zeros((len(lol), max_seq), dtype=np.int32)
        for i, row in enumerate(lol):
            tokens[i, :len(row)] = row
            masks[i, :len(row)] = 1
        return tokens, masks

    def row_topk(self, score, k=10, return_scores=False, sort=False):
        """Indices of the `k` largest entries of a 1-D score vector.

        `k` is clamped to the number of entries. With `sort=True` the indices are ordered
        by ascending score (best last). With `return_scores=True` returns `(index, scores)`.
        """
        score = np.asarray(score)
        n = score.shape[0]
        k = min(k, n)  # np.argpartition raises when kth is out of bounds
        if n == 0:
            index = np.zeros((0,), dtype=np.intp)
        else:
            index = np.argpartition(score, -k, axis=0)[-k:]
        if sort:
            index = index[np.argsort(score[index], axis=0)]
        if return_scores:
            return index, score[index]
        return index

    def batch_topk(self, score, k=10, return_scores=False, sort=False):
        """Row-wise version of `row_topk` for a 2-D `(batch, candidates)` score matrix.

        `k` is clamped to the number of columns. With `sort=True` each row's indices are
        ordered by ascending score. With `return_scores=True` returns `(index, scores)`.
        """
        score = np.asarray(score)
        n_cols = score.shape[1]
        k = min(k, n_cols)  # np.argpartition raises when kth is out of bounds
        if n_cols == 0:
            index = np.zeros((score.shape[0], 0), dtype=np.intp)
        else:
            index = np.argpartition(score, -k, axis=1)[:, -k:]
        if sort:
            order = np.argsort(np.take_along_axis(score, index, axis=1), axis=1)
            index = np.take_along_axis(index, order, axis=1)
        if return_scores:
            return index, np.take_along_axis(score, index, axis=1)
        return index

    def _snl_one(self, tries, old_scores, curr_score, len_score, top_k_index):
        """Expand the selected beams of one example by one token.

        Args:
            tries: current beam nodes of this example.
            old_scores: accumulated score per beam.
            curr_score: score per vocabulary token at the current position.
            len_score: current length per beam.
            top_k_index: positions in `tries` of the beams to expand.

        Returns:
            (nodes, scores, lengths) as flat Python lists, one entry per candidate.
        """
        _tries, _score, _l_scr = [], [], []
        for col in top_k_index:
            if col >= len(tries):
                # Index points at a padded column of the batch matrix, not a real beam.
                continue
            items = list(tries[col]['next'].items())
            if len(items) == 0:
                # Finished beam: carry it forward on the pad token, length unchanged.
                key, value = [self.pad_token], [{"next": {}, "lbls": tries[col]["lbls"]}]
                l_items = [len_score[col]]
            else:
                key, value = list(zip(*items))
                l_items = np.full((len(key),), len_score[col] + 1)
            _tries.extend(value)
            _score.extend(curr_score[list(key)] + old_scores[col])
            _l_scr.extend(l_items)
        return _tries, _score, _l_scr

    def _agl_one(self, lol, sorted_index, sorted_scores):
        """Turn the final beams of one example into flat label / score lists.

        Every sequence id stored on a selected node is mapped through `self.hash` and
        receives that beam's score (as float32).
        """
        lol_lbs, lbl_scr = [], []
        for col, item in enumerate(sorted_index):
            if item >= len(lol):
                # Index points at a padded column of the batch matrix, not a real beam.
                continue
            _items = np.concatenate([self.hash[x] for x in lol[item]["lbls"]])
            lol_lbs.extend(_items)
            lbl_scr.extend(np.full((_items.size,), sorted_scores[col], dtype=np.float32))
        return lol_lbs, lbl_scr

    def snl(self, tries, old_scores, curr_score, len_score, top_k_index):
        """Batched `_snl_one`: expand every example and pad the results to a common width.

        Returns:
            (nodes, scores, lengths): `nodes` is a list of per-example node lists; `scores`
            is padded with `-inf` and `lengths` with `0.001`.
        """
        _tries, _score, _l_scr, _max_seq = [], [], [], -1
        for i in np.arange(top_k_index.shape[0]):
            __tries, __score, __l_scr = self._snl_one(
                tries[i], old_scores[i], curr_score[i], len_score[i], top_k_index[i])
            _max_seq = max(_max_seq, len(__tries))
            _tries.append(__tries)
            _score.append(__score)
            _l_scr.append(__l_scr)
        _score, _ = self._padded_np(_score, -np.inf, _max_seq)
        _l_scr, _ = self._padded_np(_l_scr, 0.001, _max_seq)
        return _tries, _score, _l_scr

    def agl(self, lol_trie, sorted_index, sorted_scores):
        """Batched `_agl_one`: per-example lists of labels and of their scores."""
        lol_lbs, lbl_scr = [], []
        for rid in np.arange(sorted_index.shape[0]):
            _lol_lbs, _lbl_scr = self._agl_one(lol_trie[rid], sorted_index[rid], sorted_scores[rid])
            lol_lbs.append(_lol_lbs)
            lbl_scr.append(_lbl_scr)
        return lol_lbs, lbl_scr

    def decode_batch(self, preds, beam=10, l_penalty=0, start_seq=1):
        """Trie-constrained beam search over a batch of per-position token scores.

        Args:
            preds: array indexed as `preds[example, position, token]`.
            beam: beam width; clamped to the number of available candidates.
            l_penalty: exponent of the length penalty; after every step the scores are
                multiplied by `length ** -l_penalty`.
            start_seq: first position to decode (position 0 is taken to be `sos_id`).

        Returns:
            (labels, scores): per-example lists; beams are in ascending score order.
        """
        n = len(preds)
        _tries = [[self.trie[self.sos_id]] for _ in range(n)]
        _score = np.zeros((n, 1), dtype=np.float32)
        _index = np.zeros((n, 1), dtype=np.int32)
        _l_scr = np.ones((n, 1), dtype=np.int32)
        _top = None
        for i in np.arange(start_seq, self.max_height):
            _tries, _score, _l_scr = self.snl(_tries, _score, preds[:, i], _l_scr, _index)
            _score = np.multiply(_score, np.power(np.asarray(_l_scr, dtype=np.float64), -l_penalty))
            if i + 1 == self.max_height:
                # Only the final step needs sorted beams and their scores.
                _index, _top = self.batch_topk(_score, beam, True, True)
            else:
                _index = self.batch_topk(_score, beam)
        if _top is None:
            # The loop never reached its final step (max_height <= start_seq):
            # rank whatever state we have instead of mis-unpacking `_index`.
            _index, _top = self.batch_topk(_score, beam, True, True)
        return self.agl(_tries, _index, _top)

    def proc(self, model, inputs:Dict, n_bm:int=None, max_bm:Optional[int]=None, len_penalty:Optional[float]=None,
             max_info:Optional[int]=None):
        """Run `model` on `inputs` and decode its logits with the trie.

        Args:
            model: callable invoked as `model(**inputs)`; its result must expose `.logits`,
                used here as a 3-D tensor (log-softmax over dim 2, padding along dim 1).
            inputs: keyword arguments for `model`.
            n_bm, len_penalty: when not None they overwrite `self.n_bm` / `self.len_penalty`.
            max_bm, max_info: accepted but not used by this implementation.

        Returns:
            dict with flattened label ids (`info2seq2data_idx`), their scores
            (`info2seq2data_score`) and the number of labels per example
            (`info2seq2data_data2ptr`).
        """
        store_attr('n_bm,len_penalty', is_none=False)
        # The result is detached immediately, so no autograd graph is needed.
        with torch.no_grad():
            log_probs = F.log_softmax(model(**inputs).logits, dim=2)
        logits = log_probs.cpu().detach().numpy()
        # Pad the position axis with zeros up to `max_height` so the decoder can index every step.
        n_missing = max(0, self.max_height - logits.shape[1])
        logits = np.concatenate([logits, np.zeros((logits.shape[0], n_missing, logits.shape[2]))], axis=1)
        idx, scores = self.decode_batch(logits, beam=self.n_bm, l_penalty=self.len_penalty)
        outputs = {
            'info2seq2data_idx': torch.tensor(list(chain(*idx))),
            'info2seq2data_score': torch.tensor(list(chain(*scores))),
            'info2seq2data_data2ptr': torch.tensor([len(o) for o in idx]),
        }
        return outputs

    def decode_one(self, pred, beam, l_penalty, start_seq=1):
        """Single-example counterpart of `decode_batch`; `pred` is indexed `pred[position, token]`.

        Returns:
            (labels, scores) as flat lists, beams in ascending score order.
        """
        _tries = [self.trie[self.sos_id]]
        _score = np.zeros((1,), dtype=np.float32)
        _index = np.zeros((1,), dtype=np.int32)
        _l_scr = np.ones((1, ), dtype=np.int32)
        _top = None
        for i in np.arange(start_seq, self.max_height):
            _tries, _score, _l_scr = self._snl_one(_tries, _score, pred[i], _l_scr, _index)
            _score = np.multiply(_score, np.power(np.asarray(_l_scr, dtype=np.float64), -l_penalty))
            if i + 1 == self.max_height:
                _index, _top = self.row_topk(_score, beam, True, True)
            else:
                _index = self.row_topk(_score, beam)
        if _top is None:
            # The loop never reached its final step: rank the current state instead.
            _index, _top = self.row_topk(_score, beam, True, True)
        return self._agl_one(_tries, _index, _top)

    def decode_serial(self, preds, beam=10, l_penalty=0):
        """Decode every example of `preds` independently with `decode_one`."""
        labels = []
        scores = []
        for pred in preds:
            _labels, _scores = self.decode_one(pred, beam, l_penalty)
            labels.append(_labels)
            scores.append(_scores)
        return labels, scores
    

## Benchmarking

In [97]:
#| export
# Root directory for this notebook's outputs; used below to build the learner's `output_dir`.
# NOTE(review): absolute, machine-specific path. It already ends in '/', and the f-string that
# consumes it adds another '/', so the resulting path contains a double slash.
dump_dir = '/scratch/scai/phd/aiz218323/Projects/xc_nlg/outputs/09-anshul-trie-implementation/'

In [133]:
# Load the pickled data block used by the benchmarking cells below (`block.test`, `block.train`,
# `block.collator` are read from it).
# NOTE(review): `pickle.load` can execute arbitrary code from the file — only load trusted pickles.
# NOTE(review): absolute, machine-specific path.
fname = '/home/aiscuser/scratch/Projects/xc_nlg/outputs/00-nar-trie-inference-benchmarking/data/block_distilbert-base-uncased.pkl'
with open(fname, 'rb') as file: block = pickle.load(file)

In [134]:
#| export
# Checkpoint directory of the model to benchmark (the f-string has no placeholders).
mname = f'/home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000'
# Load the checkpoint into DBT007. `tn_targ` and `ig_tok` are model-specific options whose
# meaning is not visible in this notebook — see xcai.models.MMM00X.
model = DBT007.from_pretrained(mname, tn_targ=10_000, ig_tok=0)

Some weights of DBT007 were not initialized from the model checkpoint at /home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000 and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [126]:
#| export
# Learner configuration for the benchmark run: outputs go under `dump_dir`, evaluation uses
# batches of 64 per device, and the generation length penalty is 0.0.
# `label_names` names the batch field(s) treated as labels.
args = XCLearningArguments(
    output_dir=f'{dump_dir}/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K',
    generation_length_penalty=0.0,
    per_device_eval_batch_size=64,
    evaluation_strategy='steps',
    label_names=['lbl2data_idx'],
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [127]:
#| export
# Evaluate on a sample of the test split (n=1000, seed=50) to keep the benchmark fast.
# Presumably the fixed seed makes the sample reproducible — confirm against `dset.sample`.
test_dset = block.test.dset.sample(n=1000, seed=50)

In [128]:
#| export
# Metric object passed to the learner as `compute_metrics`. It is built from the number of
# labels, the test set's data-label filterer and the training data-label matrix (`prop`).
# `rep_pk` / `rep_rk` list the cut-offs that show up as P@k / R@k columns in the results below.
metric = PrecRecl(test_dset.n_lbl, test_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[2, 3, 10, 50, 100, 200])

In [129]:
#| export
# Wire the model, arguments, the block's collator and the metric into a learner.
# No train/eval datasets are given here; the dataset is passed to `predict` explicitly below.
learn = XCLearner(
    model=model, 
    args=args,
    data_collator=block.collator, 
    compute_metrics=metric,
)

In [130]:
# Build a trie over the tokenised labels of the sampled test set.
trie = Trie(max_height=32, n_bm=10)
lbl_toks = test_dset.lbl_info['input_ids']
# Label i maps to the singleton list [i]; `Trie.build` stores this list as its `hash` lookup.
lbl_info = [[i] for i in range(len(lbl_toks))]
trie.build(lbl_toks, lbl_info)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [131]:
# Attach the trie to the learner as `tbs`. Presumably the learner calls it during prediction
# for trie-constrained decoding — confirm against XCLearner.
learn.tbs = trie

In [132]:
# Run prediction on the sampled test set and show the resulting metrics table.
o = learn.predict(test_dset)
display_metric(o.metrics)



  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@2,R@3,R@10,R@50,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,9.9,8.2667,6.68,3.37,9.9,11.6549,12.774,12.5301,6.6157,8.9941,10.3577,9.4737,6.6157,8.3006,9.3411,9.2209,8.9202,11.5331,15.148,15.148,15.148,15.148,7.8529,99.5219,10.048,0.08


In [135]:
# Inspect the beam width and length penalty currently stored on the trie.
# `Trie.proc` overwrites both whenever it is called with non-None values, so they can differ
# from what was passed to the constructor.
trie.n_bm, trie.len_penalty

(5, 0.0)

In [None]:
# Re-run prediction on the sampled test set and show the metrics table again.
# NOTE(review): identical to the earlier predict cell; any difference in results comes from
# kernel state changed in between, since execution counts are not sequential.
o = learn.predict(test_dset)
display_metric(o.metrics)

  _score = np.multiply(_score, np.power(_l_scr,-l_penalty))


  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@2,R@3,R@10,R@50,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,8.35,8.9,8.61,4.38,8.35,11.7853,14.8046,14.5668,6.3292,9.697,12.4711,11.4596,6.3292,8.9395,11.0119,10.9659,8.2428,12.6652,20.3485,20.3485,20.3485,20.3485,7.0555,209.895,9.529,0.3


## Zero shot

In [13]:
# Load the pickled data block for the zero-shot experiment; this rebinds `block`.
# NOTE(review): `pickle.load` can execute arbitrary code from the file — only load trusted pickles.
# NOTE(review): absolute, machine-specific path.
fname = '/home/aiscuser/scratch/Projects/xc_nlg/outputs/00-nar-trie-inference-benchmarking/data/block.pkl'
with open(fname, 'rb') as file: block = pickle.load(file)

In [89]:
# Learner configuration for the zero-shot run; same settings as the benchmark apart from
# `output_dir`. This rebinds `args`.
args = XCLearningArguments(
    output_dir='/home/aiscuser/scratch/Projects/xc_nlg/outputs/default',
    generation_length_penalty=0.0,
    per_device_eval_batch_size=64,
    evaluation_strategy='steps',
    label_names=['lbl2data_idx'],
)

In [19]:
# Zero-shot model: BT0002 initialised from the stock 'bert-base-uncased' weights.
# `tn_targ` and `ig_tok` are model-specific options whose meaning is not visible in this notebook.
model = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [73]:
# Sample of the test split (n=1000, seed=50) for the zero-shot evaluation.
test_dset = block.test.dset.sample(n=1000, seed=50)

In [74]:
# Metric object for the zero-shot run; it reports fewer R@k cut-offs (10, 100, 200) than the
# benchmark section.
metric = PrecRecl(test_dset.n_lbl, test_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [90]:
# Learner for the zero-shot model; the dataset is passed to `predict` explicitly below.
learn = XCLearner(
    model=model, 
    args=args,
    data_collator=block.collator, 
    compute_metrics=metric,
)

In [91]:
# Build the label trie for the zero-shot run, with an explicit beam width of 5 and no length penalty.
trie = Trie(max_height=32, sos_id=101, eos_id=102, pad_token=0, n_bm=5, len_penalty=0.0)
lbl_toks = test_dset.lbl_info['input_ids']
# Label i maps to the singleton list [i]; `Trie.build` stores this list as its `hash` lookup.
lbl_info = [[i] for i in range(len(lbl_toks))]
trie.build(lbl_toks, lbl_info)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [93]:
# Attach the trie to the learner as `tbs`. Presumably the learner calls it during prediction
# for trie-constrained decoding — confirm against XCLearner.
learn.tbs = trie

In [94]:
%%time
# Timed zero-shot prediction on the sampled test set, followed by the metrics table.
o = learn.predict(test_dset)
display_metric(o.metrics)



  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,4.0,2.7333,2.0,1.01,4.0,4.0084,4.1433,4.0635,4.4408,4.6646,4.7041,4.317,4.4408,4.765,4.9661,4.9313,4.6338,4.6338,4.6338,15.5659,70.2793,14.229,0.114


CPU times: user 1min 21s, sys: 1.37 s, total: 1min 23s
Wall time: 1min 10s


In [95]:
# Inspect the beam width and length penalty stored on the trie after the run
# (`Trie.proc` overwrites both when called with non-None values).
trie.n_bm, trie.len_penalty

(5, 0.0)

## Zero shot after integration

In [8]:
from xcai.generation.generate import XCTrieBeamSearch

In [5]:
# Load the pickled data block for the integrated (XCTrieBeamSearch) run.
# NOTE(review): `pickle.load` can execute arbitrary code from the file — only load trusted pickles.
# NOTE(review): absolute, machine-specific path.
fname = '/home/aiscuser/scratch/Projects/xc_nlg/outputs/00-nar-trie-inference-benchmarking/data/block.pkl'
with open(fname, 'rb') as file: block = pickle.load(file)

In [6]:
# Learner configuration for the integrated run: evaluation batches of 64 per device and a
# generation length penalty of 0.0.
args = XCLearningArguments(
    output_dir='/home/aiscuser/scratch/Projects/xc_nlg/outputs/default',
    generation_length_penalty=0.0,
    per_device_eval_batch_size=64,
    evaluation_strategy='steps',
    label_names=['lbl2data_idx'],
)

In [7]:
# Zero-shot model: BT0002 initialised from the stock 'bert-base-uncased' weights.
# `tn_targ` and `ig_tok` are model-specific options whose meaning is not visible in this notebook.
model = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# Sample of the test split for the integrated run; n=2000 here, twice the size used in the
# sections above.
test_dset = block.test.dset.sample(n=2000, seed=50)
# Metric object passed to the learner as `compute_metrics`.
metric = PrecRecl(test_dset.n_lbl, test_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [13]:
# Build the library's trie beam-search generator directly from the data block.
# Presumably it builds the trie over the block's label tokens, as the progress bar over
# 312330 items suggests — confirm against XCTrieBeamSearch.from_block.
tbs = XCTrieBeamSearch.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [21]:
# Learner for the integrated run: the trie generator is passed in at construction time
# (`trie_generator`) instead of being attached afterwards, and the train/eval datasets are
# registered so `evaluate()` can be called without arguments.
learn = XCLearner(
    model=model, 
    args=args,
    trie_generator=tbs,
    train_dataset=block.train.dset,
    eval_dataset=test_dset,
    data_collator=block.collator, 
    compute_metrics=metric,
)

In [22]:
# Evaluate on the `eval_dataset` registered above; the returned metrics dict is the cell output.
learn.evaluate()



  self._set_arrayXarray(i, j, x)


{'eval_loss': 15.70250415802002,
 'eval_P@1': 0.038,
 'eval_P@10': 0.009349999999999999,
 'eval_P@3': 0.024666666666666698,
 'eval_P@5': 0.0186,
 'eval_N@1': 0.03799999877810478,
 'eval_N@10': 0.038646530359983444,
 'eval_N@3': 0.037448786199092865,
 'eval_N@5': 0.039278190582990646,
 'eval_PSP@1': 0.04046994856832396,
 'eval_PSP@10': 0.03828862374382027,
 'eval_PSP@3': 0.04035186294636745,
 'eval_PSP@5': 0.04208168960973801,
 'eval_PSN@1': 0.0404699482023716,
 'eval_PSN@10': 0.04572494700551033,
 'eval_PSN@3': 0.04320825636386871,
 'eval_PSN@5': 0.0458955392241478,
 'eval_R@200': 0.04337371003179827,
 'eval_R@10': 0.04337371003179827,
 'eval_R@100': 0.04337371003179827,
 'eval_runtime': 112.7881,
 'eval_samples_per_second': 17.732,
 'eval_steps_per_second': 0.142}