In [None]:
import datasets

# Load the training splits of the two NLI corpora.
snli = datasets.load_dataset('snli', split='train')
mnli = datasets.load_dataset('glue', 'mnli', split='train')

# The GLUE build of MNLI ships an extra `idx` column that SNLI does not have.
# `cast` and `concatenate_datasets` both need matching columns, so drop it
# (guarded, in case a dataset revision no longer includes the column).
if 'idx' in mnli.column_names:
    mnli = mnli.remove_columns(['idx'])

# Give SNLI the same feature schema as MNLI so the two can be concatenated.
snli = snli.cast(mnli.features)

dataset = datasets.concatenate_datasets([snli, mnli])

# Only the merged `dataset` is used from here on; release the per-corpus names.
del snli, mnli

In [None]:
def _has_label_zero(row):
    """Return True only for rows whose 'label' field equals 0."""
    # NOTE(review): label 0 is presumably the "entailment" class in
    # SNLI/MNLI -- confirm against the dataset cards.
    return row['label'] == 0


# Report the row count before and after keeping only label-0 rows.
print(f"before: {len(dataset)} rows")
dataset = dataset.filter(_has_label_zero)
print(f"after: {len(dataset)} rows")

In [None]:
from sentence_transformers import InputExample
from tqdm.auto import tqdm  # so we see progress bar

# Wrap every (premise, hypothesis) pair of the filtered `dataset` in an
# InputExample. The original loop iterated over `nli`, a name that is never
# defined in this notebook, so it raised NameError on a fresh kernel.
train_samples = [
    InputExample(texts=[row['premise'], row['hypothesis']])
    for row in tqdm(dataset)  # tqdm renders a progress bar while iterating
]

In [None]:
# Import the loader class directly. `from sentence_transformers import datasets`
# would rebind the global name `datasets` and shadow the Hugging Face
# `datasets` module imported in the first cell, breaking any re-run of that
# cell without a kernel restart.
from sentence_transformers.datasets import NoDuplicatesDataLoader

batch_size = 32

# Batch the training examples with the no-duplicates loader.
loader = NoDuplicatesDataLoader(train_samples, batch_size=batch_size)

In [None]:
from sentence_transformers import models, SentenceTransformer

# Token-level encoder module built from the 'bert-base-uncased' checkpoint.
bert = models.Transformer('bert-base-uncased')

# Pooling module sized to the encoder's word-embedding dimension, with
# mean-token pooling explicitly enabled.
embedding_dim = bert.get_word_embedding_dimension()
pooler = models.Pooling(embedding_dim, pooling_mode_mean_tokens=True)

# Chain the two modules: encoder first, then pooling.
model = SentenceTransformer(modules=[bert, pooler])

# Bare expression so the notebook displays the assembled model.
model

In [None]:
from sentence_transformers import losses

# Training objective: multiple-negatives-ranking loss bound to `model`.
loss = losses.MultipleNegativesRankingLoss(model=model)

In [None]:
epochs = 1

# Warm up over the first 10% of all training steps.
total_steps = len(loader) * epochs
warmup_steps = int(total_steps * 0.1)

# 'show_progress_bar=False' is set because the progress bar printed every
# step on to a new line.
model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    show_progress_bar=False,
    output_path='./sbert_test_mnr2',
)