In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import bm25s
import re

# Held-out evaluation questions; the evaluation loop reads the `Question` and
# `Question_Background` columns of each row.
test_df = pd.read_csv(filepath_or_buffer="../../dataset_for_hf/test.csv")

In [None]:
def create_corpus(folder, encoding="utf-8"):
    """Read every regular file directly inside ``folder`` and return their contents.

    Files are visited in sorted path order. ``Path.iterdir`` yields entries in
    arbitrary order, so sorting keeps the corpus, and therefore the document
    order in the saved BM25 index, reproducible across machines and runs.
    Sub-directories are skipped instead of making ``open`` fail.

    Args:
        folder: directory (``str`` or ``Path``) holding one document per file.
        encoding: text encoding used to decode each file. It is pinned to UTF-8
            rather than the platform default so the same bytes decode identically
            on every OS. NOTE(review): assumes the data files are UTF-8.

    Returns:
        list[str]: the text of each file, in sorted filename order.
    """
    return [
        path.read_text(encoding=encoding)
        for path in sorted(Path(folder).iterdir())
        if path.is_file()
    ]

# Build one flat corpus from the drug background folders (Wiki, complexified Wiki,
# SMILES) followed by the protein Wiki folder, preserving that order.
corpus = [
    document
    for source_dir in (
        "../../data/background_information_data/drug_data/Wiki/",
        "../../data/background_information_data/drug_data/Wiki_complexified/",
        "../../data/background_information_data/drug_data/SMILES/",
        "../../data/background_information_data/protein_data/Wiki/",
    )
    for document in create_corpus(Path(source_dir))
]

# Tokenize and index the corpus with BM25. The raw texts are passed via `corpus=`
# so they are kept alongside the index; persist everything to ./bm25s_index.
corpus_tokens = bm25s.tokenize(corpus)
retriever = bm25s.BM25(corpus=corpus)
retriever.index(corpus_tokens)
retriever.save("bm25s_index")

In [None]:
# Evaluation settings: number of documents retrieved per question, plus one
# accumulator list per metric, filled by the loop at the end of this cell.
top_k = 5
hits, recall, mrr = [], [], []

# Reload the persisted BM25 index. load_corpus=True also restores the document
# texts, which the loop compares against the gold documents.
retriever = bm25s.BM25.load("bm25s_index", load_corpus=True)

def hits_k(gold_list, retrieved_list):
    """Return 1 when every gold document is among the retrieved ones, else 0.

    NOTE(review): an empty ``gold_list`` is trivially contained and therefore
    scores 1. Confirm that this is the intended treatment of questions whose
    gold set could not be parsed.
    """
    missing = set(gold_list) - set(retrieved_list)
    return 0 if missing else 1

def recall_k(gold_list, retrieved_list):
    """Recall@k: the fraction of distinct gold documents found in ``retrieved_list``.

    The denominator is the number of gold documents. Dividing by the number of
    retrieved documents would compute Precision@k instead.

    Args:
        gold_list: iterable of relevant documents (duplicates are ignored).
        retrieved_list: the top-k retrieved documents.

    Returns:
        float in [0, 1]. Returns 0.0 when there are no gold documents, since
        there is nothing to recall and this avoids a division by zero.
    """
    gold = set(gold_list)
    if not gold:
        return 0.0
    return len(gold.intersection(retrieved_list)) / len(gold)

def mrr_k(gold_list, retrieved_list):
    """Reciprocal rank of the first retrieved document that is a gold document.

    Ranks are 1-based, so a hit in first position scores 1.0, second position
    0.5, and so on. Returns 0 when no retrieved document is in ``gold_list``.
    """
    first_rank = next(
        (rank for rank, doc in enumerate(retrieved_list, start=1) if doc in gold_list),
        None,
    )
    if first_rank is None:
        return 0
    return 1 / first_rank

for row in test_df.itertuples():
    # Retrieval: BM25 top-k documents for this question, as raw texts in rank order.
    query_tokens = bm25s.tokenize(row.Question)
    retrieved_docs, _ = retriever.retrieve(query_tokens, k=top_k)
    retrieved_docs = [x["text"] for x in retrieved_docs[0]]

    # Gold documents parsed from the question's background block. Always built
    # as a set so every metric receives the same container type.
    gold_docs = set()
    if "SMILES" in row.Question_Background:
        # NOTE(review): layout inferred from the 3-line stride; confirm against the
        # dataset. The background is assumed to be made of 3-line records. The first
        # two lines of a record form one corpus document, and records whose first
        # line mentions "INTERACTION" are not documents.
        split = row.Question_Background.split("\n")
        for i in range(0, len(split), 3):
            header = split[i]
            # Skip blank padding lines (e.g. from trailing newlines) and a truncated
            # final record. These would otherwise add an unretrievable "\n..." gold
            # entry or raise IndexError on split[i + 1].
            if not header.strip() or i + 1 >= len(split):
                continue
            if "INTERACTION" not in header:
                gold_docs.add(f"{header}\n{split[i + 1]}")
    else:
        # Otherwise, each "BACKGROUND INFORMATION:" section, running up to the next
        # DRUG/PROTEIN marker or the end of the text, is one gold document.
        gold_docs = {
            x.strip()
            for x in re.findall(r"BACKGROUND INFORMATION:(.*?)(?=DRUG|PROTEIN|\Z)", row.Question_Background, re.DOTALL)
        }

    # Per-query metrics, averaged after the loop.
    hits.append(hits_k(gold_docs, retrieved_docs))
    recall.append(recall_k(gold_docs, retrieved_docs))
    mrr.append(mrr_k(gold_docs, retrieved_docs))

print(f"Hits@{top_k}: {np.round(np.array(hits).mean(),2)}")
print(f"Recall@{top_k}: {np.round(np.array(recall).mean(),2)}")
print(f"MRR@{top_k}: {np.round(np.array(mrr).mean(),2)}")