# MedQA Dataset parsing for English data

In [3]:
import pandas as pd 
import json
import numpy as np
import torch
import os

def read_question_answer_file(file_path):
    """Reads a JSONL file with question-answer data and returns a list of dictionaries."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))  # Parse each line as JSON
    return data

# Load your dataset
dataset_path = r'med_data/phrases_no_exclude_train.jsonl'  # Replace with the path to your JSON file
questions_data = read_question_answer_file(dataset_path)

In [4]:
print(questions_data[0]['question'])
print(questions_data[0]['answer_idx'])
print(questions_data[0]['options'])

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
D
{'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}


<h2>Data pre-processing and extraction</h2>

In [5]:
import re

def clean_text(text):
  """Removes special characters and extra whitespace from text.

  Args:
    text: The input text to be cleaned.

  Returns:
    The cleaned text.
  """
  # Remove special characters, but keep letters, digits, and single spaces
  cleaned_text = re.sub(r'[^A-Za-z0-9\s]', '', text)
    # Replace multiple spaces with a single space
  cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
    # Strip leading and trailing spaces (if any)
  cleaned_text = cleaned_text.strip()
  return cleaned_text

In [6]:
def save_string_to_file(data, filename):
  """Saves a string to a text file.

  Args:
    text: The string to be saved.
    filename: The name of the file to create.
  """

  with open(filename, "w",encoding='utf-8') as f:
    json.dump(data, f)

# # Example usage:
# my_string = "This is the text I want to save."
# save_string_to_file(my_string, "output.txt")

In [7]:
def sanitize_filename(filename):
  """Sanitizes a filename by replacing special characters with underscores.

  Args:
    filename: The original filename.

  Returns:
    The sanitized filename.
  """

  # Replace non-alphanumeric characters with underscores
  filename = re.sub(r'[^\w]', '_', filename)

  # Remove leading and trailing underscores
  filename = filename.strip('_')

  return filename

In [8]:
import json

def extract_text(data):
  """Extracts text from the given JSON data.

  Args:
    data: The JSON data.

  Returns:
    A list of text strings.
  """

  texts = []
  for document in data:
    for passage in document['documents'][0]['passages']:  # Access the first document's passages
        words = passage['text'].split()
        if(len(words)> 10):
          texts.append(clean_text(passage['text']))
  return texts



<h2>Generate bert embeddings</h2>

In [9]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

# Initialize the tokenizer and model (bert-base-uncased)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Ensure the model runs on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_model.to(device)


# Function to compute embeddings
def get_embeddings(text,tokenizer,model):
    
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():  # No need to compute gradients
        outputs = model(**inputs)

    last_hidden_state = outputs.last_hidden_state
    embeddings = last_hidden_state.mean(dim=1)

    # Compute the mean of the token embeddings to get a fixed-size representation
    embeddings = embeddings.squeeze().cpu().numpy()  # (batch size x hidden size)

    return embeddings



<h2>Dictionary format data</h2>

In [10]:
# Path to the directory containing JSON files
pubmed_dir = 'Pubmed_Full_text'

# Dictionary to store the original data (not embeddings) by file key


def extract_dataDict(directory):
    data_dict = {}
    # Iterate through each JSON file in the directory
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        with open(filepath, 'r') as file:
            data = json.load(file)
            file_key = sanitize_filename(filename[:-4])  
            data_dict[file_key] = data
    return data_dict



In [11]:
data_dict = extract_dataDict(pubmed_dir)

In [12]:
print(len(data_dict))
data_dict['Arthralgias']

3627


[['Osteoarthritis is the most common form of arthritis and a leading cause of disability worldwide largely due to pain the primary symptom of the disease The pain experience in knee osteoarthritis in particular is wellrecognized as typically transitioning from intermittent weightbearing pain to a more persistent chronic pain Methods to validly assess pain in osteoarthritis studies have been developed to address the complex nature of the pain experience The etiology of pain in osteoarthritis is recognized to be multifactorial with both intraarticular and extraarticular risk factors Nonetheless greater insights are needed into pain mechanisms in osteoarthritis to enable rational mechanismbased management of pain Consequences of pain related to osteoarthritis contribute to a substantial socioeconomic burden',
  'The hallmark symptom of osteoarthritis OA the most common form of arthritis is pain This is the symptom that drives individuals to seek medical attention and contributes to functi

<h2>Document text arrangement to search using FIASS Index.</h2>

In [13]:
from tqdm import tqdm
medQA_Filtered_data = []

for key, data in tqdm(data_dict.items()):
        for texts in data:
            for text in texts:
                if len(text.split())>10:
                    medQA_Filtered_data.append(clean_text(text))

100%|██████████| 3627/3627 [00:20<00:00, 178.80it/s]


In [18]:
len(medQA_Filtered_data)

1033101

In [19]:
def getEmbeddingsListwise(data_list):
    embeddings = []
    for text in tqdm(data_list):
        embedding = get_embeddings(text,bert_tokenizer, bert_model)
        embeddings.append(embedding)
    return embeddings

fiass_embeddings = getEmbeddingsListwise(medQA_Filtered_data)

100%|██████████| 1033101/1033101 [1:58:28<00:00, 145.32it/s] 


# FIASS Installation and Storage

In [14]:
!pip install faiss-cpu 
#!pip install faiss-gpu  # For GPU support, if you have a CUDA-capable GPU

Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.8.0.post1
[0m

In [22]:
import faiss

embeddings = np.vstack(fiass_embeddings) 

# Get the dimensionality of the embeddings
embedding_dim = embeddings.shape[1]

# Create a FAISS index
# IndexFlatL2 is a simple, exact nearest neighbor search index with L2 distance
index = faiss.IndexFlatL2(embedding_dim)

# Add embeddings to the index
index.add(embeddings)

# Check the number of vectors in the index
print(f"Number of vectors in the index: {index.ntotal}")

# Save the index to a file for later use
faiss.write_index(index, 'FAISS/QA_Full_text_index.index')




Number of vectors in the index: 1033101


In [14]:
import faiss
# Load the FAISS index
faiss_index = faiss.read_index('FAISS/QA_Full_text_index.index')

# Function to search the FAISS index
def search_faiss_index(query, index, tokenizer, model, top_k=5):
    # Compute the query embedding
    query_embedding = get_embeddings(query, tokenizer, model)

    # Reshape the query embedding to match FAISS expected input format
    query_embedding = np.expand_dims(query_embedding, axis=0).astype('float32')

    # Search the FAISS index
    distances, indices = index.search(query_embedding, top_k)

    return distances, indices

In [15]:
# Example usage
query = "Nitrofurantoin"
top_k = 10

distances, indices = search_faiss_index(query, faiss_index, bert_tokenizer, bert_model, top_k=top_k)

# Print results
print(f"Top-{top_k} results:")
for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
    print(f"Result {i + 1}:")
    print(f"Index: {idx}, Distance: {distance}")
    # If you have the original text stored, you can retrieve it like this:
    # print(f"Text: {embeddings_dict['some_key'][idx]}")

Top-10 results:
Result 1:
Index: 919667, Distance: 35.589881896972656
Result 2:
Index: 456033, Distance: 36.14335250854492
Result 3:
Index: 683533, Distance: 36.14335250854492
Result 4:
Index: 684042, Distance: 36.14335250854492
Result 5:
Index: 80326, Distance: 36.747894287109375
Result 6:
Index: 703824, Distance: 36.869903564453125
Result 7:
Index: 533088, Distance: 36.94190216064453
Result 8:
Index: 150071, Distance: 37.45697021484375
Result 9:
Index: 75311, Distance: 37.72026062011719
Result 10:
Index: 400110, Distance: 37.87675476074219


# Library installations

In [20]:
!pip install langchain
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install --upgrade pip
!pip install --upgrade langchain
!pip install langchain_community
!pip list | grep langchain
!pip list | grep langchain_community

!pip install -U langchain-huggingface


Collecting langchain
  Downloading langchain-0.2.12-py3-none-any.whl.metadata (7.1 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain)
  Downloading aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting langchain-core<0.3.0,>=0.2.27 (from langchain)
  Downloading langchain_core-0.2.29-py3-none-any.whl.metadata (6.2 kB)
Collecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text_splitters-0.2.2-py3-none-any.whl.metadata (2.1 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain)
  Downloading langsmith-0.1.98-py3-none-any.whl.metadata (13 kB)
Collecting pydantic<3,>=1 (from langchain)
  Downloading pydantic-2.8.2-py3-none-a

# Building the pipeline with the langchain

In [16]:
import os
import langchain

### prompts
from langchain import PromptTemplate, LLMChain

### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings


import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)


#model = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
#model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
model = "meta-llama/Meta-Llama-3-8B"
#model = "Undi95/Meta-Llama-3-8B-hf"

tokenizer = AutoTokenizer.from_pretrained(model)

        
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
    bnb_4bit_use_double_quant = True,
)

model_llama = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config = bnb_config,
    device_map = 'auto',
    token='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'
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [26]:
question = questions_data[0]['question']
options = "\nA. Ampicillin\nB. Ceftriaxone\nC. Doxycycline\nD. Nitrofurantoin\n"

option_dict = questions_data[0]['options']

# Search FAISS index for each option and retrieve the corresponding text
retrieved_contexts = []
for option_key, option_text in option_dict.items():
    distances, indices = search_faiss_index(option_text, faiss_index, bert_tokenizer, bert_model, top_k=top_k)
    retrieved_context = " ".join([medQA_Filtered_data[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
    retrieved_contexts.append(retrieved_context)

context = '\n'.join([s.strip() + '.' for s in retrieved_contexts])


In [27]:
# prompt generation
template_context = """Question: {question}
Context: {context}[INST]Select the correct option only. No explanation required[/INST]

Options: {options}

# Answer: """ # Force a single-line response

prompt_template_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])
#prompt_context = prompt_template_context.format(question=question, options=options, context=context) 

In [None]:
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token 
inputs = tokenizer(prompt_context, return_tensors='pt', truncation=True, padding="max_length", max_length=4000).to(model_llama.device)
outputs = model_llama.generate(**inputs, max_new_tokens=1)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

In [None]:
position = response.find('#Answer:')
prediction = response[position+8 :position+10].strip()
prediction

# Designing a context summary prompt

In [None]:
# Use the summarization prompt to generate a summary
summarization_prompt = f"""[INST] Summarize the following text concisely:

{context_text}
[/INST]
"""

with torch.no_grad():
    summary_output = model_llama.generate(
        **tokenizer(summarization_prompt, return_tensors="pt").to(model_llama.device),
        max_new_tokens=128,
    )

# Decode and print the generated summary
summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
print(summary_text)

# Data Loader with RAG using FAISS index

In [36]:
def truncate_to_words(text, max_words=6000):
  """Truncates a text string to a maximum number of words.

  Args:
    text: The input text string.
    max_words: The maximum number of words to keep.

  Returns:
    The truncated text string.
  """

  words = text.split()
  if len(words) <= max_words:
    return text
  else:
    truncated_text = ' '.join(words[:max_words])
    return truncated_text

In [37]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from tqdm import tqdm  # For progress bar

class QuestionAnswerDataset_RAG(Dataset):
    def __init__(self, questions_data, faiss_index, faiss_texts):
        self.questionData = questions_data
        self.faiss_index = faiss_index
        self.faiss_texts = faiss_texts
    

    def __len__(self):
        return len(self.questionData)

    def __getitem__(self, idx):
        question_data = self.questionData[idx]
        question = question_data['question']
        options = question_data['options']
        options_str = "\n".join([f"{key}. {value}" for key, value in options.items()])
        answer = question_data['answer_idx']

        # Search FAISS index for each option and retrieve the corresponding text
        retrieved_contexts = []
        for option_key, option_text in option_dict.items():
            distances, indices = search_faiss_index(option_text, self.faiss_index, bert_tokenizer, bert_model, top_k=5)
            retrieved_context = " ".join([self.faiss_texts[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
            retrieved_contexts.append(retrieved_context)

        # Combine the contexts with the question and options for the final return
        combined_context = truncate_to_words("\n".join(retrieved_contexts))

        return question, options_str, answer, combined_context

    def get_text_embedding(self, text):
        # Placeholder for the actual embedding generation logic
        # Replace this with the method to generate embeddings from text
        return np.random.rand(768)  # Example: Replace with actual embedding


faiss_texts = [...]  # Load the list of texts associated with the FAISS index

# Load your dataset
dataset_RAG = QuestionAnswerDataset_RAG(questions_data[0:1000], faiss_index, medQA_Filtered_data)
dataloader_RAG = DataLoader(dataset_RAG, batch_size=4, shuffle=False)  # Adjust batch_size as needed

# Running Batches with context

In [None]:
correct_predictions = 0
total_predictions = 0
responses = []
answers = []

for batch in tqdm(dataloader_RAG):
    questions, options_strs, answer_idxs, combined_contexts = batch
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    prompts = [prompt_template_context.format(question=question, options=options_str, context=combined_context) for question, options_str,combined_context in zip(questions, options_strs,combined_contexts)]
    
    inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, max_length=8000).to(model_llama.device)
    
    with torch.no_grad():
        outputs = model_llama.generate(**inputs,max_new_tokens=1)
    
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    #print(decoded_outputs)
    for decoded_output, answer in zip(decoded_outputs, answer_idxs):
        position = decoded_output.find('#Answer:')
        answer_pred = decoded_output[position+8 :position+10].strip()
        #print(answer_pred)
        if answer == answer_pred.strip():
            correct_predictions += 1
        
        responses.append(answer_pred)
        answers.append(answer)
        total_predictions += 1

  0%|          | 0/250 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 1/250 [00:04<20:12,  4.87s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 2/250 [00:09<20:23,  4.93s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 3/250 [00:14<19:29,  4.73s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 4/250 [00:19<19:29,  4.75s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 5/250 [00:23<19:04,  4.67s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 6/250 [00:29<19:58,  4.91s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  3%|▎         | 7/250 [00:33<19:46,  4.88s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  3%|▎         | 8/250 [00:38<19:06,  4.74s/it]S

In [33]:
print(f"Accuracy: {correct_predictions / len(responses):.2%}")
correct_predictions

Accuracy: 0.00%


0

In [34]:
responses.count('n:')
# responses.index('n:')
# question

1000

In [35]:
def free_gpu_cache():
  """Frees the GPU cache and memory."""
  if torch.cuda.is_available():
      with torch.cuda.device(0):
          torch.cuda.empty_cache()