# Setup

In [429]:
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.losses import CosineSimilarityLoss
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# For debugging
# from scipy.stats import spearmanr, pearsonr
# from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances

### Settings

In [576]:
# Data and training
batch_size = 16  # {16, 32}

# Training
epochs = 1
warmup_steps = 100

# Data

Using the official STS Benchmark dataset:  
https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark

Information about the train/dev/test datasets are available under `data/stsbenchmark/readme.txt`

## Data loading and cleanup

There seem to be some rows with 7 columns and some with 9 columns (proabably due to additional licensing information), and so we'll try to make sure we don't miss the 3 important columns that we need: Score, Sentence 1, and Sentence 2.

In [582]:
train_csv = 'data/stsbenchmark/sts-train.csv'
dev_csv   = 'data/stsbenchmark/sts-dev.csv'
test_csv  = 'data/stsbenchmark/sts-test.csv'

### Training set

In [515]:
with open(train_csv) as f:
    lines = f.readlines()
    lines = [l.strip() for l in lines]

In [516]:
# Should be 5749
len(lines)

5749

In [517]:
rows = [l.split('\t') for l in lines]

In [518]:
row_type_counts = {}

for row in rows:
    if len(row) not in row_type_counts.keys():
        row_type_counts[len(row)] = 1
    else:
        row_type_counts[len(row)] += 1

row_type_counts

{7: 5552, 9: 197}

Assuming that the data is clean enough and the rows with 7 columns are all correctly filled, we will just quickly check the relevant columns of those with 9 fields, just to make sure our data is clean.

In [519]:
for row in rows:
    if len(row) == 9:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

4  |  Driver bac  |  Driver bac
5  |  Spain Prin  |  Spain prin
5  |  Senate con  |  Senate app
5  |  U.N. right  |  UN Rights 
5  |  US Senate   |  Senate con
5  |  Syrian Reb  |  Syrian reb
3  |  Mayawati d  |  Mayawati d
4  |  Uganda's p  |  Uganda's p
5  |  Rocks, Tea  |  Rocks, tea
5  |  Boston bom  |  Boston bom
5  |  Six dead i  |  6 killed i
5  |  China army  |  China army
4  |  Silvio Ber  |  Silvio Ber
5  |  Two killed  |  Two killed
5  |  Russia war  |  Russia war
4  |  Hosni Muba  |  Mubarak re
4  |  Egypt prot  |  Egypt prot
5  |  Couple mar  |  Couple get
4  |  Qatar's em  |  Qatari emi
3  |  Philippine  |  Philippine
4  |  Egypt brac  |  Egypt brac
4  |  Red Sox Be  |  Red Sox be
4  |  China land  |  China land
4  |  Ukrainian   |  Ukraine's 
5  |  Venezuela   |  Venezuela 
5  |  North Kore  |  North Kore
4  |  Captain of  |  Captain of
3  |  Cars, driv  |  Cars plung
5  |  Death toll  |  Death toll
4  |  Communist   |  Communist 
3  |  Egypt's Mo  |  Egypt's Mo
5  |  Ir

Everything seems OK. Let's just check the first 20 rows with 7 columns also.

In [520]:
for row in rows[:20]:
    if len(row) == 7:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

5.000  |  A plane is  |  An air pla
3.800  |  A man is p  |  A man is p
3.800  |  A man is s  |  A man is s
2.600  |  Three men   |  Two men ar
4.250  |  A man is p  |  A man seat
4.250  |  Some men a  |  Two men ar
0.500  |  A man is s  |  A man is s
1.600  |  The man is  |  The man is
2.200  |  A man is p  |  A woman is
5.000  |  A person i  |  A person t
4.200  |  The man hi  |  The man sp
4.600  |  A woman pi  |  A woman pi
3.867  |  A man is p  |  A man is p
4.667  |  A person i  |  Someone is
1.667  |  A man is r  |  A panda do
3.750  |  A dog is t  |  A dog is t
5.000  |  The polar   |  A polar be
0.500  |  A woman is  |  A woman is
3.800  |  A cat is r  |  A cat is r
5.000  |  The man is  |  A man is r


All good!

Making SentenceTransformer-specific DataLoaders out of them then.

In [521]:
examples = []

for row in rows:
    examples.append(
        InputExample(
            texts=[row[5].strip(), row[6].strip()],
            label=float(row[4])))

train_dl = DataLoader(examples, shuffle=True, batch_size=batch_size)

In [522]:
# Should be 5749
len(examples)

5749

### Development set

In [539]:
with open(dev_csv) as f:
    lines = f.readlines()
    lines = [l.strip() for l in lines]

In [540]:
# Should be 1500
len(lines)

1500

In [541]:
rows = [l.split('\t') for l in lines]

In [542]:
row_type_counts = {}

for row in rows:
    if len(row) not in row_type_counts.keys():
        row_type_counts[len(row)] = 1
    else:
        row_type_counts[len(row)] += 1

row_type_counts

{7: 1478, 9: 22}

Assuming that the data is clean enough and the rows with 7 columns are all correctly filled, we will just quickly check the relevant columns of those with 9 fields, just to make sure our data is clean.

In [543]:
for row in rows:
    if len(row) == 9:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

5  |  US Supreme  |  High court
4  |  Euro crisi  |  Eurozone d
4  |  Treasury p  |  Treasury p
3  |  EU Ministe  |  EU Ministe
5  |  Chinese lu  |  China land
2  |  Boy, 14, a  |  Two teenag
3  |  Protests a  |  Zimmerman 
3  |  Maldives b  |  Crisis-hit
1  |  Obama, Hol  |  Hollande a
5  |  Iran, IAEA  |  Iran, IAEA
2  |  Bombings k  |  Bombings k
2  |  10 Things   |  10 Things 
2  |  New UN pea  |  UN takes o
5  |  Oil falls   |  Oil prices
3  |  Israeli fo  |  Israeli fo
3  |  Israeli po  |  Israel Pol
0  |  3 killed,   |  Five kille
2  |  Scientists  |  Has Nasa d
0  |  Pranab str  |  WTO: India
2  |  Volkswagen  |  Volkswagen
0  |  Obama is r  |  Obama wait
0  |  New video   |  New York p


Everything seems OK. Let's just check the first 20 rows with 7 columns also.

In [544]:
for row in rows:
    if len(row) == 7:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

5.000  |  A man with  |  A man wear
4.750  |  A young ch  |  A child is
5.000  |  A man is f  |  The man is
2.400  |  A woman is  |  A man is p
2.750  |  A woman is  |  A man is p
2.615  |  A woman is  |  A man is c
5.000  |  A man is e  |  The man is
2.333  |  A woman is  |  A woman is
3.750  |  Three men   |  Three men 
5.000  |  A woman pe  |  A woman is
3.200  |  People are  |  Men are pl
1.583  |  A man is p  |  A man is p
5.000  |  The cougar  |  A cougar i
5.000  |  The man cu  |  A man chop
4.909  |  The man is  |  A man is p
0.800  |  A man is f  |  A woman is
2.400  |  The girl s  |  The lady s
5.000  |  A man is c  |  A man clim
4.000  |  Kittens ar  |  Kittens ar
0.636  |  A man is s  |  A man is s
3.000  |  A woman is  |  A woman is
1.714  |  A man is p  |  A man is p
3.200  |  An animal   |  An animal 
2.167  |  A man is p  |  A man is p
1.000  |  A man is p  |  A man is p
1.917  |  A girl is   |  A girl is 
4.250  |  A man is l  |  A man is l
3.000  |  An animal   |  A s

All good!

In [545]:
examples = []

dev_s1 = []
dev_s2 = []
dev_scores = []

for row in rows:
    examples.append(
        InputExample(
            texts=[row[5].strip(), row[6].strip()],
            label=float(row[4])))
    
    dev_s1.append(row[5].strip())
    dev_s2.append(row[6].strip())
    dev_scores.append(float(row[4].strip()))

dev_dl = DataLoader(examples, shuffle=True, batch_size=batch_size)

# Because EmbeddingSimilarityEvaluator won't accept a DataLoader
dev_examples = {
      'sentences1': dev_s1
    , 'sentences2': dev_s2
    , 'scores': dev_scores }

In [546]:
# Should be 1500
len(examples)

1500

### Test set

In [547]:
with open(test_csv) as f:
    lines = f.readlines()
    lines = [l.strip() for l in lines]

In [548]:
# Should be 1379
len(lines)

1379

In [549]:
rows = [l.split('\t') for l in lines]

In [550]:
row_type_counts = {}

for row in rows:
    if len(row) not in row_type_counts.keys():
        row_type_counts[len(row)] = 1
    else:
        row_type_counts[len(row)] += 1

row_type_counts

{7: 1095, 9: 284}

Assuming that the data is clean enough and the rows with 7 columns are all correctly filled, we will just quickly check the relevant columns of those with 9 fields, just to make sure our data is clean.

In [551]:
for row in rows:
    if len(row) == 9:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

3  |  I remained  |  I remained
3  |  In the US,  |  It really 
0  |  There's al  |  There is a
0  |  You also i  |  You can do
2  |  I did this  |  I have thi
0  |  You just h  |  You may wa
5  |  You do not  |  You don't 
3  |  You should  |  You should
2  |  You should  |  You should
0  |  You need t  |  You have t
2  |  It depends  |  i think it
5  |  You can do  |  Yes, you c
1  |  You should  |  You can do
4  |  You have t  |  You have t
4  |  I have few  |  I have two
1  |  You want t  |  You will h
1  |  if you don  |  The key th
1  |  Unfortunat  |  My answer 
4  |  As soon as  |  Start them
1  |  You just h  |  It depends
5  |  The answer  |  The answer
4  |  To give th  |  I'll answe
4  |  Unfortunat  |  Sorry, I d
4  |  The rule -  |  I always g
4  |  This is no  |  This sound
3  |  Yes, it's   |  It's a goo
4  |  It probabl  |  It depends
0  |  It's not a  |  It's a goo
0  |  It's prett  |  It's much 
1  |  Yes, there  |  Yes, that 
4  |  Have you t  |  Have you t
5  |  Yo

Everything seems OK. Let's just check the first 20 rows with 7 columns also.

In [552]:
for row in rows[:20]:
    if len(row) == 7:
        print(f'{row[4][:10]}  |  {row[5][:10]}  |  {row[6][:10]}')

2.500  |  A girl is   |  A girl is 
3.600  |  A group of  |  A group of
5.000  |  One woman   |  A woman me
4.200  |  A man is c  |  A man is s
1.500  |  A man is p  |  A man is p
1.800  |  A woman is  |  A woman is
3.500  |  A man is r  |  A man is r
2.200  |  A man is p  |  A man is p
2.200  |  A man is p  |  A lady is 
1.714  |  A man is p  |  A man is p
1.714  |  A man is p  |  A man is p
5.000  |  A man is c  |  A man cuts
0.600  |  A man is c  |  A man is t
4.400  |  A man is s  |  A man is c
2.000  |  A man is s  |  A man is s
1.800  |  A man is p  |  A man is p
4.400  |  A baby pan  |  A panda sl
3.600  |  A man is s  |  A man is p
3.600  |  A man atta  |  A man slap
1.200  |  A man is d  |  A man is r


All good!

In [553]:
examples = []

test_s1 = []
test_s2 = []
test_scores = []

for row in rows:
    examples.append(
        InputExample(
            texts=[row[5].strip(), row[6].strip()],
            label=float(row[4])))
    
    test_s1.append(row[5].strip())
    test_s2.append(row[6].strip())
    test_scores.append(float(row[4].strip()))

test_dl = DataLoader(examples, shuffle=True, batch_size=batch_size)

# Because EmbeddingSimilarityEvaluator won't accept a DataLoader
test_examples = {
      'sentences1': test_s1
    , 'sentences2': test_s2
    , 'scores': test_scores }

In [554]:
# Should be 1379
len(examples)

1379

# Model

In [580]:
model_name = f'S-BERT-STS-Trained-batch_size-{batch_size}-epochs-{epochs}-warmup_steps-{warmup_steps}'
model_path = 'models' + '/' + model_name

Either build a model...

In [571]:
# word_embedding_model = Transformer('bert-base-uncased')  # [optional] max_seq_length=256
# mean_pooling = Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean')
# model = SentenceTransformer(modules=[word_embedding_model, mean_pooling])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


... or load a previously saved one from disk.

In [578]:
model = SentenceTransformer(model_path)

# Training

### S-BERT Regression Objective

Documentation for the Cosine Similarity Loss:  
https://www.sbert.net/docs/package_reference/losses.html#sentence_transformers.losses.CosineSimilarityLoss

The loss that will be minimized is the following:

$$ \mathrm{Loss} = \lVert  \,\,  \mathrm{input \, label} - \mathrm{cos\_score\_transformation} \, ( \, \mathrm{cosine\_sim} \, (u,v) \,)  \,\,  \rVert_2 $$

The similarity score in the STS benchmark goes from 0 to 5, while cosine similarity goes from -1 to 1.  
We need to therefore map the cosine similarity score back to the 0-5 range.

In [563]:
class ScoreMap(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ScoreMap, self).__init__()

    def forward(self, input: Tensor) -> Tensor:
        """Maps input from [-1, 1] to [0, 5]."""
        return 5/2 * input + 5/2

In [564]:
train_loss = CosineSimilarityLoss(
                 model=model,
                 loss_fct=nn.MSELoss(),
                 cos_score_transformation=ScoreMap())

### Model Training (Fine-Tuning)

Not necessary to call `fit()`, if you've loaded a fine-tuned model.

In [None]:
# model.fit(train_objectives=[(train_dl, train_loss)],
#           epochs=epochs,
#           warmup_steps=warmup_steps)

Save the model to disk for later use if you want.

In [581]:
# model.save(path=model_path, model_name=model_name, create_model_card=True)

In [565]:
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=dev_examples['sentences1'],
    sentences2=dev_examples['sentences2'],
    scores=dev_examples['scores'],
    batch_size=batch_size,
    main_similarity=SimilarityFunction.COSINE,  # {COSINE, EUCLIDEAN, MANHATTAN, DOT_PRODUCT}
    show_progress_bar=True,
    write_csv=False)

# The SimilarityFunction values come from:
# https://github.com/UKPLab/sentence-transformers/blob/46a149433fe9af0851f7fa6f9bf37b5ffa2c891c/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py#L120

### Evaluation: Spearman Rank Correlation on Dev Set

In order to obtain the embeddings necessary for the evaluation, the model will run over the data twice: once for the first sentences, and once for the second sentences.

In [569]:
dev_spearman_rho = model.evaluate(dev_evaluator)
print(f'Spearman (dev set): {dev_spearman_rho * 100}')  # For consistency with the paper

Spearman (dev set): 73.03830917146799


# Final Evaluation

### Spearman Rank Correlation on Test Set

In [567]:
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_examples['sentences1'],
    sentences2=test_examples['sentences2'],
    scores=test_examples['scores'],
    batch_size=batch_size,
    main_similarity=SimilarityFunction.COSINE,  # {COSINE, EUCLIDEAN, MANHATTAN, DOT_PRODUCT}
    show_progress_bar=True,
    write_csv=False)

# The SimilarityFunction values come from:
# https://github.com/UKPLab/sentence-transformers/blob/46a149433fe9af0851f7fa6f9bf37b5ffa2c891c/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py#L120

In [570]:
test_spearman_rho = model.evaluate(test_evaluator)
print(f'Spearman (test set): {test_spearman_rho * 100}')  # For consistency with the paper

Spearman (test set): 65.57546300121835


# Conclusion

All else being equal, for `batch_size` in {16, 32}, we have the following Spearman scores:
 - dev set: {73.03830917146799, 72.19737929737312}
 - test set: {65.57546300121835, 64.55123444928913}