In [20]:
import pathlib, os

# Root directory that holds the datasets and the fine-tuned models.
data_path = '/ivi/ilps/personal/svakule/spoken_qa'

# Dataset version to train on. Name parts used in this notebook:
# entities | relations, combined with original | wav2vec2-base-960h.
dataset_version = 'WD18_entities_original'

# Pretrained checkpoint to fine-tune
# (alternative: msmarco-distilbert-base-tas-b).
model_name = "msmarco-distilbert-base-v3"

# Similarity function used for training: 'dot' or 'cos_sim'.
similarity = 'dot'

#### Model save path: <data_path>/models/<model_name>-<similarity>-<dataset_version>
model_save_path = os.path.join(
    data_path,
    "models",
    "-".join((model_name, similarity, dataset_version)),
)
# Create the output directory up front; an already existing one is fine.
os.makedirs(model_save_path, exist_ok=True)
print(model_save_path)

/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-v3-dot-WD18_entities_original


# Train

In [21]:
# sample training script from https://github.com/UKPLab/beir/blob/main/examples/retrieval/training/train_sbert.py
from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever

from utils import load_data

# Fine-tune the pretrained bi-encoder on the training split of the dataset
# version chosen in the configuration cell and save it to model_save_path.
corpus, queries, qrels = load_data('train', dataset_version, data_path)

#### Provide any sentence-transformers or HF model
# model_name = "distilbert-base-uncased" 
# word_embedding_model = models.Transformer(model_name, max_seq_length=350)
# pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Or provide pretrained sentence-transformer model
model = SentenceTransformer(model_name)

retriever = TrainRetriever(model=model, batch_size=16)

#### Prepare training samples
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)

#### Choose the training loss that matches the configured similarity
if similarity == 'cos_sim':
    # Cosine variant: the loss is created without an explicit similarity_fct.
    train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
elif similarity == 'dot':
    # Dot-product variant: pass the dot-score function explicitly.
    train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
else:
    # Without this branch an unknown value would leave train_loss undefined
    # and only surface as a NameError when retriever.fit is called.
    raise ValueError(
        "Unsupported similarity {!r}; expected 'cos_sim' or 'dot'".format(similarity)
    )

#### Prepare dev evaluator
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### If no dev set is present from above use dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()

#### Configure Train params
num_epochs = 1
evaluation_steps = 5000
# Warm up over 10% of the total number of training batches.
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

retriever.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=ir_evaluator,
    epochs=num_epochs,
    output_path=model_save_path,
    warmup_steps=warmup_steps,
    evaluation_steps=evaluation_steps,
    use_amp=True,
)
print(model_save_path, 'trained.')

WD18 entities original train


Adding Input Examples:   0%|          | 0/1420 [00:00<?, ?it/s]

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1420 [00:00<?, ?it/s]

/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-v3-dot-WD18_entities_original trained.


# Evaluate

In [22]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models

# Evaluate the fine-tuned model on one split of every dataset version and
# print the recall values as a single tab-separated row.

# reload previously trained model instead
# from utils import load_data
# data_path = '/ivi/ilps/personal/svakule/spoken_qa'
# model_save_path = data_path + "/models/msmarco-distilbert-base-tas-b-dot-WD18_relations_original"

split = 'valid'
dataset_versions = ['WD18_entities_original',
                    'WD18_entities_wav2vec2-base-960h',
                    'WD18_relations_original',
                    'WD18_relations_wav2vec2-base-960h']

# change similarity at inference time
# NOTE: this rebinds the notebook-level `similarity`; model_save_path was
# already built from the value set in the configuration cell.
similarity = 'dot'  # cos_sim dot
print(model_save_path, similarity)

# load model
# `models` here is the module imported from beir.retrieval in this cell.
model = DRES(models.SentenceBERT(model_save_path))
retriever = EvaluateRetrieval(model, score_function=similarity)

# Recall cut-offs reported per dataset version, in output order.
recall_cutoffs = (1, 3, 5, 10, 100)

# iterate over all dataset versions
metrics = []
# Use a dedicated loop variable so the training configuration's
# `dataset_version` is not overwritten by the evaluation loop.
for eval_version in dataset_versions:
    corpus, queries, qrels = load_data(split, eval_version, data_path)

    results = retriever.retrieve(corpus, queries)
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    metrics.extend(recall['Recall@%d' % k] for k in recall_cutoffs)

# Build the row from the metrics actually collected, so evaluating only a
# subset of dataset_versions does not break the format string.
print(''.join('%.3f\t' % value for value in metrics))

/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-v3-dot-WD18_entities_original dot
WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

WD18 relations original valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

WD18 relations wav2vec2-base-960h valid


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.911	0.967	0.975	0.981	0.987	0.368	0.451	0.481	0.525	0.648	0.035	0.054	0.065	0.082	0.174	0.025	0.037	0.043	0.054	0.117	
