In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from pyserini.search import LuceneSearcher, FaissSearcher
from pyserini.search.hybrid import HybridSearcher

ssearcher = LuceneSearcher("../temp/pyserini_index/")
# fsearcher = FaissSearcher("../temp/dpr_index/dpr_full", "../temp/qry_encoder_dpr/")

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
ssearcher.search("Minh")[0].raw

'{\n  "id" : "13059949::[10,19]",\n  "title" : "Minh",\n  "contents" : "BULLET: - Minh Mạng (1791–1841), Vietnamese emperor BULLET: - Minh Nguyen-Vo (born 1956), Vietnamese film director BULLET: - Minh Tuyet (born 1975), Vietnamese singer BULLET: - Minh Vu (born 1990), American soccer player BULLET: - Nguyen Xuan Minh (born 1971), Vietnamese businessman BULLET: - Phạm Bình Minh (born 1959), Vietnamese politician BULLET: - Quyền Văn Minh (born 1954), Vietnamese jazz saxophonist BULLET: - Son Ngoc Minh (1920–1972), Cambodian politician BULLET: - Thu Minh (born 1977), Vietnamese singer BULLET: - Trần Văn Minh (1923–2009), Vietnamese general and diplomat"\n}'

In [4]:
searcher = HybridSearcher(ssearcher, fsearcher)

In [5]:
import pandas as pd

df = pd.read_json("../data/totto/augmented/dev.jsonl", lines=True)

In [6]:
import json

def process_content(raw_content: str) -> str:
    title, content = json.loads(raw_content)["contents"].split("\n\n", 1)
    return title, content

In [8]:
from read.utils.table import linearize
import torch
import pprint
from torchmetrics import RetrievalHitRate, RetrievalRPrecision, RetrievalRecall, RetrievalPrecision, RetrievalMRR, RetrievalMAP


k2metrics = {k: [RetrievalHitRate(k=k), RetrievalRPrecision(k=k), RetrievalRecall(k=k), RetrievalPrecision(k=k), RetrievalMRR(k=k), RetrievalMAP(k=k)] for k in [1, 5, 10]}


preds = []
indices = []
golds = []

prediction_data = []

df = df.sample(frac=1)
idx = 0

for _, row in df.iloc[:1000].iterrows():
    idx += 1
    one_prediction = []
    table = row["table"]
    query = linearize(table, row["highlighted_cells"], row_sep=" ; ", value_sep=" : ", includes_header=True, return_text=True)
    
    hits = ssearcher.search(query, k=10)
    
    for hit in hits:
        doc = ssearcher.doc(hit.docid)
        title, text = process_content(doc.raw())
        
        preds.append(hit.score)
        indices.append(idx)
        golds.append(title == row["table_page_title"])
        
        # if title == row["table_page_title"]:
        #     print(f"Title: '{title}' with query ({query}) for sentence ({row['sentence']})\n")
        #     print(f"Label: {row['negatives'] is not None}\n and index {idx}")
        #     print(f"Text: {text}\n")

        one_prediction.append({"idx": idx, "title": title, "text": text, "score": hit.score, "gold": row["table_page_title"], "table": table, "query": query, "sentence": row["sentence_annotations"]})
    
    if idx % 100 == 0:
        result = {}

        pred_tensor = torch.tensor(preds)
        gold_tensor = torch.tensor(golds)
        indices_tensor = torch.tensor(indices)

        for k, metrics in k2metrics.items():
            for metric in metrics:
                result[f"{metric.__class__.__name__}@{k}"] = metric(pred_tensor, gold_tensor, indices_tensor)
        print(f"Index {idx}")
        pprint.pprint(result)
    prediction_data.append(one_prediction)


pred_tensor = torch.tensor(preds)
gold_tensor = torch.tensor(golds)
indices_tensor = torch.tensor(indices)
for k, metrics in k2metrics.items():
    for metric in metrics:
        result[f"{metric.__class__.__name__}@{k}"] = metric(pred_tensor, gold_tensor, indices_tensor)
print(f"Index {idx}")
pprint.pprint(result)

Index 100
{'RetrievalHitRate@1': tensor(0.3500),
 'RetrievalHitRate@10': tensor(0.6400),
 'RetrievalHitRate@5': tensor(0.5200),
 'RetrievalMAP@1': tensor(0.4204),
 'RetrievalMAP@10': tensor(0.4204),
 'RetrievalMAP@5': tensor(0.4204),
 'RetrievalMRR@1': tensor(0.4376),
 'RetrievalMRR@10': tensor(0.4376),
 'RetrievalMRR@5': tensor(0.4376),
 'RetrievalPrecision@1': tensor(0.3500),
 'RetrievalPrecision@10': tensor(0.1040),
 'RetrievalPrecision@5': tensor(0.1580),
 'RetrievalRPrecision@1': tensor(0.3249),
 'RetrievalRPrecision@10': tensor(0.3249),
 'RetrievalRPrecision@5': tensor(0.3249),
 'RetrievalRecall@1': tensor(0.2505),
 'RetrievalRecall@10': tensor(0.6400),
 'RetrievalRecall@5': tensor(0.4941)}
Index 200
{'RetrievalHitRate@1': tensor(0.4350),
 'RetrievalHitRate@10': tensor(0.6950),
 'RetrievalHitRate@5': tensor(0.6050),
 'RetrievalMAP@1': tensor(0.4924),
 'RetrievalMAP@10': tensor(0.4924),
 'RetrievalMAP@5': tensor(0.4924),
 'RetrievalMRR@1': tensor(0.5121),
 'RetrievalMRR@10': tenso