In [1]:
import torch

torch.cuda.is_available()

True

In [2]:
# Uses RAGatouille for the ColBERT handling
!pip install --upgrade --no-deps RAGatouille==0.0.5b3

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [3]:
from ragatouille import RAGPretrainedModel

In [4]:
miracle_n_hard_negs = 300
miracle_n_recall = 30

In [5]:
# Parameters
model_id = "bclavie/jacolbert"

# Model

In [6]:
JaColBERT = RAGPretrainedModel.from_pretrained(model_id)

# Miracle
* Need access token for huggingface

In [7]:
import os
import dotenv

dotenv.load_dotenv("huggingface_access_token", override=True)

True

In [8]:
import datasets

# query and positives
ds = datasets.load_dataset(
    "miracl/miracl", "ja", split="dev"
)
ds

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Dataset({
    features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
    num_rows: 860
})

In [9]:
# all corpus texts
corpus = datasets.load_dataset("miracl/miracl-corpus", "ja")
corpus

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['docid', 'title', 'text'],
        num_rows: 6953614
    })
})

In [10]:
import json
# hard negatives
with open("./miracl_hard_negs_1000.json") as f:
    hn = json.loads(f.read())
len(hn), list(hn.keys())[:5], hn["0"].keys(), hn["0"]["docids"][:2], hn["0"]["indices"][
    :2
]

(860,
 ['0', '3', '4', '5', '7'],
 dict_keys(['docids', 'indices']),
 ['2681119#0', '2681119#1'],
 [1393435, 1393436])

In [11]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    return corpus_item["title"] + " " + corpus_item["text"]


corpus_dict = {item["docid"]: get_text(item) for item in tqdm(corpus["train"])}

100%|██████████| 6953614/6953614 [05:20<00:00, 21694.04it/s]


In [12]:
n_total_pos = 0
n_total_tp = 0
try:
    JaColBERT.clear_encoded_docs(force=True)
except:
    pass

for i, item in enumerate(tqdm(ds)):    
    # passages are set(300 hard negatives + positives)
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]
    hn_docids = hn[item["query_id"]]["docids"][:miracle_n_hard_negs]

    # drop hard negatives in positives
    hn_docids = [docid for docid in hn_docids if docid not in positive_docids]

    # search target
    target_texts = positive_texts + [corpus_dict[docid] for docid in hn_docids]
    metadata = [{"pid": _id} for _id in positive_docids + hn_docids]


    JaColBERT.encode(target_texts, document_metadatas=metadata, verbose=False)

    results = JaColBERT.search_encoded_docs(item["query"], k=miracle_n_recall)

    topk_indices = [result['document_metadata']['pid'] for result in results]

    n_pos = len(positive_docids)
    n_tp = len(
        set(topk_indices) & set((positive_docids))
    )  # positives are first indices

    n_total_pos += n_pos
    n_total_tp += n_tp

    JaColBERT.clear_encoded_docs(force=True)
    if i == 99:
        print(f"First 100 entries recall: {n_total_tp / n_total_pos}")
miracl_recall = n_total_tp / n_total_pos

n_total_pos, n_total_tp, miracl_recall

 12%|█▏        | 100/860 [02:45<19:30,  1.54s/it]

First 100 entries recall: 0.8717948717948718


100%|██████████| 860/860 [19:23<00:00,  1.35s/it]


(1790, 1535, 0.8575418994413407)

# Output

In [13]:
jsts_score = None
jsick_score = None
model_id, jsts_score, jsick_score, miracl_recall

('bclavie/jacolbert', None, None, 0.8575418994413407)

In [14]:
import json

with open(f'./scores/{model_id.replace("/", "_")}.txt', "w") as f:
    f.write(
        json.dumps(
            {
                "model_id": model_id,
                "jsts": jsts_score,
                "jsick": jsick_score,
                "miracl": miracl_recall,
            }
        )
    )