# MedQA Dataset parsing for English data

In [1]:
import pandas as pd 
import json
import numpy as np
import torch
import re
import os

def read_question_answer_file(file_path):
    """Reads a JSONL file with question-answer data and returns a list of dictionaries."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))  # Parse each line as JSON
    return data

# Load your dataset
dataset_path = r'med_data/phrases_no_exclude_train.jsonl' 
# dataset_path = r'C:\Users\ranad\OneDrive - University of Glasgow\Attachments\Msc Final Year project\Data\MedQA-USMLE-4-options\phrases_no_exclude_train.jsonl'
questions_data = read_question_answer_file(dataset_path)

In [2]:
print(questions_data[0]['question'])
print(questions_data[0]['answer_idx'])
print(questions_data[0]['options'])

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
D
{'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}


<h2>Data pre-processing and extraction</h2>

In [3]:
def sanitize_filename(filename):
  """Sanitizes a filename by replacing special characters with underscores.

  Args:
    filename: The original filename.

  Returns:
    The sanitized filename.
  """

  # Replace non-alphanumeric characters with underscores
  filename = re.sub(r'[^\w]', '_', filename)

  # Remove leading and trailing underscores
  filename = filename.strip('_')

  return filename

In [4]:
def read_json_file(file_path):
  """Reads data from a JSON file.

  Args:
    file_path: The path to the JSON file.

  Returns:
    The parsed JSON data as a Python object.
  """

  with open(file_path, 'r') as f:
    data = json.load(f)
  return data

# # Example usage:
# file_path = 'data.json'
# data = read_json_file(file_path)
# print(data)

In [5]:
def save_string_to_file(data, filename):
  """Saves a string to a text file.

  Args:
    text: The string to be saved.
    filename: The name of the file to create.
  """

  with open(filename, "w",encoding='utf-8') as f:
    json.dump(data, f)

# # Example usage:
# my_string = "This is the text I want to save."
# save_string_to_file(my_string, "output.txt")

<h2>Dictionary format data</h2>

In [6]:
# pubmed_dir = 'C:/Users/ranad/Documents/Pubmed_Full_text'
#pubmed_dir = 'Pubmed_Full_text'
pubmed_dir = 'Pubmed_Full_text'
# Dictionary to store the original data (not embeddings) by file key


def extract_dataDict(directory):
    data_dict = {}
    # Iterate through each JSON file in the directory
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        if os.path.isfile(filepath):
            with open(filepath, 'r') as file:
                data = json.load(file)
                file_key = sanitize_filename(filename[:-4])  
                data_dict[file_key.lower()] = data
    return data_dict

data_dict = extract_dataDict(pubmed_dir)

In [7]:
#save_string_to_file(data_dict,"Processed data/Full_text_dict.json")
# save_string_to_file("Processed data/abstract_text_dict.json")
fulltext_dict = read_json_file("Processed data/Full_text_dict.json")
# abstract_dict = read_json_file("Processed data/abstract_text_dict.json")

In [8]:
print(len(fulltext_dict))
# sanitize_filename("Placing the infant in a supine position on a firm mattress while sleeping")
fulltext_dict["acute_myocardial_infarction"][0]
# # ctx = fulltext_dict[sanitize_filename("Acute_myocardial_infarction")][0]
# # truncate_to_words(".\n".join(ctx))
"Acute_Myocardial_Infarction".lower()

3627


'acute_myocardial_infarction'

<h2>Document text arrangement to search using FIASS Index.</h2>

In [9]:
# save_string_to_file(medQA_Filtered_data, "Processed data/Full_text_filtered.json")
# save_string_to_file(filtered_abstract_data, "Processed data/abstract_text_filtered.json")
# medQA_Filtered_fulltext = read_json_file("Processed data/Full_text_filtered.json")
# medQA_filtered_abstract = read_json_file("Processed data/abstract_text_filtered.json")

# Library installations

In [9]:
!pip install langchain
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install --upgrade pip
!pip install --upgrade langchain
!pip install langchain_community
!pip list | grep langchain
!pip list | grep langchain_community

!pip install -U langchain-huggingface


Collecting langchain
  Downloading langchain-0.2.14-py3-none-any.whl.metadata (7.1 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain)
  Downloading aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting langchain-core<0.3.0,>=0.2.32 (from langchain)
  Downloading langchain_core-0.2.33-py3-none-any.whl.metadata (6.2 kB)
Collecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text_splitters-0.2.2-py3-none-any.whl.metadata (2.1 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain)
  Downloading langsmith-0.1.100-py3-none-any.whl.metadata (13 kB)
Collecting pydantic<3,>=1 (from langchain)
  Downloading pydantic-2.8.2-py3-none-

# Building the pipeline with the langchain

In [10]:
import os
import langchain

### prompts
from langchain import PromptTemplate, LLMChain

### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings


import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)


#model = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
#model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
model = "meta-llama/Meta-Llama-3-8B"
#model = "Undi95/Meta-Llama-3-8B-hf"

tokenizer = AutoTokenizer.from_pretrained(model)

        
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
    bnb_4bit_use_double_quant = True,
)

model_llama = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config = bnb_config,
    device_map = 'auto',
    do_sample=False,
    token='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'
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Designing a context summary prompt

In [15]:
question = questions_data[0]['question']
options = "\nA. Ampicillin\nB. Ceftriaxone\nC. Doxycycline\nD. Nitrofurantoin\n"

option_dict = questions_data[0]['options']

# Search dictionary index for each option and retrieve the corresponding text
retrieved_contexts = []
for option_key, option_text in option_dict.items():
   ctx = fulltext_dict.get(sanitize_filename(option_text).lower())
   if (ctx and len(ctx)>0) :
     retrieved_context = truncate_to_words("\n".join(ctx[0]),800)
     retrieved_contexts.append(retrieved_context)

context = '\n'.join([s.strip() + '.' for s in retrieved_contexts])
context



In [16]:
# Use the summarization prompt to generate a summary
summarization_prompt = f"""[INST] Summarize the following text concisely:[/INST]
{context}
"""

with torch.no_grad():
    summary_output = model_llama.generate(
        **tokenizer(summarization_prompt, return_tensors="pt",truncation=True).to(model_llama.device),
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id
    )

# Decode and print the generated summary
summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
summary_text = summary_text.replace("[INST] Summarize the following text concisely in short:[/INST]", "").strip()
print(summary_text)

[INST] Summarize the following text concisely:[/INST]
AmpicillinEster Bonded Branched Polymers Characterization Cyto Genotoxicity and Controlled DrugRelease Behaviour The development and characterization of novel macromolecular conjugates of ampicillin using branched biodegradable polymers has been described in this study The conjugates have been prepared coupling the βlactam antibiotic with branched polymer matrices based on the natural oligopeptide core The cyto and genotoxicity of the synthesized polymers were evaluated with a bacterial luminescence test two protozoan assays and Salmonella typhimurium TA1535 The presence of a newly formed covalent bond between the drug and the polymer matrices was confirmed by 1HNMR and FTIR studies A drug content 156 and 102 mole in the macromolecular conjugates has been determined The obtained macromolecular products have been subjected to further in vitro release studies The total percentage of ampicillin released after 21 days of incubation was 

In [19]:
# prompt generation
template_context = """
###Context###:\n {context} \n [INST]Answer the following below question using the Context.[/INST]\n
###Question###:\n {question} \n[INST]Select the correct option only. No explanation required[/INST]\n
Options: {options}

#Answer:""" # Force a single-line response

prompt_template_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])
prompt_context = prompt_template_context.format(question=question, options=options, context=context) 

# # prompt generation
# template_context = """Question: {question}
# Context: {context}[INST]Select the correct option only. No explanation required[/INST]

# Options: {options}

# #Answer:""" # Force a single-line response

# prompt_template_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])
# #prompt_context = prompt_template_context.format(question=question, options=options, context=context) 

In [20]:
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token 
inputs = tokenizer(prompt_context, return_tensors='pt', truncation=True, padding=False, max_length=8000).to(model_llama.device)
outputs = model_llama.generate(**inputs,max_new_tokens=1,do_sample=False, top_p=None, pad_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)




###Context###:
 AmpicillinEster Bonded Branched Polymers Characterization Cyto Genotoxicity and Controlled DrugRelease Behaviour The development and characterization of novel macromolecular conjugates of ampicillin using branched biodegradable polymers has been described in this study The conjugates have been prepared coupling the βlactam antibiotic with branched polymer matrices based on the natural oligopeptide core The cyto and genotoxicity of the synthesized polymers were evaluated with a bacterial luminescence test two protozoan assays and Salmonella typhimurium TA1535 The presence of a newly formed covalent bond between the drug and the polymer matrices was confirmed by 1HNMR and FTIR studies A drug content 156 and 102 mole in the macromolecular conjugates has been determined The obtained macromolecular products have been subjected to further in vitro release studies The total percentage of ampicillin released after 21 days of incubation was nearly 60 and 14 and this resulted fr

In [19]:
position = response.find('#Answer:')
prediction = response[position+8 :position+10].strip()
prediction

'D'

# Data Loader with RAG using FAISS index

In [12]:
def truncate_to_words(text, max_words=4000):
  """Truncates a text string to a maximum number of words.

  Args:
    text: The input text string.
    max_words: The maximum number of words to keep.

  Returns:
    The truncated text string.
  """

  words = text.split()
  if len(words) <= max_words:
    return text
  else:
    truncated_text = ' '.join(words[:max_words])
    return truncated_text

In [21]:
# sanitize_filename("Placing the infant in a supine position on a firm mattress while sleeping")
# fulltext_dict["Acute_myocardial_infarction"]
# # ctx = fulltext_dict[sanitize_filename("Acute_myocardial_infarction")][0]
# # truncate_to_words(".\n".join(ctx))

In [21]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from tqdm import tqdm  # For progress bar
from datetime import datetime

class QuestionAnswerDataset_RAG(Dataset):
    def __init__(self, questions_data, faiss_texts):
        self.questionData = questions_data
        self.faiss_texts = faiss_texts
    

    def __len__(self):
        return len(self.questionData)

    def __getitem__(self, idx):
        question_data = self.questionData[idx]
        question = question_data['question']
        options = question_data['options']
        options_str = "\n".join([f"{key}. {value}" for key, value in options.items()])
        answer = question_data['answer_idx']
        
        # Search FAISS index for each option and retrieve the corresponding text
        retrieved_contexts = []
        for option_key, option_text in options.items():
            # code for full text extarction.
            ctx = self.faiss_texts.get(sanitize_filename(option_text).lower())
            if (ctx and len(ctx)>0) :
                retrieved_context = truncate_to_words("\n".join(ctx[0]),600)
                retrieved_contexts.append(retrieved_context)

        # Combine the contexts with the question and options for the final return
        combined_context = truncate_to_words("\n".join(retrieved_contexts))

        # Use the summarization prompt to generate a summary
        summarization_prompt = f"""[INST] Summarize the following text concisely:[/INST]
        {combined_context}
        """
        
        with torch.no_grad():
            summary_output = model_llama.generate(
                **tokenizer(summarization_prompt, return_tensors="pt").to(model_llama.device),
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode and print the generated summary
        summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
        summary_text = summary_text.replace("[INST] Summarize the following text concisely:[/INST]", "").strip()

        return question, options_str, answer, summary_text

    def get_text_embedding(self, text):
        # Placeholder for the actual embedding generation logic
        # Replace this with the method to generate embeddings from text
        return np.random.rand(768)  # Example: Replace with actual embedding



# Load your dataset
dataset_RAG = QuestionAnswerDataset_RAG(questions_data[:1000], data_dict)
dataloader_RAG = DataLoader(dataset_RAG, batch_size=1, shuffle=False)  # Adjust batch_size as needed

# Running Batches with context

In [None]:
correct_predictions = 0
total_predictions = 0
responses = []
answers = []

option_letters = ['A', 'B', 'C', 'D']
option_tokens = [362, 426, 356, 423]
logit_values = []

for batch in tqdm(dataloader_RAG):
    questions, options_strs, answer_idxs, combined_contexts = batch
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    prompts = [prompt_template_context.format(question=question, options=options_str, context=combined_context) for question, options_str,combined_context in zip(questions, options_strs,combined_contexts)]
    
    inputs = tokenizer(prompts, return_tensors='pt', padding=False, truncation=True, max_length=8000).to(model_llama.device)
    
    with torch.no_grad():
        # outputs = model_llama.generate(**inputs,max_new_tokens=1,do_sample=False,top_p=None,pad_token_id=tokenizer.eos_token_id)
        op = model_llama(**inputs)
    
    # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    #retrieve logits for the next token
    logits = op.logits[:, -1,:]
    
    #print(decoded_outputs)
    # for decoded_output, answer in zip(decoded_outputs, answer_idxs):
    for answer, logit in zip(answer_idxs,logits):
        # position = decoded_output.find('#Answer:')
        # answer_pred = decoded_output[position+8 :position+10].strip()
        #print(answer_pred)

        logit_values.append(logit[option_tokens].tolist())
        # if answer == answer_pred.strip():
            # correct_predictions += 1
        
        # responses.append(answer_pred)
        # answers.append(answer)
        # total_predictions += 1

# print(f"Accuracy: {correct_predictions / len(responses):.2%}")
save_string_to_file(logit_values, "Logit_scores/Wiki_CS_Document_wise_Logit/llama_RAG_CS_Logit_"+datetime.now().strftime("%Y%m%d_%H%M%S")+".json")
# correct_predictions

  0%|          | 0/1000 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
 58%|█████▊    | 580/1000 [1:10:15<51:07,  7.30s/it]  

In [23]:
print(f"Accuracy: {correct_predictions / len(responses):.2%}")
correct_predictions

Accuracy: 36.10%


361

In [21]:
print(responses.count(''))
responses.index('')
# question
responses
# filtered_list = [element for element in responses if element not in ['A', 'B', 'C', 'D']]
# filtered_list

35


['B',
 'C',
 'A',
 'C',
 'D',
 'B',
 'D',
 'B',
 'D',
 'B',
 'A',
 'A',
 'C',
 'C',
 'C',
 'C',
 'D',
 'A',
 'B',
 'B',
 'A',
 'C',
 'A',
 'C',
 'B',
 'C',
 'C',
 '',
 'B',
 'D',
 'A',
 'D',
 'D',
 'B',
 'C',
 'A',
 'B',
 'D',
 'C',
 'D',
 'D',
 'C',
 'D',
 'D',
 'C',
 'D',
 'D',
 'D',
 'D',
 'B',
 'D',
 'D',
 'C',
 'C',
 'B',
 'D',
 'A',
 'D',
 'D',
 'B',
 'A',
 'D',
 'A',
 'D',
 'A',
 'A',
 'A',
 'D',
 'C',
 'A',
 'D',
 'D',
 'D',
 'D',
 'B',
 'B',
 'A',
 'D',
 'A',
 'B',
 'C',
 'A',
 'C',
 'C',
 'D',
 'A',
 'D',
 'B',
 'A',
 'D',
 'A',
 'D',
 'A',
 'D',
 'A',
 'D',
 'B',
 'A',
 'D',
 'B',
 'D',
 'C',
 'C',
 'D',
 'A',
 'C',
 'D',
 'A',
 'D',
 'D',
 'D',
 '',
 'A',
 'B',
 '',
 'A',
 'A',
 'D',
 'B',
 'D',
 'C',
 'A',
 'D',
 'D',
 'A',
 'D',
 'B',
 'D',
 'D',
 'D',
 'B',
 'B',
 'D',
 'C',
 'C',
 'B',
 'D',
 'C',
 'B',
 'B',
 'D',
 'D',
 'C',
 'A',
 'D',
 'A',
 'D',
 'C',
 'D',
 'B',
 'B',
 'D',
 'A',
 '',
 'C',
 'C',
 'A',
 'C',
 'C',
 'B',
 'A',
 'A',
 'D',
 'C',
 'B',
 'B',
 'D',
 '

In [35]:
def free_gpu_cache():
  """Frees the GPU cache and memory."""
  if torch.cuda.is_available():
      with torch.cuda.device(0):
          torch.cuda.empty_cache()