In [2]:
!pip install -q transformers

[K     |████████████████████████████████| 3.1 MB 12.8 MB/s 
[K     |████████████████████████████████| 59 kB 6.4 MB/s 
[K     |████████████████████████████████| 3.3 MB 33.0 MB/s 
[K     |████████████████████████████████| 596 kB 45.0 MB/s 
[K     |████████████████████████████████| 895 kB 39.2 MB/s 
[?25h

## SentenceBERT for NLI objective

In [3]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, AutoTokenizer

class SBERT(nn.Module):
    def __init__(self, config, dimension, device, max_length):
        super(SBERT, self).__init__()
        self.config = config
        self.dim = dimension
        self.device = device
        self.max_length = max_length

        # Initialize the model
        self.config = AutoConfig.from_pretrained(config)
        self.tokenizer = AutoTokenizer.from_pretrained(config)
        self.bert_net = AutoModel.from_pretrained(config).to(self.device)
        self.output_layer = nn.Sequential(
            nn.Tanh(),
            nn.Linear(self.dim * 3, 3),
        ).to(device)

    def forward(self, sentence1, sentence2, pooling="mean"):
        # Encode the first sentence 
        seq_indexed_1 = self.tokenizer(sentence1, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_1 = torch.tensor(seq_indexed_1['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_1 = torch.tensor(seq_indexed_1['attention_mask'], dtype=torch.int64).to(self.device)
        embedding1 = self.bert_net(input_ids_1, attention_mask=att_mask_1)[0]

        # Encode the second sentence 
        seq_indexed_2 = self.tokenizer(sentence2, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_2 = torch.tensor(seq_indexed_2['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_2 = torch.tensor(seq_indexed_2['attention_mask'], dtype=torch.int64).to(self.device)
        embedding2 = self.bert_net(input_ids_2, attention_mask=att_mask_2)[0]
        
        # Pooling layer
        if pooling == "mean":
          u = self.mean_pooling_strategy(embedding1, att_mask_1)
          v = self.mean_pooling_strategy(embedding2, att_mask_2)
          
        elif pooling == "max":
          u = self.max_pooling_strategy(embedding1, att_mask_1)
          v = self.max_pooling_strategy(embedding2, att_mask_2)

        elif pooling == "cls":
          u = self.cls_pooling_strategy(embedding1, att_mask_1)
          v = self.cls_pooling_strategy(embedding2, att_mask_2)
        else: 
          raise ValueError("Pooling should be mean, max or cls.")

        return self.output_layer(torch.cat((u, v, torch.abs(u - v)), 1)).to(self.device)

    # Three pooling strategies

    def mean_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embedding.size()).float()
      sum_embedding = (sentence_embedding * expanded_att_mask).sum(1)
      sum_att_mask = torch.clamp(expanded_att_mask.sum(1), min=1e-9)
      return sum_embedding / sum_att_mask

    def max_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embedding.size()).float()
      sentence_embedding[expanded_att_mask == 0] = -1e9
      return torch.max(sentence_embedding, 1)[0]

    def cls_pooling_strategy(self, sentence_embedding, att_mask):
      return sentence_embedding[:,0]

    

In [11]:
from torch.optim import Adam
import numpy as np
import random
import math

np.random.seed(233)

def load_snli_data(path, max_num=None):
  sentence1, sentence2, labels = [], [], []
  label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
  invalid_row = 0

  with open(path, 'r', encoding='utf-8') as f:
    for line in f:
      row = line.strip().split('\t')
      if row[0] == 'gold_label':
        continue
      try:
        label_id = label2int[row[0]]
        labels.append(label_id)
        sentence1.append(row[5])
        sentence2.append(row[6])
      except:
        invalid_row += 1
      if max_num!=None and len(labels)>=max_num:
        break

    random.shuffle(sentence1)
    random.shuffle(sentence2)
    random.shuffle(labels)

    return sentence1, sentence2, labels, invalid_row

def train(model_config, dim, batch_size, learning_rate, max_length, max_num, loss_function, pooling, device, train_data_path, model_save_path):

    # Load and pre-process the training dataset
    sentence1, sentence2, labels, invalid_row = load_snli_data(train_data_path, max_num=max_num)
    print('Data Scale:', len(labels))
    print('Invalid Row:', invalid_row)

    # Initialization
    model = SBERT(config, dim, device, max_length)
    model.to(device)
    train_loss = loss_function
    train_loss.to(device)
    optimizer = Adam(params=[{'params': model.bert_net.parameters(), 'lr': learning_rate},
                  {'params': model.output_layer.parameters(), 'lr': learning_rate},], lr=learning_rate)
    
    # Warm up + Cosine Anneal
    warm_up_iter = len(labels)//10
    t_max = 50
    lr_max = 0.1	
    lr_min = 1e-5	

    lambda0 = lambda cur_iter: 1
    lambda1 = lambda cur_iter: cur_iter / warm_up_iter if  cur_iter < warm_up_iter else \
            (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos( (cur_iter-warm_up_iter)/(t_max-warm_up_iter)*math.pi)))/0.1

    WarmUp_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda0, lambda1])
    optimizer.zero_grad()


    # Training the model
    
    for i in range(0, len(labels), batch_size):
        if i + batch_size <= len(labels):
          features = model(sentence1[i: i+batch_size], sentence2[i: i+batch_size], pooling).to(device)
          loss = train_loss(features, torch.tensor(labels[i: i+batch_size], dtype=torch.int64).to(device)).to(device)
        else:
          features = model(sentence1[i:], sentence2[i:], pooling).to(device)
          loss = train_loss(features, torch.tensor(labels[i:], dtype=torch.int64).to(device)).to(device)
        loss.backward()
        optimizer.step()
        WarmUp_scheduler.step()
        optimizer.zero_grad()
        print('batch:{0}    loss:{1}'.format(i//batch_size, round(float(loss),5)))
    torch.save(model.state_dict(), model_save_path)


In [12]:
config = 'bert-base-uncased'
dim = 768
batch_size = 16
learning_rate = 2e-5
max_length = 30
max_num = 1000
loss_function = nn.CrossEntropyLoss()
pooling = 'mean'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_data_path = '/content/drive/MyDrive/sbert_data/snli_1.0_train.txt'
model_save_path = 'SentenceBERT_SNLI_1t'

train( config, 
    dim, 
    batch_size, 
    learning_rate, 
    max_length, 
    max_num,
    loss_function,
    pooling, 
    device, 
    train_data_path, 
    model_save_path)


Data Scale: 1000
Invalid Row: 2


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


batch:0    loss:1.142
batch:1    loss:1.11523
batch:2    loss:1.03957
batch:3    loss:1.09322
batch:4    loss:1.12638
batch:5    loss:1.09815
batch:6    loss:1.13145
batch:7    loss:1.08921
batch:8    loss:1.15764
batch:9    loss:1.06695
batch:10    loss:1.20151
batch:11    loss:1.09614
batch:12    loss:1.11692
batch:13    loss:1.13099
batch:14    loss:1.1068
batch:15    loss:1.11822
batch:16    loss:1.11792
batch:17    loss:1.1141
batch:18    loss:1.06325
batch:19    loss:1.07677
batch:20    loss:1.14935
batch:21    loss:1.07231
batch:22    loss:1.13179
batch:23    loss:1.05522
batch:24    loss:1.1208
batch:25    loss:1.06678
batch:26    loss:1.13989
batch:27    loss:1.11314
batch:28    loss:1.11799
batch:29    loss:1.11564
batch:30    loss:1.07944
batch:31    loss:1.16056
batch:32    loss:1.10456
batch:33    loss:1.10476
batch:34    loss:1.08606
batch:35    loss:1.12279
batch:36    loss:1.10558
batch:37    loss:1.09366
batch:38    loss:1.11663
batch:39    loss:1.07725
batch:40    los

## Extra Section Ⅰ: Evaluation of SentenceBert NLI

In [13]:
def evaluator(model_config, dim, device, max_length, batch_size, sample_num, test_data_path, model_load_path):

    sentence1, sentence2, labels, invalid_row = load_snli_data(test_data_path, max_num=sample_num)
    print('Predicted Samples:', len(labels))
    print('Invalid Row:', invalid_row)
    model = SBERT(config, dim, device, max_length)
    model.to(device)
    parameters = torch.load(model_load_path)
    model.load_state_dict(parameters)
    model.eval()

    predicted_labels = []
    for i in range(0, len(labels), batch_size):
      if i + batch_size <= len(labels):
          prediction = model(sentence1[i: i + batch_size], sentence2[i: i + batch_size]).to(device)
      else:
          prediction = model(sentence1[i:], sentence2[i:]).to(device)

      predicted_labels += prediction.max(1)[1].tolist()

    accuracy = 0
    for i, j in zip(predicted_labels, labels):
      if i == j:
        accuracy += 1
    print('Accuracy Rate:', accuracy/len(labels))
        

In [19]:
config = 'bert-base-uncased'
dim = 768
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_length = 30
batch_size = 16
sample_num = 100
test_data_path = '/content/drive/MyDrive/sbert_data/snli_1.0_test.txt'
model_load_path = '/content/SentenceBERT_SNLI_1t'

evaluator( config, 
      dim,
      device, 
      max_length, 
      batch_size,
      sample_num, 
      test_data_path, 
      model_load_path)

Predicted Samples: 100
Invalid Row: 1


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy Rate: 0.36


## Extra Section Ⅱ: SentenceBert model for STS(cosine similarity)

In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, AutoTokenizer

class SBERT_STS(nn.Module):
    def __init__(self, config, dimension, device, max_length):
        super(SBERT_STS, self).__init__()
        self.config = config
        self.dim = dimension
        self.device = device
        self.max_length = max_length

        # Initialize the model
        self.config = AutoConfig.from_pretrained(config)
        self.tokenizer = AutoTokenizer.from_pretrained(config)
        self.bert_net = AutoModel.from_pretrained(config).to(self.device)
       

    def forward(self, sentence1, sentence2, pooling="mean"):
        # Encode the first sentence 
        seq_indexed_1 = self.tokenizer(sentence1, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_1 = torch.tensor(seq_indexed_1['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_1 = torch.tensor(seq_indexed_1['attention_mask'], dtype=torch.int64).to(self.device)
        embedding1 = self.bert_net(input_ids_1, attention_mask=att_mask_1)[0]

        # Encode the second sentence 
        seq_indexed_2 = self.tokenizer(sentence2, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_2 = torch.tensor(seq_indexed_2['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_2 = torch.tensor(seq_indexed_2['attention_mask'], dtype=torch.int64).to(self.device)
        embedding2 = self.bert_net(input_ids_2, attention_mask=att_mask_2)[0]
        
        # Pooling layer
        if pooling == "mean":
          u = self.mean_pooling_strategy(embedding1, att_mask_1)
          v = self.mean_pooling_strategy(embedding2, att_mask_2)
          
        elif pooling == "max":
          u = self.max_pooling_strategy(embedding1, att_mask_1)
          v = self.max_pooling_strategy(embedding2, att_mask_2)

        elif pooling == "cls":
          u = self.cls_pooling_strategy(embedding1, att_mask_1)
          v = self.cls_pooling_strategy(embedding2, att_mask_2)
        else: 
          raise ValueError("Pooling should be mean, max or cls.")

        # Compute cosine similarity
        return (u * v).sum(1) / torch.sqrt( pow(u,2) + pow(v,2) ).sum(1)

    # Three pooling strategies

    def mean_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
      sum_embedding = (sentence_embedding * expanded_att_mask).sum(1)
      sum_att_mask = torch.clamp(expanded_att_mask.sum(1), min=1e-9)
      return sum_embedding / sum_att_mask

    def max_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embedding.size()).float()
      sentence_embedding[expanded_att_mask == 0] = -1e9
      return torch.max(sentence_embedding, 1)[0]

    def cls_pooling_strategy(self, sentence_embedding, att_mask):
      return sentence_embedding[:,0]

## Extra Section Ⅲ: SentenceBert model for triplet objective

In [8]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, AutoTokenizer

class SBERT_TRI(nn.Module):
    def __init__(self, config, dimension, device, max_length):
        super(SBERT_TRI, self).__init__()
        self.config = config
        self.dim = dimension
        self.device = device
        self.max_length = max_length

        # Initialize the model
        self.config = AutoConfig.from_pretrained(config)
        self.tokenizer = AutoTokenizer.from_pretrained(config)
        self.bert_net = AutoModel.from_pretrained(config).to(self.device)
       

    def forward(self, sentence1, sentence2, sentence3, pooling="mean"):
        # Encode the first sentence 
        seq_indexed_1 = self.tokenizer(sentence1, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_1 = torch.tensor(seq_indexed_1['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_1 = torch.tensor(seq_indexed_1['attention_mask'], dtype=torch.int64).to(self.device)
        embedding1 = self.bert_net(input_ids_1, attention_mask=att_mask_1)[0]

        # Encode the second sentence 
        seq_indexed_2 = self.tokenizer(sentence2, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_2 = torch.tensor(seq_indexed_2['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_2 = torch.tensor(seq_indexed_2['attention_mask'], dtype=torch.int64).to(self.device)
        embedding2 = self.bert_net(input_ids_2, attention_mask=att_mask_2)[0]

        # Encode the third sentence
        seq_indexed_3 = self.tokenizer(sentence3, padding='max_length', truncation=True, max_length=self.max_length)
        input_ids_3 = torch.tensor(seq_indexed_3['input_ids'], dtype=torch.int64).to(self.device)
        att_mask_3 = torch.tensor(seq_indexed_3['attention_mask'], dtype=torch.int64).to(self.device)
        embedding3 = self.bert_net(input_ids_3, attention_mask=att_mask_3)[0]
        
        # Pooling layer
        if pooling == "mean":
          anchor_sentence = self.mean_pooling_strategy(embedding1, att_mask_1)
          pos_sentence = self.mean_pooling_strategy(embedding2, att_mask_2)
          neg_sentence = self.mean_pooling_strategy(embedding3, att_mask_3)
          
        elif pooling == "max":
          anchor_sentence = self.max_pooling_strategy(embedding1, att_mask_1)
          pos_sentence = self.max_pooling_strategy(embedding2, att_mask_2)
          neg_sentence = self.max_pooling_strategy(embedding3, att_mask_3)
         
        elif pooling == "cls":
          anchor_sentence = self.cls_pooling_strategy(embedding1, att_mask_1)
          pos_sentence = self.cls_pooling_strategy(embedding2, att_mask_2)
          neg_sentence = self.cls_pooling_strategy(embedding3, att_mask_3)
          
        else: 
          raise ValueError("Pooling should be mean, max or cls.")
          
        return anchor_sentence, pos_sentence, neg_sentence

    # Three pooling strategies

    def mean_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
      sum_embedding = (sentence_embedding * expanded_att_mask).sum(1)
      sum_att_mask = torch.clamp(expanded_att_mask.sum(1), min=1e-9)
      return sum_embedding / sum_att_mask

    def max_pooling_strategy(self, sentence_embedding, att_mask):
      expanded_att_mask = att_mask.unsqueeze(-1).expand(sentence_embedding.size()).float()
      sentence_embedding[expanded_att_mask == 0] = -1e9
      return torch.max(sentence_embedding, 1)[0]

    def cls_pooling_strategy(self, sentence_embedding, att_mask):
      return sentence_embedding[:,0]

In [5]:
import torch.nn.functional as F

# Triplet loss function
class Triplet_Loss(nn.Module):
    def __init__(self, distance_metric= 'EUCLIDEAN' ,triplet_margin=1):
        super(Triplet_Loss, self).__init__()
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin
    def forward(self, anchor_sentence, pos_sentence, neg_sentence):
        if distance_metric== 'EUCLIDEAN':
          pos_distance = F.pairwise_distance(anchor_sentence, pos_sentence, p=2)
          neg_distance = F.pairwise_distance(anchor_sentence, neg_sentence, p=2)
        elif distance_metric== 'MANHATTAN':
          pos_distance = F.pairwise_distance(anchor_sentence, pos_sentence, p=1)
          neg_distance = F.pairwise_distance(anchor_sentence, neg_sentence, p=1)
        elif distance_metric== 'COSINE':
          pos_distance = 1 - F.cosine_similarity(anchor_sentence, pos_sentence)
          neg_distance = 1 - F.cosine_similarity(anchor_sentence, neg_sentence)
        else:
          raise ValueError("The distance metric can not be", distance_metric)

        loss = F.relu(pos_distance - neg_distance + self.triplet_margin)
        return loss.mean()

## Extra Section Ⅳ: Contrastive loss for siamese network

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Contrastive_Loss(nn.Module):

    def __init__(self, distance_metric= 'EUCLIDEAN', margin = 1):
        super(Contrastive_Loss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        
    def forward(self, sentence1, sentence2, labels):
        distance = F.pairwise_distance(sentence1, sentence2, p=2)
        loss = 0.5 * (labels * distance.pow(2) + (1 - labels) * F.relu(self.margin - distance).pow(2))
        return losses.mean()