In [1]:
from metrics.accuracy_metrics import AccuracyMetrics
from metrics.similarity_metrics import CosineSimilarity
from tqdm import tqdm
import torch

In [2]:
from loaders.elco_dataloader import get_loaders
train_loader, test_loader = get_loaders("data/ELCo.csv", batch_size=1, shuffle=True, num_workers=0)

In [11]:
import torch.nn as nn
from sentence_transformers import SentenceTransformer


class TeacherModel():
    def __init__(self):
        self.model = SentenceTransformer('bert-base-nli-mean-tokens')

    def encode(self, x):
        return self.model.encode(x, convert_to_tensor=True)
    
    def __call__(self, x):
        return self.encode(x)

class StudentModel():
    def __init__(self):
        self.model = SentenceTransformer('bert-base-nli-mean-tokens')

    def forward(self, x):
        return self.model.encode(x, convert_to_tensor=True, )
    
    def __call__(self, x):
        return self.forward(x)

In [22]:
from torch.autograd import Variable
from itertools import chain

class ELCoMTrainer:
    def __init__(self, teacher_model, student_model, optim, lr, accuracy_metric):
        self.teacher = teacher_model
        self.student = student_model
        # self.teacher_optimizer = optim(self.teacher.model._first_module().parameters(), lr=lr)
        self.student_optimizer = optim(self.student.model._first_module().parameters(), lr=lr)
        self.accuracy_metric = accuracy_metric
        self.optim = optim
        self.mse = nn.MSELoss()
        self.cosine = nn.CosineSimilarity()
        self.similarity = CosineSimilarity()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train(self, train_loader, epochs):
        # self.teacher.train()
        # self.teacher.to(self.device)
        # self.student.train()
        # self.student.to(self.device)
        total_loss = 0
        for epoch in range(epochs):
            for x, y in tqdm(train_loader):
                teacher_en = Variable(self.teacher(x), requires_grad=True).mean(axis=0).unsqueeze(0)
                student_en = Variable(self.student(x), requires_grad=True).mean(axis=0).unsqueeze(0)
                student_em = Variable(self.student(y), requires_grad=True)
                loss = self.mse(student_en, teacher_en) + self.mse(student_em, teacher_en)
                + self.cosine(student_en, student_em).item()
                # self.teacher_optimizer.zero_grad()
                self.student_optimizer.zero_grad()
                loss.backward()
                # self.teacher_optimizer.step()
                self.student_optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch + 1} loss: {total_loss / len(train_loader)}")
        print("Training finished")
    
    def test(self, test_loader):
        
        accuracy = 0
        for x, y in tqdm(test_loader):
            y = y[0]
            rank_scores = []
            y_score = self.student(y)
            for rank, sentences in x.items():
                sentences = list(chain(*sentences))
                embeddings = self.student(sentences).mean(axis=0) # simple mean
                rank_scores.append((self.similarity(embeddings, y_score).cpu().detach().numpy()[0], rank))
            sorted_rank_scores = sorted(rank_scores, key=lambda x: x[0], reverse=True)
            accuracy += self.accuracy_metric(sorted_rank_scores)
        return accuracy / len(test_loader)

In [18]:
from torch.optim import Adam
teacher_model = TeacherModel()
student_model = StudentModel()
lr = 1e-3
trainer = ELCoMTrainer(teacher_model, student_model, Adam, lr, AccuracyMetrics().accuracy)
trainer.train(train_loader, epochs=5)

100%|██████████| 168/168 [04:06<00:00,  1.47s/it]
  0%|          | 0/168 [00:00<?, ?it/s]

Epoch 1 loss: 0.3136802017688751


100%|██████████| 168/168 [04:43<00:00,  1.69s/it]
  0%|          | 0/168 [00:00<?, ?it/s]

Epoch 2 loss: 0.2646413743495941


100%|██████████| 168/168 [04:12<00:00,  1.50s/it]
  0%|          | 0/168 [00:00<?, ?it/s]

Epoch 3 loss: 0.2129564881324768


100%|██████████| 168/168 [04:02<00:00,  1.44s/it]
  0%|          | 0/168 [00:00<?, ?it/s]

Epoch 4 loss: 0.2808777689933777


100%|██████████| 168/168 [04:09<00:00,  1.48s/it]

Epoch 5 loss: 0.2949393689632416
Training finished





In [23]:
trainer = ELCoMTrainer(trainer.teacher, trainer.student, Adam, lr, AccuracyMetrics().accuracy)
trainer.test(test_loader)

100%|██████████| 42/42 [01:21<00:00,  1.94s/it]


0.150462962962963