# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [None]:
import os
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import random as _stdlib_random  # aliased: `from random import *` above already binds the name `random`

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so DataLoader shuffling and
# the classifier-head initialisation are reproducible under Restart & Run All.
SEED = 55
_stdlib_random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG of all devices, CUDA included

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 1. Data

### Train, Test, Validation

In [None]:
import datasets

# Both corpora expose premise / hypothesis / label; GLUE's MNLI additionally carries an `idx` column.
snli, mnli = datasets.load_dataset('snli'), datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


({'premise': Value('string'),
  'hypothesis': Value('string'),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction']),
  'idx': Value('int32')},
 {'premise': Value('string'),
  'hypothesis': Value('string'),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'])})

In [None]:
# Split names of the MNLI DatasetDict; every split carries an 'idx' column that is removed next.
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [None]:
# Drop MNLI's `idx` column from every split so its schema matches SNLI's
# (DatasetDict.remove_columns applies the removal split by split).
mnli = mnli.remove_columns('idx')

In [None]:
# Verify the 'idx' column is really gone from every MNLI split, then list the splits.
assert all('idx' not in columns for columns in mnli.column_names.values()), "'idx' column still present in MNLI"
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [None]:
# Label sets of both training splits; SNLI additionally uses -1 (no gold label).
tuple(np.unique(split['label']) for split in (mnli['train'], snli['train']))

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [None]:
# SNLI marks pairs where no class could be decided with label -1; keep only labelled pairs.
snli = snli.filter(lambda example: example['label'] != -1)

In [None]:
# After filtering, no SNLI split may contain the -1 "no gold label" marker.
assert all((np.asarray(split['label']) >= 0).all() for split in snli.values()), "SNLI still has -1 labels"
mnli_train_labels, snli_train_labels = np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
# both training splits now use only {0, 1, 2}
mnli_train_labels, snli_train_labels

(array([0, 1, 2]), array([0, 1, 2]))

In [None]:
from datasets import DatasetDict

# Rows kept per split after shuffling; set any of these to None to use the full split.
N_TRAIN, N_TEST, N_VALIDATION = 1000, 100, 1000
SHUFFLE_SEED = 55


def merge_nli_splits(snli_split, mnli_split, n_rows=None, seed=SHUFFLE_SEED):
    """Concatenate an SNLI split with an MNLI split, shuffle, and optionally keep the first `n_rows`."""
    merged = datasets.concatenate_datasets([snli_split, mnli_split]).shuffle(seed=seed)
    if n_rows is None:
        return merged
    return merged.select(range(min(n_rows, len(merged))))


# GLUE's MNLI `test_*` splits are unlabelled (label == -1), so they cannot serve supervised
# evaluation. The labelled `validation_matched` split is used as the MNLI part of the test
# set instead; it stays disjoint from `validation_mismatched`, which feeds the validation set.
raw_dataset = DatasetDict({
    'train': merge_nli_splits(snli['train'], mnli['train'], N_TRAIN),
    'test': merge_nli_splits(snli['test'], mnli['validation_matched'], N_TEST),
    'validation': merge_nli_splits(snli['validation'], mnli['validation_mismatched'], N_VALIDATION),
})
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

## 2. Preprocessing

In [None]:
def custom_tokenizer(sentence, word2id, max_length=128, pad_id=0, unk_id=0):
    """Whitespace-tokenize `sentence` into fixed-length token ids plus an attention mask.

    Args:
        sentence: raw text; it is lower-cased and split on whitespace (punctuation stays
            attached to the neighbouring word).
        word2id: mapping from token string to integer id; only `.get` is used, so a
            defaultdict is not mutated by lookups.
        max_length: output length; longer inputs are truncated, shorter ones padded.
        pad_id: id written into padding positions.
        unk_id: id used for tokens missing from `word2id`. Defaults to 0, the same id as
            padding, which preserves the original behaviour.

    Returns:
        (ids, attention_mask): two lists of exactly `max_length` ints; the mask is 1 for
        real (possibly unknown) tokens and 0 for padding.
    """
    ids = [word2id.get(token, unk_id) for token in sentence.lower().split()][:max_length]
    n_real = len(ids)
    n_pad = max_length - n_real
    return ids + [pad_id] * n_pad, [1] * n_real + [0] * n_pad


In [None]:
import collections
import warnings

# Token -> id vocabulary consumed by `custom_tokenizer`. Its lookups use `.get(w, 0)`,
# so missing tokens map to id 0 and the dict itself is never mutated.
word2id = collections.defaultdict(lambda: 0)
# Special tokens presumably used by the pretraining vocabulary — TODO confirm against it:
# word2id["[PAD]"] = 0
# word2id["[CLS]"] = 1
# word2id["[SEP]"] = 2

# NOTE(review): nothing in this notebook populates `word2id`, so every token is encoded as id 0
# and sentences of equal length receive identical input ids; the sentence embeddings then carry
# no lexical information. Replace this with the vocabulary used to pretrain `bert_model.pth`.
if not word2id:
    warnings.warn(
        "word2id is empty: every token will be encoded as id 0, so sentence embeddings "
        "cannot distinguish inputs. Load the vocabulary used to pretrain the BERT checkpoint."
    )

def preprocess_function(examples):
    """Tokenize a batch of premise/hypothesis pairs for `datasets.Dataset.map(batched=True)`.

    Uses `custom_tokenizer` with the global `word2id`; every sequence is padded or
    truncated to 128 ids. Returns per-example id and mask lists for both sentences,
    plus the labels copied through unchanged.
    """
    max_seq_length = 128

    def encode_column(column):
        # Tokenize every text of one column, then split [(ids, mask), ...] into two lists.
        encoded = [
            custom_tokenizer(text, word2id=word2id, max_length=max_seq_length)
            for text in examples[column]
        ]
        return [ids for ids, _ in encoded], [mask for _, mask in encoded]

    premise_ids, premise_mask = encode_column('premise')
    hypothesis_ids, hypothesis_mask = encode_column('hypothesis')

    return {
        "premise_input_ids": premise_ids,
        "premise_attention_mask": premise_mask,
        "hypothesis_input_ids": hypothesis_ids,
        "hypothesis_attention_mask": hypothesis_mask,
        "labels" : examples["label"]
    }

# `preprocess_function` reads the global `word2id` directly, so `map` needs no `fn_kwargs`.
# The raw text/label columns are then dropped and the remaining columns exposed as torch tensors.
tokenized_datasets = raw_dataset.map(preprocess_function, batched=True).remove_columns(
    ['premise', 'hypothesis', 'label']
)
tokenized_datasets.set_format("torch")

In [None]:
# Inspect the tokenized splits: the raw text/label columns are replaced by padded id/mask columns.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
})

## 3. Data loader

In [None]:
from torch.utils.data import DataLoader

batch_size = 32
# One DataLoader per split; only the training split is shuffled.
train_dataloader, eval_dataloader, test_dataloader = (
    DataLoader(tokenized_datasets[split], batch_size=batch_size, shuffle=(split == 'train'))
    for split in ('train', 'validation', 'test')
)

In [None]:
# Peek at one training batch to confirm tensor shapes.
batch = next(iter(train_dataloader))
for key in ('premise_input_ids', 'premise_attention_mask',
            'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'):
    print(batch[key].shape)

torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32])


## 4. Model

In [None]:
# Load the pretrained BERT encoder.
import sys

# Directory holding `bert_update.py` and the pretrained checkpoint (Colab layout).
MODEL_DIR = '/content/sample_data'
if MODEL_DIR not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(MODEL_DIR)

from bert_update import BERT


model = BERT(
    n_layers=12,
    n_heads=12,
    d_model=768,
    d_ff=768*4,
    d_k=64,
    n_segments=2,
    vocab_size=110913,
    max_len=1000,  # positional-embedding table size; must match the checkpoint's pos_embed (inputs here use only 128 positions)
    device=device
).to(device)

# strict=False tolerates key mismatches between checkpoint and model; report them so that
# randomly initialised (missing) or ignored (unexpected) weights do not go unnoticed.
load_result = model.load_state_dict(
    torch.load(f'{MODEL_DIR}/bert_model.pth', map_location=device), strict=False
)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")
model.to(device)

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(110913, 768)
    (pos_embed): Embedding(1000, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-11): 12 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (d

In [None]:
import os
import torch

model_path = '/content/sample_data/bert_model.pth'

print(f"Checking file: {model_path}")

if os.path.exists(model_path):
    print(f"File '{model_path}' exists.")

    file_size_bytes = os.path.getsize(model_path)
    file_size_mb = file_size_bytes / (1024 * 1024)
    print(f"File size: {file_size_bytes} bytes ({file_size_mb:.2f} MB).")

    # Load onto the CPU so the integrity check does not depend on GPU availability or memory.
    try:
        state_dict = torch.load(model_path, map_location='cpu')
        print("Successfully loaded the model state dictionary. File appears to be intact.")
    except RuntimeError as err:
        # A truncated download typically surfaces as a zip archive without a central directory.
        truncated = "PytorchStreamReader failed reading zip archive: failed finding central directory" in str(err)
        if truncated:
            print("Error: The .pth file appears to be corrupted or incomplete (missing central directory).")
            print("Please re-download or verify the integrity of 'bert_model.pth'.")
        else:
            print(f"An unexpected RuntimeError occurred during loading: {err}")
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
else:
    print(f"Error: The file {model_path} does not exist. Please ensure it is uploaded or accessible.")

Checking file: /content/sample_data/bert_model.pth
File '/content/sample_data/bert_model.pth' exists.
File size: 660760279 bytes (630.15 MB).
Successfully loaded the model state dictionary. File appears to be intact.


### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed-size sentence embedding. This notebook uses mean pooling over the non-padding tokens.

In [None]:
def mean_pool(token_embeds, attention_mask):
    """Average token embeddings over the positions where `attention_mask` is 1.

    `attention_mask` has one entry per token; it is broadcast across the embedding
    dimension. Rows whose mask is all zero yield zeros (the divisor is clamped to 1e-9).
    """
    mask = attention_mask.unsqueeze(-1).expand_as(token_embeds).float()
    summed = (token_embeds * mask).sum(dim=1)
    token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / token_counts

## 5. Loss Function

## Classification Objective Function
We concatenate the sentence embeddings $u$ and $v$ with the element-wise absolute difference $\lvert u - v \rvert$ and multiply the result with the trainable weight $W_t \in \mathbb{R}^{3n \times k}$:

$$ o = \text{softmax}\left(W_t^{\top} \left(u, v, \lvert u - v \rvert\right)\right) $$

where $n$ is the dimension of the sentence embeddings and $k$ the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

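## Regression Objective Function
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use mean-squared-error loss as the objective function. This notebook trains only the classification objective; cosine similarity is used here for evaluation and inference.

(With the resulting embeddings, semantically similar sentences can be found via cosine similarity or Manhattan / Euclidean distance.)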

<img src="./figures/sbert-architecture.png" >

In [None]:
def configurations(u, v):
    """Build the SBERT classification features (u, v, |u - v|), concatenated on the last dim.

    For (batch_size, hidden_dim) inputs the result is (batch_size, 3 * hidden_dim).
    """
    abs_diff = (u - v).abs()
    return torch.cat([u, v, abs_diff], dim=-1)

def cosine_similarity(u, v):
    """Cosine similarity between two 1-D vectors.

    Returns 0.0 when either vector has zero norm, instead of dividing by zero and
    producing NaN (the same convention as scikit-learn's cosine_similarity).
    """
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    if norm_product == 0:
        return 0.0
    return np.dot(u, v) / norm_product

In [None]:
hidden_dim = 768

# Softmax classifier over the (u, v, |u - v|) features -> 3 NLI labels.
classifier_head = nn.Linear(3 * hidden_dim, 3).to(device)

# Separate Adam optimizers (same learning rate) for the encoder and the classification head.
optimizer, optimizer_classifier = (
    torch.optim.Adam(params, lr=2e-5)
    for params in (model.parameters(), classifier_head.parameters())
)

criterion = nn.CrossEntropyLoss()

In [None]:
from transformers import get_linear_schedule_with_warmup

# Number of epochs the training loop runs; the schedule length is derived from it.
num_epoch = 5
total_steps = len(train_dataloader) * num_epoch
# Linear warmup over the first ~10% of steps, then linear decay to zero at `total_steps`.
warmup_steps = int(0.1 * total_steps)

# `num_training_steps` is the TOTAL number of optimizer steps, warmup included. The schedulers
# are stepped once per batch inside the training loop, after `optimizer.step()`, so they are
# deliberately not stepped here.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
scheduler_classifier = get_linear_schedule_with_warmup(
    optimizer_classifier, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)



## 6. Training

In [None]:


from tqdm.auto import tqdm

# Train for `num_epoch` epochs as set in the scheduler cell, so the learning-rate schedule
# (whose length is computed from `num_epoch`) spans exactly the steps taken here.
for epoch in range(num_epoch):
    model.train()
    classifier_head.train()
    running_loss = 0.0
    # tqdm renders a progress bar over the batches
    for batch in tqdm(train_dataloader, leave=True):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()

        # move the batch to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)
        label = batch['labels'].to(device)

        # Token embeddings from the encoder's last layer; each sentence is encoded on its own,
        # so every segment id is 0.
        # NOTE(review): the attention masks are only used for pooling below; padding inside the
        # encoder is presumably derived from the input ids by `get_last_hidden_state` — confirm.
        u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids=torch.zeros_like(inputs_ids_a))
        v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids=torch.zeros_like(inputs_ids_b))

        # mean-pool over real tokens -> batch_size, hidden_dim
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a)
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b)

        # SBERT classification features (u, v, |u - v|) -> batch_size, 3*hidden_dim
        x = configurations(u_mean_pool, v_mean_pool)

        # logits over the NLI labels -> batch_size, 3
        logits = classifier_head(x)

        # softmax (cross-entropy) loss between predicted logits and gold labels
        loss = criterion(logits, label)

        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        # per-step learning-rate schedule update, after the optimizer step
        scheduler.step()
        scheduler_classifier.step()

        running_loss += loss.item()

    # Report the mean loss over the epoch; the last batch's loss alone is too noisy to track progress.
    print(f'Epoch: {epoch + 1} | loss = {running_loss / len(train_dataloader):.6f}')

  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 1 | loss = 2.771497


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 2 | loss = 1.087213


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 3 | loss = 1.141512


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 4 | loss = 1.489467


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 5 | loss = 0.865883


In [None]:
# Average cosine similarity between premise and hypothesis embeddings over the validation set.
model.eval()
classifier_head.eval()
total_similarity = 0.0
num_pairs = 0
with torch.no_grad():
    for batch in eval_dataloader:
        # move the batch to the active device
        inputs_ids_a = batch['premise_input_ids'].to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].to(device)
        attention_a = batch['premise_attention_mask'].to(device)
        attention_b = batch['hypothesis_attention_mask'].to(device)

        # Same encoder interface as in training (single-sentence inputs, segment ids all 0).
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids=torch.zeros_like(inputs_ids_a))
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids=torch.zeros_like(inputs_ids_b))

        u_mean_pool = mean_pool(u, attention_a)  # batch_size, hidden_dim
        v_mean_pool = mean_pool(v, attention_b)  # batch_size, hidden_dim

        # One similarity per premise/hypothesis pair (row-wise). Flattening the whole batch into
        # a single vector would mix unrelated pairs into one score.
        pair_similarity = F.cosine_similarity(u_mean_pool, v_mean_pool, dim=-1)  # batch_size
        total_similarity += pair_similarity.sum().item()
        num_pairs += pair_similarity.numel()

# Average over pairs rather than batches, so the smaller final batch is not over-weighted.
average_similarity = total_similarity / max(num_pairs, 1)
print(f"Average Cosine Similarity: {average_similarity:.4f}")

Average Cosine Similarity: 0.8015


## 7. Inference

In [None]:
import torch
# Imported under an alias so it does not shadow the 1-D NumPy `cosine_similarity` helper
# defined earlier in this notebook (which expects 1-D arrays, unlike sklearn's 2-D version).
from sklearn.metrics.pairwise import cosine_similarity as sklearn_cosine_similarity

def calculate_similarity(model, word2id_map, sentence_a, sentence_b, device, max_seq_length=128):
    """Cosine similarity between the mean-pooled encoder embeddings of two sentences.

    Args:
        model: encoder exposing `get_last_hidden_state(input_ids, segment_ids=...)`.
        word2id_map: token -> id vocabulary passed to `custom_tokenizer`.
        sentence_a, sentence_b: raw sentences to compare.
        device: torch device the model lives on.
        max_seq_length: pad/truncate length; the default 128 matches `preprocess_function`.

    Returns:
        The cosine similarity as a NumPy scalar.
    """
    def encode(sentence):
        # Tokenize one sentence and wrap ids/mask as (1, max_seq_length) tensors on `device`.
        ids, mask = custom_tokenizer(sentence, word2id=word2id_map, max_length=max_seq_length)
        return (torch.tensor([ids], dtype=torch.long, device=device),
                torch.tensor([mask], dtype=torch.long, device=device))

    inputs_ids_a, attention_a = encode(sentence_a)
    inputs_ids_b, attention_b = encode(sentence_b)

    model.eval()  # disable training-only behaviour before inference
    with torch.no_grad():
        u_last_hidden_state = model.get_last_hidden_state(inputs_ids_a, segment_ids=torch.zeros_like(inputs_ids_a))
        v_last_hidden_state = model.get_last_hidden_state(inputs_ids_b, segment_ids=torch.zeros_like(inputs_ids_b))

        # mean-pooled sentence vectors, shape (1, hidden_dim)
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a).cpu().numpy()
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b).cpu().numpy()

    # sklearn works on 2-D arrays and returns a (1, 1) matrix for single rows.
    return sklearn_cosine_similarity(u_mean_pool, v_mean_pool)[0, 0]

# Example: compare a sentence with a contradicting paraphrase, using the global vocabulary.
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."
similarity = calculate_similarity(
    model, word2id_map=word2id, sentence_a=sentence_a, sentence_b=sentence_b, device=device
)
print(f"Cosine Similarity: {similarity:.4f}")

Cosine Similarity: 0.9956


In [None]:
import torch
import numpy as np
from sklearn.metrics import classification_report, accuracy_score

# Evaluate the NLI classifier (encoder + softmax head) on the validation split.
model.eval()
classifier_head.eval()

pred_chunks = []
label_chunks = []

with torch.no_grad():
    for batch in eval_dataloader:
        inputs_ids_a, inputs_ids_b = batch['premise_input_ids'].to(device), batch['hypothesis_input_ids'].to(device)
        attention_a, attention_b = batch['premise_attention_mask'].to(device), batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Mean-pooled sentence embeddings (segment ids all zero, as in training).
        u_mean_pool = mean_pool(
            model.get_last_hidden_state(inputs_ids_a, segment_ids=torch.zeros_like(inputs_ids_a)), attention_a
        )
        v_mean_pool = mean_pool(
            model.get_last_hidden_state(inputs_ids_b, segment_ids=torch.zeros_like(inputs_ids_b)), attention_b
        )

        # Same (u, v, |u - v|) features the head was trained on.
        features = torch.cat([u_mean_pool, v_mean_pool, torch.abs(u_mean_pool - v_mean_pool)], dim=-1)
        logits_clsf = classifier_head(features)

        pred_chunks.append(logits_clsf.argmax(dim=1).cpu())
        label_chunks.append(labels.cpu())

all_preds = torch.cat(pred_chunks).numpy()
all_labels = torch.cat(label_chunks).numpy()

target_names = ['entailment', 'neutral', 'contradiction']

report = classification_report(all_labels, all_preds, target_names=target_names, digits=2)
print("Classification Report:\n")
print(report)

accuracy = accuracy_score(all_labels, all_preds)
print(f"\nOverall Accuracy: {accuracy:.4f}")

Classification Report:

               precision    recall  f1-score   support

   entailment       0.35      0.28      0.31       338
      neutral       0.31      0.10      0.15       328
contradiction       0.34      0.63      0.44       334

     accuracy                           0.34      1000
    macro avg       0.33      0.34      0.30      1000
 weighted avg       0.33      0.34      0.30      1000


Overall Accuracy: 0.3370
