In [2]:
import argparse
import json
import os
import re
import torch
from src.RGAR import RGAR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run configuration for this evaluation notebook.

# NOTE(review): not referenced by any cell visible in this notebook — confirm it is still needed.
PYTHON_SCRIPT="pipeline.py"
# Benchmark name handed to QADataset, which lower-cases it and keeps the text
# before the first "_" as the key to look up in the benchmark JSON.
DATASET_NAME="ehrnoteqa"
# Directory in which QADataset looks for "ehrnoteqa.json".
DATASET_DIR="EHRNoteQA"
# JSON file used both to resume from earlier results and to store new ones.
OUTPUT_PATH="results/Llama-3.2-3B-MedCPT-Textbooks-MedQA-RGAR-EHRNoteQA.json"
# Passed as the `device` argument when constructing RGAR.
DEVICE_NAME="cuda:0"
# NOTE(review): not referenced by any cell visible in this notebook — confirm it is still needed.
LOG_FILE="logs/try-MedQA-RGAR-EHRNoteQA.log"

In [4]:
class QADataset:
    """Positional (int / slice) access to one entry group of ``ehrnoteqa.json``.

    The benchmark file is a JSON object whose top-level keys are dataset
    names; the value under the selected key is itself a mapping of entries.
    Entries are exposed in the sorted order of their keys, so a given
    position always refers to the same entry across runs.
    """

    def __init__(self, data, dir="."):
        """Load the benchmark file and select the requested dataset.

        Args:
            data: Dataset name. It is lower-cased and only the part before
                the first ``"_"`` is used as the lookup key.
            dir: Directory that contains ``ehrnoteqa.json``.

        Raises:
            KeyError: If the derived key is not present in the benchmark file.
            OSError: If the benchmark file cannot be opened.
        """
        self.data = data.lower().split("_")[0]
        benchmark_path = os.path.join(dir, "ehrnoteqa.json")
        # Context manager closes the handle deterministically; explicit UTF-8
        # avoids depending on the platform's default text encoding.
        with open(benchmark_path, "r", encoding="utf-8") as benchmark_file:
            benchmark = json.load(benchmark_file)
        if self.data not in benchmark:
            raise KeyError("{:s} not supported".format(data))
        self.dataset = benchmark[self.data]
        # Default ordering of the keys; for string keys this is lexicographic
        # (e.g. "10" sorts before "2").
        self.index = sorted(self.dataset.keys())

    def __len__(self):
        """Return the number of entries in the selected dataset."""
        return len(self.dataset)

    def __getitem__(self, key):
        """Return one entry for an int key, or a list of entries for a slice.

        Raises:
            KeyError: If ``key`` is neither an int nor a slice.
            IndexError: If an int key is out of range.
        """
        if isinstance(key, slice):
            # Let range() resolve start/stop/step (including negatives).
            return [self[position] for position in range(len(self))[key]]
        if isinstance(key, int):
            return self.dataset[self.index[key]]
        raise KeyError("Key type not supported.")

def extract_answer(content):
    """Pull a single uppercase option letter out of a model response.

    Strategy, in order:
      1. Collect, for every occurrence of "answer" or "Answer", the first
         uppercase letter that follows it on the same line ("." does not
         cross newlines). If any were found, the last one is returned.
      2. Otherwise return the last uppercase letter anywhere in ``content``.

    Returns:
        The extracted letter as a one-character string, or ``None`` when
        ``content`` contains no uppercase letter at all.
    """
    keyword_hits = re.findall(r'(?:answer|Answer).*?([A-Z])', content)
    if not keyword_hits:
        # No "answer ... X" pattern: fall back to the final capital letter.
        trailing = re.search(r'([A-Z])(?=[^A-Z]*$)', content)
        return trailing.group(1) if trailing else None
    return keyword_hits[-1]

In [5]:
# Benchmark entries to evaluate, read from DATASET_DIR.
dataset = QADataset(DATASET_NAME, dir=DATASET_DIR)

# Constructor arguments for the RGAR pipeline, gathered in one mapping so the
# whole configuration can be read (or logged) at a glance.
# NOTE(review): the meaning of `me`, `follow_up` and `realme` is defined by
# src.RGAR and is not visible here — confirm against that module.
_rgar_kwargs = {
    "llm_name": "meta-llama/Llama-3.2-3B-Instruct",
    "rag": True,
    "retriever_name": "MedCPT",
    "corpus_name": "Textbooks",
    "device": DEVICE_NAME,
    "cot": False,
    "me": 2,
    "follow_up": False,
    "realme": False,
}
rgar = RGAR(**_rgar_kwargs)

No sentence-transformers model found with name ncbi/MedCPT-Query-Encoder. Creating a new one with CLS pooling.


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.13s/it]


In [6]:
# Resume support: pick up any results already written to OUTPUT_PATH so that
# the evaluation continues after the last saved question.
if not os.path.exists(OUTPUT_PATH):
    results = []
else:
    with open(OUTPUT_PATH, "r", encoding="utf-8") as saved_file:
        results = json.load(saved_file)
    print(f"Loaded {len(results)} saved results.")

# Number of previously saved answers that were marked correct.
correct_count = len([entry for entry in results if entry["is_correct"]])
# Position in the dataset of the first question that still has to be answered.
start_idx = len(results)

In [7]:
# Evaluation loop: answer every not-yet-answered question, score it against
# the reference answer, and checkpoint the accumulated results to OUTPUT_PATH.

# Value passed as `k` to rgar.answer.
# NOTE(review): presumably the retrieval depth — confirm against src.RGAR.
TOP_K = 32
# Write a checkpoint every this many questions.
SAVE_EVERY = 10


def _save_results(path, records):
    """Write ``records`` to ``path`` as JSON without ever leaving a partial file.

    The data is first written to a sibling temporary file and then moved over
    ``path`` with ``os.replace``, so an interruption during the write cannot
    corrupt the file that the resume logic later reads.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4, ensure_ascii=False)
    os.replace(tmp_path, path)


# Create the output directory up front so a missing folder is reported before
# any expensive generation has happened, not at the first checkpoint.
os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)

# Derive the resume state from `results` itself so that re-running this cell
# never re-answers questions or appends duplicate entries.
start_idx = len(results)
correct_count = sum(1 for r in results if r["is_correct"])

for idx, data in enumerate(dataset[start_idx:], start=start_idx):
    question = data["question"]
    options = data["options"]
    correct_answer = data["answer"]

    # Only the first returned value (the raw model output) is used here.
    answer_json, *_ = rgar.answer(question=question, options=options, k=TOP_K)
    print(answer_json)

    predicted_answer = extract_answer(answer_json)

    if predicted_answer is None:
        print(f"Warning: Could not extract answer for Question {idx + 1}")
        # Sentinel that can never equal a real option letter.
        predicted_answer = "N/A"

    is_correct = predicted_answer == correct_answer
    if is_correct:
        correct_count += 1

    print(f"Question {idx + 1}/{len(dataset)}:")
    print(f"  Correct Answer: {correct_answer}")
    print(f"  Predicted Answer: {predicted_answer}")
    print(f"  {'Correct!' if is_correct else 'Incorrect.'}")

    results.append({
        "question": question,
        "correct_answer": correct_answer,
        "predicted_answer": predicted_answer,
        "raw_output": answer_json,
        "is_correct": is_correct
    })

    if (idx + 1) % SAVE_EVERY == 0:
        _save_results(OUTPUT_PATH, results)
        print(f"Progress saved at {idx + 1} questions.")

    torch.cuda.empty_cache()

# Accuracy is reported over the whole dataset; guard against an empty one.
total_questions = len(dataset)
accuracy = correct_count / total_questions if total_questions else 0.0
print(f"\nAccuracy: {accuracy * 100:.2f}%")
torch.cuda.empty_cache()
_save_results(OUTPUT_PATH, results)

print(f"Final results saved to {OUTPUT_PATH}")

Generated Subquery: What were the patient's vital signs upon hospital discharge compared to admission, specifically focusing on changes in heart rate, blood pressure, respiratory rate, oxygen saturation, and any notable fluctuations during their recovery period?
Generated Subanswer: No relevant information found.
Generated Subquery: What is the typical progression of symptoms and improvement in patients who have undergone [insert surgical procedure/procedure] while maintaining stable vital signs, tolerating oral intake, ambulating, and voiding independently after an initial episode of [condition]?
Generated Subanswer: Based on Document [19], No relevant information found.
Generated Subquery: What are the common complications associated with prolonged bed rest in hospitalized patients who maintain normal vital signs but experience persistent gastrointestinal issues such as constipation, urinary retention, and decreased appetite following surgery?
Generated Subanswer: Constipation, urina

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


```json
{
    "answer_choice": "A"
}
```
```json { "answer_choice": "A" } ```
Question 1/3:
  Correct Answer: A
  Predicted Answer: A
  Correct!
Generated Subquery: What is the typical management approach for cardiac arrest following percutaneous coronary intervention (PCI) with stenting in patients who develop severe bradycardia?
Generated Subanswer: Cardioversion, external pacing, and close monitoring.
Generated Subquery: What is the recommended protocol for managing hemodynamic instability after PCI with stenting in patients experiencing refractory bradycardia post-myocardial infarction (STEMI)?
Generated Subanswer: There is no mention of refractory bradycardia in the provided documents.
Generated Subquery: What is the optimal timing and dosing strategy for administering atropine in cases of symptomatic bradycardia during cardiac catheterization procedures such as PCI with stenting?
Generated Subanswer: Atropine is no longer considered effective for asystole or PEA, but can be used 