In [None]:
#| default_exp 09_combine-predictions

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from dataclasses import dataclass
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.analysis import *

from fastcore.utils import *

from xclib.utils.sparse import retain_topk
import xclib.evaluation.xc_metrics as xc_metrics

## Helper

In [None]:
#| export
@dataclass
class PredictionArguements:
    """Flags that select which saved prediction files of a run are loaded.

    `get_predictions` reads `use_centroid_label_representation`,
    `use_teacher_data_representation`, `centroid_data_attribute_representation`
    and `use_teacher_lbl_representation` to pick a file name. The remaining
    fields are carried along but are not read by the code in this notebook.
    """

    # --- centroid-based label representation ---
    use_centroid_label_representation: bool = False
    use_centroid_data_metadata: bool = True
    # 'data_repr' selects the "student-repr" file; any other value selects "student-fused-repr".
    centroid_data_attribute_representation: str = 'data_repr'
    centroid_data_batch_size: int = 2048

    # --- teacher representations ---
    use_teacher_lbl_representation: bool = False
    use_teacher_data_representation: bool = False
    

In [None]:
#| export
def get_predictions(pred_dir, args):
    """Load the pickled test (and, where applicable, train) predictions of a run.

    Parameters
    ----------
    pred_dir : directory holding the `*_predictions*.pkl` files.
    args : object exposing the `PredictionArguements` flags; they decide which
        file names are looked up.

    Returns
    -------
    (test_o, train_o) : the unpickled objects, in that order. An entry is `None`
        when its file does not exist. `train_o` is always `None` when
        `args.use_centroid_label_representation` is set, because no train file
        is looked up on that path.
    """
    def _unpickle_if_exists(path):
        # A missing file is tolerated and reported as None rather than raising.
        if path is None or not os.path.exists(path):
            return None
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    train_path = None
    if not args.use_centroid_label_representation:
        # Plain label representations: both a train and a test file may exist.
        if args.use_teacher_lbl_representation:
            train_path = f'{pred_dir}/train_predictions_teacher.pkl'
            test_path = f'{pred_dir}/test_predictions_teacher.pkl'
        else:
            train_path = f'{pred_dir}/train_predictions.pkl'
            test_path = f'{pred_dir}/test_predictions.pkl'
    elif args.use_teacher_data_representation:
        test_path = f'{pred_dir}/test_predictions_teacher_centroid.pkl'
    elif args.centroid_data_attribute_representation == 'data_repr':
        test_path = f'{pred_dir}/test_predictions_student-repr_centroid.pkl'
    else:
        test_path = f'{pred_dir}/test_predictions_student-fused-repr_centroid.pkl'

    # Train is read before test, mirroring the order in which failures would surface.
    train_o = _unpickle_if_exists(train_path)
    test_o = _unpickle_if_exists(test_path)
    return test_o, train_o
    

In [None]:
#| export
def get_sparse_predictions(dirname, run_name, use_centroid_label_representation, use_centroid_data_metadata,
                           centroid_data_attribute_representation, centroid_data_batch_size, use_teacher_lbl_representation,
                           use_teacher_data_representation, n_lbl=None):
    """Load a run's saved predictions and convert them with `get_pred_sparse`.

    Parameters
    ----------
    dirname, run_name : the run directory is `f"{dirname}/{run_name}"`; the
        predictions are read from the `predictions` folder of the checkpoint
        returned by `get_best_model` for that directory.
    use_centroid_label_representation, use_centroid_data_metadata,
    centroid_data_attribute_representation, centroid_data_batch_size,
    use_teacher_lbl_representation, use_teacher_data_representation :
        forwarded unchanged into a `PredictionArguements` instance, which
        `get_predictions` uses to choose the files.
    n_lbl : number of labels passed to `get_pred_sparse`. When omitted, the
        module-level `block.n_lbl` is used, as in the original notebook code.
        Pass it explicitly when calling from an importing module, where no
        global `block` exists.

    Returns
    -------
    (test_lbl, train_lbl) : results of `get_pred_sparse` for the test and train
        predictions, in that order.

    Raises
    ------
    NameError : if `n_lbl` is omitted and no global `block` is defined.
    """
    if n_lbl is None:
        # Backward-compatible fallback to the notebook-level `block`. In the exported
        # module `block` only exists under the `__main__` guard, so fail with a hint.
        try:
            n_lbl = block.n_lbl
        except NameError as exc:
            raise NameError("global `block` is not defined; pass `n_lbl` explicitly") from exc

    output_dir = f"{dirname}/{run_name}"
    pred_dir = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}/predictions'

    args = PredictionArguements(
        use_centroid_label_representation=use_centroid_label_representation,
        use_centroid_data_metadata=use_centroid_data_metadata,
        centroid_data_attribute_representation=centroid_data_attribute_representation,
        centroid_data_batch_size=centroid_data_batch_size,
        use_teacher_lbl_representation=use_teacher_lbl_representation,
        use_teacher_data_representation=use_teacher_data_representation,
    )
    # get_predictions returns (test, train); either may be None if its file is absent.
    test_o, train_o = get_predictions(pred_dir, args)

    test_lbl = get_pred_sparse(test_o, n_lbl)
    train_lbl = get_pred_sparse(train_o, n_lbl)

    return test_lbl, train_lbl
    

## Load data

In [None]:
# Toggle: True rebuilds the XCBlock from the dataset config and re-pickles it;
# False loads the previously pickled block from `pkl_file`.
build_block = False

# NOTE: `pkl_dir` ends with '/', so `pkl_file` contains a doubled slash before 'processed'.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

if build_block:
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    block = XCBlock.from_cfg(data_dir, 'data_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                             sampling_features=[('lbl2data',4), ('lnk2data',3)], oversample=True)
    # Cache the freshly built block so later runs can take the cheaper load path.
    with open(pkl_file, 'wb') as file: pickle.dump(block, file)
else:
    # NOTE: pickle.load can execute arbitrary code; only load pickles from a trusted source.
    with open(pkl_file, 'rb') as file: block = pickle.load(file)

""" Prune metadata """
# Train split: replace the data->metadata matrix with its retain_topk(k=5) version
# (presumably the 5 highest-scoring entries per row -- confirm against xclib);
# the label->metadata matrix is passed back unchanged.
data_meta = retain_topk(block.train.dset.meta.lnk_meta.data_meta, k=5)
lbl_meta = block.train.dset.meta.lnk_meta.lbl_meta
block.train.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)

# Test split: same procedure, but with k=3. `data_meta`/`lbl_meta` are reused
# and now refer to the test matrices.
data_meta = retain_topk(block.test.dset.meta.lnk_meta.data_meta, k=3)
lbl_meta = block.test.dset.meta.lnk_meta.lbl_meta
block.test.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)

# Set the sampling options on the first collator transform unconditionally, so they
# apply whether the block was just built or loaded from the pickle.
block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4),('lnk2data',3)]
block.collator.tfms.tfms[0].oversample = True

# Drop `meta_info` on both splits (presumably not needed for the steps below -- confirm).
block.train.dset.meta.lnk_meta.meta_info = None
block.test.dset.meta.lnk_meta.meta_info = None


## Load predictions

In [None]:
# Run whose saved predictions act as model "A" in the fusion below.
output_dir = '/home/scai/phd/aiz218323/scratch/outputs/'
run_name = '64-ngame-ep-for-wikiseealso-with-entropy-loss-1-0'

# Flags that decide which prediction pickles of the run are read.
use_centroid_label_representation=False
use_centroid_data_metadata=True
centroid_data_attribute_representation='data_repr'
centroid_data_batch_size=2048
use_teacher_lbl_representation=False
use_teacher_data_representation=False

# Keyword arguments keep the long flag list from being silently mis-ordered.
test_a, train_a = get_sparse_predictions(
    dirname=output_dir,
    run_name=run_name,
    use_centroid_label_representation=use_centroid_label_representation,
    use_centroid_data_metadata=use_centroid_data_metadata,
    centroid_data_attribute_representation=centroid_data_attribute_representation,
    centroid_data_batch_size=centroid_data_batch_size,
    use_teacher_lbl_representation=use_teacher_lbl_representation,
    use_teacher_data_representation=use_teacher_data_representation,
)

In [None]:
# Run whose saved predictions act as model "B"; `output_dir` is reused from the previous cell.
# NOTE(review): this run name and these flags are identical to the cell that produces
# `test_a`/`train_a`, so "B" loads the same predictions as "A" -- confirm this is intended.
run_name = '64-ngame-ep-for-wikiseealso-with-entropy-loss-1-0'

use_centroid_label_representation=False
use_centroid_data_metadata=True
centroid_data_attribute_representation='data_repr'
centroid_data_batch_size=2048
use_teacher_lbl_representation=False
use_teacher_data_representation=False

test_b, train_b = get_sparse_predictions(
    dirname=output_dir,
    run_name=run_name,
    use_centroid_label_representation=use_centroid_label_representation,
    use_centroid_data_metadata=use_centroid_data_metadata,
    centroid_data_attribute_representation=centroid_data_attribute_representation,
    centroid_data_batch_size=centroid_data_batch_size,
    use_teacher_lbl_representation=use_teacher_lbl_representation,
    use_teacher_data_representation=use_teacher_data_representation,
)

## Fusion

In [None]:
# Inverse propensity scores computed from the training label matrix; A and B are
# the constants of the propensity model.
# NOTE(review): confirm A=0.55, B=1.5 are the values recommended for this dataset.
prop = xc_metrics.compute_inv_propesity(block.train.dset.data.data_lbl, A=0.55, B=1.5)
# ScoreFusion is not defined in this notebook (presumably star-imported from xcai -- confirm).
fuser = ScoreFusion(prop)

In [None]:
# Fit the fuser on the two runs' train-split predictions against the train label matrix.
# n_samples=1000 presumably limits how many points are sampled for fitting -- confirm
# against ScoreFusion.fit.
fuser.fit(train_a, train_b, block.train.dset.data.data_lbl, n_samples=1000)

In [None]:
# Combine the two test-split prediction matrices with the fitted fuser.
pred = fuser.predict(test_a, test_b, beta=1.0)

# Flatten the ground truth and the fused predictions into the keyword tensors passed to
# the metric: `*_idx` hold the column indices, `*_ptr` the number of stored entries per
# row (differences of consecutive `indptr` values), and `pred_score` the stored values.
output = dict(
    targ_idx=torch.tensor(block.test.dset.data.data_lbl.indices),
    targ_ptr=torch.tensor([hi - lo for lo, hi in zip(block.test.dset.data.data_lbl.indptr[:-1],
                                                     block.test.dset.data.data_lbl.indptr[1:])]),
    pred_idx=torch.tensor(pred.indices),
    pred_ptr=torch.tensor([hi - lo for lo, hi in zip(pred.indptr[:-1], pred.indptr[1:])]),
    pred_score=torch.tensor(pred.data),
)

In [None]:
# Build the evaluation metric for the test split. `rep_pk` / `rep_rk` list the cut-offs
# that appear in the report (P@1,3,5,10 and R@10,100,200 in the table below).
# Note that `prop` here receives the train label matrix itself, not the `prop` scores
# computed in the fusion cell. The test filterer is presumably used to exclude
# filtered data-label pairs -- confirm against PrecRecl.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])
# `output` supplies the targ_*/pred_* tensors as keyword arguments.
m = metric(**output)
display_metric(m, remove_prefix=False)

  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200
0,23.9732,16.7873,13.121,8.7125,23.9732,24.5717,25.9121,28.1608,16.5622,19.7185,22.531,27.8253,16.5622,19.5033,21.4851,24.0724,34.4291,53.4049,58.4375


## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Script entry point: load the data block, load two runs' predictions, fuse them
    # and print the resulting metrics.

    # ---- Load data ----
    build_block = False
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl'

    if not build_block:
        # NOTE: pickle.load can execute arbitrary code; only load trusted pickles.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)
    else:
        block = XCBlock.from_cfg(data_dir, 'data', transform_type='xcs', tokenizer='distilbert-base-uncased',
                                 sampling_features=[('lbl2data',1)], oversample=False)
        # Cache the freshly built block for later runs.
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)

    # Applied unconditionally, so it holds for both the built and the unpickled block.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',1)]
    block.collator.tfms.tfms[0].oversample = False

    # ---- Load predictions ----
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/'
    run_name = '64-ngame-ep-for-wikiseealso-with-entropy-loss-1-0'

    use_centroid_label_representation=False
    use_centroid_data_metadata=True
    centroid_data_attribute_representation='data_repr'
    centroid_data_batch_size=2048
    use_teacher_lbl_representation=False
    use_teacher_data_representation=False

    test_a, train_a = get_sparse_predictions(
        dirname=output_dir,
        run_name=run_name,
        use_centroid_label_representation=use_centroid_label_representation,
        use_centroid_data_metadata=use_centroid_data_metadata,
        centroid_data_attribute_representation=centroid_data_attribute_representation,
        centroid_data_batch_size=centroid_data_batch_size,
        use_teacher_lbl_representation=use_teacher_lbl_representation,
        use_teacher_data_representation=use_teacher_data_representation,
    )

    # Second model: the original re-assigned the very same directory, run name and flags
    # here, so those values are simply reused.
    # NOTE(review): "B" therefore loads the same predictions as "A" -- confirm intent.
    test_b, train_b = get_sparse_predictions(
        dirname=output_dir,
        run_name=run_name,
        use_centroid_label_representation=use_centroid_label_representation,
        use_centroid_data_metadata=use_centroid_data_metadata,
        centroid_data_attribute_representation=centroid_data_attribute_representation,
        centroid_data_batch_size=centroid_data_batch_size,
        use_teacher_lbl_representation=use_teacher_lbl_representation,
        use_teacher_data_representation=use_teacher_data_representation,
    )

    # ---- Fusion ----
    prop = xc_metrics.compute_inv_propesity(block.train.dset.data.data_lbl, A=0.55, B=1.5)
    fuser = ScoreFusion(prop)
    fuser.fit(train_a, train_b, block.train.dset.data.data_lbl, n_samples=1000)

    pred = fuser.predict(test_a, test_b, beta=1.0)

    # `*_ptr` hold per-row entry counts (differences of consecutive `indptr` values).
    output = dict(
        targ_idx=torch.tensor(block.test.dset.data.data_lbl.indices),
        targ_ptr=torch.tensor([hi - lo for lo, hi in zip(block.test.dset.data.data_lbl.indptr[:-1],
                                                         block.test.dset.data.data_lbl.indptr[1:])]),
        pred_idx=torch.tensor(pred.indices),
        pred_ptr=torch.tensor([hi - lo for lo, hi in zip(pred.indptr[:-1], pred.indptr[1:])]),
        pred_score=torch.tensor(pred.data),
    )

    metric = PrecRecl(
        block.n_lbl,
        block.test.data_lbl_filterer,
        prop=block.train.dset.data.data_lbl,
        pk=10,
        rk=200,
        rep_pk=[1, 3, 5, 10],
        rep_rk=[10, 100, 200],
    )
    m = metric(**output)
    print(m.metrics)
    