In [1]:
import os

# Set env vars BEFORE importing huggingface modules
os.environ["HF_HOME"] = "/projects/sciences/computing/sheju347/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/projects/sciences/computing/sheju347/.cache/huggingface/hub"

# Now import huggingface modules
from huggingface_hub import constants

print("HF_HOME:", constants.HF_HOME)
print("HF_HUB_CACHE:", constants.HF_HUB_CACHE)

HF_HOME: /projects/sciences/computing/sheju347/.cache/huggingface
HF_HUB_CACHE: /projects/sciences/computing/sheju347/.cache/huggingface/hub


In [3]:

def format_choices(choices, answer_key, mask_correct_answer):
    final_answers = []
    for i, (x, y) in enumerate(choices.items()):
        if mask_correct_answer and x == answer_key:
            y = "None of the other answers"
        final_answers.append(f'[{x}] : {y}')
    return "\n".join(final_answers)

In [None]:

def get_RAG_context(query, formated_choices, score_threshold = None, pick_rag_index = None, use_classifier = False):
    ip = "localhost"
    port = 8080
    endpoint = "search"
    response = requests.get(f"http://{ip}:{port}/{endpoint}", params = {"q": query, "k": self.topK_searchEngine})
    if response.status_code == 200:


        # 1. BM25
        # text = response.text
        # doc_list = text.split("###RAG_DOC###")
        # for result in results:
        results = response.json() # [{docId, docNo, score, content}, ...]
        doc_data_list = []
        for i in range(len(results)):
            result = results[i]
            data = {"docId": result["docNo"], "BM25_score": result["score"], "BM25_ranking": i + 1, "content": result["content"]}
            doc_data_list.append(data)
        
        # print(f"1 len(doc_list): {len(doc_list)}")

        # 2. SPLADE
        if topK_SPLADE > 0:
            doc_data_list = self.RAG_SPLADE_filter(query, doc_data_list)
        # print(f"2 len(doc_list): {len(doc_list)}")
        

        # 3. MonoT5
        if topK_crossEncoder > 0:
            # doc_data_list = self.RAG_CrossEncoder_rerank(query + '\n' + formated_choices, doc_data_list)
            doc_data_list = self.RAG_MonoT5_rerank(query + '\n' + formated_choices, doc_data_list, score_threshold)
            # print(f"3 len(doc_list): {len(doc_list)}")            

        # 4. LLM list reranker
        if topK_LLM > 0:
            doc_data_list = self.RAG_LLM_rerank(query + '\n' + formated_choices, doc_data_list)
        
        # only feed the nth retrieved document into the model
        if pick_rag_index != None:
            doc_data_list = [doc_data_list[pick_rag_index]]

        # Use a classifier model to select which context+query can produce correct answer
        if use_classifier:
            doc_data_list = self.RAG_classifier(query + '\n' + formated_choices, doc_data_list)
            
        # doc_list -> context (str)
        context = ""
        for doc_data in doc_data_list:
            if self.check_RAG_doc_useful(query, formated_choices, doc_data):
                context += doc_data["content"]
                context += "\n\n"
            
        # Print RAG scores and rankings
        for doc_data in doc_data_list:
            del doc_data["content"]
        print(f"RAG data: {doc_data_list}")
        logging.info(f"RAG data: {doc_data_list}")
        
        return context
    else:
        return f"HTTPError: {response.status_code} - {response.text}"

In [5]:

prompt_RAG = '''
You are a medical question answering assistant.

The following context may or may not be useful. Use it only if it helps answer the question.
INSTRUCTIONS:
- If the context directly helps answer the question, use it and cite appropriately
- If the context is topically related but not diagnostically relevant, acknowledge it but rely on your medical knowledge
- If the context might mislead you toward a less likely diagnosis, explicitly state why you're not following it

Context:
{context}

Question:
{question}

{choices}
'''


from datasets import load_dataset
dataset_path = "GBaker/MedQA-USMLE-4-options"
subset_name = None
split = "test"
dataset = load_dataset(dataset_path, name = subset_name, trust_remote_code = True)
data_list = dataset[split]
# print(len(data_list))
for data in data_list:
    question = data["question"]
    choices = data["options"]
    answer_key = data["answer_idx"]
    prompt = prompt_RAG
    formated_choices = format_choices(choices, answer_key, mask_correct_answer = False)

    
    context = self.get_RAG_context(question, formated_choices, score_threshold = score_threshold, pick_rag_index = pick_rag_index, use_classifier = use_classifier)
    content = prompt.format(context = context, question = question, choices = formated_choices)

    

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'GBaker/MedQA-USMLE-4-options' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
