# Choose baseline

In [3]:
from utils import load_data

split = 'valid'

data_path = '/ivi/ilps/personal/svakule/spoken_qa'
dataset_versions = ['WD18_entities_original',
                    'WD18_entities_wav2vec2-base-960h',
                    'WD18_relations_original',
                    'WD18_relations_wav2vec2-base-960h']

In [8]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models

baselines = [["msmarco-distilbert-base-v2", ''],
             ["msmarco-distilbert-base-v3", 'cos_sim'],
             ["msmarco-roberta-base-v3", 'cos_sim'],
             ["msmarco-distilbert-base-dot-prod-v3", 'dot'],
             ["msmarco-distilbert-base-tas-b", 'dot']]

similarities = ['cos_sim', 'dot']


for model_save_path, original_similarity in baselines[1:2]:
    # load model
    model = DRES(models.SentenceBERT(model_save_path))
    
    # wiggle similarity function at inference time
    for similarity in similarities:
        print(model_save_path, original_similarity, similarity)
        retriever = EvaluateRetrieval(model, score_function=similarity)

        # iterate over all dataset versions
        for dataset_version in dataset_versions[:]:
            corpus, queries, qrels = load_data(split, dataset_version, data_path)

            results = retriever.retrieve(corpus, queries)
            ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
            print('%.3f\t'%precision['P@1'])
    #         print('%.3f\t'*5 % (recall['Recall@1'], recall['Recall@3'],
    #                             recall['Recall@5'], recall['Recall@10'],
    #                             recall['Recall@100']))
            break

msmarco-distilbert-base-v3 cos_sim cos_sim
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.915	
msmarco-distilbert-base-v3 cos_sim dot
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.916	


# Evaluate

In [3]:
import pathlib, os
from beir.datasets.data_loader import GenericDataLoader

data_path = '/ivi/ilps/personal/svakule/spoken_qa'
split = 'valid'
dataset_versions = ['WD18_entities_original',
                    'WD18_entities_wav2vec2-base-960h',
                    'WD18_relations_original',
                    'WD18_relations_wav2vec2-base-960h']

In [6]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models

from utils import load_data

model_path = '/ivi/ilps/personal/svakule/spoken_qa/models/'
model_names = [["msmarco-distilbert-base-v3", 'cos_sim'],
               ["msmarco-distilbert-base-tas-b", 'dot'],
               [model_path+"msmarco-distilbert-base-v3-cos_sim-WD18_entities_original", 'cos_sim'],
               [model_path+"msmarco-distilbert-base-tas-b-dot-WD18_entities_original", 'dot'],
               [model_path+"msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original", 'cos_sim'],
               [model_path+"msmarco-distilbert-base-tas-b-dot-WD18_relations_original", 'dot'],
               [model_path+"msmarco-distilbert-base-tas-b-cos_sim-WD18_relations_original", 'cos_sim']]

# similarities = ['cos_sim', 'dot']


# iterate over models
for model_name, similarity in model_names[:]:
    # load model
    model = DRES(models.SentenceBERT(model_name))
    
    # wiggle similarity function at inference time
#     for similarity in similarities:
    print(model_name, similarity)
    retriever = EvaluateRetrieval(model, score_function=similarity)

    # iterate over datasets
    for dataset_version in dataset_versions[:]:
        corpus, queries, qrels = load_data(split, dataset_version, data_path)

        results = retriever.retrieve(corpus, queries)
        ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
        print('%.3f\t'*5 % (recall['Recall@1'], recall['Recall@3'],
                            recall['Recall@5'], recall['Recall@10'],
                            recall['Recall@100']))
#         print('%.3f\t'%precision['P@1'])

msmarco-distilbert-base-v3 cos_sim
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.915	0.961	0.970	0.978	0.988	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.428	0.519	0.551	0.593	0.719	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.178	0.271	0.324	0.386	0.653	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.122	0.196	0.232	0.287	0.518	
msmarco-distilbert-base-tas-b dot
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.921	0.968	0.977	0.984	0.990	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.452	0.544	0.578	0.623	0.750	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.298	0.452	0.506	0.572	0.801	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.206	0.321	0.371	0.429	0.640	
/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-v3-cos_sim-WD18_entities_original cos_sim
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.953	0.978	0.983	0.987	0.990	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.456	0.539	0.571	0.616	0.734	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.043	0.079	0.100	0.137	0.343	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.031	0.061	0.078	0.109	0.270	
/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-dot-WD18_entities_original dot
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.939	0.974	0.979	0.983	0.989	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.433	0.512	0.542	0.579	0.698	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.134	0.231	0.292	0.360	0.591	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.100	0.170	0.212	0.263	0.449	
/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original cos_sim
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.954	0.981	0.984	0.986	0.990	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.460	0.551	0.586	0.620	0.738	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.045	0.095	0.123	0.167	0.418	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.032	0.065	0.087	0.125	0.315	
/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-dot-WD18_relations_original dot
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.223	0.347	0.416	0.521	0.823	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.142	0.207	0.236	0.288	0.490	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.815	0.910	0.937	0.963	0.989	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.621	0.717	0.751	0.787	0.847	
/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-cos_sim-WD18_relations_original cos_sim
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.101	0.129	0.146	0.169	0.280	
WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.084	0.116	0.128	0.145	0.223	
WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.719	0.873	0.925	0.956	0.986	
WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.547	0.678	0.729	0.772	0.842	
