In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import bm25s
import re

# Load the evaluation set; the retrieval loop below reads its `Question`
# and `Question_Background` columns.
test_df = pd.read_csv("../../dataset_for_hf/test.csv")

In [None]:
def create_corpus(folder):
    """Read every file directly inside ``folder`` and return their contents.

    Args:
        folder: Directory to read, given as a ``Path`` or a path string.
            Only regular files directly inside it are read; sub-directories
            are skipped.

    Returns:
        list[str]: One string per file, ordered by file path so that the
        corpus (and anything built from it) is the same on every run.
    """
    corpus = []
    # iterdir() yields entries in arbitrary, OS-dependent order; sort them so
    # the document order is reproducible.
    for file in sorted(Path(folder).iterdir()):
        if not file.is_file():
            # Opening a directory would raise, so skip anything that is not a file.
            continue
        # Decode explicitly as UTF-8 instead of relying on the locale default.
        corpus.append(file.read_text(encoding="utf-8"))
    return corpus

# Assemble the retrieval corpus: every document from the four background-
# information folders, concatenated in the order the folders are listed.
corpus = [
    document
    for source_dir in (
        Path("../../data/background_information_data/drug_data/Wiki/"),
        Path("../../data/background_information_data/drug_data/Wiki_complexified/"),
        Path("../../data/background_information_data/drug_data/SMILES/"),
        Path("../../data/background_information_data/protein_data/Wiki/"),
    )
    for document in create_corpus(source_dir)
]

# Tokenize the documents, build the BM25 index over them and save it to disk.
# The raw corpus is handed to the retriever; the evaluation cell reloads the
# index with load_corpus=True and reads document text from the results.
corpus_tokens = bm25s.tokenize(corpus)
retriever = bm25s.BM25(corpus=corpus)
retriever.index(corpus_tokens)
retriever.save("bm25s_index")

In [None]:
# Per-question metric accumulators; the evaluation loop appends one value to
# each list for every test question.
hits_5, hits_10, hits_15 = [], [], []
recall, mrr = [], []

# Reload the saved index from disk together with its corpus.
retriever = bm25s.BM25.load("bm25s_index", load_corpus=True)

def hits_k(gold_list, retrieved_list):
    """Return 1 when every gold document was retrieved, otherwise 0.

    Duplicates and ordering are ignored: both arguments are compared as sets,
    so an empty ``gold_list`` always counts as a hit.
    """
    all_gold_found = set(gold_list) <= set(retrieved_list)
    return int(all_gold_found)

def recall_k(gold_list, retrieved_list):
    """Fraction of the gold documents that appear among the retrieved ones.

    Args:
        gold_list: Relevant (gold) documents for the query.
        retrieved_list: Documents returned by the retriever (the top-k).

    Returns:
        float: ``len(gold & retrieved) / len(gold)`` over distinct documents,
        always in ``[0, 1]``; ``0.0`` when there are no gold documents.
    """
    gold = set(gold_list)
    if not gold:
        # Recall is undefined without relevant documents; report 0 instead of
        # dividing by zero.
        return 0.0
    # Normalise by the number of gold documents. Dividing by the number of
    # retrieved documents would measure precision, not recall.
    return len(gold.intersection(retrieved_list)) / len(gold)

def mrr_k(gold_list, retrieved_list):
    """Reciprocal rank of the first retrieved document that is a gold document.

    Ranks are 1-based, following the order of ``retrieved_list``. Returns 0
    when none of the retrieved documents is in ``gold_list``.
    """
    reciprocal_ranks = (
        1 / rank
        for rank, candidate in enumerate(retrieved_list, start=1)
        if candidate in gold_list
    )
    # Only the earliest match counts; fall back to 0 when there is none.
    return next(reciprocal_ranks, 0)

# Evaluate retrieval on every test question: query the index, rebuild the
# gold documents from the question's background field, and record the metrics.
for row in test_df.itertuples():
    # Querying
    # The same tokenized question is retrieved three times, once per cutoff k.
    query_tokens = bm25s.tokenize(row.Question)
    retrieved_docs_5, _ = retriever.retrieve(query_tokens, k=5)
    retrieved_docs_10, _ = retriever.retrieve(query_tokens, k=10)
    retrieved_docs_15, _ = retriever.retrieve(query_tokens, k=15)
    
    # Index [0] selects the results for the first query (a single question is
    # passed in); each result is a mapping and only its "text" field is kept.
    retrieved_docs_5 = [x["text"] for x in retrieved_docs_5[0]]
    retrieved_docs_10 = [x["text"] for x in retrieved_docs_10[0]]
    retrieved_docs_15 = [x["text"] for x in retrieved_docs_15[0]]

    # Processing Gold data
    gold_docs = []
    if "SMILES" in row.Question_Background: # Then you know it will be a molecular interaction question.
        # The background is walked in groups of three lines. For every group
        # whose first line does not mention "INTERACTION", the numbered
        # "DRUG <n> NAME" / "DRUG <n> SMILES" labels on the first two lines are
        # replaced with unnumbered ones and the two lines are joined.
        # NOTE(review): presumably this reproduces the text of the SMILES corpus
        # files, and each group is name / SMILES / separator line -- confirm.
        split = row.Question_Background.split("\n")
        for i in range(0, len(split), 3):
            if "INTERACTION" not in split[i]:
                name = re.sub(r"DRUG \d+\s? NAME","DRUG NAME", split[i])
                smiles = re.sub(r"DRUG \d+\s? SMILES","DRUG SMILES", split[i+1])
                gold_docs.append(f"{name}\n{smiles}")
    else:
        # Capture the text after each "BACKGROUND INFORMATION:" marker, up to
        # the next "DRUG" / "PROTEIN" occurrence or the end of the string, and
        # strip surrounding whitespace.
        gold_docs = [x.strip() for x in re.findall(r"BACKGROUND INFORMATION:(.*?)(?=DRUG|PROTEIN|\Z)", row.Question_Background, re.DOTALL)]

    # Metrics
    # Hits are recorded at k = 5, 10 and 15; recall_k and mrr_k are evaluated
    # on the top-5 results only.
    hits_5.append(hits_k(gold_docs, retrieved_docs_5))
    hits_10.append(hits_k(gold_docs, retrieved_docs_10))
    hits_15.append(hits_k(gold_docs, retrieved_docs_15))
    recall.append(recall_k(gold_docs, retrieved_docs_5))
    mrr.append(mrr_k(gold_docs, retrieved_docs_5))

# Report each metric as its mean over all test questions, rounded to two
# decimals, one metric per line.
print(
    *(
        f"{label}: {np.round(np.array(scores).mean(),2)}"
        for label, scores in (
            ("Hits@5", hits_5),
            ("Hits@10", hits_10),
            ("Hits@15", hits_15),
            ("Recall@5", recall),
            ("MRR", mrr),
        )
    ),
    sep="\n",
)