# BioASQ Dataset Exploration

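This notebook explores the structure of the rag-mini-bioasq dataset and evaluates a dense-retrieval baseline on it. Passages and questions are embedded with sentence-transformers/all-MiniLM-L6-v2, and retrieval is scored with precision, recall, MRR and nDCG at several cutoffs.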

In [3]:
import sys
sys.path.insert(0, '../src')

from datasets import load_dataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')
%matplotlib inline

## 1. Load Dataset

In [5]:
# Load the dataset
text_corpus = load_dataset('rag-datasets/rag-mini-bioasq', "text-corpus")
question_answer_passages = load_dataset('rag-datasets/rag-mini-bioasq', "question-answer-passages")

Generating passages split: 100%|██████████| 40221/40221 [00:00<00:00, 298249.92 examples/s]
Generating test split: 100%|██████████| 4719/4719 [00:00<00:00, 657200.94 examples/s]


In [28]:
# Sanity check: question ids run 0..N-1 in order, so row i of the test split is question id i
assert all(question_answer_passages["test"][i]["id"] == i for i in range(len(question_answer_passages["test"])))

In [29]:
question_answer_passages["test"][0]

{'question': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?',
 'answer': "Coding sequence mutations in RET, GDNF, EDNRB, EDN3, and SOX10 are involved in the development of Hirschsprung disease. The majority of these genes was shown to be related to Mendelian syndromic forms of Hirschsprung's disease, whereas the non-Mendelian inheritance of sporadic non-syndromic Hirschsprung disease proved to be complex; involvement of multiple loci was demonstrated in a multiplicative model.",
 'relevant_passage_ids': '[20598273, 6650562, 15829955, 15617541, 23001136, 8896569, 21995290, 12239580, 15858239]',
 'id': 0}
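In this split, `relevant_passage_ids` is stored as a JSON-encoded string rather than a list, so it has to be decoded before use. The following is a minimal sketch that decodes it and summarizes how many gold passages each question has. `parse_relevant_ids` and `relevant_count_summary` are hypothetical helper names, and the example uses only the record printed above.

```python
import json
import statistics


def parse_relevant_ids(record):
    """Decode the JSON string in `relevant_passage_ids` into a list of ints."""
    return json.loads(record["relevant_passage_ids"])


def relevant_count_summary(records):
    """Min / median / max number of gold passages per question (hypothetical helper)."""
    counts = [len(parse_relevant_ids(r)) for r in records]
    return {
        "min": min(counts),
        "median": statistics.median(counts),
        "max": max(counts),
    }


# The first test record shown above
sample = {
    "question": "Is Hirschsprung disease a mendelian or a multifactorial disorder?",
    "relevant_passage_ids": "[20598273, 6650562, 15829955, 15617541, 23001136, 8896569, 21995290, 12239580, 15858239]",
    "id": 0,
}
print(parse_relevant_ids(sample)[:3])    # [20598273, 6650562, 15829955]
print(relevant_count_summary([sample]))  # {'min': 9, 'median': 9, 'max': 9}
```

On the full split, the call would be `relevant_count_summary(question_answer_passages["test"])`, since iterating a `datasets` split yields plain dicts. The count matters for interpreting recall. A question with 9 gold passages can reach a recall@1 of at most 1/9.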

In [None]:
# TODO: save passages and lookup dictionaries
# encode all passages with sentence-transformers/all-MiniLM-L6-v2
# top-k retrieval with faiss (the cells below currently use exact top-k search in torch instead)

In [15]:
text_corpus["passages"][0]

{'passage': 'New data on viruses isolated from patients with subacute thyroiditis de Quervain \nare reported. Characteristic morphological, cytological, some physico-chemical \nand biological features of the isolated viruses are described. A possible role \nof these viruses in human and animal health disorders is discussed. The isolated \nviruses remain unclassified so far.',
 'id': 9797}

In [16]:
from tqdm import tqdm
pid_to_passage = {}  # passage id -> passage text
pid_to_idx = {}      # passage id -> row index in text_corpus["passages"]
# Iterating the split yields one dict per row, which avoids re-indexing the dataset twice per row
for i, row in enumerate(tqdm(text_corpus["passages"])):
    pid_to_passage[row["id"]] = row["passage"]
    pid_to_idx[row["id"]] = i

100%|██████████| 40221/40221 [00:02<00:00, 19272.07it/s]


In [36]:
idx_to_pid = {idx: pid for pid, idx in pid_to_idx.items()}  # row index -> passage id
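`pid_to_idx` and `idx_to_pid` are only inverses of each other if passage ids are unique. In addition, `idx_to_pid` decodes top-k results correctly only if row `i` of the passage-embedding matrix is the passage at dataset row `i`. The sketch below demonstrates both checks on toy ids, and adds a coverage check that every gold passage id exists in the corpus. `build_id_maps` and `gold_coverage` are hypothetical helpers that are not part of the original notebook.

```python
def build_id_maps(passage_ids):
    """Map passage id -> dataset row and back; fail loudly on duplicate ids."""
    pid_to_idx = {pid: i for i, pid in enumerate(passage_ids)}
    # A duplicate id would silently overwrite an earlier row in the dict
    if len(pid_to_idx) != len(passage_ids):
        raise ValueError("passage ids are not unique")
    idx_to_pid = {i: pid for pid, i in pid_to_idx.items()}
    return pid_to_idx, idx_to_pid


def gold_coverage(relevant_id_lists, pid_to_idx):
    """Fraction of gold passage ids that are present in the corpus."""
    gold = [pid for ids in relevant_id_lists for pid in ids]
    return sum(pid in pid_to_idx for pid in gold) / len(gold)


# Toy corpus whose ids are NOT sorted, to show why row order matters
ids = [9797, 25122144, 14697511]
pid_to_idx, idx_to_pid = build_id_maps(ids)
print(idx_to_pid[1])   # 25122144
# Ordering passages by sorted id would put 14697511 in row 1 instead
print(sorted(ids)[1])  # 14697511
print(gold_coverage([[9797, 123]], pid_to_idx))  # 0.5
```

On the real corpus, the call would look roughly like `build_id_maps(list(text_corpus["passages"]["id"]))`.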

In [18]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [20]:
model = model.to("cuda")

In [None]:
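# Encode passages in dataset row order so that embedding row i is passage idx_to_pid[i]
# (ordering by sorted passage id only lines up with idx_to_pid if the corpus is already sorted by id)
passages_list = list(text_corpus["passages"]["passage"])
embeddings = model.encode(passages_list, show_progress_bar=True)  # all-MiniLM-L6-v2 L2-normalizes its outputs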

Batches: 100%|██████████| 1257/1257 [00:28<00:00, 44.26it/s] 
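The passage encoding relies on all-MiniLM-L6-v2 returning unit-length embeddings, because the retrieval below scores passages with a plain dot product. For unit-length vectors, the dot product equals cosine similarity. The following is a minimal pure-Python check of that identity, using made-up toy vectors.

```python
import math


def l2_normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


q = [3.0, 4.0]  # toy "question" vector
p = [1.0, 2.0]  # toy "passage" vector
# Dot product of the normalized vectors equals cosine similarity of the raw vectors
print(math.isclose(dot(l2_normalize(q), l2_normalize(p)), cosine(q, p)))  # True
```

If the checkpoint's normalization were ever in doubt, `model.encode(..., normalize_embeddings=True)` makes it explicit.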


In [30]:
questions_list = [qap["question"] for qap in question_answer_passages["test"]]

In [31]:
question_embeddings = model.encode(questions_list, show_progress_bar=True)

Batches: 100%|██████████| 148/148 [00:02<00:00, 72.26it/s]
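The batch counts in the progress bars follow from the dataset sizes. `SentenceTransformer.encode` uses `batch_size=32` by default, and the retrieval loop further down uses `batch_size = 128`. The following is a quick arithmetic check that assumes those two batch sizes.

```python
import math

ENCODE_BATCH_SIZE = 32      # sentence-transformers' default for encode()
RETRIEVAL_BATCH_SIZE = 128  # batch_size in the retrieval loop further down

passage_batches = math.ceil(40221 / ENCODE_BATCH_SIZE)
question_batches = math.ceil(4719 / ENCODE_BATCH_SIZE)
retrieval_batches = math.ceil(4719 / RETRIEVAL_BATCH_SIZE)
print(passage_batches, question_batches, retrieval_batches)  # 1257 148 37
```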


In [35]:
import torch
# Move both embedding matrices to the GPU as tensors for batched matrix products
question_embeddings = torch.from_numpy(question_embeddings).to("cuda")
embeddings = torch.from_numpy(embeddings).to("cuda")

In [None]:
(embeddings @ question_embeddings[0]).shape

(40221,)

In [40]:
batch_size = 128
k = 20
retrieved_passages = []
with torch.no_grad():
    for i in tqdm(range(0, len(question_embeddings), batch_size)):
        batch_question_embeddings = question_embeddings[i:i+batch_size]
        passage_scores = batch_question_embeddings @ embeddings.T # shape (batch_size, num_passages)
        topk_passage_scores, topk_passage_indices = torch.topk(passage_scores, k=k, dim=1) # shape (batch_size, k)
        topk_passage_indices = topk_passage_indices.cpu().tolist()
        for indices in topk_passage_indices:
            retrieved_passages.append([idx_to_pid[idx] for idx in indices])

100%|██████████| 37/37 [00:00<00:00, 825.81it/s]
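The loop above performs exact (brute-force) maximum-inner-product search. Every question is scored against all 40,221 passages, and `torch.topk` keeps the `k` best. The following pure-Python sketch performs the same operation for a single query using `heapq.nlargest`, which avoids sorting the full score list. `topk_indices` is a hypothetical helper, and the scores are toy values.

```python
import heapq


def topk_indices(scores, k):
    """Indices of the k highest scores, best first (exact top-k)."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


scores = [0.12, 0.87, 0.45, 0.91, 0.30]  # toy question-passage scores
print(topk_indices(scores, 3))  # [3, 1, 2]
# Same result as a full descending sort truncated to k
assert topk_indices(scores, 3) == sorted(range(len(scores)), key=lambda i: -scores[i])[:3]
```

The faiss plan from the earlier cell would replace this step with an index such as `faiss.IndexFlatIP`, which also performs exact inner-product search. At 40k passages, however, one GPU matrix product per batch is already fast, as the progress bar shows.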


In [49]:
# compute metrics: precision, recall, mrr, ndcg at k=1,5,10,20
import json
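def compute_metrics(retrieved_passages, question_answer_passages, k):
    """Mean precision, recall, MRR and nDCG at cutoff k over the test questions.

    Assumes retrieved_passages[i] is the ranked list of passage ids for test question i.
    """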
    precision_at_k = []
    recall_at_k = []
    mrr_at_k = []
    ndcg_at_k = []
    
    for i, qap in enumerate(question_answer_passages["test"]):
        relevant_pids = set(json.loads(qap["relevant_passage_ids"]))
        retrieved_pids = retrieved_passages[i][:k]
        
        # Precision@k
        precision = len(set(retrieved_pids) & relevant_pids) / k
        precision_at_k.append(precision)
        
        # Recall@k
        recall = len(set(retrieved_pids) & relevant_pids) / len(relevant_pids) if relevant_pids else 0
        recall_at_k.append(recall)
        
        # MRR@k
        mrr = 0
        for rank, pid in enumerate(retrieved_pids, start=1):
            if pid in relevant_pids:
                mrr = 1 / rank
                break
        mrr_at_k.append(mrr)
        
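        # NDCG@k (binary relevance, log2 rank discount)
        dcg = 0
        # Ideal DCG: every gold passage ranked at the top (j is kept distinct from the outer loop's i)
        idcg = sum(1 / torch.log2(torch.tensor(j + 2)) for j in range(min(len(relevant_pids), k)))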
        for rank, pid in enumerate(retrieved_pids, start=1):
            if pid in relevant_pids:
                dcg += 1 / torch.log2(torch.tensor(rank + 1))
        ndcg = dcg / idcg if idcg > 0 else 0
        ndcg_at_k.append(ndcg)
    
    return {
        "precision": sum(precision_at_k) / len(precision_at_k),
        "recall": sum(recall_at_k) / len(recall_at_k),
        "mrr": sum(mrr_at_k) / len(mrr_at_k),
        "ndcg": sum(ndcg_at_k) / len(ndcg_at_k)
    }
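To sanity-check `compute_metrics`, it helps to work through one ranking by hand. The following self-contained sketch applies the same four formulas (binary relevance, log2 discount) to a toy ranking with made-up ids: 5 retrieved passages, 3 gold passages, and hits at ranks 2 and 5. It uses `math.log2` instead of `torch.log2`, so the results are plain floats. `metrics_at_k` is a hypothetical helper.

```python
import math


def metrics_at_k(retrieved, relevant, k):
    """Precision, recall, MRR and nDCG at k for one query (binary relevance)."""
    top = retrieved[:k]
    hits = [rank for rank, pid in enumerate(top, start=1) if pid in relevant]
    precision = len(hits) / k
    recall = len(hits) / len(relevant) if relevant else 0.0
    mrr = 1 / hits[0] if hits else 0.0
    dcg = sum(1 / math.log2(rank + 1) for rank in hits)
    # Ideal ranking puts min(|relevant|, k) gold passages at ranks 1, 2, ...
    idcg = sum(1 / math.log2(r + 1) for r in range(1, min(len(relevant), k) + 1))
    ndcg = dcg / idcg if idcg > 0 else 0.0
    return {"precision": precision, "recall": recall, "mrr": mrr, "ndcg": ndcg}


retrieved = ["a", "b", "c", "d", "e"]
relevant = {"b", "e", "f"}
m = metrics_at_k(retrieved, relevant, k=5)
print({name: round(v, 4) for name, v in m.items()})
# {'precision': 0.4, 'recall': 0.6667, 'mrr': 0.5, 'ndcg': 0.4776}
```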

In [50]:
for k in [1, 5, 10, 20]:
    metrics = compute_metrics(retrieved_passages, question_answer_passages, k)
    print(f"Metrics at k={k}: {metrics}")

Metrics at k=1: {'precision': 0.5596524687433778, 'recall': 0.13204025265349995, 'mrr': 0.5596524687433778, 'ndcg': tensor(0.5597)}
Metrics at k=5: {'precision': 0.3802076711167698, 'recall': 0.2926906106714452, 'mrr': 0.6261213533940809, 'ndcg': tensor(0.4885)}
Metrics at k=10: {'precision': 0.28436109345199984, 'recall': 0.37335606086672535, 'mrr': 0.6313573631755454, 'ndcg': tensor(0.4615)}
Metrics at k=20: {'precision': 0.18544183089637167, 'recall': 0.43850747045641564, 'mrr': 0.6339913724707721, 'ndcg': tensor(0.4505)}
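At k=1, the precision, MRR and nDCG values coincide (0.5597). With binary relevance this is expected. Each of the three per-question scores reduces to the indicator that the top-ranked passage is relevant. Let $r_1$ denote the top-ranked passage and $R_q$ the gold set of question $q$ (with $|R_q| \ge 1$):

$$
\mathrm{P@1}_q = \mathbb{1}[r_1 \in R_q], \qquad
\mathrm{MRR@1}_q = \frac{\mathbb{1}[r_1 \in R_q]}{1}, \qquad
\mathrm{nDCG@1}_q = \frac{\mathbb{1}[r_1 \in R_q] / \log_2 2}{1 / \log_2 2} = \mathbb{1}[r_1 \in R_q]
$$

$$
\mathrm{R@1}_q = \frac{\mathbb{1}[r_1 \in R_q]}{|R_q|} \le \frac{1}{|R_q|}
$$

Recall@1 is therefore capped per question by the number of gold passages. This explains why it stays low (0.1320) even though more than half of the top-1 results are relevant.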


In [45]:
pid_to_passage[25122144]

'PURPOSE: The phenotypic manifestations of cerebral cavernous malformation \ndisease caused by rare PDCD10 mutations have not been systematically examined, \nand a mechanistic link to Rho kinase-mediated hyperpermeability, a potential \ntherapeutic target, has not been established.\nMETHODS: We analyzed PDCD10 small interfering RNA-treated endothelial cells for \nstress fibers, Rho kinase activity, and permeability. Rho kinase activity was \nassessed in cerebral cavernous malformation lesions. Brain permeability and \ncerebral cavernous malformation lesion burden were quantified, and clinical \nmanifestations were assessed in prospectively enrolled subjects with PDCD10 \nmutations.\nRESULTS: We determined that PDCD10 protein suppresses endothelial stress fibers, \nRho kinase activity, and permeability in vitro. Pdcd10 heterozygous mice have \ngreater lesion burden than other Ccm genotypes. We demonstrated robust Rho \nkinase activity in murine and human cerebral cavernous malformation 

In [46]:
pid_to_passage[14697511]

'Cerebral cavernous malformations (CCM) are vascular malformations, mostly \nlocated in the central nervous system, which occur in 0.1-0.5% of the \npopulation. They are characterized by abnormally enlarged and often leaking \ncapillary cavities without intervening neural parenchyma. Some are clinically \nsilent, whereas others cause seizures, intracerebral haemorrhage or focal \nneurological deficits. These vascular malformations can arise sporadically or \nmay be inherited as an autosomal dominant condition with incomplete penetrance. \nAt least 45% of families affected with cerebral cavernous malformations harbour \na mutation in Krev interaction trapped-1 (Krit1) gene (cerebral cavernous \nmalformation gene-1, CCM1). This gene contains 16 coding exons which encode a \n736-amino acid protein containing three ankyrin repeats and a FERM domain. \nNeither the CCM1 pathogenetic mechanisms nor the function of the Krit1 protein \nare understood so far, although several hypotheses have bee