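# RAMEN-NLG prediction fusion

Runs the best checkpoint of a trained `DBT014` model (generation and representation branches) on a sample of the training set and on the test set, then fuses the two score sets with `ScoreFusion` and evaluates the fused ranking.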

In [1]:
#| default_exp 23-1-ramen-fusion-prediction-1-0

In [2]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [24]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from scipy import sparse
from xcai.basics import *
from xcai.models.MMM0XX import DBT014

In [4]:
import xclib.evaluation.xc_metrics as xc_metrics

In [6]:
#| export
# Weights & Biases logging is switched off for this prediction-only run.
os.environ.update({'WANDB_MODE': 'disabled'})

In [7]:
#| export
# Pin the process to GPUs 2 and 3. This has to happen before CUDA is first initialised
# (the first CUDA query in this notebook is `torch.cuda.device_count()` further down).
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '2,3',
    'WANDB_PROJECT': '23-1-ramen-fusion-prediction-1-0',
})

## Prediction

In [8]:
#| export
# Dataset block (train/test splits, collator, label count) built from the on-disk config.
# NOTE(review): hard-coded absolute data root; consider a configurable DATA_DIR.
block = XCBlock.from_cfg(
    '/home/aiscuser/scratch/datasets',
    'data',
    valid_pct=0.001,
    tfm='xcnlg',
    tokenizer='distilbert-base-uncased',
    # presumably samples label-side features for each data point — TODO confirm meaning of (1, 2)
    smp_features=[('lbl2data', 1, 2)],
)

  self._set_arrayXarray(i, j, x)


In [13]:
#| export
# Output directory of the training run; its best checkpoint is located below via `get_best_model`.
OUTPUT_DIR = '/home/aiscuser/scratch/Projects/xc_nlg/outputs/23-ramen-style-oak-training-pipeline-with-multitriplet-loss-with-clustering-2-6/'

args = XCLearningArguments(
    output_dir=OUTPUT_DIR,
    logging_first_step=True,
    # Per-device batch sizes.
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    # Representation branch.
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=100,
    # Generation branch.
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    # Dataset keys presumably holding the evaluation targets (indices + per-row pointers) — confirm.
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [14]:
# NOTE(review): same assignment as in the model-loading cell (exported) — this cell is redundant.
best_ckpt_name = os.path.basename(get_best_model(args.output_dir))
mname = f'{args.output_dir}/{best_ckpt_name}'

In [11]:
#| export
# `get_best_model` presumably returns the path of the best checkpoint under output_dir;
# only its basename is kept and re-joined to output_dir.
mname = f'{args.output_dir}/{os.path.basename(get_best_model(args.output_dir))}'
# Global batch size = per-device batch size x number of visible GPUs. The device count is
# clamped to 1 (the same convention Hugging Face TrainingArguments uses) so a CPU-only run
# does not pass bsz=0 to the model.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*max(torch.cuda.device_count(), 1)

# NOTE(review): the loss hyper-parameters below are passed straight to DBT014; presumably they
# should match the values used when training this checkpoint — confirm.
model = DBT014.from_pretrained(mname, ig_tok=0, bsz=bsz, tn_targ=1000, margin=0.3, tau=0.1,
                               n_negatives=5, apply_softmax=True, lw=0.01, m_lw=0.1, meta_prefix='hlk')

Some weights of the model checkpoint at /home/aiscuser/scratch/Projects/xc_nlg/outputs/23-ramen-style-oak-training-pipeline-with-multitriplet-loss-with-clustering-2-6//checkpoint-124000 were not used when initializing DBT014: ['gen_lfn.o', 'rep_lfn.u', 'rep_lfn.v']
- This IS expected if you are initializing DBT014 from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DBT014 from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
#| export
# Label trie built from the dataset block and handed to XCLearner below; presumably it
# constrains generation to valid label sequences — TODO confirm.
trie = XCTrie.from_block(block)

In [24]:
#| export
# Predict on a fixed 50k-row sample of the training set (seeded). These predictions are
# saved and later used to fit the score fuser.
train_dset = block.train.dset.sample(n=50_000, seed=50)
metric = PrecRecl(block.n_lbl, train_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

learn = XCLearner(model=model, args=args, trie=trie, train_dataset=block.train.dset, eval_dataset=train_dset,
                  data_collator=block.collator, compute_metrics=metric)

# Everything that reads `train_pred` must stay inside the guard. When the exported script is
# re-imported by a spawned subprocess, the guarded body is skipped, so an unguarded
# `display_metric(train_pred.metrics)` would raise NameError there.
if __name__ == '__main__':
    mp.freeze_support()
    train_pred = learn.predict(train_dset)
    display_metric(train_pred.metrics)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [59]:
#| export
# Persist the train-sample predictions so the fusion section can reload them without
# re-running inference. `train_pred` is only assigned inside the `__main__` guard of the
# prediction cell, so this must be guarded too: otherwise a spawned subprocess that
# re-imports the exported script hits a NameError here.
if __name__ == '__main__':
    pred_dir = f'{mname}/predictions/'
    os.makedirs(pred_dir, exist_ok=True)
    with open(f'{pred_dir}/train_predictions.pkl', 'wb') as file:
        pickle.dump(train_pred, file)

In [43]:
#| export
# NOTE(review): `test_dset` (a 2,000-row sample) is only passed as `eval_dataset`. The
# prediction below runs on the full `block.test.dset`, which matches the full-test filterer
# `block.test.data_lbl_filterer` used here and when scoring the fused predictions.
test_dset = block.test.dset.sample(n=2000, seed=50)
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

learn = XCLearner(model=model, args=args, trie=trie, train_dataset=block.train.dset, eval_dataset=test_dset,
                  data_collator=block.collator, compute_metrics=metric)

# `test_pred` only exists inside the guard, so its display must be guarded as well;
# otherwise a spawned subprocess that re-imports the exported script raises NameError.
if __name__ == '__main__':
    mp.freeze_support()
    test_pred = learn.predict(block.test.dset)
    display_metric(test_pred.metrics)

In [None]:
#| export
# Persist the test-set predictions for the fusion section. `test_pred` is only assigned
# inside the `__main__` guard of the prediction cell, so this is guarded too: otherwise a
# spawned subprocess that re-imports the exported script hits a NameError here.
if __name__ == '__main__':
    pred_dir = f'{mname}/predictions/'
    os.makedirs(pred_dir, exist_ok=True)
    with open(f'{pred_dir}/test_predictions.pkl', 'wb') as file:
        pickle.dump(test_pred, file)

## Fusion

In [19]:
def get_output_sparse(pred_idx, pred_ptr, pred_score, targ_idx, targ_ptr, n_lbl):
    """Turn flat prediction/target outputs into a pair of CSR matrices.

    Both the prediction and the target side describe ragged per-row lists in flat form:
    ``*_idx`` holds the label (column) indices of all rows concatenated, and ``*_ptr``
    holds the number of entries of each row (counts, not cumulative offsets).

    Args:
        pred_idx: flat predicted label indices.
        pred_ptr: number of predictions per row; its length defines the number of rows.
        pred_score: one score per entry of ``pred_idx``.
        targ_idx: flat ground-truth label indices.
        targ_ptr: number of ground-truth labels per row; must describe as many rows
            as ``pred_ptr``.
        n_lbl: total number of labels (number of columns).

    All array inputs may be torch tensors (CPU or CUDA), numpy arrays or lists.

    Returns:
        ``(pred, targ)``: two ``scipy.sparse.csr_matrix`` of shape ``(n_rows, n_lbl)``.
        ``pred`` stores ``pred_score``; ``targ`` stores int64 ones.

    Raises:
        ValueError: if ``pred_ptr`` and ``targ_ptr`` describe different numbers of rows.
    """
    def _as_numpy(x):
        # scipy needs host memory: bring tensors to the CPU before converting.
        return torch.as_tensor(x).detach().cpu().numpy()

    def _counts_to_indptr(counts):
        # Per-row counts -> CSR indptr: a leading 0 followed by the running total.
        counts = _as_numpy(counts).astype(np.int64)
        return np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(counts)])

    pred_indptr = _counts_to_indptr(pred_ptr)
    targ_indptr = _counts_to_indptr(targ_ptr)
    n_data = len(pred_indptr) - 1
    if len(targ_indptr) - 1 != n_data:
        raise ValueError(f'pred_ptr describes {n_data} rows but targ_ptr describes {len(targ_indptr) - 1}')

    targ_idx = _as_numpy(targ_idx)
    targ_score = np.ones(targ_idx.shape[0], dtype=np.int64)

    pred = sparse.csr_matrix((_as_numpy(pred_score), _as_numpy(pred_idx), pred_indptr), shape=(n_data, n_lbl))
    targ = sparse.csr_matrix((targ_score, targ_idx, targ_indptr), shape=(n_data, n_lbl))
    return pred, targ


In [20]:
pred_dir = f'{mname}/predictions/'

# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote itself.
with open(f'{pred_dir}/train_predictions.pkl', 'rb') as fh:
    train_pred = pickle.load(fh)

with open(f'{pred_dir}/test_predictions.pkl', 'rb') as fh:
    test_pred = pickle.load(fh)

In [25]:
# Test set: generation predictions plus ground truth as CSR matrices.
gen_pred, targ = get_output_sparse(n_lbl=block.n_lbl, **test_pred.gen_output)
# Exponentiate the generation scores; presumably they are log-probabilities — TODO confirm.
gen_pred.data = np.exp(gen_pred.data)
# Representation predictions; the target matrix returned alongside them is not used.
repr_pred, _ = get_output_sparse(n_lbl=block.n_lbl, **test_pred.repr_output)

In [26]:
# Train sample: generation predictions plus ground truth as CSR matrices.
train_gen_pred, train_targ = get_output_sparse(n_lbl=block.n_lbl, **train_pred.gen_output)
# Exponentiate the generation scores; presumably they are log-probabilities — TODO confirm.
train_gen_pred.data = np.exp(train_gen_pred.data)
# Representation predictions; the target matrix returned alongside them is not used.
train_repr_pred, _ = get_output_sparse(n_lbl=block.n_lbl, **train_pred.repr_output)

In [27]:
# Inverse label propensities from the training data-label matrix (xclib's propensity model;
# the function name is spelled this way in xclib). TODO: document why A=0.5, B=0.4.
prop = xc_metrics.compute_inv_propesity(
    block.train.dset.data.data_lbl,
    A=0.5,
    B=0.4,
)
fuser = ScoreFusion(prop)

In [28]:
# Fit the fusion on the train-sample predictions (generation scores, representation scores,
# ground truth), so no test-set information is used for fitting.
fuser.fit(train_gen_pred, train_repr_pred, train_targ)

In [39]:
# Fuse the test-set generation and representation scores with the fitted fuser.
# NOTE(review): the meaning of `beta` is not visible here (presumably a mixing weight);
# TODO document how 0.1 was chosen.
pred = fuser.predict(gen_pred, repr_pred, beta=0.1)

In [40]:
# Re-pack the fused CSR predictions into the flat (indices, per-row count, scores) form that
# PrecRecl consumes; targets are taken unchanged from the test generation output.
output = {
    'targ_idx': test_pred.gen_output['targ_idx'],
    'targ_ptr': test_pred.gen_output['targ_ptr'],
    'pred_idx': torch.tensor(pred.indices),
    # Per-row counts are consecutive differences of indptr. np.diff is vectorized and keeps an
    # integer dtype even when `pred` has no rows (an empty Python list would give a float tensor).
    'pred_ptr': torch.tensor(np.diff(pred.indptr)),
    'pred_score': torch.tensor(pred.data),
}

In [41]:
# Score the fused test predictions with the same metric configuration as the prediction cells.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10, rk=200,
    rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200],
)
m = metric(**output)
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,32.1562,21.5244,16.2393,10.2679,32.1562,31.938,32.9399,34.8155,25.5299,27.8744,30.1075,34.573,25.5299,28.1301,29.8701,32.096,39.6081,54.3698,57.9056


In [35]:
# NOTE(review): this cell repeats the previous evaluation cell verbatim. Its recorded output
# (In [35]) predates the `pred` / `output` cells (In [39] / In [40]), so the numbers shown
# under it came from an earlier fusion run, presumably with a different `beta`. Confirm, then
# delete this cell or turn it into an explicit beta sweep.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10, rk=200,
    rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200],
)
m = metric(**output)
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,31.5427,21.1346,16.0016,10.1471,31.5427,31.4279,32.4816,34.3752,26.0245,27.9952,30.1248,34.4638,26.0245,28.35,30.0287,32.2021,39.1777,54.2491,57.8725
