This article will cover what MNR loss is, the data it requires, and how to implement it to fine-tune our own high-quality sentence transformers.

Implementation will cover two training approaches. The first is more involved, and outlines the exact steps to fine-tune the model. 

* https://www.pinecone.io/learn/fine-tune-sentence-transformers-mnr/


In [25]:
import random
random.sample(range(100), 10)

[9, 10, 58, 47, 18, 94, 52, 24, 38, 15]

In [1]:
# https://huggingface.co/docs/datasets/process
import datasets
import random
snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')
mnli = mnli.remove_columns(['idx']).select(random.sample(range(1000), 10))
snli = snli.cast(mnli.features).select(random.sample(range(1000), 10))

dataset = datasets.concatenate_datasets([snli, mnli])

del snli, mnli

Reusing dataset snli (C:\Users\piush\.cache\huggingface\datasets\snli\plain_text\1.0.0\1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
Reusing dataset glue (C:\Users\piush\.cache\huggingface\datasets\glue\mnli\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at C:\Users\piush\.cache\huggingface\datasets\snli\plain_text\1.0.0\1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b\cache-b5fc5c7da89ce7e6.arrow


In [2]:
#Because we are using MNR loss, we only want anchor-positive pairs. We can apply a filter to remove all other pairs (including erroneous -1 labels).
print(f"before: {len(dataset)} rows")
dataset = dataset.filter(
    lambda x: True if x['label'] == 0 else False
)
print(f"after: {len(dataset)} rows")

before: 20 rows


  0%|          | 0/1 [00:00<?, ?ba/s]

after: 11 rows


In [44]:
dataset[5]

{'premise': 'The rule contains information collection requirements which will allow EPA to determine that detergent additives which are effective in controlling deposits are used and that emission control goals are realized.',
 'hypothesis': 'The rule has data collection requirements which aid the EPA to realize their emission control goals.',
 'label': 0}

We must convert our human-readable sentences into transformer-readable tokens, so we go ahead and tokenize our sentences. Both premise and hypothesis features must be split into their own input_ids and attention_mask tensors.

In [45]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

dataset = dataset.map(
    lambda x: tokenizer(
            x['premise'], max_length=128, padding='max_length',
            truncation=True
        ), batched=True
)

dataset = dataset.rename_column('input_ids', 'anchor_ids')
dataset = dataset.rename_column('attention_mask', 'anchor_mask')

dataset

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['premise', 'hypothesis', 'label', 'anchor_ids', 'token_type_ids', 'anchor_mask'],
    num_rows: 9
})

After that, we’re ready to initialize our DataLoader, which will be used for loading batches of data into our model during training.

In [70]:
dataset.set_format(type='torch', output_all_columns=True)

In [71]:
len(dataset['anchor_mask'] == 1)

9

In [72]:
import torch

batch_size = 2

loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

In [73]:
loader

<torch.utils.data.dataloader.DataLoader at 0x2207212e100>

And with that, our data is ready. Let’s move on to training.

#### PyTorch Fine-Tuning
When training SBERT models, we don’t start from scratch. Instead, we begin with an already pretrained BERT — all we need to do is fine-tune it for building sentence embeddings.

In [74]:
from transformers import BertModel

# start from a pretrained bert-base-uncased model
model = BertModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


MNR and softmax loss training approaches use a * ‘siamese’*-BERT architecture during fine-tuning. Meaning that during each step, we process a sentence A (our anchor) into BERT, followed by sentence B (our positive).

Because these two sentences are processed separately, it creates a siamese-like network with two identical BERTs trained in parallel. In reality, there is only a single BERT being used twice in each step.

We can extend this further with triplet-networks. In the case of triplet networks for MNR, we would pass three sentences, an anchor, it’s positive, and it’s negative. However, we are not using triplet-networks, so we have removed the negative rows from our dataset (rows where label is 2).

BERT outputs 512 768-dimensional embeddings. We convert these into averaged sentence embeddings using mean-pooling. Using the siamese approach, we produce two of these per step — one for the anchor that we will call a, and another for the positive called p

In [75]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

In the mean_pool function, we’re taking these token-level embeddings (the 512) and the sentence attention_mask tensor. We resize the attention_mask to match the higher 768-dimensionality of the token embeddings.

The resized mask in_mask is applied to the token embeddings to exclude padding tokens from the mean pooling operation. Mean-pooling takes the average activation of values across each dimension but excluding those padding values, which would reduce the average activation. This operation transformers our token-level embeddings (shape 512*768) to sentence-level embeddings (shape 1*768).

These steps are performed in batches, meaning we do this for many (anchor, positive) pairs in parallel. That is important in our next few steps.

Let’s put that all together and set up a training loop. First, we move our model and layers to a CUDA-enabled GPU if available.

In [76]:
# set device and move model there
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
print(f'moved to {device}')

moved to cuda


In [77]:
# define layers to be used in multiple-negatives-ranking
cos_sim = torch.nn.CosineSimilarity()
loss_func = torch.nn.CrossEntropyLoss()
scale = 20.0  # we multiply similarity score by this scale value
# move layers to device
cos_sim.to(device)
loss_func.to(device)

CrossEntropyLoss()

Then we set up the optimizer and schedule for training. We use an Adam optimizer with a linear warmup for 10% of the total number of steps.

In [78]:
anchors = dataset['anchor_mask'] == 1

In [79]:
from transformers.optimization import get_linear_schedule_with_warmup

# initialize Adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=2e-5)

# setup warmup for first ~10% of steps
total_steps = int(len(anchors) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optim, num_warmup_steps=warmup_steps,
    num_training_steps=total_steps-warmup_steps
)

And now we define the training loop, using the same training process that we worked through before.

In [80]:
from tqdm.auto import tqdm
epochs = 1

# 1 epoch should be enough, increase if wanted
for epoch in range(epochs):
    model.train()  # make sure model is in training mode
    # initialize the dataloader loop with tqdm (tqdm == progress bar)
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # zero all gradients on each new step
        optim.zero_grad()
        # prepare batches and more all to the active device
        anchor_ids = batch['anchor']['input_ids'].to(device)
        anchor_mask = batch['anchor']['attention_mask'].to(device)
        pos_ids = batch['positive']['input_ids'].to(device)
        pos_mask = batch['positive']['attention_mask'].to(device)
        # extract token embeddings from BERT
        a = model(
            anchor_ids, attention_mask=anchor_mask
        )[0]  # all token embeddings
        p = model(
            pos_ids, attention_mask=pos_mask
        )[0]
        # get the mean pooled vectors
        a = mean_pool(a, anchor_mask)
        p = mean_pool(p, pos_mask)
        # calculate the cosine similarities
        scores = torch.stack([
            cos_sim(
                a_i.reshape(1, a_i.shape[0]), p
            ) for a_i in a])
        # get label(s) - we could define this before if confident of consistent batch sizes
        labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)
        # and now calculate the loss
        loss = loss_func(scores*scale, labels)
        # using loss, calculate gradients and then optimize
        loss.backward()
        optim.step()
        # update learning rate scheduler
        scheduler.step()
        # update the TDQM progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

  0%|          | 0/5 [00:00<?, ?it/s]

TypeError: new(): invalid data type 'numpy.str_'

### TO DO 
There is some issue with dataset set_format. Re-run this example after looking into it.

In [None]:
# With that, we’ve fine-tuned our BERT model using MNR loss. Now we save it to file.

import os

model_path = './sbert_test_mnr'

if not os.path.exists(model_path):
    os.mkdir(model_path)

model.save_pretrained(model_path)

And this can now be loaded using either the SentenceTransformer or HF from_pretrained methods. 

### Fast Fine-Tuning
As we already mentioned, there is an easier way to fine-tune models using MNR loss. The sentence-transformers library allows us to use pretrained sentence transformers and comes with some handy training utilities.

We will start by preprocessing our data. This is the same as we did before for the first few steps.

Before, we tokenized our data and then loaded it into a PyTorch DataLoader. This time we follow a slightly different format. We * don’t* tokenize; we reformat into a list of sentence-transformers InputExample objects and use a slightly different DataLoader.

In [3]:
from sentence_transformers import InputExample
from tqdm.auto import tqdm  # so we see progress bar

train_samples = []
for row in tqdm(dataset):
    train_samples.append(InputExample(
        texts=[row['premise'], row['hypothesis']]
    ))

  0%|          | 0/11 [00:00<?, ?it/s]

In [4]:
from sentence_transformers import datasets

batch_size = 32

loader = datasets.NoDuplicatesDataLoader(
    train_samples, batch_size=batch_size)

Our InputExample contains just our a and p sentence pairs, which we then feed into the NoDuplicatesDataLoader object. This data loader ensures that each batch is duplicate-free — a helpful feature when ranking pair similarity across randomly sampled pairs with MNR loss.

Now we define the model. The sentence-transformers library allows us to build models using modules. We need just a transformer model (we will use bert-base-uncased again) and a mean pooling module.

In [5]:
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer('bert-base-uncased')
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[bert, pooler])

model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

We now have an initialized model. Before training, all that’s left is the loss function — MNR loss.

In [6]:
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

And with that, we have our data loader, model, and loss function ready. All that’s left is to fine-tune the model! As before, we will train for a single epoch and warmup for the first 10% of our training steps.

In [7]:
epochs = 1
warmup_steps = int(len(loader) * epochs * 0.1)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    #output_path='./sbert_test_mnr2',
    show_progress_bar=True
)  # I set 'show_progress_bar=False' as it printed every step
#    on to a new line

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration: 0it [00:00, ?it/s]

In [8]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)