In [15]:
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Configure the root logger at INFO level with a timestamp prefix on every record
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [37]:
# Checkpoint to fine-tune. Other checkpoints are kept here as commented-out alternatives:
# model_name = "answerdotai/ModernBERT-base"
# model_name = "sergeyzh/BERTA"
model_name = "ai-forever/ru-en-RoSBERTa"

num_epochs = 1
num_hard_negatives = 5  # Number of hard negatives mined for each question-answer pair
# Use the large batch only when a CUDA device is available, otherwise fall back to a small one
train_batch_size = 8 if not torch.cuda.is_available() else 64

In [38]:
# 1a. Load the checkpoint named by `model_name` as a CrossEncoder to finetune
# 1b. (Optional) attach model card metadata to it
model = CrossEncoder(
    model_name,
    model_card_data=CrossEncoderModelCardData(
        # The field takes one language code or a list of codes; "ru-en" is not a
        # language code, so the two languages are listed separately.
        language=["ru", "en"],
        license="apache-2.0",
        model_name=f"{model_name} trained on GooAQ",
    ),
)
# Quick sanity check of the loaded configuration
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at ai-forever/ru-en-RoSBERTa and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-04-20 17:59:24 - Use pytorch device: cpu


Model max length: 512
Model num labels: 1


In [18]:
# Score two sentence pairs that share the same first sentence, then display the scores
sentence_pairs = [
    ("Сегодня прекрасный день на улице.", "Сегодня так солнечно!"),
    ("Сегодня прекрасный день на улице.", "Он поехал на работу пораньше."),
]
scores = model.predict(sentence_pairs)
scores

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

array([0.54257035, 0.51132107], dtype=float32)

In [19]:
# 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Read the gooaq training dataset")

# Keep only the first 100k rows of the train split
gooaq_train_split = load_dataset("sentence-transformers/gooaq", split="train")
full_dataset = gooaq_train_split.select(range(100_000))

# Hold out 1k rows for evaluation; the fixed seed keeps the split stable across runs
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["test"]

logging.info(train_dataset)
logging.info(eval_dataset)

2025-04-20 17:40:39 - Read the gooaq training dataset
2025-04-20 17:40:41 - Dataset({
    features: ['question', 'answer'],
    num_rows: 99000
})
2025-04-20 17:40:41 - Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})


In [20]:
# Print the first 3 question/answer pairs of the training dataset as a quick sanity check
print("Sample questions and answers from training dataset:")
first_rows = train_dataset.select(range(3))
for idx, row in enumerate(first_rows):
    print(f"\nSample {idx + 1}:")
    print(f"Question: {row['question']}")
    print(f"Answer: {row['answer']}")

Sample questions and answers from training dataset:

Sample 1:
Question: what are the 5 characteristics of a star?
Answer: Key Concept: Characteristics used to classify stars include color, temperature, size, composition, and brightness.

Sample 2:
Question: are copic markers alcohol ink?
Answer: Copic Ink is alcohol-based and flammable. Keep away from direct sunlight and extreme temperatures.

Sample 3:
Question: what is the difference between appellate term and appellate division?
Answer: Appellate terms An appellate term is an intermediate appellate court that hears appeals from the inferior courts within their designated counties or judicial districts, and are intended to ease the workload on the Appellate Division and provide a less expensive forum closer to the people.


In [None]:
# 2b. Extend the training dataset with hard negatives mined by a very efficient embedding model
embedding_model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
emb_model_batch_size = 4096

# Alternative embedding models, each with its own batch size:
# embedding_model_name = "sergeyzh/BERTA"
# emb_model_batch_size = 512
#
# embedding_model_name = "sergeyzh/LaBSE-ru-turbo"
# emb_model_batch_size = 512

embedding_model = SentenceTransformer(embedding_model_name, device="cpu")

mining_options = dict(
    num_negatives=num_hard_negatives,  # Negatives to mine per question-answer pair
    absolute_margin=0,  # Query-negative similarity must be this much lower than query-positive similarity
    range_min=0,  # Number of most similar samples to skip
    range_max=100,  # Only the x most similar samples are considered
    sampling_strategy="top",  # Take the top negatives from that range
    batch_size=emb_model_batch_size,  # Encoding batch size for the embedding model
    output_format="labeled-pair",  # (query, passage, label) rows, as required by BinaryCrossEntropyLoss
    use_faiss=True,
)
hard_train_dataset = mine_hard_negatives(train_dataset, embedding_model, **mining_options)
logging.info(hard_train_dataset)

2025-04-20 17:42:47 - Load pretrained SentenceTransformer: sentence-transformers/static-retrieval-mrl-en-v1


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

Batches:   0%|          | 0/25 [00:00<?, ?it/s]

2025-04-20 17:43:01 - Loading faiss with AVX2 support.
2025-04-20 17:43:01 - Successfully loaded faiss with AVX2 support.
2025-04-20 17:43:01 - Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes.
Querying FAISS index: 100%|██████████| 7/7 [02:35<00:00, 22.17s/it]


Metric       Positive       Negative     Difference
Count          99,000        479,402               
Mean           0.5882         0.5000         0.0988
Median         0.5989         0.5000         0.0610
Std            0.1425         0.1098         0.1040
Min           -0.0514         0.1552         0.0000
25%            0.4993         0.4215         0.0179
50%            0.5989         0.5000         0.0610
75%            0.6889         0.5777         0.1494
Max            0.9748         0.9508         0.7159
Skipped 878,050 potential negatives (8.78%) due to the absolute_margin of 0.
Could not find enough negatives for 15598 samples (3.15%). Consider adjusting the range_max and absolute_margin parameters if you'd like to find more valid negatives.


2025-04-20 17:45:53 - Dataset({
    features: ['question', 'answer', 'label'],
    num_rows: 578402
})


In [13]:
import numpy as np

# Print up to 3 randomly chosen samples from the hard negatives dataset.
# A seeded Generator makes the selection identical on every re-run (the unseeded
# global `np.random` state did not), and the sample size is clamped so that a
# dataset with fewer than 3 rows does not make `choice(..., replace=False)` raise.
rng = np.random.default_rng(12)
sample_count = min(3, len(hard_train_dataset))
random_indices = rng.choice(len(hard_train_dataset), size=sample_count, replace=False)
print("Sample questions and answers from hard negatives dataset:")
for i, sample in enumerate(hard_train_dataset.select(random_indices)):
    print(f"\nSample {i+1}:")
    print(f"Question: {sample['question']}")
    print(f"Answer: {sample['answer']}")
    print(f"Label: {sample['label']}")


Sample questions and answers from hard negatives dataset:

Sample 1:
Question: what are the three categories of micro aggressive acts?
Answer: ['Microbiological hazards. Microbiological hazards include bacteria, yeasts, moulds and viruses.', 'Chemical hazards. ... ', 'Physical hazards. ... ', 'Allergens.']
Label: 0

Sample 2:
Question: what is difference between get and post method in rest api?
Answer: Both GET and POST method is used to transfer data from client to server in HTTP protocol but Main difference between POST and GET method is that GET carries request parameter appended in URL string while POST carries request parameter in message body which makes it more secure way of transferring data from client to ...
Label: 1

Sample 3:
Question: who has arya killed with brown eyes?
Answer: The eyes are always the same size from birth to death. Baby eyes are proportionally larger than adult eyes, but they are still smaller.
Label: 0


In [16]:
from datasets import load_from_disk

# 2c. (Optional) Write the hard training dataset to disk, then read it back from that copy
hard_train_dir = "datasets/gooaq-hard-train"
hard_train_dataset.save_to_disk(hard_train_dir)
hard_train_dataset = load_from_disk(hard_train_dir)

Saving the dataset (0/1 shards):   0%|          | 0/578402 [00:00<?, ? examples/s]

In [39]:
# 3. Training loss.
# The recommended pos_weight is the ratio between positives to negatives, a.k.a. `num_hard_negatives`
positive_weight = torch.tensor(num_hard_negatives)
loss = BinaryCrossEntropyLoss(model=model, pos_weight=positive_weight)

# 4a. Evaluators. CrossEncoderNanoBEIREvaluator is a light-weight evaluator for English reranking
nano_beir_dataset_names = ["msmarco", "nfcorpus", "nq"]  # "rus-NFCorpus" is left disabled
nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
    dataset_names=nano_beir_dataset_names,
    batch_size=train_batch_size,
)

                                                                        

In [28]:
# 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
# We include the positive answer in the list of negatives, so the evaluator can use the performance of the
# embedding model as a baseline.
hard_eval_dataset = mine_hard_negatives(
    eval_dataset,
    embedding_model,
    corpus=full_dataset["answer"],  # Use the full dataset as the corpus
    num_negatives=30,  # How many documents to rerank
    # Reuse the batch size configured next to the embedding model instead of a hard-coded 4096,
    # so switching to a heavier embedding model (with a smaller batch size) applies here too
    batch_size=emb_model_batch_size,
    include_positives=True,
    output_format="n-tuple",
    use_faiss=True,
)
logging.info(hard_eval_dataset)

Setting range_max to 31 based on the provided parameters.


Batches:   0%|          | 0/22 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Querying FAISS index: 100%|██████████| 1/1 [00:01<00:00,  1.80s/it]
2025-04-20 17:47:32 - Dataset({
    features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30'],
    num_rows: 1000
})


Metric       Positive       Negative     Difference
Count           1,000         30,000               
Mean           0.5853         0.4550         0.1303
Median         0.5996         0.4488         0.1176
Std            0.1434         0.1217         0.1602
Min            0.0475         0.1611        -0.5592
25%            0.4970         0.3644         0.0149
50%            0.5996         0.4488         0.1176
75%            0.6784         0.5365         0.2318
Max            0.9532         0.9604         0.6286


In [29]:
# Every column after the first two ("question", "answer") holds one mined document to rerank
document_columns = hard_eval_dataset.column_names[2:]

# One evaluator sample per eval row: the query, its positive answer, and the mined documents
reranking_samples = []
for row in hard_eval_dataset:
    reranking_samples.append(
        {
            "query": row["question"],
            "positive": [row["answer"]],
            "documents": [row[column] for column in document_columns],
        }
    )

reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=reranking_samples,
    batch_size=train_batch_size,
    name="gooaq-dev",
    always_rerank_positives=False,
)

In [32]:
import os

import wandb
from dotenv import load_dotenv

# Read variables from a .env file into the process environment
load_dotenv()

# The key comes from the environment only; fail early with a clear message when it is absent
wandb_api_key = os.environ.get("WANDB_API_KEY")
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY not found in .env file.")

wandb.login(key=wandb_api_key)
print("Logged into wandb")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/anton/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtony-pitchblack[0m ([33moverfit1010[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Logged into wandb


In [None]:
# 4c. Chain the evaluators (reranking first, then NanoBEIR) and run the base model on them
evaluators = [reranking_evaluator, nano_beir_evaluator]
evaluator = SequentialEvaluator(evaluators)

evaluator(model)

In [None]:
# 5. Define the training arguments
# Keep only the part after the last "/" (a name without "/" is returned unchanged by split)
short_model_name = model_name.split("/")[-1]
run_name = f"reranker-{short_model_name}-gooaq-bce"

# Enable bf16 only when a CUDA device that supports it is present. Hard-coding bf16=True
# makes argument validation fail on CUDA devices without bf16 support.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Kept off; bf16 is used instead where the hardware supports it
    bf16=use_bf16,  # Detected above instead of being set by hand
    dataloader_num_workers=4,
    load_best_model_at_end=True,
    metric_for_best_model="eval_gooaq-dev_ndcg@10",
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    seed=12,
)

# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=hard_train_dataset,
    loss=loss,
    evaluator=evaluator,
)
trainer.train()

2025-04-20 17:57:12 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:


KeyboardInterrupt: 

In [None]:
# 7. Evaluate the final model by calling the combined evaluator on it once more;
#    the results are useful to include in the model card
evaluator(model)

In [None]:
# 8. Write the final model to a "final" sub-directory of the run's output folder
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 9. (Optional) Upload the model to the Hugging Face Hub.
# Running `huggingface-cli login` beforehand is recommended so the upload is authenticated.
try:
    model.push_to_hub(run_name)
except Exception:
    # Best effort: a failed upload is logged together with manual instructions instead of being re-raised
    upload_error_message = (
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
    logging.error(upload_error_message)