In [1]:
import torch
import os

# Sanity check shown as the cell output: is a CUDA device visible to PyTorch?
torch.cuda.is_available()

True

In [2]:
from sentence_transformers import CrossEncoder
import torch


In [3]:
# MIRACL evaluation knobs:
#   miracle_n_hard_negs -- how many entries of each query's hard-negative list are reranked
#   miracle_n_recall    -- K in Recall@K (top_k passed to the reranker)
miracle_n_hard_negs, miracle_n_recall = 300, 30

In [4]:
# Parameters: Hugging Face Hub id of the cross-encoder reranker under evaluation
model_id = (
    "hotchpotch/japanese-bge-reranker-v2-m3-v1"
)

# Model

In [5]:
# Prefer the GPU when available; on GPU the weights are cast to fp16 for faster inference.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = CrossEncoder(model_id, max_length=512, device=device)

if device == "cuda":
    model.model.half()  # half-precision weights: faster inference on GPU

# MIRACL
* Requires a Hugging Face access token to download the MIRACL datasets

In [6]:
import os
# Optionally load the Hugging Face access token from the local dotenv-style file
# "huggingface_access_token". python-dotenv is an optional dependency: if it is not
# installed we report that and continue, relying on whatever credentials are already
# configured in the environment.
try:
    import dotenv
except ImportError as e:
    print(e)
else:
    # load_dotenv returns False when it set no variables (e.g. the file is missing);
    # report it here rather than failing later with an opaque authentication error.
    if not dotenv.load_dotenv("huggingface_access_token", override=True):
        print("warning: no variables loaded from 'huggingface_access_token'")

No module named 'dotenv'


In [7]:
import datasets

# Queries and relevance judgments: MIRACL Japanese dev split
# (columns: query_id, query, positive_passages, negative_passages).
# MIRACL is loaded through a dataset script, and `datasets` warns that explicit
# `trust_remote_code=True` will become mandatory for it. Passing it opts in to
# running that script from the Hub repository.
ds = datasets.load_dataset(
    "miracl/miracl", "ja", split="dev", trust_remote_code=True
)
ds

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Dataset({
    features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
    num_rows: 860
})

In [8]:
# Full Japanese MIRACL passage corpus (columns: docid, title, text).
# Like the query set, it is loaded through a dataset script; explicit
# `trust_remote_code=True` opts in to running it and avoids the deprecation path.
corpus = datasets.load_dataset("miracl/miracl-corpus", "ja", trust_remote_code=True)
corpus

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['docid', 'title', 'text'],
        num_rows: 6953614
    })
})

In [9]:
import json
# Precomputed hard negatives, keyed by MIRACL query_id. Each entry has "docids"
# (corpus docids, consumed by the evaluation loop) and "indices" (integer row
# numbers -- presumably into corpus["train"]; TODO confirm, they are not used here).
# Explicit utf-8 so the JSON parses the same regardless of the platform locale.
with open("./miracl_hard_negs_1000.json", encoding="utf-8") as f:
    hn = json.load(f)

# The evaluation loop indexes hn by every dev query_id; fail fast here instead of
# with a KeyError midway through a long reranking run.
assert set(ds["query_id"]).issubset(hn), "hard-negative file is missing some dev queries"

(
    len(hn),
    list(hn.keys())[:5],
    hn["0"].keys(),
    hn["0"]["docids"][:2],
    hn["0"]["indices"][:2],
)

(860,
 ['0', '3', '4', '5', '7'],
 dict_keys(['docids', 'indices']),
 ['2681119#0', '2681119#1'],
 [1393435, 1393436])

In [10]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    """Build the passage string given to the reranker: title and body joined by a single space."""
    return " ".join((corpus_item["title"], corpus_item["text"]))


# docid -> "title text" lookup over the whole corpus (~7M rows; building it
# row by row over the dataset took over a minute in the recorded run).
corpus_dict = dict(
    (corpus_row["docid"], get_text(corpus_row)) for corpus_row in tqdm(corpus["train"])
)

100%|██████████| 6953614/6953614 [01:16<00:00, 90557.25it/s]


In [11]:
import transformers
# Only show transformers log records at ERROR level or above. This hides the
# "Be aware, overflowing tokens are not returned for the setting you have chosen"
# warning emitted while reranking.
transformers.logging.set_verbosity(transformers.logging.ERROR)


In [12]:
# Recall@K on MIRACL-ja dev. For each query, rerank its positives plus the first
# `miracle_n_hard_negs` hard negatives. Count how many positives land in the top
# `miracle_n_recall` results. Recall is micro-averaged over all positives.
n_total_pos = 0
n_total_tp = 0

for item in tqdm(ds, total=len(ds)):
    query = item["query"]
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]

    # Drop any "hard negative" that is actually judged positive, so it is not
    # scored twice (once as positive, once as negative). Set lookup keeps this O(n).
    positive_docid_set = set(positive_docids)
    hn_docids = [
        docid
        for docid in hn[item["query_id"]]["docids"][:miracle_n_hard_negs]
        if docid not in positive_docid_set
    ]
    negative_texts = [corpus_dict[docid] for docid in hn_docids]

    # Positives are placed first, so a candidate index i is a positive iff i < n_pos.
    target_texts = positive_texts + negative_texts
    n_pos = len(positive_docids)

    # rank() returns the top_k candidates; "corpus_id" is the index into target_texts.
    ranked = model.rank(query, target_texts, top_k=miracle_n_recall)
    topk_indices = {d["corpus_id"] for d in ranked}
    n_tp = sum(1 for idx in topk_indices if idx < n_pos)

    n_total_pos += n_pos
    n_total_tp += n_tp

# An empty evaluation set yields NaN instead of raising ZeroDivisionError.
miracl_recall = n_total_tp / n_total_pos if n_total_pos else float("nan")

n_total_pos, n_total_tp, miracl_recall

  0%|          | 0/860 [00:00<?, ?it/s]

100%|██████████| 860/860 [14:30<00:00,  1.01s/it]


(1790, 1695, 0.946927374301676)

# Output

In [13]:
# JSTS / JSICK are not computed in this notebook; they are recorded as missing (None).
jsts_score = jsick_score = None
(model_id, jsts_score, jsick_score, miracl_recall)

('hotchpotch/japanese-bge-reranker-v2-m3-v1', None, None, 0.946927374301676)

In [14]:
import json

# Persist this model's scores as one JSON object under ./scores/. The ".txt"
# extension is kept from the original naming. "/" in the Hub id is replaced with "_"
# so it does not become a sub-directory.
# open(..., "w") raises FileNotFoundError if ./scores does not exist, so create it first.
os.makedirs("./scores", exist_ok=True)
with open(f'./scores/{model_id.replace("/", "_")}.txt', "w", encoding="utf-8") as f:
    json.dump(
        {
            "model_id": model_id,
            "jsts": jsts_score,
            "jsick": jsick_score,
            "miracl": miracl_recall,
        },
        f,
    )