<a href="https://colab.research.google.com/github/rahiakela/natural-language-processing-research-and-practice/blob/main/nlp-for-semantic-search/3_training_sentence_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Training Sentence Transformers the OG Way (with Softmax Loss)

There are several ways of training sentence transformers. One of the most popular (and the approach we will cover) is using Natural Language Inference (NLI) datasets.

NLI focuses on identifying sentence pairs that do or do not entail one another. We will use two NLI datasets: the Stanford Natural Language Inference (SNLI) and Multi-Genre NLI (MNLI) corpora.

Merging these two corpora gives us 943K sentence pairs (550K from SNLI, 393K from MNLI). All pairs include a `premise` and a `hypothesis`, and each pair is assigned a `label`:

- 0: entailment, i.e. the premise suggests the hypothesis.
- 1: neutral, i.e. the premise and hypothesis could both be true, but they are not necessarily related.
- 2: contradiction, i.e. the premise and hypothesis contradict each other.
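
As a small illustration of this labeling scheme, the sketch below maps label ids to their names and summarizes one pair. The names `ID2LABEL` and `describe_pair` are hypothetical helpers, not part of either dataset's API.

```python
# Label ids as listed above; both names here are illustrative only
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}


def describe_pair(pair):
    """Return a readable one-line summary of a premise/hypothesis pair."""
    relation = ID2LABEL[pair["label"]]
    return f"{pair['premise']!r} -> {pair['hypothesis']!r}: {relation}"


# a record in the premise/hypothesis/label format described above
sample = {
    "premise": "A person on a horse jumps over a broken down airplane.",
    "hypothesis": "A person is training his horse for a competition.",
    "label": 1,
}
print(describe_pair(sample))
```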

**Reference**:

https://www.pinecone.io/learn/train-sentence-transformers-softmax/

##Setup

In [None]:
!pip -q install datasets
!pip -q install transformers
!pip -q install sentence_transformers

In [2]:
import datasets

from transformers import BertTokenizer
from transformers import BertModel
from transformers.optimization import get_linear_schedule_with_warmup

from sentence_transformers import InputExample

import torch
from torch.utils.data import DataLoader

import os
from tqdm.auto import tqdm

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##NLI Training

When training the model, we will be feeding sentence A (the premise) into BERT, followed by sentence B (the hypothesis) on the next step.

From there, the model is optimized with softmax loss computed against the `label` field. We will explain this in more depth soon.

In [None]:
snli = datasets.load_dataset("snli", split="train")

In [5]:
snli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 550152
})

In [6]:
print(snli[0])

{'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}


In [None]:
m_nli = datasets.load_dataset("glue", "mnli", split="train")

In [8]:
m_nli

Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 392702
})

In [None]:
m_nli = m_nli.remove_columns(["idx"])
snli = snli.cast(m_nli.features)
dataset = datasets.concatenate_datasets([snli, m_nli])

In [10]:
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

Both datasets contain `-1` values in the label feature where no confident class could be assigned. We remove them using the `filter` method.

In [11]:
print(len(dataset))

# rows labeled -1 are those where no class could be decided, so we drop them
dataset = dataset.filter(lambda x: x["label"] != -1)
print(len(dataset))

942854


942069


We must convert our human-readable sentences into transformer-readable tokens, so we go ahead and tokenize our sentences. Both `premise` and `hypothesis` features must be split into their own `input_ids` and `attention_mask` tensors.

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
all_cols = ["label"]

for part in ["premise", "hypothesis"]:
  dataset = dataset.map(
      lambda x: tokenizer(x[part], max_length=128, padding="max_length", truncation=True),
      batched=True
  )
  for col in ["input_ids", "attention_mask"]:
    dataset = dataset.rename_column(col, part + "_" + col)
    all_cols.append(part + "_" + col)

In [14]:
print(all_cols)

['label', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask']


Now, all we need to do is prepare the data to be read into the model. 

To do this, we first convert the `dataset` features into PyTorch tensors and then initialize a data loader which will feed data into our model during training.

In [15]:
# convert dataset features to PyTorch tensors
dataset.set_format(type="torch", columns=all_cols)

# initialize the dataloader
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
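
To see what a batch from such a loader looks like, here is a minimal sketch on toy data. The list `toy_rows` is a made-up stand-in for the tokenized dataset (10 rows, sequence length 4), so the shapes are illustrative rather than those of the real data.

```python
import torch
from torch.utils.data import DataLoader

# toy stand-in for the tokenized dataset: 10 rows, sequences of length 4
toy_rows = [
    {
        "label": torch.tensor(i % 3),
        "premise_input_ids": torch.zeros(4, dtype=torch.long),
    }
    for i in range(10)
]

# the default collate function stacks each dict field into one batch tensor
toy_loader = DataLoader(toy_rows, batch_size=4, shuffle=False)
toy_batches = list(toy_loader)

print(len(toy_batches))                           # number of batches
print(toy_batches[0]["premise_input_ids"].shape)  # (batch, seq_len)
print(toy_batches[-1]["label"].shape)             # the final batch is smaller
```

Each batch is a dictionary keyed by column name, which is why the training loop can index it with names such as `premise_input_ids`.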

And we’re done with data preparation. Let’s move on to the training approach.

##Softmax Loss

Optimizing with softmax loss was the primary method used by Reimers and Gurevych in the original SBERT paper.

Although this was used to train the first sentence transformer model, it is no longer the go-to training approach. Instead, [the MNR loss approach](https://www.pinecone.io/learn/fine-tune-sentence-transformers-mnr/) is most common today.

However, we hope that explaining softmax loss will help demystify the different approaches applied to training sentence transformers.

###Model Preparation

When we train an SBERT model, we don’t need to start from scratch. We begin with an already pretrained BERT model (and tokenizer).

In [None]:
# start from a pretrained bert-base-uncased model
model = BertModel.from_pretrained("bert-base-uncased")

We will be using what is called a `siamese`-BERT architecture during training. 

All this means is that given a sentence pair, we feed sentence A into BERT first, then feed sentence B once BERT has finished processing the first.

This has the effect of creating a siamese-like network where we can imagine two identical BERTs are being trained in parallel on sentence pairs. 

In reality, there is just a single model processing two sentences one after the other.
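
The weight sharing can be sketched with a toy encoder. Here a single `torch.nn.Linear` layer stands in for BERT (an assumption made purely to keep the example small): the same module encodes both inputs, and gradients from both passes land in one set of weights.

```python
import torch

torch.manual_seed(0)

# toy stand-in for BERT: a single linear layer (illustrative assumption)
encoder = torch.nn.Linear(4, 2)

sentence_a = torch.randn(1, 4)
sentence_b = torch.randn(1, 4)

# the *same* module encodes both inputs, one after the other
u = encoder(sentence_a)
v = encoder(sentence_b)

# any loss over (u, v) sends gradients from both passes into the shared weights
loss = torch.abs(u - v).sum()
loss.backward()

print(encoder.weight.grad.shape)
```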

<img src="https://d33wubrfki0l68.cloudfront.net/b340f39a6c100e322d5354315e678e4caeea39a0/c5297/images/train-sentence-transformer-2.jpg" width=600 alt="Start SBERT">

With the `max_length=128` we set during tokenization, BERT will output `128x768`-dimensional token embeddings for each sentence (`512x768` at BERT's maximum sequence length). We will convert these into an average embedding using mean-pooling. This pooled output is our sentence embedding. We will have two per step: one for sentence A that we call $u$, and one for sentence B, called $v$.



In [17]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
  # expand attention_mask to match the shape of the 768-dimensional token embeddings
  mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
  # perform mean-pooling but exclude padding tokens (zeroed out by mask)
  pooling = torch.sum(token_embeds * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

  return pooling

Here we take BERT’s token embeddings output and the sentence’s `attention_mask` tensor. We then resize the `attention_mask` to align to the higher 768-dimensionality of the token embeddings.

We apply this resized mask to the token embeddings to exclude padding tokens from the mean pooling operation. Mean pooling then averages the remaining token activations in each of the 768 dimensions, producing one value per dimension. With our `max_length=128`, this reduces each sentence's tensor from `(128, 768)` to `(1, 768)`.
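
A tiny worked example makes the effect of the mask visible. The sketch below repeats the pooling logic under the hypothetical name `masked_mean_pool` so that it runs on its own, and uses made-up 2-dimensional embeddings in which the third token is padding.

```python
import torch


def masked_mean_pool(token_embeds, attention_mask):
    # same logic as mean_pool above, repeated so this sketch is self-contained
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    return torch.sum(token_embeds * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


# one sentence, three tokens, two dimensions; the last token is padding
toy_embeds = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
toy_mask = torch.tensor([[1, 1, 0]])

pooled = masked_mean_pool(toy_embeds, toy_mask)
naive = toy_embeds.mean(dim=1)  # plain mean, padding included

print(pooled)  # tensor([[2., 3.]])
print(naive)   # much larger, because the padding values leak in
```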

The next step is to concatenate these embeddings. Several different approaches to this were presented in the paper:

| Concatenation | NLI Performance |
|--|--|
| (u, v) | 66.04 |
| (\|u-v\|) | 69.78 |
| (u*v) | 70.54 |
| (\|u-v\|, u*v) | 78.37 |
| (u, v, u*v) | 77.44 |
| (u, v, \|u-v\|) | 80.78 |
| (u, v, \|u-v\|, u*v) | 80.44 |

Of these, the best performing is built by concatenating vectors $u, v$, and $|u-v|$. 

Concatenating all three produces a vector three times the length of each original vector. We label this concatenated vector $(u, v, |u-v|)$, where $|u-v|$ is the element-wise absolute difference between vectors $u$ and $v$.

<img src="https://d33wubrfki0l68.cloudfront.net/d1da925240ca6265229f0e0e5cc896931f41f9bf/5a050/images/train-sentence-transformer-3.jpg" width=400 alt="UV Vectors">

We will perform this concatenation operation using PyTorch. 

Once we have our mean-pooled sentence vectors $u$ and $v$ we concatenate with:

In [18]:
u = torch.tensor([3, 3])
v = torch.tensor([2, 2])

# produces |u-v| tensor
uv_abs = torch.abs(torch.sub(u, v))
# then we concatenate
x = torch.cat([u, v, uv_abs], dim=-1)
x

tensor([3, 3, 2, 2, 1, 1])
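
The same concatenation carries over to batches of 768-dimensional sentence vectors. In this sketch, random tensors stand in for real pooled embeddings, and the batch size of 16 matches the data loader above.

```python
import torch

batch_size, hidden = 16, 768

# random stand-ins for a batch of mean-pooled sentence embeddings
u_batch = torch.randn(batch_size, hidden)
v_batch = torch.randn(batch_size, hidden)

# concatenate along the last dimension so each row becomes (u, v, |u-v|)
x_batch = torch.cat([u_batch, v_batch, torch.abs(u_batch - v_batch)], dim=-1)

print(x_batch.shape)  # torch.Size([16, 2304])
```

The resulting width of `768 * 3 = 2304` is the input size the classification layer has to accept.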

Vector $(u, v, |u-v|)$ is fed into a feed-forward neural network (FFNN). 

The FFNN processes the vector and outputs three activation values, one for each of our label classes: `entailment`, `neutral`, and `contradiction`.

```python
# we would initialize the feed-forward NN first
ffnn = torch.nn.Linear(768 * 3, 3)
...
# then later in the code process our concatenated vector with it
x = ffnn(x)
```

As these activations and label classes are aligned, we now calculate the softmax loss between them.

<img src="https://d33wubrfki0l68.cloudfront.net/48ecded904827a6e7c02dd72b9fe4a1f6227052b/13070/images/train-sentence-transformer-4.jpg" width=600 alt="SBERT Training">

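Softmax loss is calculated by applying a softmax function across the three activation values (or nodes), producing a predicted probability for each label. 

We then use cross-entropy loss to measure the difference between these predicted probabilities and the true label. PyTorch's `CrossEntropyLoss` applies the softmax internally, so we pass it the raw FFNN outputs.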

```python
# as before, we would initialize the loss function first
loss_func = torch.nn.CrossEntropyLoss()
...
# then later in the code add them to the process
loss = loss_func(x, label)  # label is our *true* 0, 1, 2 class
```
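
To make the relationship between softmax and cross-entropy concrete, the sketch below computes the loss twice for one made-up set of activations: once with `CrossEntropyLoss`, and once by hand as the negative log of the softmax probability of the true class. The values in `logits` are arbitrary.

```python
import torch

# made-up FFNN activations for one sentence pair, and its true class
logits = torch.tensor([[2.0, 0.5, -1.0]])
true_label = torch.tensor([0])

# built-in: takes raw activations and applies log-softmax internally
loss_builtin = torch.nn.CrossEntropyLoss()(logits, true_label)

# by hand: softmax over the three activations, then -log p(true class)
probs = torch.softmax(logits, dim=-1)
loss_manual = -torch.log(probs[0, true_label[0]])

print(probs)
print(loss_builtin.item(), loss_manual.item())
```

The two values agree, so no explicit softmax layer is needed before the loss.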

The model is then optimized using this loss. 

We use an Adam optimizer with a learning rate of `2e-5` and a linear warmup period covering the first 10% of training steps. 

To set that up, we use the standard PyTorch Adam optimizer alongside a learning rate scheduler provided by HF transformers:

```python
# we would initialize everything first
optim = torch.optim.Adam(model.parameters(), lr=2e-5)
# and setup a warmup for the first ~10% steps
total_steps = int(len(dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)

# num_training_steps is the total step count, warmup included
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps=warmup_steps, 
                                            num_training_steps=total_steps)
...
# then during the training loop we update the scheduler per step
scheduler.step()
```
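
The shape of a linear warmup followed by linear decay can be sketched in plain Python. The function `lr_multiplier` below is a hypothetical helper, assumed to mirror the factor such a schedule applies to the base learning rate; it is not taken from the library itself.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given step (sketch)."""
    if step < num_warmup_steps:
        # ramp linearly from 0 up to 1 during warmup
        return step / max(1, num_warmup_steps)
    # then decay linearly from 1 down to 0 by the final training step
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 2e-5
for step in (0, 5, 10, 55, 100):
    # 100 total steps with 10 warmup steps, chosen only for readability
    print(step, base_lr * lr_multiplier(step, 10, 100))
```

Under this assumption the learning rate peaks at `2e-5` when warmup ends and reaches zero on the last training step.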

Now let’s put all of that together in a PyTorch training loop.



In [19]:
# move BERT to the active device and build the classification head and loss
model.to(device)
ffnn = torch.nn.Linear(768*3, 3).to(device)
loss_func = torch.nn.CrossEntropyLoss()

# optimize BERT together with the classification head
optim = torch.optim.Adam(list(model.parameters()) + list(ffnn.parameters()), lr=2e-5)
# and setup a warmup for the first ~10% steps
total_steps = int(len(dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)

# num_training_steps is the total step count, warmup included
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps=warmup_steps, 
                                            num_training_steps=total_steps)

In [20]:
# 1 epoch should be enough, increase if wanted
for epoch in range(1):
  # make sure model is in training mode
  model.train()

  # initialize the dataloader loop with tqdm (tqdm == progress bar)
  loop = tqdm(dataloader, leave=True)
  for batch in loop:
    # zero all gradients on each new step
    optim.zero_grad()

    # prepare batches and move all to the active device
    inputs_ids_a = batch["premise_input_ids"].to(device)
    inputs_ids_b = batch["hypothesis_input_ids"].to(device)
    attention_a = batch["premise_attention_mask"].to(device)
    attention_b = batch["hypothesis_attention_mask"].to(device)
    label = batch["label"].to(device)

    # extract token embeddings from BERT
    u = model(inputs_ids_a, attention_mask=attention_a)[0]  # all token embeddings A
    v = model(inputs_ids_b, attention_mask=attention_b)[0]  # all token embeddings B

    # get the mean pooled vectors
    u = mean_pool(u, attention_a)
    v = mean_pool(v, attention_b)

    # build the |u-v| tensor
    uv = torch.sub(u, v)
    uv_abs = torch.abs(uv)

    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1)

    # process concatenated tensor through FFNN
    x = ffnn(x)

    # calculate the 'softmax-loss' between predicted and true label
    loss = loss_func(x, label)
    
    # using loss, calculate gradients and then optimize
    loss.backward()
    optim.step()

    # update learning rate scheduler
    scheduler.step()

    # update the tqdm progress bar
    loop.set_description(f"Epoch {epoch}")
    loop.set_postfix(loss=loss.item())

We only train for a single epoch here. Realistically this should be enough.


The last thing we need to do is save the model.

In [None]:
model_path = "./sbert_test_a"

# create the output directory if it does not already exist
os.makedirs(model_path, exist_ok=True)

model.save_pretrained(model_path)

Now let’s compare everything we’ve done so far with `sentence-transformers` training utilities.

##Fine-Tuning With Sentence Transformers