In [4]:
import pathlib, os
from utils import load_data, data_path, model_path

# ---- Experiment configuration -------------------------------------------------
# Dataset variant to train on.
# Author's notes on possible values: entities, relations, wav2vec2-base-960h, original.
dataset_version = 'WD18_entities_wav2vec2-base-960h'

# Base checkpoint to fine-tune.
# Alternatives noted by the author:
#   msmarco-distilbert-base-tas-b
#   msmarco-distilbert-base-v3
#   sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco
# Example of a fine-tuned model name:
#   msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_wav2vec2-base-960h
model_name = "msmarco-distilbert-base-tas-b"

# Similarity used by the training loss: 'cos_sim' or 'dot'.
similarity = 'cos_sim'

BATCH_SIZE = 512
NEPOCHS = 512
# Optional run label; 'QE' switches on the special-token variant, '' leaves it off.
LABEL = ''

# ---- Output directory ----------------------------------------------------------
# The directory name encodes the run settings. Only the last path component of the
# checkpoint id is used, so a hub id such as "org/name" contributes just "name".
# An empty LABEL leaves a trailing '-' at the end of the name.
name_parts = (
    model_name.split('/')[-1],
    similarity,
    dataset_version,
    NEPOCHS,
    BATCH_SIZE,
    LABEL,
)
new_model_name = "{}-{}-{}-{}-{}-{}".format(*name_parts)

model_save_path = os.path.join(model_path, new_model_name)
os.makedirs(model_save_path, exist_ok=True)  # no error if the directory already exists
print(model_save_path)

# Load the training split: documents, queries and their relevance judgements.
corpus, queries, qrels = load_data(dataset_version, 'train', data_path)

if LABEL == 'QE':
    # Prepend a type marker to every text: '[E]' for corpus entries, '[Q]' for queries.
    # Corpus entries are updated in place; query strings are replaced under the same key.
    for doc in corpus.values():
        doc['text'] = '[E]' + doc['text']
    for query_id in list(queries):
        queries[query_id] = '[Q]' + queries[query_id]

/ivi/ilps/personal/svakule/spoken_qa/models/msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_wav2vec2-base-960h-512-512-
Loaded: WD18 entities wav2vec2-base-960h train


# Train

In [None]:
# sample training script from https://github.com/UKPLab/beir/blob/main/examples/retrieval/training/train_sbert.py
import torch
import torch.nn as nn

from sentence_transformers import losses, models, SentenceTransformer

from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever


# Build the bi-encoder that will be fine-tuned.
if LABEL != 'QE':
    # No marker tokens requested: load the pretrained sentence-transformer as-is.
    model = SentenceTransformer(model_name)

else:
    # Assemble the model from the raw HF transformer so its vocabulary can be extended.
    word_embedding_model = models.Transformer(model_name, max_seq_length=350)

    # Register the marker tokens, then resize the embedding matrix to the new
    # tokenizer length so the added tokens get their own embeddings.
    special_tokens_dict = {'additional_special_tokens': ['[Q]', '[E]', '[R]']}
    num_added_toks = word_embedding_model.tokenizer.add_special_tokens(special_tokens_dict)
    word_embedding_model.auto_model.resize_token_embeddings(len(word_embedding_model.tokenizer))

    # Pooling layer with the library's default settings, sized to the transformer output.
    embedding_dim = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(embedding_dim)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Wrap the model in BEIR's training helper
# TrainRetriever is used directly instead of being wrapped in nn.DataParallel.
# The wrapper was only ever dereferenced through `.module`, so it never
# parallelised anything, and nn.DataParallel calls `module.to(device)` when exactly
# one GPU is visible, which fails for an object that has no `.to` method.
# NOTE(review): this assumes TrainRetriever is a plain helper class and not a
# torch.nn.Module -- confirm against the installed beir version.
retriever = TrainRetriever(model=model, batch_size=BATCH_SIZE)
print("Visible CUDA devices:", torch.cuda.device_count())

#### Prepare training samples
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)

#### Pick the ranking loss that matches the configured similarity
if similarity == 'cos_sim':
    # MultipleNegativesRankingLoss with its default similarity function
    train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
elif similarity == 'dot':
    # same loss, scored with the dot product instead
    train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)
else:
    # Fail here with a clear message rather than with a NameError on train_loss below.
    raise ValueError(
        "Unsupported similarity {!r}; expected 'cos_sim' or 'dot'".format(similarity))

#### Prepare dev evaluator
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### If no dev set is present from above use dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()

#### Configure Train params
evaluation_steps = 5000
# Warm up over 10% of the total number of optimisation steps
# (samples * epochs / batch size).
warmup_steps = int(len(train_samples) * NEPOCHS / retriever.batch_size * 0.1)

retriever.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=ir_evaluator,
              epochs=NEPOCHS,
              output_path=model_save_path,
              warmup_steps=warmup_steps,
              evaluation_steps=evaluation_steps,
              save_best_model=True,
              use_amp=True)
print(model_save_path, 'trained.')

Let's use 2 GPUs!


Adding Input Examples:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch:   0%|          | 0/512 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

# Evaluate

In [None]:
from utils import evaluate_model, model_path

# Evaluate the checkpoint written by the training cell on the 'entities' setting.
# The path is built with os.path.join, exactly like model_save_path in the
# configuration cell, so evaluation reads the directory training wrote to even
# when model_path has no trailing separator.
# To evaluate a different run, point new_model_name at its directory name, e.g.
# 'msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_wav2vec2-base-960h'.
evaluate_model(os.path.join(model_path, new_model_name), 'entities')