In [None]:
! pip install -Uq transformers sentence_transformers datasets fsspec faiss-gpu-cu12

In [None]:
import wandb

# `wandb` is installed, so the Trainer's W&B integration would otherwise be used;
# start a disabled (no-op) run up front so nothing is logged to W&B.
WANDB_MODE = "disabled"
wandb.init(mode=WANDB_MODE)

### **Load Dataset**

In [None]:
from datasets import load_dataset

# Amharic news retrieval data: train/test splits of (query, passage) pairs, each
# row also carrying a `negative_passages` list.
DATASET_REPO = "yosefw/amharic-news-retrieval-dataset-v2-with-negatives-V2"
am_relevance_dataset = load_dataset(DATASET_REPO)
am_relevance_dataset

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


DatasetDict({
    train: Dataset({
        features: ['query_id', 'passage_id', 'query', 'passage', 'category', 'link', 'source_dataset', 'negative_passages'],
        num_rows: 61469
    })
    test: Dataset({
        features: ['query_id', 'passage_id', 'query', 'passage', 'category', 'link', 'source_dataset', 'negative_passages'],
        num_rows: 6832
    })
})

In [None]:
# Distinct passage ids of the test split, kept so that test passages can be
# filtered out of the training negatives.
test_passage_ids = {passage_id for passage_id in am_relevance_dataset["test"]["passage_id"]}
len(test_passage_ids)

6764

In [None]:
# Binary dataset: one (query, passage, label) row for every positive passage and
# for a fixed selection of that query's negative passages.

from datasets import Dataset
from tqdm import tqdm
import random

# Negatives used per training row: the first NUM_HEAD_NEGATIVES and the last
# NUM_TAIL_NEGATIVES entries of its `negative_passages` list.
NUM_HEAD_NEGATIVES = 3
NUM_TAIL_NEGATIVES = 4
# Set to True to drop negatives whose passage_id also occurs in the test split
# (requires `test_passage_ids`, computed from the test split earlier).
EXCLUDE_TEST_NEGATIVES = False


def select_negatives(negatives, num_head=NUM_HEAD_NEGATIVES, num_tail=NUM_TAIL_NEGATIVES):
    """Return the first `num_head` and last `num_tail` items of `negatives`, each at most once.

    Concatenating `negatives[:h]` and `negatives[-t:]` repeats items whenever the
    list has fewer than `h + t` entries (the slices overlap), which would silently
    duplicate training pairs. Starting the tail no earlier than the end of the head
    avoids that; for lists of at least `h + t` items the result is unchanged.
    """
    head = negatives[:num_head]
    tail_start = max(len(head), len(negatives) - num_tail)
    return head + negatives[tail_start:]


ds_rows = []
for row in tqdm(am_relevance_dataset["train"]):
    ds_rows.append({
        "query_id": row["query_id"],
        "passage_id": row["passage_id"],
        "query": row["query"],
        "passage": row["passage"],
        "label": 1,
    })

    neg_passages = row["negative_passages"]
    if EXCLUDE_TEST_NEGATIVES:
        # Avoid training on test passages (as negatives for other queries).
        neg_passages = [neg for neg in neg_passages if neg["passage_id"] not in test_passage_ids]

    for neg in select_negatives(neg_passages):
        ds_rows.append({
            "query_id": row["query_id"],
            "passage_id": neg["passage_id"],
            "query": row["query"],
            "passage": neg["passage"],
            "label": 0,
        })

relevance_train_dataset = Dataset.from_list(ds_rows).shuffle(seed=42)
relevance_train_dataset

100%|██████████| 61469/61469 [00:27<00:00, 2248.76it/s]


Dataset({
    features: ['query_id', 'passage_id', 'query', 'passage', 'label'],
    num_rows: 491752
})

In [None]:
# Evaluation queries come from the whole test split (no subsampling).
relevance_test_dataset = am_relevance_dataset["test"]
relevance_test_dataset

Dataset({
    features: ['query_id', 'passage_id', 'query', 'passage', 'category', 'link', 'source_dataset', 'negative_passages'],
    num_rows: 6832
})

### **Model Training**

In [None]:
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator

# Pretrained Amharic encoder that gets fine-tuned into a reranker.
model_name = "rasyosef/roberta-medium-amharic"

# Metadata recorded in the generated model card.
card_data = CrossEncoderModelCardData(
    language="am",
    license="mit",
    model_name="roberta-amharic-reranker-medium",
)

# num_labels=1: a single relevance output per (query, passage) pair; the scores
# returned by `model.predict` at the end of the notebook lie between 0 and 1.
# The classification head is newly initialised (see the loading warning).
model = CrossEncoder(
    model_name,
    num_labels=1,
    max_length=510,
    device="cuda",
    model_card_data=card_data,
)

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at rasyosef/roberta-medium-amharic and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import SentenceTransformer

# Bi-encoder (downloaded from the 🤗 Hub) used only to retrieve candidate
# negatives for the evaluation set; it is not trained in this notebook.
embedding_model = SentenceTransformer("rasyosef/roberta-amharic-text-embedding-base", device="cuda")

# Candidate pool: the passage texts of the *train* split (not the whole dataset).
eval_negative_corpus = am_relevance_dataset["train"]["passage"]

# For every test (query, passage) pair, mine 50 negatives from the pool with the
# embedding model. The "n-tuple" format yields columns query, passage,
# negative_1, negative_2, ...
hard_eval_dataset = mine_hard_negatives(
    relevance_test_dataset.select_columns(["query", "passage"]),
    embedding_model,
    corpus=eval_negative_corpus,
    num_negatives=50,  # candidates per query that the reranker must order
    batch_size=128,
    include_positives=False,
    output_format="n-tuple",
    use_faiss=True,
)

print(hard_eval_dataset)

Setting range_max to 52 based on the provided parameters.
Found 6828 unique queries out of 6832 total queries.
Found an average of 1.001 positives per query.


Batches:   0%|          | 0/517 [00:00<?, ?it/s]

Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Querying FAISS index: 100%|██████████| 1/1 [00:00<00:00,  8.29it/s]


Metric       Positive       Negative     Difference
Count           6,832        341,600               
Mean           0.7040         0.5328         0.1713
Median         0.7382         0.5228         0.1890
Std            0.1498         0.0886         0.1491
Min           -0.0702         0.2566        -0.6424
25%            0.6270         0.4710         0.0808
50%            0.7382         0.5228         0.1890
75%            0.8155         0.5850         0.2812
Max            0.9446         0.9402         0.5229
Dataset({
    features: ['query', 'passage', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative

In [None]:
# Peek at the first mined example: the first 64 characters of every column.
first_example = hard_eval_dataset[0]
for column_name in hard_eval_dataset.column_names:
    snippet = first_example[column_name][:64]
    print(f"{column_name} :  {snippet}")

query :  ·ãà·ã∞ ·àÉ·åà·à≠ ·â§·âµ ·ã®·â∞·àò·àà·à± ·ã®·çñ·àà·â≤·ä´ ·ãµ·à≠·åÖ·â∂·âΩ ·â†·àù·à≠·å´ ·â¶·à≠·ãµ ·ä†·àç·â∞·àò·ãò·åà·â°·àù ·â∞·â£·àà 
passage :  ‚Äú·ä®·ãú·åç·äê·âµ ·åã·à≠ ·ã®·â∞·ã´·ã´·ãò·ãç ·àÖ·åç ·ä•·äï·âÖ·çã·âµ ·àÜ·äñ·â•·äì·àç‚Äù·â†·àò·äï·åç·àµ·âµ ·å•·à™ ·â∞·ã∞·à≠·åé·àã·â∏·ãç ·ä®·ãç·å≠ ·ä†·åà·à´·âµ ·ãà·ã∞ ·àÉ·åà
negative_1 :  ·ä†·ã≤·àµ ·ä†·â†·â£·ç°- ·çì·à≠·â≤·ãé·âΩ ·â†·ä†·ã≤·à± ·ä†·ãã·åÖ ·àò·à∞·à®·âµ ·àà·àò·àò·ãù·åà·â• ·ã®·â∞·à∞·å£·â∏·ãç ·ã®·åä·ãú ·åà·ã∞·â• ·àä·å†·äì·âÄ·âÖ 18 ·âÄ·äì·âµ
negative_2 :  ·â†·â¶·à≠·ãµ ·àò·à∞·à®·ãõ·â∏·ãç ·ã®·â∞·åà·àà·å∏·ãç ·â†·â∞·àà·ã´·ã© ·åä·ãú·ã´·âµ ·â∞·àò·àµ·à≠·â∞·ãç ·â†·â•·àî·à≠·äì ·â†·ä†·åà·à≠ ·ä†·âÄ·çç ·çì·à≠·â≤·äê·âµ ·â†·àï·åã·ãä ·àò
negative_3 :  ·ã®·ä¢·âµ·ãÆ·åµ·ã´ ·â•·àî·à´·ãä ·àù·à≠·å´ ·â¶·à≠·ãµ 27 ·ã®·çñ·àà·â≤·ä´ ·çì·à≠·â≤·ãé·âΩ·äï ·àò·à∞·à®·ãô·äï ·ä†·àµ·â≥·ãà·âÄ\n·âÄ·ã∞·àù ·à≤·àç ·â†·äê·â†·à®·ãç ·ä†·à∞
negative_4 :  ·ã®·ä¢·âµ·ãÆ·åµ·ã´ ·â•·àî·à´·ãä ·àù·à≠·å´ ·â¶·à≠·ãµ ·â†·ä†·ã≤·à± ·àÖ·åç ·àã·ã≠ ·ã®·â∞·âÄ·àò·å° ·àò·àµ·çà·à≠·â∂·âΩ·äï ·ä†·àã·à

In [None]:
# Reranking evaluator run during training and used to pick the best checkpoint.
# It takes 1,600 random queries from `hard_eval_dataset`, which is built from the
# test split (the dataset has no separate validation split). Each query is scored
# against its positive passage plus all of its mined negatives.
negative_columns = hard_eval_dataset.column_names[2:]  # negative_1, negative_2, ...
dev_subset = hard_eval_dataset.shuffle(seed=42).select(range(1600))

dev_samples = []
for example in dev_subset:
    dev_samples.append({
        "query": example["query"],
        "positive": [example["passage"]],
        "negative": [example[name] for name in negative_columns],
    })

dev_evaluator = CrossEncoderRerankingEvaluator(
    samples=dev_samples,
    at_k=10,
    batch_size=128,
    name="amh-passage-retrieval-dev",
    always_rerank_positives=True,
    show_progress_bar=True,
)

In [None]:
# Baseline before fine-tuning: the classifier head is freshly initialised, so the
# ranking is essentially random.
baseline_metrics = dev_evaluator(model)
baseline_metrics



{'amh-passage-retrieval-dev_map': 0.08924252914778016,
 'amh-passage-retrieval-dev_mrr@10': 0.05746974206349206,
 'amh-passage-retrieval-dev_ndcg@10': 0.08978664796439237}

#### **Train**

In [None]:
import torch
from sentence_transformers.cross_encoder import CrossEncoderTrainer, CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.training_args import BatchSamplers

# Used for the output directory and as the W&B run name.
run_name = "roberta-amharic-reranker-medium"

num_train_epochs = 4
train_batch_size = 64

args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    learning_rate=4e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    fp16=True,
    # Each query is repeated across its positive and negative rows; this sampler
    # avoids putting duplicate texts in the same batch.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    dataloader_num_workers=2,
    # After training, `model` holds the checkpoint with the best dev MRR@10,
    # which is not necessarily the final epoch.
    load_best_model_at_end=True,
    metric_for_best_model="amh-passage-retrieval-dev_mrr@10",
    # Evaluate, checkpoint and log once per epoch:
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=2,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    seed=42,
)

# The binary training set is imbalanced (several negatives per positive), so
# positives are up-weighted by the observed negative:positive ratio instead of a
# hard-coded `torch.tensor(7)` (which was also an integer tensor).
train_labels = relevance_train_dataset["label"]
num_positive = sum(train_labels)
num_negative = len(train_labels) - num_positive
assert num_positive > 0, "training set contains no positive examples"
pos_weight = torch.tensor(num_negative / num_positive)

loss = BinaryCrossEntropyLoss(model=model, pos_weight=pos_weight)

# Create the trainer. No eval_dataset is passed, so evaluation comes only from
# `dev_evaluator` (hence "No log" for validation loss).
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=relevance_train_dataset.select_columns(["query", "passage", "label"]),
    loss=loss,
    evaluator=dev_evaluator,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (653 > 510). Running this sequence through the model will result in indexing errors


In [None]:
# Fine-tune; `dev_evaluator` runs and a checkpoint is saved once per epoch.
train_output = trainer.train()
train_output

Epoch,Training Loss,Validation Loss,Amh-passage-retrieval-dev Map,Amh-passage-retrieval-dev Mrr@10,Amh-passage-retrieval-dev Ndcg@10
1,0.4048,No log,0.791953,0.7893,0.828904
2,0.2366,No log,0.823457,0.821167,0.85464
3,0.1588,No log,0.796708,0.794708,0.835278
4,0.1024,No log,0.822961,0.821262,0.85509




TrainOutput(global_step=30736, training_loss=0.22566378234513781, metrics={'train_runtime': 4898.576, 'train_samples_per_second': 401.547, 'train_steps_per_second': 6.274, 'total_flos': 0.0, 'train_loss': 0.22566378234513781, 'epoch': 4.0})

In [None]:
# The training arguments set `load_best_model_at_end=True`, so `model` now holds
# the best checkpoint by dev MRR@10, which may not be the final epoch.
# Save it under a name that says so; the last epoch's weights remain in the
# checkpoint folders under `models/<run_name>/`.
model.save_pretrained(run_name + "-best")

### **Evaluation**

In [None]:
# Final evaluation after training. The dev evaluator scored the first 1,600 rows of
# `hard_eval_dataset.shuffle(seed=42)` and those scores selected the best
# checkpoint (load_best_model_at_end). Report metrics only on the remaining,
# untouched queries so the estimate is not biased by that selection.
DEV_SEED, DEV_SIZE = 42, 1600  # must match the dev evaluator's subset
heldout_eval_dataset = hard_eval_dataset.shuffle(seed=DEV_SEED).select(
    range(DEV_SIZE, len(hard_eval_dataset))
)
assert len(heldout_eval_dataset) > 0, "no held-out queries left after the dev subset"

negative_columns = hard_eval_dataset.column_names[2:]  # negative_1, negative_2, ...
evaluator = CrossEncoderRerankingEvaluator(
    samples=[
        {
            "query": row["query"],
            "positive": [row["passage"]],
            "negative": [row[column_name] for column_name in negative_columns],
        }
        for row in heldout_eval_dataset
    ],
    at_k=10,
    batch_size=128,
    name="amh-passage-retrieval-test",
    always_rerank_positives=True,
    show_progress_bar=True,
)

evaluator(model)



{'amh-passage-retrieval-dev_map': 0.8267625270131399,
 'amh-passage-retrieval-dev_mrr@10': 0.8246759530314858,
 'amh-passage-retrieval-dev_ndcg@10': 0.8578685327553554}

In [None]:
from sentence_transformers.cross_encoder import CrossEncoder

# Sanity check: score one query against five passages. In the recorded output
# only the last passage scores close to 1.
# To score with the published checkpoint instead, load it first:
# model = CrossEncoder("rasyosef/roberta-amharic-reranker-medium")
# NOTE(review): the Amharic strings below look like UTF-8 text shown as Mac Roman
# (mojibake) in this copy of the notebook; verify against the original file.
demo_query = "·â†·àà·äï·ã∞·äï ·ã®·ã©·ä≠·à¨·äï ·ã∞·åã·çä ·ä†·åà·à´·âµ ·àò·à™·ãé·âΩ ·åâ·â£·ä§"
demo_passages = [
    "·åÄ·à≠·àò·äï ·ä†·åà·à≠ ·ã®·â∞·àõ·à© ·ä¢·âµ·ãÆ·åµ·ã´·ãâ·ã´·äï ·âÅ·å•·à≠ ·ä®·ä†·çç·à™·âÉ ·ä†·äï·ã∞·äõ ·äê·ãâ·ç¢ ·ä•·äê ·ä´·àú·à©·äï ·ä®·äê·ä¨·äï·ã´·äï ·àÅ·àâ ·ã≠·â†·àç·å£·àç·ç¢ ·â†·àõ·äÖ·â†·à≠ ·ã®·àò·ã∞·à´·åÄ·â≥·âΩ·äï ·ãã·äì·ãâ ·ãì·àã·àõ ·ã®·ä¢·âµ·ãÆ-·åÄ·à≠·àò·äï·äï ·åç·äï·äô·äê·âµ·äï ·ä®·çç ·àõ·ãµ·à®·åç ·äê·ãâ·ç¢ ·ä® 10 ·à∫·àÖ ·â†·àã·ã≠ ·ä¢·âµ·ãÆ·åµ·ã´·ãâ·ã´·äï ·â∞·àõ·à™·ãé·âΩ ·â†·åÄ·à≠·àò·äï ·ä†·åà·à≠ ·â∞·àù·à®·ãâ ·â∞·àò·àç·à∞·ãã·àç·ç¢ ·ä†·àç·ã´·àù ·ä†·ãâ·àÆ·å≥ ·ãà·ã∞ ·ä†·àú·à™·ä´·àù ·ã®·â∞·àª·åà·à© ·ä†·àâ·ç¢",
    "·â†·åã·àù·â§·àã ·ä•·àµ·ä´·àÅ·äï 23 ·à∞·ãé·âΩ ·â†·äÆ·àå·à´ ·àò·àû·â≥·â∏·ãç·äï ·ã®·â∞·äì·åà·à©·âµ ·ä†·äï·ãµ ·ã®·å§·äì ·â£·àà·àô·ã´ ·â†·àΩ·â≥·ãç·äï ·àà·àò·ä®·àã·ä®·àç ·ä†·äï·ãµ ·ã®·åç·àç ·â∞·âã·àù ·ãµ·åã·çç ·àõ·ãµ·à®·åâ·äï ·ä•·äì ·â†·ãì·àà·àù ·å§·äì ·ãµ·à≠·åÖ·âµ ·ã®·àÖ·ä≠·àù·äì ·ãµ·åã·çç ·àò·à∞·å†·â±·äï ·ä†·â•·à´·à≠·â∞·ãã·àç·ç°·ç° ·ä®·ã∞·â°·â• ·à±·ã≥·äï ·ä†·åé·à´·â£·âΩ ·ä†·äÆ·â¶ ·ãà·à®·ã≥ ·àà·àò·åÄ·àò·à™·ã´ ·åä·ãú ·àò·ä®·à∞·â± ·â†·â∞·äê·åà·à®·ãç n·â†·àΩ·â≥·ãç ·â•·ãô ·à∞·ãé·âΩ ·ã®·â∞·å†·âÅ·âµ ·ãã·äï·â±·ãã ·â†·â∞·â£·àà ·ãà·à®·ã≥ ·àò·àÜ·äë·äï ·å†·âÅ·àò·ãã·àç·ç°·ç°",
    "·â†·ã∞·àù ·ä•·å•·à®·âµ ·àù·ä≠·äï·ã´·âµ ·àà·àÖ·àô·àõ·äï ·ä†·àµ·çà·àã·åä·ãâ·äï ·àÖ·ä≠·àù·äì ·àà·àò·àµ·å†·âµ ·â∞·â∏·åç·à®·äì·àç ·à≤·àâ ·ã®·â∞·äì·åà·à©·âµ ·ã®·âÜ·â¶ ·àÜ·àµ·çí·â≥·àç ·àµ·à´ ·ä†·àµ·ä™·ã´·åÖ ·â¢·äñ·à©·àù ·ã®·ã∞·àù ·ä•·å•·à®·â± ·àÄ·ä™·àû·âΩ ·ä†·åà·àç·åç·àé·âµ ·ä•·äï·ã≤·à∞·å° ·ä†·àã·àµ·âª·àã·â∏·ãâ·àù ·ã≠·àã·àâ·ç¢ ·ã®·ãà·àç·ãµ·ã´ ·ã∞·àù ·â£·äï·ä≠ ·ä†·åà·àç·åç·àé·âµ ·àµ·à´ ·ä†·àµ·ä™·ã´·åÖ ·ä†·àÅ·äï ·àà·ä†·àµ·â∏·ä≥·ã≠ ·àÖ·ä≠·àù·äì ·â•·âª ·ã®·àö·àÜ·äï ·àà20 ·âÄ·äì·âµ ·â•·âª ·ã®·àö·âÜ·ã≠ ·ã∞·àù ·äê·ãâ ·ã´·àà·äï ·ã≠·àã·àâ",
    "·ä†·àú·à™·ä´ ·àà·ãç·å≠ ·ã®·àù·âµ·à∞·å†·ãç·äï ·ä•·à≠·ã≥·â≥ ·àõ·âÜ·àü ·â†·â∞·àà·ã≠ ·â†·ä†·çç·à™·âÉ ·àÉ·åà·à´·âµ ·ä®·çç·â∞·äõ ·åç·à≠·â≥·äì ·ãµ·äï·åã·å§ ·çà·å•·àØ·àç·ç¢ ·àà·àÅ·àà·â∞·äõ ·åä·ãú ·ãà·ã∞ ·ã®·à•·àç·å£·äï ·àò·äï·â†·à≠ ·ã®·àò·å°·âµ ·ã®·ä†·àú·à™·ä´·ãç ·çï·à¨·ãù·ã∞·äï·âµ ·ã∂·äì·àç·ãµ ·âµ·à´·àù·çï ·ã®·ãç·å≠ ·ä•·à≠·ã≥·â≥ ·àà·ãò·å†·äì ·âÄ·äì·âµ ·ä•·äï·ã≤·âÜ·àù ·ãà·àµ·äê·ãã·àç·ç¢ ·âµ·à´·àù·çï ·ä•·à≠·ã≥·â≥ ·ä†·çç·à™·âÉ·ãç·ã´·äï·äï ·ä†·à≥·äï·çè·â∏·ãã·àç ·äê·ãç ·ã®·àö·àâ·âµ·ç¢",
    "·â†·àà·äï·ã∞·äï ·ã®·â•·à™·â≥·äí·ã´·ãç ·å†/·àö·äí ·ä¨·ã≠·à≠ ·àµ·â≥·à≠·àò·à≠ ·ä†·àµ·â∞·äì·åã·åÖ·äê·âµ ·ã®·â∞·ä´·àÑ·ã∞·ãç ·ãã·äì·ãé·âπ ·ã®·ã©·ä≠·à¨·äï ·ã∞·åã·çä ·ä†·åà·àÆ·âΩ ·àò·à™·ãé·âΩ ·åâ·â£·ä§ ·ã©·ä≠·à¨·äï·äï ·â†·ãà·â≥·ã∞·à´·ãä·äì ·çã·ã≠·äì·äì·àµ ·ã®·â†·àà·å† ·àà·àò·à≠·ã≥·âµ ·â∞·àµ·àõ·àù·â∂ ·àõ·â•·âÉ·â± ·â∞·åà·àç·åø·àç·ç¢ ·ã≠·àÖ ·ä®·ä†·ãç·àÆ·çì ·àÖ·â•·à®·âµ ·ãç·å≠ ·ã®·ä´·äì·ã≥·äï·äì ·ã®·äñ·à≠·ãå·ã≠ ·àò·à™·ãé·âΩ ·â•·àé·àù ·ã®·â±·à≠·ä≠·äï ·ã®·ãç·å≠ ·åâ·ã≥·ã≠ ·àö·äí·àµ·âµ·à≠ ·åâ·â£·ä§ ·ã®·â∞·å†·à´·ãâ ·ã®·äê·å¨ ·â§·â∞·àò·äï·åç·àµ·âµ ·ä≠·àµ·â∞·âµ·äï ·â∞·ä®·âµ·àé ·äê·ãâ·ç¢",
]
scores = model.predict([[demo_query, passage] for passage in demo_passages])
scores

array([3.4760369e-04, 3.4751720e-04, 9.3876332e-04, 4.1233754e-04,
       9.9968314e-01], dtype=float32)

In [None]:
# Publishing is opt-in: set PUSH_TO_HUB = True to upload the fine-tuned reranker.
# The Colab secrets import sits inside the guard. An unconditional
# `from google.colab import userdata` raises ImportError outside Colab, which
# would break "Run All" even though nothing is pushed.
PUSH_TO_HUB = False

if PUSH_TO_HUB:
    from google.colab import userdata  # Colab-only secrets store (HF_WRITE token)

    model.push_to_hub(
        "roberta-amharic-reranker-medium",
        exist_ok=True,
        token=userdata.get("HF_WRITE"),
    )