In [6]:
# Training hyper-parameters.
NEPOCHS = 50
BATCH_SIZE = 256

# Name of the teacher checkpoint, and the name under which the distilled student is stored.
# The student name records the teacher plus the hyper-parameters above.
teacher_model_name = 'msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original'
model_name = f"model-distillation-{teacher_model_name}-{NEPOCHS}-{BATCH_SIZE}"

In [7]:
# based on the sample training script from https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/distillation/model_distillation.py
import os
from torch.utils.data import DataLoader

from sentence_transformers import losses, evaluation
from sentence_transformers import util, InputExample, SentenceTransformer
from sentence_transformers.datasets import ParallelSentencesDataset

from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

from utils import data_path, model_path, load_queries


# Models
# Teacher and student are both initialised from the same checkpoint directory.
# The directory is built with os.path.join (as output_path is further down) so that it
# does not depend on model_path ending with a path separator.
teacher_model_dir = os.path.join(model_path, teacher_model_name)

# load the pre-trained teacher model; it is only used as the reference during distillation
teacher_model = SentenceTransformer(teacher_model_dir)

# load the same checkpoint a second time as the student, which is the model that gets trained
student_model = SentenceTransformer(teacher_model_dir)


# Data
# Two views of the same questions, both indexed by question id:
# the original text (source, for the teacher) and the wav2vec2-base-960h transcripts
# (target, for the student).
questions_o = load_queries(queries_version='original')
questions_t = load_queries(queries_version='wav2vec2-base-960h')
assert len(questions_o) == len(questions_t)

# The first character of the question id selects the split:
# 't' -> training pair, 'v' -> dev pair; ids with any other prefix are skipped.
train_ids = [q_id for q_id in questions_o if q_id[0] == 't']
dev_ids = [q_id for q_id in questions_o if q_id[0] == 'v']

# source - target question pairs for training, and two aligned lists for the dev evaluator
train_questions = [[questions_o[q_id], questions_t[q_id]] for q_id in train_ids]
dev_questions_o = [questions_o[q_id] for q_id in dev_ids]
dev_questions_t = [questions_t[q_id] for q_id in dev_ids]
print(train_questions[0])

assert len(dev_questions_o) == len(dev_questions_t)
print(dev_questions_o[0], dev_questions_t[0])

# Dataset that pairs each [source, target] row for distillation.
train_data = ParallelSentencesDataset(
    student_model=student_model,
    teacher_model=teacher_model,
    batch_size=BATCH_SIZE,
    use_embedding_cache=False,
)
train_data.add_dataset(train_questions, max_sentence_length=256)
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)


# Training
# Directory handed to fit() as output_path (save_best_model=True below).
output_path = os.path.join(model_path, model_name)
print(output_path)

# Training objective: MSE loss on the student model.
train_loss = losses.MSELoss(model=student_model)

# Dev evaluator: measures the Mean Squared Error (MSE) between the teacher embeddings of the
# original dev questions and the student embeddings of their transcripts.
dev_evaluator_mse = evaluation.MSEEvaluator(dev_questions_o, dev_questions_t, teacher_model=teacher_model)
dev_evaluator = evaluation.SequentialEvaluator([dev_evaluator_mse])

optimizer_settings = {'lr': 1e-4, 'eps': 1e-6, 'correct_bias': False}

# Train the student model to imitate the teacher.
# NOTE(review): the recorded run shows 259 iterations per epoch, which is fewer than
# evaluation_steps=5000 - confirm whether step-based evaluation ever fires with this setting.
student_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=NEPOCHS,
    warmup_steps=1000,
    evaluation_steps=5000,
    output_path=output_path,
    save_best_model=True,
    optimizer_params=optimizer_settings,
    use_amp=True,
)

['what movie is produced by warner bros.', 'what muvi is produced by warner brothers']
Who was the trump ocean club international hotel and tower named after he was the trumpotian club international hotal and tower named after
/ivi/ilps/personal/svakule/spoken_qa/models/model-distillation-msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original-50-256


Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

Iteration:   0%|          | 0/259 [00:00<?, ?it/s]

In [8]:
import os

from utils import evaluate_model, model_path

# Evaluate the distilled student, passing 'entities' as the second argument.
# The checkpoint directory is built with os.path.join, the same way the training cell builds
# output_path, so the directory that is loaded matches the directory that was saved
# regardless of whether model_path ends with a path separator.
# To evaluate a different run, override model_name before the call, e.g.
# model_name = 'model-distillation-msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original-20-128'
evaluate_model(os.path.join(model_path, model_name), 'entities')

/ivi/ilps/personal/svakule/spoken_qa/models/model-distillation-msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original-50-256 cos_sim
Loaded: WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

Loaded: WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.941	0.975	0.982	0.984	0.989	0.483	0.566	0.597	0.639	0.750	
/ivi/ilps/personal/svakule/spoken_qa/models/model-distillation-msmarco-distilbert-base-tas-b-cos_sim-WD18_entities_original-50-256 dot
Loaded: WD18 entities original valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

Loaded: WD18 entities wav2vec2-base-960h valid


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Batches:   0%|          | 0/223 [00:00<?, ?it/s]

0.920	0.961	0.972	0.978	0.988	0.506	0.595	0.636	0.681	0.801	
