In [1]:
from datasets import load_dataset

In [2]:
data_files = {"train": "all-nli-qwen-reranker-train.parquet",
              "test": "all-nli-qwen-reranker-test.parquet",
              "dev": "all-nli-qwen-reranker-dev.parquet" }
train_dataset = load_dataset("parquet", data_files=data_files, split="train")

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating dev split: 0 examples [00:00, ? examples/s]

In [3]:
train_dataset

Dataset({
    features: ['sentence1', 'sentence2', 'score'],
    num_rows: 942069
})

In [4]:
train_dataset[0]

{'sentence1': 'A person on a horse jumps over a broken down airplane.',
 'sentence2': 'A person is training his horse for a competition.',
 'score': 0.008453369140625}

In [5]:
from sentence_transformers import SentenceTransformer

In [6]:
model_name = "answerdotai/ModernBERT-base"
model = SentenceTransformer(model_name, device="cuda")

No sentence-transformers model found with name answerdotai/ModernBERT-base. Creating a new one with mean pooling.
Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`


In [7]:
model.to('cuda')

SentenceTransformer(
  (0): Transformer({'max_seq_length': 8192, 'do_lower_case': False, 'architecture': 'ModernBertModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [8]:
from sentence_transformers.losses import CoSENTLoss
loss = CoSENTLoss(model)

In [9]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/ModernSBERT",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
)

In [10]:
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("parquet", data_files=data_files, split="dev")

# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="all-nli-dev",
)

In [11]:
test_dataset = load_dataset("parquet", data_files=data_files, split="test")

In [12]:
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
)

In [13]:
len(test_dataset)

19656

In [14]:
test_evaluator(model)

{'pearson_cosine': 0.2906708668132107, 'spearman_cosine': 0.3960724014157506}

In [15]:
from sentence_transformers import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,4.8134
1000,4.649
1500,4.5356
2000,4.4374
2500,4.3812
3000,4.3505
3500,4.331
4000,4.3341
4500,4.293
5000,4.2945


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [16]:
model.save_pretrained("ModernSBERT-base-all-nli-qwen-reranker/final")

In [None]:
model = SentenceTransformer("./ModernSBERT-base-all-nli-qwen-reranker/final")

In [17]:
test_evaluator(model)

{'pearson_cosine': 0.8129172295846817, 'spearman_cosine': 0.9003960149946096}