# MedQA Dataset parsing for English data

In [1]:
import pandas as pd 
import json
import numpy as np
import torch
import os

def read_question_answer_file(file_path):
    """Reads a JSONL file with question-answer data and returns a list of dictionaries."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))  # Parse each line as JSON
    return data

# Load your dataset
dataset_path = r'med_data/phrases_no_exclude_train.jsonl'  # Replace with the path to your JSON file
questions_data = read_question_answer_file(dataset_path)

In [2]:
print(questions_data[0]['question'])
print(questions_data[0]['answer_idx'])
print(questions_data[0]['options'])

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
D
{'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}


<h2>Data pre-processing and extraction</h2>

In [3]:
import re

def clean_text(text):
  """Removes special characters and extra whitespace from text.

  Args:
    text: The input text to be cleaned.

  Returns:
    The cleaned text.
  """
  cleaned_text = ''
  if text is None:
     return cleaned_text
#   # Regular expression to match HTTP/HTTPS URLs
#   url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    
#   # Regular expression to match DOI links (e.g., doi:10.2519/jospt.2024.12398)
#   doi_pattern = r'doi:\s*\d{2}\.\d{4,9}/[-._;()/:A-Z0-9]+'

  # Replace URLs with an empty string
  cleaned_text = re.sub(url_pattern, '', text)
    
  # Replace DOI links with an empty string
  cleaned_text = re.sub(doi_pattern, '', cleaned_text, flags=re.IGNORECASE)
    
  # Remove special characters, but keep letters, digits, and single spaces
  cleaned_text = re.sub(r'[^A-Za-z0-9\s]', '', text)
  # Replace multiple spaces with a single space
  cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
  # Strip leading and trailing spaces (if any)
  cleaned_text = cleaned_text.strip()
  return cleaned_text

In [8]:
def save_string_to_file(data, filename):
  """Saves a string to a text file.

  Args:
    text: The string to be saved.
    filename: The name of the file to create.
  """

  with open(filename, "w",encoding='utf-8') as f:
    json.dump(data, f)

# # Example usage:
# my_string = "This is the text I want to save."
# save_string_to_file(my_string, "output.txt")

In [9]:
def sanitize_filename(filename):
  """Sanitizes a filename by replacing special characters with underscores.

  Args:
    filename: The original filename.

  Returns:
    The sanitized filename.
  """

  # Replace non-alphanumeric characters with underscores
  filename = re.sub(r'[^\w]', '_', filename)

  # Remove leading and trailing underscores
  filename = filename.strip('_')

  return filename

In [10]:
import json

def extract_text(data):
  """Extracts text from the given JSON data.

  Args:
    data: The JSON data.

  Returns:
    A list of text strings.
  """

  texts = []
  for document in data:
    for passage in document['documents'][0]['passages']:  # Access the first document's passages
        words = passage['text'].split()
        if(len(words)> 10):
          texts.append(clean_text(passage['text']))
  return texts



<h2>Generate bert embeddings</h2>

In [11]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

# Initialize the tokenizer and model (bert-base-uncased)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Ensure the model runs on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_model.to(device)


# Function to compute embeddings
def get_embeddings(text,tokenizer,model):
    
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():  # No need to compute gradients
        outputs = model(**inputs)

    last_hidden_state = outputs.last_hidden_state
    embeddings = last_hidden_state.mean(dim=1)

    # Compute the mean of the token embeddings to get a fixed-size representation
    embeddings = embeddings.squeeze().cpu().numpy()  # (batch size x hidden size)

    return embeddings



<h2>Dictionary format data</h2>

In [12]:
# Path to the directory containing JSON files
pubmed_dir = 'Pubmed_Abstract'

# Dictionary to store the original data (not embeddings) by file key


def extract_dataDict(directory):
    data_dict = {}
    # Iterate through each JSON file in the directory
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        with open(filepath, 'r') as file:
            data = json.load(file)
            file_key = sanitize_filename(filename[:-4])  
            data_dict[file_key] = data
    return data_dict



In [13]:
data_dict = extract_dataDict(pubmed_dir)

In [15]:
print(len(data_dict))
print(data_dict['Arthralgias'])
dic_abs = data_dict['Arthralgias']



3627
{'39115687': 'OBJECTIVES: Previous evidence suggests that bisphosphonates (BPs) may lower the risk of recurrent fractures and enhance functional recovery in patients with fractures. However, there has been controversy regarding the optimal timing of treatment initiation for patients with fragility fractures. We conducted a meta-analysis to evaluate the available evidence on the use of BPs during the perioperative period and compared it to both non-perioperative periods and non-usage.\nMETHODS: Electronic searches were performed using PubMed, EMBASE, Web of Science and the Cochrane Library published before February 2023, without any language restrictions. The primary outcomes included fracture healing rate, healing time, and new fractures. We also examined a wide range of secondary outcomes. Random effects meta-analysis was used.\nRESULTS: A total of 19 clinical trials involving 2543 patients were included in this meta-analysis. When comparing patients with non-perioperative BPs us

In [16]:
# Fetching the abstract text from the documents.
from tqdm import tqdm
filtered_abstract_data = []

for key, data in tqdm(data_dict.items()):
    for k, value in data.items():
        filtered_abstract_data.append(value)

100%|██████████| 3627/3627 [00:00<00:00, 496580.40it/s]


In [17]:
filtered_abstract_data = [element for element in filtered_abstract_data if element]
len(filtered_abstract_data)

31817

<h2>Document text arrangement to search using FIASS Index.</h2>

In [26]:
def getEmbeddingsListwise(data_list):
    embeddings = []
    for text in tqdm(data_list):
        embedding = get_embeddings(text,bert_tokenizer, bert_model)
        embeddings.append(embedding)
    return embeddings

fiass_embeddings = getEmbeddingsListwise(filtered_abstract_data)

100%|██████████| 31817/31817 [05:15<00:00, 100.74it/s]


# FIASS Installation and Storage

In [13]:
# !pip install faiss-cpu 
# !pip install faiss-gpu  # For GPU support, if you have a CUDA-capable GPU
import faiss

In [15]:
import faiss

embeddings = np.vstack(fiass_embeddings) 
# Get the dimensionality of the embeddings
embedding_dim = embeddings.shape[1]

# Create a FAISS index
# IndexFlatL2 is a simple, exact nearest neighbor search index with L2 distance
index = faiss.IndexFlatL2(embedding_dim)

# Add embeddings to the index
index.add(embeddings)

# Check the number of vectors in the index
print(f"Number of vectors in the index: {index.ntotal}")

# Save the index to a file for later use
faiss.write_index(index, 'FAISS/QA_abstract_index.index')




Number of vectors in the index: 31817


In [18]:
import faiss
# Load the FAISS index
faiss_index = faiss.read_index('FAISS/QA_abstract_index.index')

# Function to search the FAISS index
def search_faiss_index(query, index, tokenizer, model, top_k=5):
    # Compute the query embedding
    query_embedding = get_embeddings(query, tokenizer, model)

    # Reshape the query embedding to match FAISS expected input format
    query_embedding = np.expand_dims(query_embedding, axis=0).astype('float32')

    # Search the FAISS index
    distances, indices = index.search(query_embedding, top_k)

    return distances, indices

In [19]:
# Example usage
query = "Nitrofurantoin"
top_k = 10

distances, indices = search_faiss_index(query, faiss_index, bert_tokenizer, bert_model, top_k=top_k)

# Print results
print(f"Top-{top_k} results:")
for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
    print(f"Result {i + 1}:")
    print(f"Index: {idx}, Distance: {distance}")
    # If you have the original text stored, you can retrieve it like this:
    # print(f"Text: {embeddings_dict['some_key'][idx]}")

Top-10 results:
Result 1:
Index: 825, Distance: 51.212196350097656
Result 2:
Index: 816, Distance: 51.5128173828125
Result 3:
Index: 23352, Distance: 52.298065185546875
Result 4:
Index: 24373, Distance: 52.298065185546875
Result 5:
Index: 22601, Distance: 53.6589241027832
Result 6:
Index: 21158, Distance: 54.455162048339844
Result 7:
Index: 834, Distance: 54.53868103027344
Result 8:
Index: 25209, Distance: 54.53868103027344
Result 9:
Index: 20967, Distance: 54.56218719482422
Result 10:
Index: 374, Distance: 54.8944206237793


# Library installations

In [16]:
!pip install langchain
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install --upgrade pip
!pip install --upgrade langchain
!pip install langchain_community
!pip list | grep langchain
!pip list | grep langchain_community

!pip install -U langchain-huggingface


Collecting langchain
  Downloading langchain-0.2.12-py3-none-any.whl.metadata (7.1 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain)
  Downloading aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting langchain-core<0.3.0,>=0.2.27 (from langchain)
  Downloading langchain_core-0.2.29-py3-none-any.whl.metadata (6.2 kB)
Collecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text_splitters-0.2.2-py3-none-any.whl.metadata (2.1 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain)
  Downloading langsmith-0.1.98-py3-none-any.whl.metadata (13 kB)
Collecting pydantic<3,>=1 (from langchain)
  Downloading pydantic-2.8.2-py3-none-a

# Building the pipeline with the langchain

In [7]:
import os
import langchain

### prompts
from langchain import PromptTemplate, LLMChain

### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings


import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)


#model = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
#model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
model = "meta-llama/Meta-Llama-3-8B"
# model = "Undi95/Meta-Llama-3-8B-hf"

tokenizer_llama = AutoTokenizer.from_pretrained(model)

        
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
    bnb_4bit_use_double_quant = True,
)

model_llama = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config = bnb_config,
    device_map = 'auto',
    do_sample=False,
    token='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'
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [55]:
question = questions_data[60]['question']
option_dict = questions_data[60]['options']
options_str = "\n".join([f"{key}. {value}" for key, value in option_dict.items()])

# Search FAISS index for each option and retrieve the corresponding text
retrieved_contexts = []
for option_key, option_text in option_dict.items():
    distances, indices = search_faiss_index(option_text, faiss_index, bert_tokenizer, bert_model, top_k=5)
    retrieved_context = " ".join([filtered_abstract_data[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
    retrieved_contexts.append(retrieved_context)

context = '\n'.join([s.strip() + '.' for s in retrieved_contexts])
context = truncate_to_words(context)

In [20]:
# prompt generation
template_context = """Question: {question}
Context: {context}[INST]Select the correct option from A, B, C, D after reading the context below. No explanation required[/INST]

Options: {options}

#Answer:""" # Force a single-line response

prompt_template_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])
#prompt_context = prompt_template_context.format(question=question, options=options_str, context=context) 

In [57]:
tokenizer_llama.padding_side = 'left'
tokenizer_llama.pad_token = tokenizer_llama.eos_token 
inputs = tokenizer_llama(prompt_context, return_tensors='pt', truncation=True, padding="max_length", max_length=8000).to(model_llama.device)
outputs = model_llama.generate(**inputs, max_new_tokens=1)
response = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Question: A new screening test utilizing a telemedicine approach to diagnosing diabetic retinopathy has been implemented in a diabetes clinic. An ophthalmologist’s exam was also performed on all patients as the gold standard for diagnosis. In a pilot study of 500 patients, the screening test detected the presence of diabetic retinopathy in 250 patients. Ophthalmologist exam confirmed a diagnosis of diabetic retinopathy in 200 patients who tested positive in the screening test, as well as 10 patients who tested negative in the screening test. What is the sensitivity, specificity, positive predictive value, and negative predictive value of the screening test?
Context: BACKGROUND: Artificial intelligence (AI) electrocardiogram (ECG) analysis can enable detection of hyperkalemia. In this validation, we assessed the algorithm's performance in two high acuity settings. METHODS: An emergency department (ED) cohort (February to August 2021) and a mixed intensive care unit (ICU) cohort (August 

In [48]:
position = response.find('#Answer:')
prediction = response[position+8 :position+10].strip()
prediction

'B'

# Designing a context summary prompt

In [None]:
# Use the summarization prompt to generate a summary
summarization_prompt = f"""[INST] Summarize the following text concisely:

{context_text}
[/INST]
"""

with torch.no_grad():
    summary_output = model_llama.generate(
        **tokenizer(summarization_prompt, return_tensors="pt").to(model_llama.device),
        max_new_tokens=128,
    )

# Decode and print the generated summary
summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
print(summary_text)

# Data Loader with RAG using FAISS index

In [21]:
def truncate_to_words(text, max_words=3500):
  """Truncates a text string to a maximum number of words.

  Args:
    text: The input text string.
    max_words: The maximum number of words to keep.

  Returns:
    The truncated text string.
  """

  words = text.split()
  if len(words) <= max_words:
    return text
  else:
    truncated_text = ' '.join(words[:max_words])
    return truncated_text

In [23]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from tqdm import tqdm  # For progress bar

class QuestionAnswerDataset_RAG(Dataset):
    def __init__(self, questions_data, faiss_index, faiss_texts):
        self.questionData = questions_data
        self.faiss_index = faiss_index
        self.faiss_texts = faiss_texts
    

    def __len__(self):
        return len(self.questionData)

    def __getitem__(self, idx):
        question_data = self.questionData[idx]
        question = question_data['question']
        options = question_data['options']
        options_str = "\n".join([f"{key}. {value}" for key, value in options.items()])
        answer = question_data['answer_idx']

        # Search FAISS index for each option and retrieve the corresponding text
        retrieved_contexts = []
        for option_key, option_text in options.items():
            distances, indices = search_faiss_index(option_text, self.faiss_index, bert_tokenizer, bert_model, top_k=5)
            retrieved_context = " ".join([self.faiss_texts[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
            retrieved_context = truncate_to_words(retrieved_context,1500)
            retrieved_contexts.append(retrieved_context)

        # Combine the contexts with the question and options for the final return
        combined_context = truncate_to_words("\n".join(retrieved_contexts))

        return question, options_str, answer, combined_context

    def get_text_embedding(self, text):
        # Placeholder for the actual embedding generation logic
        # Replace this with the method to generate embeddings from text
        return np.random.rand(768)  # Example: Replace with actual embedding


faiss_texts = [...]  # Load the list of texts associated with the FAISS index

# Load your dataset
dataset_RAG = QuestionAnswerDataset_RAG(questions_data, faiss_index, filtered_abstract_data)
dataloader_RAG = DataLoader(dataset_RAG, batch_size=1, shuffle=False)  # Adjust batch_size as needed

# Running Batches with context

In [24]:
correct_predictions = 0
total_predictions = 0
responses = []
answers = []

for batch in tqdm(dataloader_RAG):
    questions, options_strs, answer_idxs, combined_contexts = batch
    tokenizer_llama.padding_side = 'left'
    tokenizer_llama.pad_token = tokenizer_llama.eos_token
    prompts = [prompt_template_context.format(question=question, options=options_str, context=combined_context) for question, options_str,combined_context in zip(questions, options_strs,combined_contexts)]
    
    inputs = tokenizer_llama(prompts, return_tensors='pt', padding=True, truncation=True, max_length=8000).to(model_llama.device)
    
    with torch.no_grad():
        outputs = model_llama.generate(**inputs,max_new_tokens=1)
    
    decoded_outputs = tokenizer_llama.batch_decode(outputs, skip_special_tokens=True)
    
    #print(decoded_outputs)
    for decoded_output, answer in zip(decoded_outputs, answer_idxs):
        position = decoded_output.find('#Answer:')
        answer_pred = decoded_output[position+8 :position+10].strip()
        #print(answer_pred)
        if answer == answer_pred.strip():
            correct_predictions += 1
        
        responses.append(answer_pred)
        answers.append(answer)
        total_predictions += 1

  0%|          | 0/10178 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 1/10178 [00:00<2:16:51,  1.24it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 2/10178 [00:02<3:10:56,  1.13s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 3/10178 [00:03<3:25:06,  1.21s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 4/10178 [00:03<2:37:47,  1.07it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 5/10178 [00:05<2:55:15,  1.03s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 6/10178 [00:05<2:40:27,  1.06it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 7/10178 [00:06<2:44:01,  1.03it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 8

KeyboardInterrupt: 

In [30]:
print(f"Accuracy: {correct_predictions / len(responses):.2%}")
correct_predictions

Accuracy: 28.20%


282

In [24]:
responses.count('n:')
# responses.index('n:')


0

In [30]:
def free_gpu_cache():
  """Frees the GPU cache and memory."""
  if torch.cuda.is_available():
      with torch.cuda.device(0):
          torch.cuda.empty_cache()

In [1]:
import torch
import gc

# Clear unused objects and cache
gc.collect()
torch.cuda.empty_cache()