# Semantic Textual Similarity (STS) Benchmark

1. Baseline BERT
2. SBERT (Sentence BERT)

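Reproduce the STS-benchmark results reported in the Sentence-BERT paper (Reimers & Gurevych, 2019, "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks").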

In [1]:
import os
import sys

# Make the project-local modules (bert, bert_config, bert_utils, sentenceBERT) importable.
# Set the EXPERIMENTS_ROOT environment variable instead of editing this absolute path.
PROJECT_ROOT = os.environ.get("EXPERIMENTS_ROOT", "/home/varun/projects/experiments-with-gpt2/")
# Guard keeps the cell idempotent: re-running it must not keep appending the same entry.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from datasets import load_dataset

from bert_utils import dynamic_padding
from bert import BERT
from bert_config import BERTConfig

# STS benchmark sentence pairs; rows are read below via their 'sentence1', 'sentence2'
# and 'score' fields.
dataset = load_dataset('sentence-transformers/stsb')


In [2]:
# Evaluation: embed sentence1 and sentence2 with BERT, then correlate cosine similarity with gold scores
from scipy.stats import spearmanr

# Fall back to CPU so the benchmark still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _embed_sentences(model, encoded):
    """Run ``model.bert`` on one tokenized batch (as produced by ``dynamic_padding``) on ``device``."""
    return model.bert(encoded.input_ids.to(device),
                      encoded.token_type_ids.to(device),
                      encoded.attention_mask.to(device))


def calculate_spearmanr_correlation(model, split="test", batch_size=16):
    """Spearman correlation between predicted and gold similarity on an STS split.

    Each sentence pair is embedded with ``model.bert`` and scored by the cosine
    similarity of the two embeddings. The resulting scores are compared with the
    gold ``score`` field using Spearman's rank correlation.

    Args:
        model: module exposing ``bert(input_ids, token_type_ids, attention_mask)``;
            assumed to return one embedding vector per sentence, i.e. shape
            (batch, hidden) — TODO confirm.
        split: name of the ``dataset`` split to evaluate (default ``"test"``).
        batch_size: number of sentence pairs per forward pass.

    Returns:
        Spearman's rho between the gold scores and the cosine similarities.
    """
    eval_set = dataset[split]
    model.to(device)
    # A freshly constructed nn.Module is in training mode. Switch to eval so that
    # train-only layers (e.g. dropout, if the model has any) do not randomize the
    # embeddings. Remember the previous mode so it can be restored afterwards.
    was_training = model.training
    model.eval()
    cos = torch.nn.CosineSimilarity(dim=1)
    sim_scores = torch.zeros(len(eval_set))
    # shuffle defaults to False, so batches follow eval_set order and align with gt_scores.
    loader = DataLoader(eval_set, batch_size=batch_size, collate_fn=dynamic_padding)
    offset = 0
    try:
        with torch.no_grad():
            for batch in tqdm(loader):
                s1_embedding = _embed_sentences(model, batch['sentence1'])
                s2_embedding = _embed_sentences(model, batch['sentence2'])
                batch_sim = cos(s1_embedding, s2_embedding).cpu()
                # Place each batch at a running offset, so the shorter final batch
                # lands correctly without recomputing slice bounds from the loop index.
                sim_scores[offset:offset + len(batch_sim)] = batch_sim
                offset += len(batch_sim)
    finally:
        # Restore the caller's mode so evaluating mid-training has no lasting side effect.
        model.train(was_training)
    gt_scores = [item["score"] for item in eval_set]
    return spearmanr(gt_scores, sim_scores.tolist()).statistic

In [3]:
# Spearman rank correlation of the trained sentenceBERT checkpoint
import torch
from sentenceBERT import sentenceBERT

# Alternative model (plain BERT from BERTConfig defaults; presumably what the baseline
# table rows were computed with — TODO confirm):
# model = BERT.from_pretrained(BERTConfig())
ckpt_path = "out/bert_ckpt_train.pt"
# map_location="cpu" lets a checkpoint saved from GPU tensors load on any machine;
# calculate_spearmanr_correlation moves the model to the evaluation device itself.
# NOTE(review): torch.load unpickles arbitrary objects. Once the checkpoint is confirmed
# to hold only tensors and plain containers, pass weights_only=True. That also silences
# the FutureWarning seen in the output.
ckpt = torch.load(ckpt_path, map_location="cpu")
model_config = BERTConfig(**ckpt["model_config"])
model = sentenceBERT(model_config)
model.load_state_dict(ckpt["model"])
spearman_rho = calculate_spearmanr_correlation(model)
spearman_rho

  ckpt = torch.load(ckpt_path)


Loading pre-trained weights for bert


100%|██████████| 87/87 [00:04<00:00, 18.86it/s]


0.7460946033571232


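Spearman's rank correlation between the cosine similarity of the two sentence embeddings and the gold STS scores (higher is better):

| Model                               | Spearman's correlation coefficient |
| :---------------------------------- | :--------------------------------: |
| BERT [CLS] embedding                | 0.2030 |
| Avg BERT embeddings                 | 0.4693 |
| SBERT fine-tuned on SNLI            | 0.7057 |
| SBERT fine-tuned on SNLI + MultiNLI | 0.7460 |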