# MedQA Dataset parsing for English data

In [1]:
import pandas as pd 
import json
import numpy as np
import torch
import os

def read_question_answer_file(file_path):
    """Read a JSON Lines file of question records into a list.

    Args:
        file_path: Path to a UTF-8 JSONL file with one JSON object per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank / whitespace-only lines (e.g. an extra empty line at the end of the
            # file) are not valid JSON and would make json.loads raise, so skip them.
            if not line.strip():
                continue
            data.append(json.loads(line))  # Parse each line as JSON
    return data

# Load the MedQA training questions. Set the MEDQA_TRAIN_PATH environment variable to
# point at the JSONL file on another machine; otherwise the author's local path is used.
dataset_path = os.environ.get(
    'MEDQA_TRAIN_PATH',
    r'C:\Users\ranad\OneDrive - University of Glasgow\Attachments\Msc Final Year project\Data\MedQA-USMLE-4-options\phrases_no_exclude_train.jsonl',
)
questions_data = read_question_answer_file(dataset_path)

In [2]:
# Show the question text, the correct option letter and the option mapping of the first record.
first_record = questions_data[0]
for field in ('question', 'answer_idx', 'options'):
    print(first_record[field])

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
D
{'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}


<h2>Data pre-processing and extraction</h2>

In [10]:
import re

def clean_text(text):
  """Reduce text to ASCII letters, digits and single spaces.

  Every character other than A-Z, a-z, 0-9 and whitespace is deleted outright
  (not replaced by a space), so punctuation inside tokens vanishes:
  "well-recognized" becomes "wellrecognized" and "36.5" becomes "365".
  Whitespace runs are then collapsed to one space and the ends are trimmed.

  Args:
    text: The input text to be cleaned.

  Returns:
    The cleaned text.
  """
  alnum_only = re.sub(r'[^A-Za-z0-9\s]', '', text)
  return re.sub(r'\s+', ' ', alnum_only).strip()

In [11]:
def save_string_to_file(data, filename):
  """Serialize ``data`` as JSON and write it to ``filename`` (UTF-8).

  Despite the function name, the value is written with ``json.dump``: a plain
  string is stored as a quoted JSON string literal (e.g. ``"abc"``), not as
  raw text, and lists/dicts are stored as JSON arrays/objects.

  Args:
    data: Any JSON-serializable value (str, list, dict, ...).
    filename: Path of the file to create or overwrite.
  """

  with open(filename, "w", encoding='utf-8') as f:
    json.dump(data, f)

# # Example usage:
# my_string = "This is the text I want to save."
# save_string_to_file(my_string, "output.txt")

In [12]:
def sanitize_filename(filename):
  """Turn an arbitrary string into a filesystem-friendly name.

  Each character that is not a regex word character (letter, digit or
  underscore; Unicode-aware) becomes an underscore, and underscores at
  either end of the result are trimmed away.

  Args:
    filename: The original filename.

  Returns:
    The sanitized filename.
  """
  return re.sub(r'[^\w]', '_', filename).strip('_')

In [13]:
import json

def extract_text(data):
  """Collect cleaned passage texts from parsed retrieval records.

  Only the passages of the first entry in each record's ``documents`` list
  are read, and a passage is kept only if its raw text has more than 10
  whitespace-separated words.

  Args:
    data: Iterable of records shaped like
      ``record['documents'][0]['passages'] -> [{'text': ...}, ...]``.

  Returns:
    A list of passage texts passed through ``clean_text``, in input order.
  """
  return [
      clean_text(passage['text'])
      for record in data
      for passage in record['documents'][0]['passages']
      if len(passage['text'].split()) > 10
  ]



<h2>Generate BERT embeddings</h2>

In [None]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

# bert-base-uncased encoder used to embed the corpus passages and the retrieval queries.
# Pick the GPU when one is available, then place the model there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)


# Function to compute embeddings
def get_embeddings(text,tokenizer,model):
    """Encode ``text`` and mean-pool the final hidden states into fixed-size vectors.

    Args:
        text: A string, or a list of strings encoded together as one padded batch.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy array of shape ``(hidden,)`` for a single text, or ``(batch, hidden)``
        for two or more texts (size-1 axes are squeezed away).
    """
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    # Move the inputs to wherever the model's weights live instead of relying on a global device.
    model_device = next(model.parameters()).device
    inputs = {key: value.to(model_device) for key, value in inputs.items()}

    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)

    last_hidden_state = outputs.last_hidden_state  # (batch, seq_len, hidden)

    # Average over real tokens only: with padding=True, shorter texts in a batch carry
    # pad positions whose hidden states would otherwise be folded into the mean.
    attention_mask = inputs.get('attention_mask')
    if attention_mask is None:
        pooled = last_hidden_state.mean(dim=1)
    else:
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    # squeeze drops the batch axis for a single text, giving a 1-D vector.
    return pooled.squeeze().cpu().numpy()

<h2>Loading the PubMed passages into a dictionary</h2>

In [15]:
# Path to the directory containing JSON files
pubmed_dir = 'C:/Users/ranad/Documents/Pubmed_Full_text/'


def extract_dataDict(directory):
    """Load every JSON file in ``directory`` into a dict keyed by sanitized file stem.

    Args:
        directory: Folder holding one JSON document per file.

    Returns:
        Dict mapping ``sanitize_filename(<file name without extension>)`` to the
        parsed JSON content. Files are visited in sorted order so the dict, and
        every list later flattened from it, comes out the same on every run.
    """
    data_dict = {}
    for filename in sorted(os.listdir(directory)):
        filepath = os.path.join(directory, filename)
        if not os.path.isfile(filepath):
            continue  # sub-directories cannot be opened as files
        # Explicit UTF-8 (as used for the MedQA file) instead of the platform default
        # encoding, which on Windows is a legacy code page.
        with open(filepath, 'r', encoding='utf-8') as file:
            data = json.load(file)
        # splitext drops the extension whatever its length; the old [:-4] slice
        # assumed a 4-character suffix.
        file_key = sanitize_filename(os.path.splitext(filename)[0])
        data_dict[file_key] = data
    return data_dict



In [16]:
# Parse every PubMed file into {sanitized file name: parsed JSON}.
data_dict = extract_dataDict(directory=pubmed_dir)

In [17]:
# How many files were loaded, then a look at one entry.
num_documents = len(data_dict)
print(num_documents)
data_dict['Arthralgias']

2992


[['Osteoarthritis is the most common form of arthritis and a leading cause of disability worldwide largely due to pain the primary symptom of the disease The pain experience in knee osteoarthritis in particular is wellrecognized as typically transitioning from intermittent weightbearing pain to a more persistent chronic pain Methods to validly assess pain in osteoarthritis studies have been developed to address the complex nature of the pain experience The etiology of pain in osteoarthritis is recognized to be multifactorial with both intraarticular and extraarticular risk factors Nonetheless greater insights are needed into pain mechanisms in osteoarthritis to enable rational mechanismbased management of pain Consequences of pain related to osteoarthritis contribute to a substantial socioeconomic burden',
  'The hallmark symptom of osteoarthritis OA the most common form of arthritis is pain This is the symptom that drives individuals to seek medical attention and contributes to functi

<h2>Arranging document passages for FAISS index search</h2>

In [19]:
from tqdm import tqdm
# Flatten data_dict (file key -> list of passage groups -> passage strings), keep only
# passages whose raw text has more than 10 words, and clean each kept passage.
medQA_Filtered_data = [
    clean_text(passage)
    for passage_groups in tqdm(data_dict.values())
    for passage_group in passage_groups
    for passage in passage_group
    if len(passage.split()) > 10
]

100%|██████████████████████████████████████████████████████████████████████████████| 2992/2992 [00:34<00:00, 86.48it/s]


In [20]:
# Number of passages that survived the >10-word filter.
passage_count = len(medQA_Filtered_data)
passage_count

849033

In [23]:
def getEmbeddingsListwise(data_list, batch_size=1):
    """Embed every text in ``data_list`` with the notebook's BERT encoder.

    Args:
        data_list: Sequence of strings to embed.
        batch_size: Texts sent to the encoder per forward pass. The default of 1
            reproduces one-text-at-a-time encoding. Larger values are much faster
            on GPU, but give the same vectors as per-text encoding only if
            ``get_embeddings`` excludes padding tokens from its mean pooling;
            verify that before raising it.

    Returns:
        List with one 1-D numpy vector per input text, in input order.
    """
    embeddings = []
    if batch_size <= 1:
        for text in tqdm(data_list):
            embeddings.append(get_embeddings(text, bert_tokenizer, bert_model))
        return embeddings

    for start in tqdm(range(0, len(data_list), batch_size)):
        batch = list(data_list[start:start + batch_size])
        # get_embeddings squeezes the batch axis, so a 1-text batch comes back 1-D;
        # atleast_2d restores (batch, hidden) before splitting it into rows.
        batch_embeddings = np.atleast_2d(get_embeddings(batch, bert_tokenizer, bert_model))
        embeddings.extend(batch_embeddings)
    return embeddings

fiass_embeddings = getEmbeddingsListwise(medQA_Filtered_data)

  1%|▍                                                                     | 5791/849033 [1:04:01<155:22:29,  1.51it/s]


KeyboardInterrupt: 

# FAISS installation and index storage

In [25]:
# !pip install faiss-cpu 
#pip install faiss-gpu  # For GPU support, if you have a CUDA-capable GPU

In [None]:
import faiss

# Stack the per-passage vectors from getEmbeddingsListwise into one (n_passages, dim)
# matrix. FAISS only accepts C-contiguous float32 input.
if not fiass_embeddings:
    raise ValueError("fiass_embeddings is empty - run the embedding cell first")
embeddings = np.ascontiguousarray(np.vstack(fiass_embeddings), dtype='float32')

# Get the dimensionality of the embeddings
embedding_dim = embeddings.shape[1]

# IndexFlatL2 is a simple, exact nearest neighbor search index with L2 distance
index = faiss.IndexFlatL2(embedding_dim)

# Add embeddings to the index; vector i corresponds to medQA_Filtered_data[i]
index.add(embeddings)

# Check the number of vectors in the index
print(f"Number of vectors in the index: {index.ntotal}")

# Save the index for later use. Raw string: '\Q' is not a valid escape sequence
# (the plain literal warns), and the resulting path value is unchanged.
index_path = r'FAISS\QA200_327_index.index'
index_dir = os.path.dirname(index_path)
if index_dir:
    os.makedirs(index_dir, exist_ok=True)  # write_index does not create folders
faiss.write_index(index, index_path)




In [None]:
# Load the FAISS index. Raw string: '\Q' is not a valid escape sequence (the plain
# literal triggers a Deprecation/SyntaxWarning); the path value is unchanged.
faiss_index = faiss.read_index(r'FAISS\QA200_327_index.index')


def search_faiss_index(query, index, tokenizer, model, top_k=5):
    """Embed ``query`` and return its ``top_k`` nearest neighbours in ``index``.

    Args:
        query: Query string (or list of query strings) to embed with ``get_embeddings``.
        index: FAISS index built from embeddings of the same encoder.
        tokenizer, model: Encoder used to embed the query.
        top_k: Number of neighbours to return per query.

    Returns:
        ``(distances, indices)`` arrays of shape ``(n_queries, top_k)``.
    """
    query_embedding = get_embeddings(query, tokenizer, model)

    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    # atleast_2d turns a single 1-D query vector into one row and leaves a batch as-is.
    query_matrix = np.ascontiguousarray(np.atleast_2d(query_embedding), dtype='float32')

    distances, indices = index.search(query_matrix, top_k)

    return distances, indices

In [None]:
# Example usage: query the index just loaded from disk and show the matching passages.
query = "Nitrofurantoin"
top_k = 5

distances, indices = search_faiss_index(query, faiss_index, bert_tokenizer, bert_model, top_k=top_k)

# Print results
print(f"Top-{top_k} results:")
for rank, (distance, idx) in enumerate(zip(distances[0], indices[0]), start=1):
    if idx < 0:
        # FAISS pads missing results with -1 when the index holds fewer than top_k vectors.
        continue
    print(f"Result {rank}:")
    print(f"Index: {idx}, Distance: {distance}")
    # Vector i was added from medQA_Filtered_data[i], so idx maps straight back to its text.
    print(f"Text: {medQA_Filtered_data[idx][:200]}")

# Library installations

In [8]:
# !pip install langchain
# !pip install transformers
# !pip install accelerate
# !pip install bitsandbytes
# !pip install --upgrade pip
# !pip install --upgrade langchain
# !pip install langchain_community
# !pip list | grep langchain
# !pip list | grep langchain_community

# !pip install -U langchain-huggingface


Collecting langchain
  Downloading langchain-0.2.12-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-core<0.3.0,>=0.2.27 (from langchain)
  Downloading langchain_core-0.2.29-py3-none-any.whl.metadata (6.2 kB)
Collecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text_splitters-0.2.2-py3-none-any.whl.metadata (2.1 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain)
  Downloading langsmith-0.1.98-py3-none-any.whl.metadata (13 kB)
Collecting jsonpatch<2.0,>=1.33 (from langchain-core<0.3.0,>=0.2.27->langchain)
  Downloading jsonpatch-1.33-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting packaging<25,>=23.2 (from langchain-core<0.3.0,>=0.2.27->langchain)
  Downloading packaging-24.1-py3-none-any.whl.metadata (3.2 kB)
Collecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.17->langchain)
  Downloading orjson-3.10.6-cp311-none-win_amd64.whl.metadata (51 kB)
Downloading langchain-0.2.12-py3-none-any.whl (990 kB)
   ------------------------

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tables 3.8.0 requires blosc2~=2.0.0, which is not installed.
tables 3.8.0 requires cython>=0.29.21, which is not installed.
python-lsp-black 1.2.1 requires black>=22.3.0, but you have black 0.0 which is incompatible.


Collecting sentencepiece (from transformers)
  Downloading sentencepiece-0.2.0-cp311-cp311-win_amd64.whl.metadata (8.3 kB)
Downloading sentencepiece-0.2.0-cp311-cp311-win_amd64.whl (991 kB)
   ---------------------------------------- 0.0/991.5 kB ? eta -:--:--
   --------------------- ------------------ 524.3/991.5 kB 3.4 MB/s eta 0:00:01
   ---------------------------------------- 991.5/991.5 kB 3.9 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0
Collecting accelerate
  Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub>=0.21.0 (from accelerate)
  Downloading huggingface_hub-0.24.5-py3-none-any.whl.metadata (13 kB)
Collecting safetensors>=0.3.1 (from accelerate)
  Downloading safetensors-0.4.4-cp311-none-win_amd64.whl.metadata (3.9 kB)
Collecting fsspec>=2023.5.0 (from huggingface-hub>=0.21.0->accelerate)
  Downloading fsspec-2024.6.1-py3-none-any.whl.metadata (11 kB)
Downloading accelerat

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
s3fs 2023.4.0 requires fsspec==2023.4.0, but you have fsspec 2024.6.1 which is incompatible.


Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-win_amd64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.43.3-py3-none-win_amd64.whl (136.5 MB)
   ---------------------------------------- 0.0/136.5 MB ? eta -:--:--
   ---------------------------------------- 0.3/136.5 MB ? eta -:--:--
   ---------------------------------------- 1.0/136.5 MB 2.6 MB/s eta 0:00:52
   ---------------------------------------- 1.3/136.5 MB 2.2 MB/s eta 0:01:01
    --------------------------------------- 2.1/136.5 MB 2.6 MB/s eta 0:00:52
    --------------------------------------- 2.6/136.5 MB 2.8 MB/s eta 0:00:48
    --------------------------------------- 2.9/136.5 MB 2.3 MB/s eta 0:00:58
    --------------------------------------- 3.4/136.5 MB 2.3 MB/s eta 0:00:59
   - -------------------------------------- 3.9/136.5 MB 2.3 MB/s eta 0:00:57
   - -------------------------------------- 4.2/136.5 MB 2.4 MB/s eta 0:00:56
   - -------------------------------------- 4.5/136.5 MB 2.3 MB/s eta

Collecting langchain_community
  Downloading langchain_community-0.2.11-py3-none-any.whl.metadata (2.7 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)
  Downloading marshmallow-3.21.3-py3-none-any.whl.metadata (7.1 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Downloading langchain_community-0.2.11-py3-none-any.whl (2.3 MB)
   ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
   ---- ----------------------------------- 0.3/2.3 MB ? eta -:--:--
   --------- ------------------------------ 0.5/2.3 MB 985.5 kB/s eta 0:00:02
   -------

'grep' is not recognized as an internal or external command,
operable program or batch file.
'grep' is not recognized as an internal or external command,
operable program or batch file.


Collecting langchain-huggingface
  Downloading langchain_huggingface-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting sentence-transformers>=2.6.0 (from langchain-huggingface)
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Collecting tokenizers>=0.19.1 (from langchain-huggingface)
  Downloading tokenizers-0.19.1-cp311-none-win_amd64.whl.metadata (6.9 kB)
Collecting transformers>=4.39.0 (from langchain-huggingface)
  Downloading transformers-4.44.0-py3-none-any.whl.metadata (43 kB)
Downloading langchain_huggingface-0.0.3-py3-none-any.whl (17 kB)
Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
Downloading tokenizers-0.19.1-cp311-none-win_amd64.whl (2.2 MB)
   ---------------------------------------- 0.0/2.2 MB ? eta -:--:--
   ---- ----------------------------------- 0.3/2.2 MB ? eta -:--:--
   ------------------ --------------------- 1.0/2.2 MB 2.5 MB/s eta 0:00:01
   ---------------------------- ----------- 1.6/2.2 MB 3.0 MB/s eta 0:00:

# Building the pipeline with LangChain

In [21]:
import os
import langchain

### prompts
from langchain import PromptTemplate, LLMChain

### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings


import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)


#model = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
#model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
#model = "meta-llama/Meta-Llama-3-8B"
model = "Undi95/Meta-Llama-3-8B-hf"

# The Hugging Face access token is read from the HF_TOKEN environment variable instead
# of being hard-coded here (None falls back to any cached `huggingface-cli login`).
# The token that used to be pasted in this cell is exposed and should be revoked.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model, token=hf_token)

# 4-bit NF4 weights with double quantisation; matmuls are computed in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
    bnb_4bit_use_double_quant = True,
)

model_llama = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config = bnb_config,
    device_map = 'auto',
    token=hf_token,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
question = questions_data[0]['question']
# Build the option block from the record instead of hard-coding it: one "KEY. value"
# line per option, wrapped in newlines ("\nA. Ampicillin\n...\nD. Nitrofurantoin\n").
options = "\n" + "\n".join(f"{key}. {value}" for key, value in questions_data[0]['options'].items()) + "\n"
# No retrieved passages for this single-example sanity check; put RAG output here to test with context.
context = ""

In [None]:
# prompt generation
# Plain (non f-) string on purpose: the {question}/{context}/{options} placeholders must
# survive so PromptTemplate.format() can fill them per example. As an f-string they were
# substituted once, at definition time, from whatever globals happened to exist.
template_context = """Question: {question}
Context: {context}[INST]Select the correct option only. No explanation required[/INST]

Options: {options}

# Answer: """ # Force a single-line response

prompt_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])


In [None]:
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token
# Render the template into the actual prompt text; the tokenizer cannot consume a PromptTemplate object.
prompt_text = prompt_context.format(question=question, options=options, context=context)
# A single prompt needs no padding; padding="max_length" made the model run over ~4000 pad tokens.
inputs = tokenizer(prompt_text, return_tensors='pt', truncation=True, max_length=4000).to(model_llama.device)
with torch.no_grad():
    outputs = model_llama.generate(**inputs, max_new_tokens=1)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

In [None]:
# The prompt template ends with "# Answer: " (a space after '#'), so searching for
# '#Answer:' never matched and the slice started at a meaningless offset. Anchor on the
# last "Answer:" in the decoded text and take the first character generated after it.
marker = 'Answer:'
position = response.rfind(marker)
prediction = response[position + len(marker):].strip()[:1] if position != -1 else ''
prediction

# Designing a context summary prompt

In [None]:
# Use the summarization prompt to generate a summary
# TODO(review): `context_text` is not defined anywhere in this notebook; set it (e.g. to the
# passages retrieved for a question) before running this cell.
summarization_prompt = f"""[INST] Summarize the following text concisely:

{context_text}
[/INST]
"""

summary_inputs = tokenizer(summarization_prompt, return_tensors="pt").to(model_llama.device)
with torch.no_grad():
    summary_output = model_llama.generate(
        **summary_inputs,
        max_new_tokens=128,
    )

# generate() returns prompt + continuation; decode only the newly generated tokens so the
# printed text is the summary itself rather than an echo of the prompt.
prompt_length = summary_inputs["input_ids"].shape[1]
summary_text = tokenizer.decode(summary_output[0][prompt_length:], skip_special_tokens=True)
print(summary_text)

# Data Loader with RAG using FAISS index

In [None]:
from torch.utils.data import DataLoader, Dataset


class QuestionAnswerDataset_RAG(Dataset):
    """MedQA questions paired with passages retrieved from a FAISS index.

    Every answer option is used as a retrieval query; the top-k passages found for
    each option are joined into one line of context.

    Items are ``(question, options_str, answer_idx, combined_context)`` where
    ``options_str`` has one "KEY. value" line per option and ``combined_context``
    has one line of retrieved text per option.
    """

    def __init__(self, questions_data, faiss_index, faiss_texts, tokenizer=None, model=None, top_k=5):
        """
        Args:
            questions_data: List of MedQA records with 'question', 'options', 'answer_idx'.
            faiss_index: FAISS index whose i-th vector was built from ``faiss_texts[i]``.
            faiss_texts: Passage texts in the same order as the index vectors.
            tokenizer, model: Encoder for the queries; must be the one that built the
                index. Default to the notebook's ``bert_tokenizer`` / ``bert_model``.
            top_k: Number of passages retrieved per option.
        """
        self.questionData = questions_data
        self.faiss_index = faiss_index
        self.faiss_texts = faiss_texts
        self.tokenizer = bert_tokenizer if tokenizer is None else tokenizer
        self.model = bert_model if model is None else model
        self.top_k = top_k

    def __len__(self):
        return len(self.questionData)

    def __getitem__(self, idx):
        question_data = self.questionData[idx]
        question = question_data['question']
        options = question_data['options']
        options_str = "\n".join([f"{key}. {value}" for key, value in options.items()])
        answer = question_data['answer_idx']

        # Search the FAISS index with each option's text and collect the matching passages.
        retrieved_contexts = []
        for option_text in options.values():
            # search_faiss_index embeds the query text itself and returns (distances, indices),
            # each with one row per query. The previous call passed a vector, omitted the
            # tokenizer/model arguments, and iterated the returned tuple as if it were ids.
            _, indices = search_faiss_index(
                option_text, self.faiss_index, self.tokenizer, self.model, top_k=self.top_k
            )
            # Drop FAISS's -1 padding and any id beyond the loaded texts.
            hits = [self.faiss_texts[i] for i in indices[0] if 0 <= i < len(self.faiss_texts)]
            retrieved_contexts.append(" ".join(hits))

        # Combine the contexts with the question and options for the final return
        combined_context = "\n".join(retrieved_contexts)

        return question, options_str, answer, combined_context

    def get_text_embedding(self, text):
        """Embed ``text`` with the retrieval encoder (previously returned random noise)."""
        return get_embeddings(text, self.tokenizer, self.model)


# Passage texts in the exact order their embeddings were added to the FAISS index;
# the index is built from the embeddings of medQA_Filtered_data.
faiss_texts = medQA_Filtered_data

# Load your dataset
dataset_RAG = QuestionAnswerDataset_RAG(questions_data, faiss_index, faiss_texts)
dataloader_RAG = DataLoader(dataset_RAG, batch_size=8, shuffle=False)  # Adjust batch_size as needed

# Running Batches with context

In [34]:
correct_predictions = 0
total_predictions = 0
responses = []
answers = []

# Tokenizer settings are loop-invariant. Left padding keeps every prompt flush against the
# generated token; EOS is reused as the pad token.
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

for batch in tqdm(dataloader_RAG):
    questions, options_strs, answer_idxs, combined_contexts = batch
    prompts = [
        prompt_context.format(question=question, options=options_str, context=combined_context)
        for question, options_str, combined_context in zip(questions, options_strs, combined_contexts)
    ]

    inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, max_length=4000).to(model_llama.device)

    with torch.no_grad():
        outputs = model_llama.generate(**inputs, max_new_tokens=1)

    # Decode only the generated continuation. The template ends with "# Answer: ", so the
    # old search for '#Answer:' (no space) never matched and read from the wrong offset.
    prompt_length = inputs['input_ids'].shape[1]
    decoded_outputs = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

    for decoded_output, answer in zip(decoded_outputs, answer_idxs):
        answer_pred = decoded_output.strip()[:1]  # option letter, e.g. "D"
        if answer == answer_pred:
            correct_predictions += 1

        responses.append(answer_pred)
        answers.append(answer)
        total_predictions += 1

10178

In [None]:
# Guard against an empty run (e.g. the loop was interrupted before any batch finished).
accuracy = correct_predictions / total_predictions if total_predictions else 0.0
print(f"Accuracy: {accuracy:.2%}")
correct_predictions