# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [1]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## 1. Data

### Train, Test, Validation 

In [2]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

  from .autonotebook import tqdm as notebook_tqdm


({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [3]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [4]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [5]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [6]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [7]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [8]:
mnli = mnli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [9]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [132]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(85))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(15))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(15)))
})
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 85
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 15
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 15
    })
})

## 2. Preprocessing

In [133]:
from torchtext.data.utils import get_tokenizer

# Load the 'basic_english' tokenizer
tokenizer = get_tokenizer('basic_english')
vocab = torch.load('./model/vocab')

In [134]:
len(vocab)

17825

In [135]:
tokens_to_check = ['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]', 'the', 'of', 'and']
for token in tokens_to_check:
    print(f"Index of '{token}': {vocab[token]}")

Index of '[PAD]': 0
Index of '[CLS]': 1
Index of '[SEP]': 2
Index of '[MASK]': 3
Index of '[UNK]': 4
Index of 'the': 5
Index of 'of': 11
Index of 'and': 7


In [136]:
import re

sent = "Hello, world! How are you doing today? Let's explore - regex."
cleaned_sent = re.sub("[.,!?\\-]", '', sent.lower())

print(cleaned_sent)


hello world how are you doing today let's explore  regex


In [142]:
max_seq_length = 256

# Example usage before your model.forward() call


# def tokenize_and_pad(sentences, tokenizer, vocab, max_length=512):
#     # Tokenizes sentences, converts tokens to IDs, adds special tokens, and applies padding
#     tokenized = [tokenizer(re.sub("[.,!?\\-]", '', sent.lower())) for sent in sentences]
#     input_ids = [[vocab['[CLS]']] + [vocab[token] for token in tokens] + [vocab['[SEP]']] for tokens in tokenized]

#     attn_mask = [[1] * len(tokens) + [0] * (max_length - len(tokens)) for tokens in input_ids]
#     input_ids = [tokens + [0] * (max_length - len(tokens)) for tokens in input_ids]
#     return input_ids, attn_mask


import torch
import re

def tokenize_and_pad(sentences, tokenizer, vocab, max_length=512):
    # Directly using the provided indices for special tokens
    UNK_TOKEN_ID = 4  # '[UNK]' token ID
    CLS_TOKEN_ID = 1  # '[CLS]' token ID
    SEP_TOKEN_ID = 2  # '[SEP]' token ID
    PAD_TOKEN_ID = 0  # '[PAD]' token ID

    tokenized = [tokenizer(re.sub("[.,!?\\-]", '', sent.lower())) for sent in sentences]
    input_ids = []

    for tokens in tokenized:
        sentence_ids = [CLS_TOKEN_ID]
        for token in tokens:
            try:
                token_id = vocab[token]  # Attempt to get the token ID from vocab
            except KeyError:
                token_id = UNK_TOKEN_ID  # Use UNK_TOKEN_ID for tokens not in vocab
            sentence_ids.append(token_id)
        sentence_ids.append(SEP_TOKEN_ID)
        input_ids.append(sentence_ids)

    attn_mask = [[1] * len(ids) for ids in input_ids]
    input_ids = [ids + [PAD_TOKEN_ID] * (max_length - len(ids)) for ids in input_ids]
    attn_mask = [mask + [0] * (max_length - len(mask)) for mask in attn_mask]

    # Convert lists to tensors ensuring they are of type long to match expected input types for PyTorch models
    input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
    attn_mask_tensor = torch.tensor(attn_mask, dtype=torch.long)

    return input_ids_tensor, attn_mask_tensor





def preprocess_function(examples):
    # Tokenize and pad both premise and hypothesis
    premise_input_ids, premise_attn_mask = tokenize_and_pad(examples['premise'], tokenizer, vocab, max_seq_length)
    hypothesis_input_ids, hypothesis_attn_mask = tokenize_and_pad(examples['hypothesis'], tokenizer, vocab, max_seq_length)
    
    # Extract labels
    labels = examples["label"]
    
    return {
        "premise_input_ids": premise_input_ids,
        "premise_attention_mask": premise_attn_mask,
        "hypothesis_input_ids": hypothesis_input_ids,
        "hypothesis_attention_mask": hypothesis_attn_mask,
        "labels": labels
    }

# Map the preprocessing function across the dataset in a batched manner
tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

# Remove the original columns to focus on the processed ones and set the format to PyTorch tensors
tokenized_datasets = tokenized_datasets.remove_columns(['premise', 'hypothesis', 'label'])
tokenized_datasets.set_format("torch")


Map: 100%|██████████| 85/85 [00:00<00:00, 567.40 examples/s]
Map: 100%|██████████| 15/15 [00:00<00:00, 1039.45 examples/s]
Map: 100%|██████████| 15/15 [00:00<00:00, 327.28 examples/s]


In [157]:
sentences = ["Example sentence.", "Another example."]
input_ids, attn_mask = tokenize_and_pad(sentences, tokenizer, vocab, max_length=256)

# Convert lists to tensors
input_ids = torch.tensor(input_ids)
attn_mask = torch.tensor(attn_mask)


  input_ids = torch.tensor(input_ids)
  attn_mask = torch.tensor(attn_mask)


In [158]:
class SimpleBERTModel(torch.nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, 768)
        # Add additional layers and components as necessary

    def forward(self, input_ids, attention_mask=None):
        embedded = self.embedding(input_ids)
        # Implement the rest of the forward pass
        return embedded


In [159]:
model = SimpleBERTModel(len(vocab))


In [160]:
# Example training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

def train(model, input_ids, attn_mask, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, torch.tensor([1, 0]))  # Example target
    loss.backward()
    optimizer.step()
    return loss.item()


In [161]:
for epoch in range(1):  # Example single epoch
    loss = train(model, input_ids, attn_mask, optimizer, criterion)
    print(f"Epoch {epoch}, Loss: {loss}")


RuntimeError: Expected target size [2, 768], got [2]

In [143]:
tokenized_datasets['train'][0]

{'premise_input_ids': tensor([   1,  182,   31,   25, 1348, 6710,    2,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
       

## 3. Data loader

In [144]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 16
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [145]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_attention_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_attention_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([16, 256])
torch.Size([16, 256])
torch.Size([16, 256])
torch.Size([16, 256])
torch.Size([16])


## 4. Model

In [146]:
# # start from a pretrained bert-base-uncased model
# from transformers import BertTokenizer, BertModel
# model = BertModel.from_pretrained('bert-base-uncased')
# model.to(device)
from model_class import *

# load the model and all its hyperparameters
load_path = './model/bert_best_model.pt'
params, state = torch.load(load_path)
model = BERT(**params, device=device).to(device)
model.load_state_dict(state)

<All keys matched successfully>

In [147]:
model

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(93, 768)
    (pos_embed): Embedding(512, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=512, bias=True)
        (W_K): Linear(in_features=768, out_features=512, bias=True)
        (W_V): Linear(in_features=768, out_features=512, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (decoder)

### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [148]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## 5. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

<img src="./figures/sbert-architecture.png" >

In [149]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

<img src="./figures/sbert-ablation.png" width="350" height="300">

In [150]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [151]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



## 6. Training

In [152]:
from tqdm.auto import tqdm

def train(model, classifier_head, data, optimizer, optimizer_classifier, scheduler, scheduler_classifier, criterion, device):
    epoch_loss = []
    model.train()
    classifier_head.train()

    for step, batch in enumerate(tqdm(data, leave=True, desc='Training: ')):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        segment_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int32).to(device)  # each input contains only one sentence hence we define them all as sentence '0'
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids)  
        v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids)  

        # u_last_hidden_state = u.last_hidden_state # all token embeddings A = batch_size, seq_len, hidden_dim
        # v_last_hidden_state = v.last_hidden_state # all token embeddings B = batch_size, seq_len, hidden_dim

         # get the mean pooled vectors
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
        
        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim
        
        # concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
        
        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer
        
        # calculate the 'softmax-loss' between predicted and true label
        loss = criterion(x, label)
        
        # using loss, calculate gradients and then optimizerize
        loss.backward()
        epoch_loss.append(loss.item())
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step() # update learning rate scheduler
        scheduler_classifier.step()

    return np.mean(epoch_loss)

In [153]:
def evaluate(model, classifier_head, data, criterion, device):
    epoch_loss = []
    model.eval()
    classifier_head.eval()

    with torch.no_grad():
        for step, batch in enumerate(tqdm(data, leave=True, desc='Evaluate: ')):
            
            # prepare batches and more all to the active device
            inputs_ids_a = batch['premise_input_ids'].to(device)
            inputs_ids_b = batch['hypothesis_input_ids'].to(device)
            attention_a = batch['premise_attention_mask'].to(device)
            attention_b = batch['hypothesis_attention_mask'].to(device)
            segment_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int32).to(device)  # each input contains only one sentence hence we define them all as sentence '0'
            label = batch['labels'].to(device)
            
            # extract token embeddings from BERT at last_hidden_state
            u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids)  
            v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids)  

            # u_last_hidden_state = u.last_hidden_state # all token embeddings A = batch_size, seq_len, hidden_dim
            # v_last_hidden_state = v.last_hidden_state # all token embeddings B = batch_size, seq_len, hidden_dim

            # get the mean pooled vectors
            u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
            v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
            
            # build the |u-v| tensor
            uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
            uv_abs = torch.abs(uv) # batch_size,hidden_dim
            
            # concatenate u, v, |u-v|
            x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
            
            # process concatenated tensor through classifier_head
            x = classifier_head(x) #batch_size, classifer
            
            # calculate the 'softmax-loss' between predicted and true label
            loss = criterion(x, label)
            epoch_loss.append(loss.item())

    return np.mean(epoch_loss)

In [154]:
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [155]:
head_path = './model/best_s_bert_classifier_head.pt'
model_path = './model/best_s_bert.pt'

In [156]:
num_epoch = 5

best_val_loss = float('inf')
train_losses = []
val_losses = []

# 1 epoch should be enough, increase if wanted
for epoch in range(num_epoch):
    start_time = time.time()
    train_loss = train(model, classifier_head, train_dataloader, optimizer, optimizer_classifier, scheduler, scheduler_classifier, criterion, device)
    val_loss = evaluate(model, classifier_head, eval_dataloader, criterion, device)

    #for plotting
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # save the model only when its validation loss is lower than all its predecessors
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(classifier_head, head_path)  # save the classifier head
        torch.save([model.params, model.state_dict()], model_path)  # save the model's parameters and state to a file
        
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {val_loss:.3f}')

Training:   0%|          | 0/6 [00:00<?, ?it/s]


IndexError: index out of range in self

## 7. Inference

In [None]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(model, tokenizer, sentence_a, sentence_b, device):
    # Tokenize and convert sentences to input IDs and attention masks
    inputs_a = tokenizer(sentence_a, return_tensors='pt', truncation=True, padding=True).to(device)
    inputs_b = tokenizer(sentence_b, return_tensors='pt', truncation=True, padding=True).to(device)

    # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids']
    attention_a = inputs_a['attention_mask']
    inputs_ids_b = inputs_b['input_ids']
    attention_b = inputs_b['attention_mask']

    # Extract token embeddings from BERT
    u = model(inputs_ids_a, attention_mask=attention_a)[0]  # all token embeddings A = batch_size, seq_len, hidden_dim
    v = model(inputs_ids_b, attention_mask=attention_b)[0]  # all token embeddings B = batch_size, seq_len, hidden_dim

    # Get the mean-pooled vectors
    u = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim
    v = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1)  # batch_size, hidden_dim

    # Calculate cosine similarity
    similarity_score = cosine_similarity(u.reshape(1, -1), v.reshape(1, -1))[0, 0]

    return similarity_score

# Example usage:
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."
similarity = calculate_similarity(model, tokenizer, sentence_a, sentence_b, device)
print(f"Cosine Similarity: {similarity:.4f}")