In [2]:
# Python utils
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import os
import random
import sys


# DL utils
import torch;
import transformers;

from torch import nn
from torch.nn import TripletMarginLoss
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel

from pytorch_pretrained_bert import BertTokenizer, BertModel

In [3]:
def get_transformer_tok_and_model(model_name="bert-base-uncased", device=None):
    """Load a Hugging Face tokenizer and encoder model and place the model on a device.

    Args:
        model_name: checkpoint identifier passed to both `from_pretrained` calls.
        device: device to move the model to. When None, uses 'cuda' if CUDA is
            available and 'cpu' otherwise, so the notebook also runs on CPU-only
            machines instead of failing on an unconditional `.to('cuda')`.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    return tokenizer, model


def embed_text_batch_using_bert(text_list, aggregation='cls', tok=None, bert_model=None):
    """Embed a batch of texts with a BERT-style encoder.

    Args:
        text_list: iterable of strings, encoded together as one padded batch.
        aggregation: 'cls' returns the hidden state of the first ([CLS]) token;
            'mean' returns the mean of the non-padding token hidden states.
        tok: tokenizer to use. Defaults to the notebook-level `tokenizer` global.
        bert_model: encoder to use. Defaults to the notebook-level `model` global.
            NOTE(review): a later cell rebinds `model` to a `CLNetwork`, so pass
            this explicitly when calling after that cell has run.

    Returns:
        2-D numpy array with one row per input text.

    Raises:
        ValueError: if `aggregation` is not 'cls' or 'mean'.
    """
    if aggregation not in ('cls', 'mean'):
        raise ValueError(f"aggregation must be 'cls' or 'mean', got {aggregation!r}")

    # Fall back to the notebook globals the original version relied on.
    tok = tok if tok is not None else tokenizer
    bert_model = bert_model if bert_model is not None else model

    # Run on whatever device the model lives on instead of hard-coding 'cuda'.
    first_param = next(bert_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    encoded = tok(list(text_list), padding=True, truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: no autograd graph is needed for the returned numpy arrays.
    with torch.no_grad():
        outputs = bert_model(input_ids, attention_mask=attention_mask)
    # The last hidden state is the first element of the model output.
    last_hidden_states = outputs[0].cpu()

    if aggregation == 'cls':
        return last_hidden_states[:, 0, :].numpy()

    # Masked mean: exclude padding positions so shorter texts are not diluted.
    mask = attention_mask.cpu().unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()

### Custom Dataset class

Implements `__init__`, `__len__`, and `__getitem__` so it can be used as a map-style `torch.utils.data.Dataset`.

In [None]:
class QuoraQuestionsDataset(Dataset):
    """Map-style dataset over an already-built, indexable collection of items.

    Items are returned exactly as stored; no tokenization or tensor
    conversion happens here. Presumably each item is a tokenized question
    triple (per the attribute name) — TODO confirm against the data loader.
    """

    def __init__(self, data):
        # Stored under this attribute name so any code reading it keeps working.
        self.question_tok_triples = data

    def __len__(self):
        triples = self.question_tok_triples
        return len(triples)

    def __getitem__(self, idx):
        triples = self.question_tok_triples
        return triples[idx]

### Network

In [4]:
class CLNetwork(nn.Module):
    """Encoder wrapping a pretrained `bert-base-uncased` (pytorch_pretrained_bert).

    Args:
        emb_dim: accepted but currently unused.
            NOTE(review): no projection head consumes it yet.
    """

    def __init__(self, emb_dim=768):
        super().__init__()
        # Keep the attribute name `bert` so parameter / state-dict names are unchanged.
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, x):
        """Return the wrapped BERT model's raw output for input `x`.

        `x` is presumably a LongTensor of token ids — TODO confirm.
        """
        return self.bert(x)

In [5]:
def init_weights(m):
    """Kaiming-normal initialize the weight of module `m` when it has a suitable one.

    Safe to use with `module.apply(init_weights)`, which visits every submodule.
    Modules without a `weight` tensor (activations, containers, or
    `weight=None`) and modules whose weight has fewer than 2 dimensions
    (e.g. LayerNorm) are left untouched. For the latter, Kaiming fan-in is
    undefined and `kaiming_normal_` would raise.

    Args:
        m: a torch.nn.Module.
    """
    weight = getattr(m, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() >= 2:
        torch.nn.init.kaiming_normal_(weight)

In [6]:
class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    For each triplet i:
        loss_i = max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)
    and the result is the mean over triplets.

    Args:
        margin: minimum desired gap between negative and positive distances.
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def calc_euclidean(self, x1, x2):
        """Squared Euclidean distance between `x1` and `x2` along the last dim.

        Returns one distance per leading index (shape `x1.shape[:-1]`);
        for 1-D inputs this is a scalar.
        """
        # Reduce only the feature dim: summing over the batch too would apply
        # the hinge to batch totals instead of to each triplet.
        return (x1 - x2).pow(2).sum(dim=-1)

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        distance_positive = self.calc_euclidean(anchor, positive)
        distance_negative = self.calc_euclidean(anchor, negative)
        # Hinge per triplet, then average over the batch.
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        return losses.mean()

In [8]:
# Training setup: seeds, device, model, optimizer and loss.
SEED = 42
EMB_DIM = 768
LEARNING_RATE = 1e-3

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the second GPU as before, but fall back so the cell runs on any machine.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Move the model to the device before constructing the optimizer (PyTorch guidance).
model = CLNetwork(EMB_DIM).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = TripletLoss()