In [1]:
import os
import torch

# Confirm PyTorch can see a CUDA device (the model cell falls back to CPU otherwise).
torch.cuda.is_available()

True

In [2]:
from sentence_transformers import CrossEncoder
import torch


  from tqdm.autonotebook import tqdm, trange


In [3]:
# Hard negatives mixed into each query's candidate pool, and the k of recall@k.
miracle_n_hard_negs, miracle_n_recall = 300, 30

In [4]:
# Parameters
# Hugging Face Hub id of the SPLADE checkpoint being evaluated.
model_id: str = "hotchpotch/japanese-splade-base-v1"

# Model


In [5]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = model_id

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
if device == "cuda":
    # Module.to / Module.half return the module itself; fp16 on GPU for faster inference.
    model = model.to(device).half()
# Inference mode (disables dropout); the returned module is echoed as the cell output.
model.eval()


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [6]:
import numpy as np
from tqdm import tqdm


def splade_max_pooling(logits, attention_mask):
    """Pool token-level MLM logits into one SPLADE term-weight vector per text.

    Computes log(1 + ReLU(logits)), zeroes padded positions using
    ``attention_mask``, then max-pools over the sequence dimension (dim 1).

    Args:
        logits: tensor of shape (batch, seq_len, vocab_size), e.g. the
            ``.logits`` of a masked-LM head.
        attention_mask: tensor of shape (batch, seq_len); 1 for real tokens,
            0 for padding.

    Returns:
        Tensor of shape (batch, vocab_size) with non-negative term weights.
    """
    # log1p is more accurate than log(1 + x) for small x, notably in fp16.
    relu_log = torch.log1p(torch.relu(logits))
    # Weights are >= 0, so zeroing padded positions keeps them out of the max.
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    return max_val


def splade_vector(texts: list[str], bs=64, max_length: int = 512):
    """Encode texts into dense SPLADE vectors with the global ``model``/``tokenizer``.

    Args:
        texts: texts to encode.
        bs: batch size for tokenization and forward passes (must be >= 1).
        max_length: token limit per text; longer inputs are truncated.

    Returns:
        float32 numpy array of shape (len(texts), vocab_size). An empty
        ``texts`` yields an array with zero rows.

    Raises:
        ValueError: if ``bs`` < 1.
    """
    if bs < 1:
        raise ValueError(f"bs must be >= 1, got {bs}")

    pooled_all = []
    with torch.no_grad():
        for start in range(0, len(texts), bs):
            inputs = tokenizer(
                texts[start : start + bs],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            pooled = splade_max_pooling(logits, attention_mask)
            # Upcast (the model may run in fp16 on CUDA) so numpy dot products
            # are precise and BLAS-backed; move off the device per batch so GPU
            # memory does not grow with len(texts).
            pooled_all.append(pooled.float().cpu())

    if not pooled_all:
        # torch.cat([]) raises; return a well-shaped empty matrix instead.
        return np.zeros((0, model.config.vocab_size), dtype=np.float32)
    return torch.cat(pooled_all).numpy()

In [7]:
# Smoke test: build 1000 random 100-character lowercase strings (seeded).
import random

random.seed(42)
texts = [
    "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(100))
    for _ in range(1000)
]

# Encode all of them; each row of the returned matrix is one text's SPLADE vector.
vectors = list(splade_vector(texts))
vectors[0].shape

(32768,)

# MIRACL

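- Requires a Hugging Face access token; the next cell tries to load it from a `huggingface_access_token` dotenv file (requires `python-dotenv`).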


In [8]:
import os

# Best-effort: load Hugging Face credentials from the local dotenv file
# "huggingface_access_token" when python-dotenv is installed. Only a missing
# package is tolerated; other failures (e.g. an unreadable file) now surface
# instead of being silently printed.
try:
    import dotenv
except ImportError as e:
    print(e)
else:
    dotenv.load_dotenv("huggingface_access_token", override=True)

No module named 'dotenv'


In [9]:
import datasets

# Queries with their positive (and negative) passages: MIRACL, Japanese, dev split.
ds = datasets.load_dataset(path="miracl/miracl", name="ja", split="dev")
ds

Dataset({
    features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
    num_rows: 860
})

In [10]:
# Full Japanese MIRACL passage collection (a DatasetDict with a single "train" split).
corpus = datasets.load_dataset(path="miracl/miracl-corpus", name="ja")
corpus

DatasetDict({
    train: Dataset({
        features: ['docid', 'title', 'text'],
        num_rows: 6953614
    })
})

In [11]:
import json

# Pre-mined hard negatives keyed by query_id; each value holds parallel
# "docids" and "indices" lists. JSON is UTF-8 by spec, so don't rely on the
# platform default encoding.
with open("./miracl_hard_negs_1000.json", encoding="utf-8") as f:
    hn = json.load(f)
(
    len(hn),
    list(hn.keys())[:5],
    hn["0"].keys(),
    hn["0"]["docids"][:2],
    hn["0"]["indices"][:2],
)

(860,
 ['0', '3', '4', '5', '7'],
 dict_keys(['docids', 'indices']),
 ['2681119#0', '2681119#1'],
 [1393435, 1393436])

In [12]:
# Union of every docid any query is scored against: its positives plus its
# first `miracle_n_hard_negs` mined hard negatives.
target_curpus_ids = set()
for item in ds:
    target_curpus_ids.update(pp["docid"] for pp in item["positive_passages"])
    target_curpus_ids.update(hn[item["query_id"]]["docids"][:miracle_n_hard_negs])
print(f"target_curpus_ids: {len(target_curpus_ids)}")

target_curpus_ids: 227726


In [13]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    """Return the passage string that gets encoded: ``"<title> <text>"``."""
    return " ".join((corpus_item["title"], corpus_item["text"]))


# docid -> "<title> <text>" for every passage in the corpus.
corpus_dict = dict((row["docid"], get_text(row)) for row in tqdm(corpus["train"]))

100%|██████████| 6953614/6953614 [01:14<00:00, 92802.37it/s]


In [14]:
len(corpus_dict)

6953614

In [15]:
import transformers

# Show only ERROR-level transformers logs, which hides the tokenizer warning
# "Be aware, overflowing tokens are not returned for the setting you have chosen".
transformers.logging.set_verbosity(transformers.logging.ERROR)


In [16]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


# Micro-averaged recall@miracle_n_recall over the MIRACL ja dev queries.
# Each query is ranked against its own candidate pool: all of its positives
# followed by up to `miracle_n_hard_negs` mined hard negatives.
n_total_pos = 0
n_total_tp = 0

for item in tqdm(ds):
    query_emb = splade_vector([item["query"]])

    # Candidate pool; positives occupy indices [0, n_pos).
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]
    positive_set = set(positive_docids)
    # Drop mined "hard negatives" that are actually labelled positives.
    hn_docids = [
        docid
        for docid in hn[item["query_id"]]["docids"][:miracle_n_hard_negs]
        if docid not in positive_set
    ]

    target_texts = positive_texts + [corpus_dict[docid] for docid in hn_docids]
    target_embs = splade_vector(target_texts)

    # SPLADE relevance: dot product of the term-weight vectors.
    scores = np.dot(query_emb, target_embs.T)[0]
    topk_indices = np.argsort(-scores)[:miracle_n_recall]
    n_pos = len(positive_docids)
    # argsort indices are unique, so this counts positives retrieved in the top k.
    n_tp = int(np.count_nonzero(topk_indices < n_pos))

    n_total_pos += n_pos
    n_total_tp += n_tp

# Guard against an empty dataset instead of raising ZeroDivisionError.
miracl_recall = n_total_tp / n_total_pos if n_total_pos else float("nan")

n_total_pos, n_total_tp, miracl_recall

  0%|          | 0/860 [00:00<?, ?it/s]

100%|██████████| 860/860 [05:35<00:00,  2.56it/s]


(1790, 1656, 0.9251396648044693)

# Output


In [17]:
# JSTS / JSICK are not evaluated here; only the MIRACL recall is reported.
jsts_score = jsick_score = None
model_id, jsts_score, jsick_score, miracl_recall

('hotchpotch/japanese-splade-base-v1', None, None, 0.9251396648044693)

In [18]:
import json

import os

# Persist the scores as a JSON object to ./scores/<model_id with "/" -> "_">.txt.
# Create the output directory first so a fresh checkout doesn't hit
# FileNotFoundError.
os.makedirs("./scores", exist_ok=True)
with open(f'./scores/{model_id.replace("/", "_")}.txt', "w", encoding="utf-8") as f:
    json.dump(
        {
            "model_id": model_id,
            "jsts": jsts_score,
            "jsick": jsick_score,
            "miracl": miracl_recall,
        },
        f,
    )