In [9]:
# Experiment configuration: dataset location, base checkpoint and the question
# variants used to pick the training / validation query files.
data_path = '/ivi/ilps/personal/svakule/spoken_qa'
dataset_name = 'WD18'
model_name = "msmarco-distilbert-base-tas-b"
trained_on = 'original'
validated_on = 'original'

#### Model save path: <output root>/<model_name>-<dataset_name>-<trained_on>
model_save_path = os.path.join(
    "/ivi/ilps/personal/svakule/msmarco",
    "output",
    f"{model_name}-{dataset_name}-{trained_on}",
)
# exist_ok=True: re-running the cell does not fail when the directory is already there
os.makedirs(model_save_path, exist_ok=True)

In [10]:
# load our dataset for training
# KGQA dataset from https://github.com/askplatypus/wikidata-simplequestions
from beir.datasets.data_loader import GenericDataLoader

def load_data(split='valid', questions='original'):
    """Load the corpus, queries and relevance judgements for one dataset split.

    All files are looked up under ``data_path/dataset_name`` (module-level
    variables):

    * queries: ``<split>_<questions>.jsonl``
    * qrels:   ``<split>.tsv``
    * corpus:  ``entities.jsonl`` (the same file regardless of ``split``)

    Args:
        split: prefix of the query and qrels file names, e.g. 'train' or 'valid'.
        questions: question variant used as the suffix of the query file name;
            'original' selects the original text questions.

    Returns:
        The result of ``GenericDataLoader.load_custom()``; callers in this
        notebook unpack it as ``(corpus, queries, qrels)``.
    """
    dataset_dir = os.path.join(data_path, dataset_name)
    # Alternative query source noted by the author: questions transcribed from
    # synthesised speech, e.g. "wav2vec2-base-960h.jsonl".
    loader = GenericDataLoader(
        corpus_file=os.path.join(dataset_dir, "entities.jsonl"),
        query_file=os.path.join(dataset_dir, f"{split}_{questions}.jsonl"),
        qrels_file=os.path.join(dataset_dir, f"{split}.tsv"),
    )
    return loader.load_custom()

# Training split; `trained_on` selects which question variant file is read
corpus, queries, qrels = load_data(split='train', questions=trained_on)
# Validation split; load_data reads the same entities.jsonl corpus file for every split
dev_corpus, dev_queries, dev_qrels = load_data(split='valid', questions=validated_on)

2021-06-18 21:02:10 - Loaded 28497 Documents.
2021-06-18 21:02:10 - Doc Example: {'text': 'Mirosław Bork', 'title': ''}
2021-06-18 21:02:10 - Loaded 22719 Queries.
2021-06-18 21:02:10 - Query Example: who is a musician born in detroit
2021-06-18 21:02:10 - Loaded 28497 Documents.
2021-06-18 21:02:10 - Doc Example: {'text': 'Mirosław Bork', 'title': ''}
2021-06-18 21:02:10 - Loaded 2811 Queries.
2021-06-18 21:02:10 - Query Example: What is a film directed by wiebke von carolsfeld?


In [None]:
# TODO: corrupt queries (placeholder cell; nothing implemented yet)


In [11]:
# sample training script from https://github.com/UKPLab/beir/blob/main/examples/retrieval/training/train_sbert.py
from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
import pathlib, os

#### Provide any sentence-transformers or HF model
# (alternative kept for reference: assemble the model from a transformer plus a pooling layer)
# model_name = "distilbert-base-uncased" 
# word_embedding_model = models.Transformer(model_name, max_seq_length=350)
# pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Or provide pretrained sentence-transformer model
# `model_name` is the checkpoint name assigned in the configuration cell
model = SentenceTransformer(model_name)

# BEIR training wrapper around the model; its batch_size is read again when computing warmup_steps
retriever = TrainRetriever(model=model, batch_size=16)

#### Prepare training samples
# corpus / queries / qrels are the training split returned by load_data(split='train', ...)
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)

#### Training SBERT with cosine similarity (no similarity_fct is passed, so the loss default applies)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
#### training SBERT with dot-product
# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)

#### Prepare dev evaluator
# built from the validation split returned by load_data(split='valid', ...)
ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### If no dev set is present from above use dummy evaluator
# ir_evaluator = retriever.load_dummy_evaluator()

#### Configure Train params
num_epochs = 1
# Passed to fit() as evaluation_steps.
# NOTE(review): the recorded run shows 1420 iterations for the single epoch, so this
# threshold is presumably never reached and the evaluator only runs at the end of the
# epoch. Confirm whether a smaller value was intended.
evaluation_steps = 5000
# 10% of the total number of training steps (samples * epochs / batch size)
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

# Fine-tune the model; model_save_path (created in the configuration cell) is passed as output_path.
# use_amp=True requests automatic mixed precision (see the sentence-transformers fit documentation).
retriever.fit(train_objectives=[(train_dataloader, train_loss)], 
                evaluator=ir_evaluator, 
                epochs=num_epochs,
                output_path=model_save_path,
                warmup_steps=warmup_steps,
                evaluation_steps=evaluation_steps,
                use_amp=True)

2021-06-18 21:02:10 - Load pretrained SentenceTransformer: msmarco-distilbert-base-tas-b
2021-06-18 21:02:10 - Did not find folder msmarco-distilbert-base-tas-b
2021-06-18 21:02:10 - Search model on server: http://sbert.net/models/msmarco-distilbert-base-tas-b.zip
2021-06-18 21:02:10 - Load SentenceTransformer from folder: /home/svakule/.cache/torch/sentence_transformers/sbert.net_models_msmarco-distilbert-base-tas-b
2021-06-18 21:02:11 - Use pytorch device: cuda


HBox(children=(FloatProgress(value=0.0, description='Adding Input Examples', max=1420.0, style=ProgressStyle(d…


2021-06-18 21:02:11 - Loaded 22719 training pairs.
2021-06-18 21:02:11 - eval set contains 28497 documents and 2811 queries
2021-06-18 21:02:11 - Starting to Train...


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1420.0, style=ProgressStyle(description_w…


2021-06-18 21:04:29 - Information Retrieval Evaluation on eval dataset after epoch 0:
2021-06-18 21:04:38 - Queries: 2811
2021-06-18 21:04:38 - Corpus: 28497

2021-06-18 21:04:38 - Score-Function: cos_sim
2021-06-18 21:04:38 - Accuracy@1: 95.30%
2021-06-18 21:04:38 - Accuracy@3: 98.04%
2021-06-18 21:04:38 - Accuracy@5: 98.26%
2021-06-18 21:04:38 - Accuracy@10: 98.51%
2021-06-18 21:04:38 - Precision@1: 95.30%
2021-06-18 21:04:38 - Precision@3: 32.68%
2021-06-18 21:04:38 - Precision@5: 19.65%
2021-06-18 21:04:38 - Precision@10: 9.85%
2021-06-18 21:04:38 - Recall@1: 95.30%
2021-06-18 21:04:38 - Recall@3: 98.04%
2021-06-18 21:04:38 - Recall@5: 98.26%
2021-06-18 21:04:38 - Recall@10: 98.51%
2021-06-18 21:04:38 - MRR@10: 0.9671
2021-06-18 21:04:38 - NDCG@10: 0.9716
2021-06-18 21:04:38 - MAP@100: 0.9673
2021-06-18 21:04:38 - Score-Function: dot_score
2021-06-18 21:04:38 - Accuracy@1: 95.70%
2021-06-18 21:04:38 - Accuracy@3: 98.04%
2021-06-18 21:04:38 - Accuracy@5: 98.22%
2021-06-18 21:04:38 