# Clover fusion prediction 

In [1]:
#| default_exp 44-1-encoder-parallel-clover-for-wikiseealso

In [2]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from scipy import sparse
from xcai.basics import *
from xcai.models.PPP0XX import DBT013

In [4]:
import xclib.evaluation.xc_metrics as xc_metrics

In [5]:
#| export
# Set WANDB_MODE for this process before any learner is built.
# NOTE(review): 'disabled' presumably turns Weights & Biases logging off for the run — confirm against the wandb docs.
os.environ['WANDB_MODE'] = 'disabled'

In [6]:
#| export
# Limit the process to GPUs 12 and 13; torch.cuda.device_count() is later used to size the batch.
# NOTE(review): this runs after `import torch`; it presumably still takes effect only because
# CUDA has not been initialised yet — confirm, and keep this cell ahead of any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '12,13'
# Project name that W&B would log under (WANDB_MODE is set to 'disabled' in the cell above).
os.environ['WANDB_PROJECT']='xc-nlg_22-oak-training-pipeline-with-multitriplet-loss-and-clustering'

## Prediction

In [7]:
#| export
# Root folder of the datasets and the pre-processed pickle for this experiment.
# NOTE(review): both are absolute, machine-specific paths.
data_dir = '/home/aiscuser/scratch/datasets'
pkl_file = f'{data_dir}/processed/wikiseealso_data_distilbert-base-uncased_xcnlg_ngame.pkl'

# `block` is the object every later cell reads (train/test datasets, collator, n_lbl).
# NOTE(review): pickle.load can execute arbitrary code; only point this at a file you produced yourself.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [8]:
#| export
# Arguments handed to XCLearner for the prediction runs below.
# Keywords are grouped by the prefix of their name; values are unchanged.
args = XCLearningArguments(
    # --- run folder, logging, precision ---
    output_dir='/home/aiscuser/outputs/44-encoder-parallel-clover-for-wikiseealso-1-0',
    logging_first_step=True,
    fp16=True,

    # --- per-device batch sizes (also used to derive `bsz` for the model) ---
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,

    # --- generation_* options ---
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,

    # --- representation_* options ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=100,

    # --- output combination and target keys ---
    output_concatenation_weight=1.0,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
)

In [9]:
#| export
# Keep only the last path component of args.output_dir and re-root it under the scratch project outputs.
output_dir = f"/home/aiscuser/scratch/Projects/xc_nlg/outputs/{os.path.basename(args.output_dir)}"
# Checkpoint folder: the basename of whatever get_best_model returns, placed under output_dir.
# mname is both the model source for from_pretrained and the parent of the predictions folder.
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

In [13]:
#| export
# Batch size passed to the model: number of visible CUDA devices times the
# larger of the two per-device batch sizes configured in `args`.
bsz = torch.cuda.device_count() * max(args.per_device_eval_batch_size, args.per_device_train_batch_size)

# Restore the DBT013 checkpoint at `mname` with the settings used for this experiment.
model = DBT013.from_pretrained(
    mname,
    ig_tok=0,
    bsz=bsz,
    tn_targ=5000,
    margin=0.3,
    tau=0.1,
    n_negatives=5,
    apply_softmax=True,
    lw=0.01,
    tie_word_embeddings=False,
)

In [14]:
#| export
# Build the trie from the loaded block; it is handed to XCLearner as `trie` in the prediction cells.
# NOTE(review): presumably used to constrain generation to valid label sequences — confirm in XCTrie.
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [57]:
#| export
# Predict on a seeded 50k-example sample of the training split; these
# predictions are later used to fit the score fuser.
train_dset = block.train.dset.sample(n=50_000, seed=50)
# The label filterer is taken from the sampled dataset (the rows actually being
# predicted); `prop` still uses the label matrix of the full training split.
metric = PrecRecl(block.n_lbl, train_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

learn = XCLearner(model=model, args=args, trie=trie, train_dataset=block.train.dset, eval_dataset=train_dset,
                  data_collator=block.collator, compute_metrics=metric)

if __name__ == '__main__':
    mp.freeze_support()
    train_pred = learn.predict(train_dset)

    # Report inside the guard: `train_pred` is only bound when this branch runs,
    # so a module-level call raises NameError whenever the exported module is
    # imported (for example by a spawned worker) instead of executed as a script.
    display_metric(train_pred.metrics)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [59]:
#| export
# Predictions are stored next to the checkpoint, under <mname>/predictions/.
pred_dir = f'{mname}/predictions/'
os.makedirs(pred_dir, exist_ok=True)
# NOTE(review): `train_pred` is assigned only inside the `if __name__ == '__main__':` block of the
# prediction cell, so this line needs that branch to have run first.
with open(f'{pred_dir}/train_predictions.pkl', 'wb') as file: pickle.dump(train_pred, file)

In [20]:
#| export
# Seeded 2000-example sample of the test split. It is only wired into the
# learner as `eval_dataset`; the prediction below runs on the FULL test split,
# which is why the metric uses the full-split filterer `block.test.data_lbl_filterer`.
test_dset = block.test.dset.sample(n=2000, seed=50)
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

learn = XCLearner(model=model, args=args, trie=trie, train_dataset=block.train.dset, eval_dataset=test_dset,
                  data_collator=block.collator, compute_metrics=metric)

if __name__ == '__main__':
    mp.freeze_support()
    test_pred = learn.predict(block.test.dset)

    # Report inside the guard: `test_pred` is only bound when this branch runs,
    # so a module-level call raises NameError whenever the exported module is
    # imported (for example by a spawned worker) instead of executed as a script.
    display_metric(test_pred.metrics)

In [None]:
#| export
# Same destination folder as the training predictions: <mname>/predictions/.
pred_dir = f'{mname}/predictions/'
os.makedirs(pred_dir, exist_ok=True)
# NOTE(review): `test_pred` is assigned only inside the `if __name__ == '__main__':` block of the
# prediction cell, so this line needs that branch to have run first.
with open(f'{pred_dir}/test_predictions.pkl', 'wb') as file: pickle.dump(test_pred, file)

## Fusion

In [12]:
import numpy as np

In [10]:
# Reload the prediction outputs written by the Prediction section so the fusion
# cells can run without repeating inference. Requires `mname` from the export cells.
pred_dir = f'{mname}/predictions/'

# NOTE(review): pickle.load can execute arbitrary code; these files are expected to be
# the ones this notebook wrote itself.
with open(f'{pred_dir}/train_predictions.pkl', 'rb') as f: 
    train_pred = pickle.load(f)
    
with open(f'{pred_dir}/test_predictions.pkl', 'rb') as f: 
    test_pred = pickle.load(f)

In [13]:
# Convert the test outputs of both heads to sparse matrices with block.n_lbl label columns.
# `targ` (ground truth) is taken from the generation output; the second return of the
# representation call is discarded.
gen_pred,targ = get_output_sparse(**test_pred.gen_output, n_lbl=block.n_lbl)
repr_pred,_ = get_output_sparse(**test_pred.repr_output, n_lbl=block.n_lbl)
# Exponentiate the stored generation scores only (structure is untouched).
# NOTE(review): presumably maps log-probabilities to probabilities — confirm the score scale.
gen_pred.data = np.exp(gen_pred.data)

In [14]:
# Same conversion as for the test outputs, applied to the training-sample predictions
# that the fuser is fitted on.
train_gen_pred,train_targ = get_output_sparse(**train_pred.gen_output, n_lbl=block.n_lbl)
train_repr_pred,_ = get_output_sparse(**train_pred.repr_output, n_lbl=block.n_lbl)
# Exponentiate the stored generation scores so they match the transform applied to the test scores.
train_gen_pred.data = np.exp(train_gen_pred.data)

In [22]:
# Inverse propensities computed from the full training label matrix
# (`compute_inv_propesity` is the spelling exposed by xclib).
# NOTE(review): A=0.55, B=1.5 are presumably the dataset-specific propensity constants — confirm for this dataset.
prop = xc_metrics.compute_inv_propesity(block.train.dset.data.data_lbl, A=0.55, B=1.5)
# The fuser is constructed with these propensities; it is fitted and applied in the next cells.
fuser = ScoreFusion(prop)

In [23]:
# Fit the fuser on the training-sample scores from both heads against the training targets.
fuser.fit(train_gen_pred, train_repr_pred, train_targ)

In [30]:
# Fuse the two test score matrices into a single sparse prediction matrix.
# NOTE(review): beta=0.1 presumably weights one score source against the other — confirm in ScoreFusion.predict.
pred = fuser.predict(gen_pred, repr_pred, beta=0.1)

In [31]:
# Repackage the fused sparse predictions into the keyword layout that the
# metric is called with (`metric(**output)`): flat indices/scores plus a
# per-row count vector. Targets are reused from the generation output.
output = dict(
    targ_idx=test_pred.gen_output['targ_idx'],
    targ_ptr=test_pred.gen_output['targ_ptr'],
    pred_idx=torch.tensor(pred.indices),
    # Row lengths: difference between consecutive entries of the index pointer.
    pred_ptr=torch.tensor([pred.indptr[i + 1] - pred.indptr[i] for i in range(len(pred.indptr) - 1)]),
    pred_score=torch.tensor(pred.data),
)

In [32]:
# Score the fused predictions with the same metric configuration used for the raw test predictions
# (full-test label filterer, training label matrix as `prop`).
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])
m = metric(**output)
# remove_prefix=False is passed so display_metric is asked not to strip metric-name prefixes.
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,32.3128,21.7133,16.4126,10.3124,32.3128,32.1152,33.1314,34.8998,25.7029,28.1331,30.4293,34.7173,25.7029,28.2914,30.0372,32.1681,39.5329,52.4273,55.1471


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,32.3128,21.7133,16.4126,10.3124,32.3128,32.1152,33.1314,34.8998,25.7029,28.1331,30.4293,34.7173,25.7029,28.2914,30.0372,32.1681,39.5329,52.4273,55.1471


In [52]:
# Same evaluation code as the previous metric cell, executed again later (In [52]).
# NOTE(review): the recorded numbers differ from the previous cell, so `output` must have been
# rebuilt from a different `pred` in between (presumably another fusion setting) — the producing
# cells are not preserved here; confirm before quoting these results.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])
m = metric(**output)
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,31.1399,20.7286,15.6277,9.8765,31.1399,30.8007,31.7445,33.5378,24.4896,26.6505,28.803,33.1025,24.4896,26.9292,28.598,30.7434,38.1122,52.0618,55.2864


In [55]:
# Same evaluation code once more (In [55]).
# NOTE(review): results again differ from the earlier metric cells, which indicates `output` was
# regenerated with yet another configuration that is not visible in the notebook — confirm which
# setting produced these numbers.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])
m = metric(**output)
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,30.4053,20.2255,15.3097,9.7354,30.4053,30.135,31.1323,32.987,24.9653,26.6902,28.7182,32.9481,24.9653,27.0559,28.6494,30.7709,37.6288,51.9256,55.2399
