In [5]:
import pathlib, os
from utils import model_path, data_path

# ---- Experiment configuration ----
# Checkpoint to start from (alternative: msmarco-distilbert-base-v3).
model_name = "msmarco-distilbert-base-tas-b"
# Similarity tag that is recorded in the output folder name (alternative: dot).
similarity = 'cos_sim'

dataset = 'msmarco'
# Narrow the data root imported from utils down to this dataset's folder.
data_path = os.path.join(data_path, dataset)

# In this cell these three values only feed the output folder name below.
# NOTE(review): the training cell hard-codes batch_size=16 and num_epochs=1,
# which do not match BATCH_SIZE / NEPOCHS — confirm which values are intended.
BATCH_SIZE = 512
NEPOCHS = 512
LABEL = 'augmented'

# ---- Output location for the fine-tuned model ----
# Folder name: <base model>-<similarity>-<dataset>-<epochs>-<batch size>-<label>;
# any "org/" prefix on the model name is dropped.
name_parts = (model_name.split('/')[-1], similarity, dataset, NEPOCHS, BATCH_SIZE, LABEL)
new_model_name = "-".join(str(part) for part in name_parts)
model_save_path = os.path.join(model_path, new_model_name)
os.makedirs(model_save_path, exist_ok=True)
print(model_save_path)

/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-cos_sim-msmarco-512-512-augmented


In [None]:
# Read the training split (corpus, queries, relevance judgements) from data_path.
from beir.datasets.data_loader import GenericDataLoader

train_loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = train_loader.load(split="train")

# Report how many queries were loaded before any corruption is applied.
print(f"{len(queries)} original queries")

In [None]:
# Corrupt every training query with simulated typos; query ids and relevance
# judgements are carried over unchanged.
import augly.text as textaugs

# NOTE(review): nqueries is not used anywhere in this cell — every query is
# corrupted. Confirm whether a 3000-query subsample was intended.
nqueries = 3000

corrupted_queries, new_qrels = {}, {}
for q_id, q in queries.items():
    new_q = textaugs.simulate_typos(q)
    # simulate_typos is annotated as returning either a string or a list of
    # strings; keep a single string per query so downstream training samples
    # receive plain text. Falls back to the original query on an empty result.
    if isinstance(new_q, (list, tuple)):
        new_q = new_q[0] if new_q else q
    corrupted_queries[q_id] = new_q  # replaces the original query text under the same id
    new_qrels[q_id] = qrels[q_id]

# From here on `queries` / `qrels` refer to the corrupted training set.
# Re-running this cell without reloading the data corrupts the queries again.
queries = corrupted_queries
qrels = new_qrels
print(len(queries), 'corrupted queries')

In [None]:
# Fine-tune the retriever on the current `corpus` / `queries` / `qrels`.
# Adapted from the BEIR sample training script:
# https://github.com/UKPLab/beir/blob/main/examples/retrieval/training/train_sbert.py
from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever

#### Model: start from the pretrained sentence-transformers checkpoint `model_name`.
#### Alternative kept from the sample script — build from any HF model instead:
# model_name = "distilbert-base-uncased"
# word_embedding_model = models.Transformer(model_name, max_seq_length=350)
# pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
model = SentenceTransformer(model_name)

# NOTE(review): batch_size=16 here and num_epochs=1 below are hard-coded and differ
# from BATCH_SIZE / NEPOCHS that are embedded in model_save_path — confirm intent.
retriever = TrainRetriever(model=model, batch_size=16)

#### Training data: build samples, then a shuffled dataloader over them.
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)

#### Loss: MultipleNegativesRankingLoss with its default similarity function
#### (the sample script labels this the cosine variant).
#### Dot-product variant from the sample script:
# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

#### Evaluator: no dev split is loaded here, so use the dummy evaluator.
#### With a dev split available:
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)
ir_evaluator = retriever.load_dummy_evaluator()

#### Schedule: warm up over 10% of (samples * epochs / batch size).
num_epochs = 1
evaluation_steps = 5000
total_steps = len(train_samples) * num_epochs / retriever.batch_size
warmup_steps = int(total_steps * 0.1)

#### Train and write the model to model_save_path.
fit_options = dict(
    evaluator=ir_evaluator,
    epochs=num_epochs,
    output_path=model_save_path,
    warmup_steps=warmup_steps,
    evaluation_steps=evaluation_steps,
    save_best_model=True,
    use_amp=True,
)
retriever.fit(train_objectives=[(train_dataloader, train_loss)], **fit_options)

In [None]:
# Evaluate the fine-tuned model saved at model_save_path on the test split,
# once per scoring function.
# The BEIR model wrapper is imported under an alias because `models` in this
# notebook is bound to sentence_transformers.models by the training cell, while
# the SentenceBERT wrapper lives in beir.retrieval.models.
from beir.retrieval import models as beir_models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

model = DRES(beir_models.SentenceBERT(model_save_path))
similarities = ['cos_sim', 'dot']

# The test split does not depend on the scoring function, so load it once.
# Note: this rebinds corpus / queries / qrels from the training data to the test data.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

# wiggle similarity function at inference time
for similarity in similarities:
    print(model_name, similarity)
    retriever = EvaluateRetrieval(model, score_function=similarity)

    results = retriever.retrieve(corpus, queries)
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    print(ndcg, _map, recall, precision)