In [None]:
#| default_exp 00_ngame-for-msmarco-inference

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from nbdev.showdoc import *
# Run nbdev's export step from inside the notebook (this cell has no `#| export`, so it is
# itself not part of the exported module).
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from xcai.basics import *
from xcai.models.PPP0XX import DBT009,DBT011

In [None]:
# Notebook-only setting (this cell has no `#| export`): disable Weights & Biases logging
# while working interactively.
os.environ['WANDB_MODE'] = 'disabled'

In [None]:
#| export
# Exported environment setup: expose GPUs 0 and 1 to this process and set the W&B project name.
# NOTE(review): torch is already imported in an earlier cell; this presumably still takes
# effect because CUDA is initialised lazily — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['WANDB_PROJECT'] = 'mogicX_00-msmarco'

## Setup

In [None]:
# output_dir is handed to XCLearningArguments / load_model; pkl_dir is the root under which
# the cached data-block file (pkl_file) is placed.
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/00_ngame-for-msmarco'
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

# Both values are passed to build_block; config_key presumably selects an entry inside
# config_file — confirm against build_block.
config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt.json'
config_key = 'data_entity-gpt'

# Checkpoint identifier forwarded to DBT009.from_pretrained through model_fn.
mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

In [None]:
# Which splits to run inference on.
do_train_inference, do_test_inference = True, True

# Whether to persist predictions / representations; any of the flags above or below being
# set makes do_inference truthy in a later cell.
save_train_inference, save_test_inference = False, False
save_representation = False

# Data-block options: both feed the cache file name, and use_sxc_sampler is passed to build_block.
use_sxc_sampler = True
only_test = True

In [None]:
# Cache file name = base name + sampler tag + optional "_only-test" marker + extension.
pkl_file = ''.join([
    f'{pkl_dir}/mogicX/msmarco_data_distilbert-base-uncased',
    '_sxc' if use_sxc_sampler else '_xcs',
    '_only-test' if only_test else '',
    '.joblib',
])

In [None]:
# True as soon as any inference or saving step is requested; forwarded to load_model later.
do_inference = any((
    do_train_inference,
    do_test_inference,
    save_train_inference,
    save_test_inference,
    save_representation,
))

In [None]:
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed//mogicX/msmarco_data_distilbert-base-uncased_sxc_only-test.joblib'

In [None]:
%%time
# Make sure the directory that will hold the cache file exists before building.
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
# Forward the configured `only_test` flag instead of a literal True: pkl_file's "_only-test"
# suffix is derived from the same variable, so the file name and the built block cannot
# drift apart when the flag is changed in the configuration cell.
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=True, only_test=only_test)

CPU times: user 53min 33s, sys: 10min 11s, total: 1h 3min 44s
Wall time: 19min


In [None]:
# Interactive (notebook) training/evaluation arguments. Keyword order is irrelevant to the
# call, so the options are grouped by topic, one per line.
args = XCLearningArguments(
    output_dir=output_dir,
    logging_first_step=True,

    # batch sizes
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,

    # representation-based prediction
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,

    # evaluation / checkpoint cadence
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # optimisation
    num_train_epochs=300,
    learning_rate=2e-5,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    warmup_steps=100,
    max_grad_norm=None,
    fp16=True,

    # cluster-based batching options (semantics are defined by XCLearningArguments)
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # dataset keys and encoder options
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    use_encoder_parallel=True,
)

comet_ml version 3.39.1 is installed, but version 3.43.2 or higher is required. Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=3.43.2'.


In [None]:
# Metric object later handed to XCLearner as `compute_metrics`; it is built from the test
# dataset's label count and the test split's label filterer.
# NOTE(review): pk / rk / rep_pk / rep_rk / mk look like cut-offs for precision, recall and
# MRR — confirm against PrecReclMrr.
metric = PrecReclMrr(block.test.dset.n_lbl, block.test.data_lbl_filterer,
                     pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

In [None]:
def model_fn(mname, bsz):
    """Create the DBT009 model for this experiment.

    Args:
        mname: checkpoint identifier forwarded to ``DBT009.from_pretrained``.
        bsz: forwarded unchanged as the ``bsz`` keyword.

    Returns:
        Whatever ``DBT009.from_pretrained`` returns for the fixed hyper-parameters below.
    """
    return DBT009.from_pretrained(
        mname,
        bsz=bsz,
        tn_targ=5000,
        margin=0.3,
        tau=0.1,
        n_negatives=10,
        apply_softmax=True,
        use_encoder_parallel=True,
    )

def init_fn(model):
    """Initialise ``model`` in place by calling its ``init_dr_head()``; returns None."""
    model.init_dr_head()

In [None]:
# Larger of the train/eval per-device batch sizes, multiplied by the number of CUDA devices
# torch reports.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

# load_model receives the model factory, the kwargs for it, and the head initialiser.
# NOTE(review): how do_inference / use_pretrained choose between a checkpoint under
# args.output_dir and the pretrained weights is decided inside load_model — confirm there.
model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=True)

Some weights of DBT009 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-dot-v5 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Wire model, arguments, train/test datasets, collator and metric into the learner.
learn = XCLearner(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

# NOTE(review): `input_args` is not assigned by any notebook cell above this one; in the
# visible source it is only created by parse_args() inside the exported `__main__` driver,
# so on a fresh kernel this call presumably raises NameError — confirm.
main(learn, input_args, n_lbl=block.test.dset.n_lbl)

## MRR

In [None]:
"""
This module computes evaluation metrics for MSMARCO dataset on the ranking task. Internal hard-coded eval files version. DO NOT PUBLISH!
Command line:
python msmarco_eval_ranking.py <path_to_candidate_file>

Creation Date : 06/12/2018
Last Modified : 4/09/2019
Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com>
python3 get_score.py beir_dev_qrels.tsv topk_og_distill.txt 
"""
import sys
import statistics
import json

from collections import Counter

# Rank cut-off read by compute_metrics: only the first MaxMRRRank candidates of a query are scored.
MaxMRRRank = 10

def load_reference_from_stream(f):
    """Read the relevant passages per query from an iterable of text lines.

    Every line holds tab-separated fields; the first is the query id and the
    second a relevant passage id, both parsed as ints. A query id may appear
    on several lines.

    Args:
        f (stream): iterable yielding the lines.
    Returns:
        dict: query_id (int) -> list of relevant passage ids (ints), in file order.
    Raises:
        IOError: when a line cannot be parsed (the underlying error is printed first).
    """
    relevant_by_qid = {}
    for row in f:
        try:
            row = row.strip().split('\t')
            query_id = int(row[0])
            relevant_by_qid.setdefault(query_id, []).append(int(row[1]))
        except Exception as err:
            print(err)
            raise IOError('\"%s\" is not valid format' % row)
    return relevant_by_qid

def load_reference(path_to_reference):
    """Open a reference file and parse it with load_reference_from_stream.

    Args:
        path_to_reference (str): path of the file to read (opened in text mode).
    Returns:
        dict: query_id (int) -> list of relevant passage ids (ints).
    """
    with open(path_to_reference,'r') as stream:
        return load_reference_from_stream(stream)

def load_candidate_from_stream(f):
    """Load ranked candidate passages from an iterable of text lines.

    Every line holds three tab-separated integers: query id, passage id and
    rank. The rank is used directly as an index into a per-query list of 1000
    slots, so it must lie in [0, 1000). Slots never assigned keep the
    placeholder value 0.

    Args:
        f (stream): iterable yielding the lines.
    Returns:
        dict: query_id (int) -> list of 1000 passage ids (int) indexed by rank.
    Raises:
        IOError: when a line is malformed or its rank is out of range. The
            original error is chained as ``__cause__``. Only ``Exception``
            subclasses are converted, so KeyboardInterrupt / SystemExit
            propagate unchanged.
    """
    n_slots = 1000
    qid_to_ranked_candidate_passages = {}
    for l in f:
        try:
            l = l.strip().split('\t')
            qid = int(l[0])
            pid = int(l[1])
            rank = int(l[2])
            # A negative rank would index from the end of the list and silently
            # overwrite another slot, so range-check explicitly.
            if not 0 <= rank < n_slots:
                raise ValueError('rank %d is outside [0, %d)' % (rank, n_slots))
            if qid not in qid_to_ranked_candidate_passages:
                # By default, all PIDs in the list of 1000 are 0. Only override those that are given
                qid_to_ranked_candidate_passages[qid] = [0] * n_slots
            qid_to_ranked_candidate_passages[qid][rank] = pid
        except Exception as e:
            raise IOError('\"%s\" is not valid format' % l) from e
    return qid_to_ranked_candidate_passages
                
def load_candidate(path_to_candidate):
    """Open a candidate file and parse it with load_candidate_from_stream.

    Args:
        path_to_candidate (str): path of the file to read (opened in text mode).
    Returns:
        dict: query_id (int) -> list of 1000 passage ids (int) indexed by rank.
    """
    with open(path_to_candidate,'r') as stream:
        ranked_by_qid = load_candidate_from_stream(stream)
    return ranked_by_qid

def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Check that no query ranks the same passage more than once.

    Args:
        qids_to_relevant_passageids (dict): query id -> relevant passage ids, as
            read by load_reference / load_reference_from_stream. It is not
            inspected by the current checks and is kept for interface
            compatibility.
        qids_to_ranked_candidate_passages (dict): query id -> ranked candidate
            passage ids. The value 0 is treated as the "unfilled slot"
            placeholder and may repeat freely.
    Returns:
        bool, str: whether the candidates are allowed, and a message describing
        the problem ('' when there is none). When several queries are faulty
        the message describes the last one encountered.
    """
    message = ''
    allowed = True

    for qid, candidates in qids_to_ranked_candidate_passages.items():
        # Drop the placeholder 0 before choosing which PID to report; otherwise
        # the message could say "PID=0" instead of the passage that was repeated.
        duplicate_pids = sorted(
            pid for pid, count in Counter(candidates).items() if count > 1 and pid != 0
        )
        if duplicate_pids:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                    qid=qid, pid=duplicate_pids[0])
            allowed = False

    return allowed, message

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, max_rank=None):
    """Compute the mean reciprocal rank (MRR) of candidate rankings.

    Only queries present in both mappings are scored. For each of them the
    reciprocal rank of the first relevant passage among the top ``max_rank``
    candidates is taken (0 when none is found). The sum is divided by the
    number of queries in the reference mapping.

    This function does not read ``fault_list.txt``. An earlier version opened
    it unconditionally although its contents were never used: the handle was
    never closed and a missing file (e.g. on the first run) raised
    FileNotFoundError.

    Args:
        qids_to_relevant_passageids (dict): query id -> relevant passage ids, as
            read by load_reference / load_reference_from_stream.
        qids_to_ranked_candidate_passages (dict): query id -> candidate passage
            ids ordered by rank.
        max_rank (int, optional): rank cut-off; defaults to the module-level
            ``MaxMRRRank``.
    Returns:
        tuple: ``(all_scores, fault_list, score_dicts)`` where ``all_scores`` is
        ``{'MRR @<max_rank>': float, 'QueriesRanked': int}``, ``fault_list`` holds
        the scored query ids with no relevant passage in the top ``max_rank``,
        and ``score_dicts`` holds one record per scored query.
    Raises:
        IOError: when no candidate query id appears in the reference mapping.
    """
    if max_rank is None:
        max_rank = MaxMRRRank

    score_dicts = []
    fault_list = []
    reciprocal_rank_sum = 0
    n_scored = 0

    for qid, candidates in qids_to_ranked_candidate_passages.items():
        if qid not in qids_to_relevant_passageids:
            continue
        n_scored += 1
        target_pid = [int(x) for x in qids_to_relevant_passageids[qid]]
        # Slicing (instead of indexing 0..max_rank-1) tolerates candidate lists
        # that are shorter than the cut-off.
        top_candidates = [int(x) for x in candidates[:max_rank]]

        this_MRR = 0
        for position, pid in enumerate(top_candidates, start=1):
            if pid in target_pid:
                this_MRR = 1/position
                reciprocal_rank_sum += this_MRR
                break
        if this_MRR == 0:
            fault_list.append(qid)
        score_dicts.append({'QID': qid, 'Target PIDs': target_pid, 'Candidate PIDs': top_candidates, 'score': this_MRR})

    if n_scored == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    all_scores = {
        'MRR @%d' % max_rank: reciprocal_rank_sum/len(qids_to_relevant_passageids),
        'QueriesRanked': len(qids_to_ranked_candidate_passages),
    }
    return all_scores, fault_list, score_dicts
                
def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Load a reference file and a candidate file, then score the candidates.

    Args:
        path_to_reference (str): path to the reference file, one
            tab-separated ``QUERYID PASSAGEID`` pair per line; a QUERYID may
            repeat with different PASSAGEIDs.
        path_to_candidate (str): path to the candidate file, one
            tab-separated ``QUERYID PASSAGEID RANK`` triple per line.
        perform_checks (bool): when true, run quality_checks_qids and print
            its message if there is one; the check result does not stop the
            scoring.
    Returns:
        Whatever compute_metrics returns for the two loaded mappings.
    """
    references = load_reference(path_to_reference)
    candidates = load_candidate(path_to_candidate)

    if perform_checks:
        _, problem = quality_checks_qids(references, candidates)
        if problem != '':
            print(problem)

    return compute_metrics(references, candidates)

def main():
    """Command-line entry point: score a candidate file against a reference file.

    Command line:
    python msmarco_eval_ranking.py <path to reference> <path_to_candidate_file>

    Prints the metrics and the number of faulty queries, writes the faulty
    query ids to ``fault_list.txt`` in the current directory and, when the
    candidate path contains ``.txt``, writes one JSON record per scored query
    to the path obtained by replacing ``.txt`` with ``_score.json``.

    Exits with a usage message when fewer than two arguments are supplied
    (previously this surfaced as a bare IndexError).
    """
    # NOTE(review): other cells of this notebook call main(learn, input_args, n_lbl=...),
    # presumably a different `main` brought in by `from xcai.basics import *`. Running this
    # cell in the same kernel rebinds that name to this zero-argument function — confirm.
    if len(sys.argv) < 3:
        sys.exit('usage: python msmarco_eval_ranking.py <path to reference> <path_to_candidate_file>')

    path_to_reference = sys.argv[1]
    path_to_candidate = sys.argv[2]
    metrics, fault_list, scr_dict = compute_metrics_from_files(path_to_reference, path_to_candidate)
    print('#####################')
    for metric in sorted(metrics):
        print('{}: {}'.format(metric, metrics[metric]))
    print('#####################')
    print('Faulty QIDs: ', len(set(fault_list)))
    # save fault_list to disk
    with open('fault_list.txt', 'w') as f:
        for item in fault_list:
            f.write("%s\n" % item)
    # save score_dicts to disk
    newfile = path_to_candidate.replace('.txt', '_score.json')
    if path_to_candidate != newfile:
        with open(newfile, 'w') as f:
            for item in scr_dict:
                f.write(json.dumps(item) + '\n')
            print('saved score dicts to disk')
    else:
        print('Could not save score dicts to disk')

    
# Script entry point for the evaluation code defined in this cell.
# NOTE(review): inside a Jupyter kernel __name__ is also '__main__', so executing this cell
# calls main() with the kernel's own sys.argv — confirm that is intended.
if __name__ == '__main__':
    main()

## Driver

In [None]:
#| export
# Exported driver: builds the data block, arguments, model and learner from command-line
# options (parse_args) and hands them to main().
if __name__ == '__main__':
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/00_ngame-for-msmarco'

    config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt.json'
    config_key = 'data_entity-gpt'

    mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

    input_args = parse_args()

    # Cache file name = base name + sampler tag + optional "_only-test" marker + extension.
    pkl_file = f'{input_args.pickle_dir}/mogicX/msmarco_data_distilbert-base-uncased'
    pkl_file = f'{pkl_file}_sxc' if input_args.use_sxc_sampler else f'{pkl_file}_xcs'
    if input_args.only_test: pkl_file = f'{pkl_file}_only-test'
    pkl_file = f'{pkl_file}.joblib'

    # Truthy as soon as any inference or saving step is requested; forwarded to load_model.
    do_inference = input_args.do_train_inference or input_args.do_test_inference or input_args.save_train_prediction or input_args.save_test_prediction or input_args.save_representation

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, only_test=input_args.only_test)

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        # Same keyword the interactive cell of this notebook passes to XCLearningArguments;
        # `evaluation_strategy` is the older spelling of this option.
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    def model_fn(mname, bsz):
        """Create the DBT009 model from checkpoint `mname` with this experiment's fixed settings."""
        model = DBT009.from_pretrained(mname, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10,
                                       apply_softmax=True, use_encoder_parallel=True)
        return model

    def init_fn(model):
        """Initialise `model` in place by calling its init_dr_head()."""
        model.init_dr_head()

    metric = PrecReclMrr(block.test.dset.n_lbl, block.test.data_lbl_filterer,
                     pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    # Larger of the train/eval per-device batch sizes times the number of CUDA devices.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

    model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.test.dset.n_lbl)
    