In [None]:
! pip install -q sentence-transformers datasets pylate beir ranx

In [None]:
import wandb
# Initialise Weights & Biases in "disabled" mode for this session.
# NOTE(review): presumably this keeps the trainer below from logging the run to W&B — confirm.
wandb.init(mode="disabled")

In [None]:
! rm -r output/
! rm -r pylate-index/

In [None]:
import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import evaluation, losses, models, utils

# The run name is used for logging and as the checkpoint sub-folder under output/.
run_name = "colbert-base-amharic"
output_dir = f"output/{run_name}"

# Pre-trained encoder used as the base for contrastive training.
model_name = "rasyosef/roberta-base-amharic"

# 1. Build the ColBERT model on top of the base encoder.
#    If the checkpoint is not already a ColBERT model, a linear layer is added to the base encoder.
model = models.ColBERT(
    model_name_or_path=model_name,
    document_length=256,
)

# Optional: compiling the model makes the training faster.
# model = torch.compile(model)

In [None]:
# Display the model's query_length and document_length settings
# (document_length was passed as 256 when the model was created).
model.query_length, model.document_length

(32, 256)

In [None]:
# Load Dataset
from datasets import load_dataset

# Dataset with query / passage pairs and a `negative_passages` column per row.
DATASET_NAME = "yosefw/amharic-news-retrieval-dataset-v2-with-negatives-V2"
am_relevance_dataset = load_dataset(DATASET_NAME)

# Show the available splits and their columns.
am_relevance_dataset

DatasetDict({
    train: Dataset({
        features: ['query_id', 'passage_id', 'query', 'passage', 'category', 'link', 'source_dataset', 'negative_passages'],
        num_rows: 61469
    })
    test: Dataset({
        features: ['query_id', 'passage_id', 'query', 'passage', 'category', 'link', 'source_dataset', 'negative_passages'],
        num_rows: 6832
    })
})

In [None]:
# Distinct passage ids that appear in the test split.
test_passage_ids = {
    passage_id for passage_id in am_relevance_dataset["test"]["passage_id"]
}

# Number of distinct test passages.
len(test_passage_ids)

6764

In [None]:
from datasets import Dataset
from tqdm import tqdm
import random

# Expand every training example into two rows that share the query and the
# positive passage but carry four different negatives each.
ds_rows = []
for row in tqdm(am_relevance_dataset["train"]):
  neg_passages = row["negative_passages"]
  # (Disabled) drop negatives whose passage also occurs in the test split:
  # neg_passages = list(filter(lambda x: x["passage_id"] not in test_passage_ids, neg_passages))

  # Keep the first four and the last four entries of the negatives list.
  neg_pool = neg_passages[:4] + neg_passages[-4:]

  # offset 0 uses pool positions 0, 2, 4, 6; offset 1 uses positions 1, 3, 5, 7.
  for offset in (0, 1):
    example = {
        "query_id": row["query_id"],
        "passage_id": row["passage_id"],
        "query": row["query"],
        "positive": row["passage"],
    }
    for slot in range(4):
      example[f"negative_{slot + 1}"] = neg_pool[2 * slot + offset]["passage"]
    ds_rows.append(example)

# Sorting by query_id places the two rows of each query next to each other.
relevance_dataset = Dataset.from_list(ds_rows).sort("query_id")
relevance_dataset

100%|██████████| 61469/61469 [00:25<00:00, 2436.52it/s]


Dataset({
    features: ['query_id', 'passage_id', 'query', 'positive', 'negative_1', 'negative_2', 'negative_3', 'negative_4'],
    num_rows: 122938
})

In [None]:
# Count the test-split passage ids that also occur as `passage_id` in the training rows.
train_passage_ids = set(relevance_dataset["passage_id"])
len(test_passage_ids & train_passage_ids)

161

In [None]:
# The dataset has no validation split, so the first EVAL_SIZE rows of the
# training rows are held out for evaluation and the remainder is used for training.
EVAL_SIZE = 4_000

eval_indices = range(EVAL_SIZE)
train_indices = range(EVAL_SIZE, len(relevance_dataset))

eval_dataset = relevance_dataset.select(eval_indices)
train_dataset = relevance_dataset.select(train_indices)

# Show the first training example together with the split summary.
train_dataset[0], train_dataset

In [None]:
# Show the first held-out example together with the eval split summary.
eval_dataset[0], eval_dataset

In [None]:
# Contrastive loss used to fine-tune the ColBERT model.
train_loss = losses.Contrastive(model=model)

# Every eval row carries four negatives, so each (query, positive) pair is
# repeated once per negative column to form (anchor, positive, negative) triplets.
NEGATIVE_COLUMNS = ["negative_1", "negative_2", "negative_3", "negative_4"]

# Materialise the columns as plain lists before repeating / concatenating them:
# recent `datasets` releases may return a lazy column object from
# `dataset[column]` that does not support list `*` and `+`.
eval_queries = list(eval_dataset["query"])
eval_positives = list(eval_dataset["positive"])

# All negative_1 values first, then all negative_2 values, and so on, which
# lines up with the repeated anchors / positives below.
eval_negatives = [
    negative
    for column in NEGATIVE_COLUMNS
    for negative in eval_dataset[column]
]

# Initialize the evaluator
dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_queries * len(NEGATIVE_COLUMNS),
    positives=eval_positives * len(NEGATIVE_COLUMNS),
    negatives=eval_negatives,
)

In [None]:
# Hyper-parameters a reader is most likely to tune.
eval_steps = 1500  # shared interval for evaluation, checkpointing and logging
batch_size = 32  # larger batch size often improves results, but requires more memory
num_train_epochs = 4  # adjust based on your requirements

# Training configuration, grouped by concern.
args = SentenceTransformerTrainingArguments(
    # Bookkeeping
    output_dir=output_dir,
    run_name=run_name,  # will be used in W&B if `wandb` is installed
    # Optimisation
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-5,
    lr_scheduler_type="linear",
    # Mixed precision
    fp16=True,  # set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # set to True if you have a GPU that supports BF16
    # Evaluate, checkpoint and log every `eval_steps` steps
    eval_strategy="steps",
    eval_steps=eval_steps,
    save_strategy="steps",
    save_steps=eval_steps,
    # save_total_limit=2,
    logging_strategy="steps",
    logging_steps=eval_steps,
)

# Collator built from the model's own tokenize method.
colbert_collator = utils.ColBERTCollator(model.tokenize)

# Trainer for the contrastive training run.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=train_loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
    data_collator=colbert_collator,
)

In [None]:
# Start the training process using the trainer configured above.
# The cell's output is whatever `trainer.train()` returns.
trainer.train()

### **Evaluate**

In [None]:
# Score the in-memory model with the triplet evaluator; "." is passed as its output path.
# To evaluate a saved checkpoint instead, pass for example
#   model=models.ColBERT("/content/output/contrastive-bert-small/checkpoint-2850")
dev_evaluator(model=model, output_path=".")

{'accuracy': 0.9801875352859497}

In [None]:
import os
from google.colab import userdata

# Copy the Colab secret named "HF_WRITE" into the HF_TOKEN environment variable.
# NOTE(review): presumably push_to_hub authenticates with this token — confirm.
os.environ["HF_TOKEN"] = userdata.get("HF_WRITE")

# push model to hub
trainer.model.push_to_hub("ColBERT-Amharic-Base")

### **BEIR Eval**

In [None]:
from pylate import evaluation, indexes, models, retrieve

# Step 1: Initialize the ColBERT model to evaluate

# dataset = "scifact" # Choose the dataset you want to evaluate
model = models.ColBERT(
    model_name_or_path="rasyosef/ColBERT-Amharic-Base",
    device="cuda" # "cpu" or "cuda" or "mps"
)

# Step 2: Create a Voyager index
index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="colbert-eval",
    override=True,  # Overwrite any existing index
)

# Step 3: Build the corpus, the queries and the relevance judgements
test_dataset = am_relevance_dataset["test"]
train_split = am_relevance_dataset["train"]

# Materialise the columns as plain lists before concatenating them: recent
# `datasets` releases may return a lazy column object from `dataset[column]`
# that does not support list `+`.
passage_ids = list(test_dataset["passage_id"]) + list(train_split["passage_id"])
passage_texts = list(test_dataset["passage"]) + list(train_split["passage"])

# The corpus is every passage of the test and train splits. Going through a
# dict keeps one text per passage id (the last occurrence wins).
documents = [
    {"id": pid, "text": text}
    for pid, text in dict(zip(passage_ids, passage_texts)).items()
]

queries = list(test_dataset["query"])

# Relevance judgements are keyed by query text. Accumulate instead of
# overwriting, so a query text that occurs in several test rows keeps all of
# its relevant passages rather than only the last one.
qrels = {}
for query, pid in zip(queries, test_dataset["passage_id"]):
    qrels.setdefault(query, {})[pid] = 1

# Step 4: Encode the documents
documents_embeddings = model.encode(
    [document["text"] for document in documents],
    batch_size=32,
    is_query=False,  # Indicate that these are documents
    show_progress_bar=True,
)

# Step 5: Add document embeddings to the index
index.add_documents(
    documents_ids=[document["id"] for document in documents],
    documents_embeddings=documents_embeddings,
)

# Step 6: Encode the queries
queries_embeddings = model.encode(
    queries,
    batch_size=32,
    is_query=True,  # Indicate that these are queries
    show_progress_bar=True,
)

# Step 7: Retrieve top-k documents
retriever = retrieve.ColBERT(index=index)
scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=100,  # Retrieve the top 100 matches for each query
    batch_size=128,
)

# Step 8: Evaluate the retrieval results
results = evaluation.evaluate(
    scores=scores,
    qrels=qrels,
    queries=queries,
    metrics=[f"ndcg@{k}" for k in [5, 10, 100]] # NDCG for different k values
    + ["recall@5", "recall@10", "recall@50", "recall@100"]   # Recall at k
    + ["mrr@5", "mrr@10", "mrr@100"]
)

print(results)

In [None]:
# Step 8: Evaluate the retrieval results and report each metric on its own line
ndcg_metrics = [f"ndcg@{k}" for k in (5, 10, 100)]
recall_metrics = [f"recall@{k}" for k in (5, 10, 50, 100)]
mrr_metrics = [f"mrr@{k}" for k in (5, 10, 100)]

results = evaluation.evaluate(
    scores=scores,
    qrels=qrels,
    queries=queries,
    metrics=ndcg_metrics + recall_metrics + mrr_metrics,
)

print(results)

# Reporting order: NDCG, then MRR, then recall.
metrics = ndcg_metrics + mrr_metrics + recall_metrics
for m in metrics:
  print(f"{m}: {results[m]}")

{'ndcg@5': np.float64(0.8253545025645599), 'ndcg@10': np.float64(0.8346465001473153), 'ndcg@100': np.float64(0.8443166547375272), 'recall@5': np.float64(0.9015817223198594), 'recall@10': np.float64(0.930140597539543), 'recall@50': np.float64(0.9667545401288811), 'recall@100': np.float64(0.9745166959578208), 'mrr@5': np.float64(0.799558191759422), 'mrr@10': np.float64(0.8034251471531787), 'mrr@100': np.float64(0.8054998678863597)}
ndcg@5: 0.8253545025645599
ndcg@10: 0.8346465001473153
ndcg@100: 0.8443166547375272
mrr@5: 0.799558191759422
mrr@10: 0.8034251471531787
mrr@100: 0.8054998678863597
recall@5: 0.9015817223198594
recall@10: 0.930140597539543
recall@50: 0.9667545401288811
recall@100: 0.9745166959578208
