This notebook builds the Medical RAG (retrieval-augmented generation) index over the patient notes and runs its Gradio interface.

Installing the packages

In [None]:
!pip install langchain spacy nltk faiss-cpu cohere python-dotenv

Loading the API keys

In [None]:
from dotenv import load_dotenv
import os

# Load variables from .env (if present) into the process environment.
load_dotenv()

# Cohere API key; answer_with_content() passes it to cohere.ClientV2.
cohere_api_key = os.getenv("COHERE_API_KEY")
if not cohere_api_key:
    # Report the misconfiguration now instead of as an obscure client error on the first query.
    print("Warning: COHERE_API_KEY is not set (check your .env file); Cohere requests will likely fail.")

In [None]:
import os
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re
from sklearn.metrics.pairwise import cosine_similarity
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter, NLTKTextSplitter
import nltk
import torch
import gradio as gr
import cohere
# Fetch NLTK's 'punkt_tab' tokenizer data (presumably needed by NLTKTextSplitter in context_aware_chunking_nltk).
nltk.download(info_or_id='punkt_tab')




Chunking functions (one of them is selected in `read_all_txt_files` via `chunk_choice`)

In [None]:
#semantic chunking
def semantic_chunking(text, similarity_threshold, model=None):
    """
    Perform semantic chunking on a given text by grouping semantically similar sentences.

    Each sentence is compared with the sentence immediately before it. If their cosine
    similarity is >= ``similarity_threshold``, it is appended (space-separated) to the
    current chunk. Otherwise it starts a new chunk.

    Parameters:
    - text (str): The input text to chunk.
    - similarity_threshold (float): Minimum cosine similarity between two adjacent
      sentences for them to stay in the same chunk.
    - model (SentenceTransformer, optional): Model used to embed the sentences. Any object
      whose ``encode(list_of_str)`` returns one vector per sentence works. Defaults to the
      module-level ``embedding_model``.

    Returns:
    - List[str]: A list of chunks, each containing semantically similar sentences.
      Empty or whitespace-only text yields ``[]``.
    """
    if model is None:
        model = embedding_model

    # Split the text into sentences: runs of text ending in '.', '!' or '?'.
    # The second alternative keeps trailing text that has no terminal punctuation
    # (e.g. a final "Discharge Plan: follow up" line), which would otherwise be lost.
    print("Splitting the sentences")
    sentence_pattern = r'[^.!?]*[.!?]|[^.!?]+\Z'
    # Punctuated matches are never blank; this only drops a whitespace-only remainder.
    sentences = [s for s in re.findall(sentence_pattern, text) if s.strip()]
    if not sentences:
        return []

    # Generate embeddings for sentences (one row per sentence)
    print("Creating the embeddings")
    embeddings = np.asarray(model.encode(sentences), dtype=float)
    norms = np.linalg.norm(embeddings, axis=1)

    semantic_chunks = [sentences[0]]
    for i in range(1, len(sentences)):
        denom = norms[i - 1] * norms[i]
        # Zero-length vectors get similarity 0, as sklearn's cosine_similarity does.
        similarity = float(np.dot(embeddings[i - 1], embeddings[i]) / denom) if denom else 0.0

        if similarity >= similarity_threshold:
            # Combine the current sentence with the previous chunk
            semantic_chunks[-1] += " " + sentences[i]
        else:
            # Start a new chunk
            semantic_chunks.append(sentences[i])

    return semantic_chunks

#Section based chunking
# Header names that extract_sections_chunks uses to split a note into sections.
headers = [
    # Identification / admission details
    "Patient ID", "Admission Date", "Discharge Date", "Date of Birth", "Sex", "Service",
    # Presentation and history
    "Chief Complaint", "History of Present Illness", "Past Medical History",
    "Past Surgical History", "Social History", "Family History", "Allergies",
    "Medications on Admission",
    # Stay and discharge
    "Hospital Course", "Investigations", "Procedures", "Discharge Plan",
]


def extract_sections_chunks(text, section_headers=None, include_header=True):
    """Split a note into one chunk per recognised section header.

    Parameters:
    - text (str): Full note text.
    - section_headers (list[str], optional): Header names to split on. Defaults to the
      module-level ``headers`` list. Matching is case-sensitive, and a header only
      matches when it is not embedded in a longer word ("Sex" does not match "Sexual").
    - include_header (bool): If True (default), each chunk starts with its header name,
      e.g. "Allergies: none", so the chunk keeps the section context that the answer
      prompt refers to. If False, only the text after the header is kept.

    Returns:
    - List[str]: Stripped, non-empty chunks in document order. Text before the first
      header becomes its own chunk. A note with no recognised header is returned whole
      (as a one-element list), and empty text yields ``[]``.
    """
    if section_headers is None:
        section_headers = headers
    # Use a new name: assigning to ``headers`` inside the function would make it a
    # local variable and raise UnboundLocalError when reading the module-level list.
    # Longest first so the alternation prefers the longest header at a position.
    ordered_headers = sorted((h for h in section_headers if h), key=len, reverse=True)

    stripped_text = text.strip()
    if not ordered_headers:
        return [stripped_text] if stripped_text else []

    # (?<!\w) / (?!\w): the header must not be glued to surrounding word characters.
    pattern = (
        r"(?<!\w)(?:"
        + "|".join(re.escape(header) for header in ordered_headers)
        + r")(?!\w)"
    )
    matches = list(re.finditer(pattern, text))
    if not matches:
        return [stripped_text] if stripped_text else []

    chunks = []
    preamble = text[:matches[0].start()].strip()
    if preamble:
        chunks.append(preamble)

    for i, match in enumerate(matches):
        start_idx = match.start() if include_header else match.end()
        # Section runs until the next header (or the end of the text).
        end_idx = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        content = text[start_idx:end_idx].strip()
        if content:
            chunks.append(content)
    return chunks

#Recursive chunking
def recursive_chunking(text, chunk_size=400, chunk_overlap=100):
  """Split ``text`` with LangChain's RecursiveCharacterTextSplitter.

  Parameters:
  - text (str): Note text to split.
  - chunk_size (int): ``chunk_size`` forwarded to the splitter (default 400).
  - chunk_overlap (int): ``chunk_overlap`` forwarded to the splitter (default 100).

  Returns:
  - List[str]: The chunk texts, in order.
  """
  text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap
  )
  # create_documents() wraps each piece in a Document; only the text is needed here.
  return [doc.page_content for doc in text_splitter.create_documents([text])]

#Fixed size chunking
def chunk_string_with_overlap(string):
    """Split ``string`` into fixed-size, slightly overlapping character windows.

    The window size is ``len(string) // 10`` and consecutive windows overlap by
    ``window // 10`` characters. If the full windows do not reach the end of the
    string, one final shorter chunk covers the remainder.

    Parameters:
    - string (str): Text to split.

    Returns:
    - List[str]: The chunks. Empty input gives ``[]``. Input shorter than 10
      characters (window size 0) is returned whole, because a zero-size window
      would never advance and loop forever.
    """
    if not string:
        return []
    chunk_size = len(string) // 10
    if chunk_size == 0:
        return [string]
    overlap_size = chunk_size // 10
    # overlap_size < chunk_size whenever chunk_size >= 1, so step is always >= 1.
    step = chunk_size - overlap_size

    chunks = []
    start = 0
    while start + chunk_size <= len(string):
        chunks.append(string[start:start + chunk_size])
        start += step

    # End of the last full window. Only add a tail chunk if it reaches text that
    # no window has covered yet; otherwise the tail would be a redundant fragment.
    covered_end = (start - step) + chunk_size
    if covered_end < len(string):
        chunks.append(string[start:])
    return chunks


#Context aware spacy chunking
def context_aware_chunking_spacy(text, chunk_size=400):
  """Split ``text`` into sentence-aware chunks with LangChain's SpacyTextSplitter.

  Parameters:
  - text (str): Note text to split.
  - chunk_size (int): ``chunk_size`` forwarded to the splitter (default 400).

  Returns:
  - List[str]: The chunk texts, in order.
  """
  text_splitter = SpacyTextSplitter(chunk_size=chunk_size)
  # split_text() already yields the chunk strings; list() returns them as a fresh list.
  return list(text_splitter.split_text(text))

#Context aware nltk chunking
def context_aware_chunking_nltk(text, chunk_size=400):
  """Split ``text`` into sentence-aware chunks with LangChain's NLTKTextSplitter.

  Parameters:
  - text (str): Note text to split.
  - chunk_size (int): ``chunk_size`` forwarded to the splitter (default 400).

  Returns:
  - List[str]: The chunk texts, in order.
  """
  text_splitter = NLTKTextSplitter(chunk_size=chunk_size)
  # split_text() already yields the chunk strings; list() returns them as a fresh list.
  return list(text_splitter.split_text(text))



Initializing the embedding model

In [None]:
def prepare_model(model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """Load the SentenceTransformer used to embed both note chunks and queries.

    Parameters:
    - model_name (str): Model id passed to ``SentenceTransformer``.
      Defaults to 'sentence-transformers/all-MiniLM-L6-v2'.

    Returns:
    - SentenceTransformer: The loaded model. The FAISS index built from its
      embeddings must use the same vector dimension.
    """
    return SentenceTransformer(model_name)

embedding_model = prepare_model()

Data Ingestion and Preprocessing Functions

In [None]:
def extract_id_query(query:str):
  """Return the patient ID mentioned in ``query`` as a digit string, or ``None``.

  Accepts forms such as "patient id 123", "Patient ID: 123" or "id=123"
  (case-insensitive). "id" must be a whole word, so words that merely contain
  it ("Did 5 doctors ...", "paid 100") are not mistaken for an ID.
  """
  match = re.search(r"\b(?:patient\s*id|id)\b\s*[:=]?\s*(\d+)", query, re.IGNORECASE)
  return match.group(1) if match else None

def extract_id(note:str):
  """Return the digits after the first case-sensitive "Patient ID:" label in ``note``, or ``None``."""
  found = re.search(r"Patient ID:\s*(\d+)", note)
  return None if found is None else found.group(1)

def _chunk_note(content, chunk_choice):
    """Dispatch one note's text to the chunker selected by ``chunk_choice``."""
    if chunk_choice == 0:
        return chunk_string_with_overlap(content)
    if chunk_choice == 1:
        return semantic_chunking(content, 0.50)
    if chunk_choice == 2:
        return extract_sections_chunks(content)
    if chunk_choice == 3:
        return recursive_chunking(content)
    if chunk_choice == 4:
        return context_aware_chunking_spacy(content)
    return context_aware_chunking_nltk(content)


def read_all_txt_files(folder_path:str,chunk_choice:int) -> tuple[list[str], list]:
    """Read every ``.txt`` note directly inside ``folder_path`` and chunk it.

    Parameters:
    - folder_path (str): Directory holding the notes. It is not searched recursively.
    - chunk_choice (int): Chunking strategy:
      - 0: fixed-size
      - 1: semantic (threshold 0.50)
      - 2: section-based
      - 3: recursive
      - 4: spaCy
      - any other value: NLTK

    Returns:
    - (contents, ids): Parallel lists holding each chunk and the patient ID parsed from
      the note it came from (``None`` when the note has no "Patient ID: <digits>" line).
      If ``folder_path`` is not a directory, a message is printed and both lists are empty.
    """
    contents = []
    ids = []

    if not os.path.isdir(folder_path):
        print(f"Invalid folder path: {folder_path}")
        return contents, ids

    # os.listdir order is arbitrary; sorting makes chunk order (and hence the
    # FAISS row ids assigned later) reproducible between runs.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Skip non-.txt entries and directories that happen to end in ".txt".
        if not filename.endswith(".txt") or not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        patient_id = extract_id(content)
        for chunk in _chunk_note(content, chunk_choice):
            contents.append(chunk)
            ids.append(patient_id)

    return contents, ids


def get_embeddings(embedding_model,text):
    """Encode ``text`` with ``embedding_model`` and return the result as plain Python lists.

    ``encode`` is called with ``convert_to_numpy=True``, and ``.tolist()`` turns that
    array into (nested) Python floats so it can be stored in a DataFrame cell.
    """
    return embedding_model.encode(text, convert_to_numpy=True).tolist()

def prepare_vector_database(embedding_model,contents : list[str], ids):
    """Embed every chunk and load the vectors into an exact (flat) L2 FAISS index.

    Parameters:
    - embedding_model: Model passed to ``get_embeddings`` for each chunk.
    - contents (list[str]): Chunk texts. Row i of the returned DataFrame corresponds to vector i.
    - ids: Patient ID of each chunk, parallel to ``contents``.

    Returns:
    - index (faiss.IndexFlatL2): One float32 vector per chunk.
    - df (pd.DataFrame): Columns "Document content", "Patient ID" and "Embeddings".
    - metadata_store (dict[int, dict]): Row id -> {"patient_id": ..., "note": ...}.
    """
    df = pd.DataFrame({'Document content':contents,"Patient ID":ids})
    df["Embeddings"] = df['Document content'].apply(lambda x:get_embeddings(embedding_model,text = x))

    if len(df):
        # One (n_chunks, dim) float32 matrix, added to the index in a single call.
        vectors = np.asarray(df["Embeddings"].tolist(), dtype='float32')
        dimension = vectors.shape[1]
    else:
        # No chunks: probe the model once so the (empty) index still matches the
        # dimension of the query embeddings searched against it later.
        vectors = None
        dimension = len(get_embeddings(embedding_model, ""))

    # Size the index from the data instead of hard-coding 384, so swapping the
    # embedding model does not break index.add() / index.search().
    index = faiss.IndexFlatL2(dimension)
    if vectors is not None:
        index.add(vectors)

    metadata_store = {
        i: {"patient_id": patient_id, "note": note}
        for i, (patient_id, note) in enumerate(zip(df["Patient ID"], df["Document content"]))
    }

    return index, df, metadata_store


Context preparation and Query handling functions

In [None]:
def prepare_content(query:str, index, df,embedding_model,k = 3, fetch_k=None):
    """Retrieve up to ``k`` note chunks belonging to the patient named in ``query``.

    The patient ID is parsed from the query with ``extract_id_query``. Only chunks whose
    "Patient ID" matches it (compared as strings) are kept.

    The patient filter runs *after* the vector search. Searching only the global
    top-``k`` would therefore return nothing for a patient whose chunks are not among
    the ``k`` nearest overall. Instead, ``fetch_k`` neighbours are inspected (default:
    every vector in the index) and the ``k`` closest matching chunks are kept.

    Parameters:
    - query (str): User question, expected to contain a patient ID.
    - index: FAISS index whose vector i corresponds to row i of ``df``.
    - df (pd.DataFrame): Must have "Document content" and "Patient ID" columns.
    - embedding_model: Model passed to ``get_embeddings`` for the query.
    - k (int): Maximum number of chunks to return.
    - fetch_k (int, optional): Number of nearest neighbours to inspect before filtering.

    Returns:
    - (str, float): The retrieved chunks joined by blank lines (or
      "No notes found for this patient ID."), and the sum of ``1 - distance / 2``
      over the retrieved chunks.
    """
    no_notes = "No notes found for this patient ID."
    query_id = extract_id_query(query)

    total = index.ntotal
    if total == 0 or k <= 0:
        return no_notes, 0.0
    n_candidates = total if fetch_k is None else min(max(fetch_k, k), total)

    query_embedding = np.asarray(get_embeddings(embedding_model, query), dtype='float32').reshape(1, -1)
    distances, neighbours = index.search(query_embedding, k=n_candidates)

    chunks = []
    total_score = 0.0
    for doc_index, distance in zip(neighbours[0], distances[0]):
        # Negative ids mark missing results; anything past the frame is inconsistent.
        if doc_index < 0 or doc_index >= len(df):
            print(f"Invalid index: {doc_index}, dataframe length = {len(df)}")
            continue
        row = df.iloc[doc_index]
        if str(row["Patient ID"]) != str(query_id):
            continue

        chunks.append(row["Document content"])
        # IndexFlatL2 reports squared L2 distance; for unit-length vectors d = 2 - 2*cos,
        # so cos = 1 - d/2. NOTE(review): assumes the model emits normalised embeddings.
        total_score += 1 - distance / 2
        if len(chunks) >= k:
            break

    # Separate chunks so adjacent notes are not fused into one run-on string.
    retrieved_content = "\n\n".join(chunks)
    if retrieved_content.strip() == "":
        retrieved_content = no_notes

    return retrieved_content, float(total_score)


def answer_with_content(query:str, index, df,embedding_model, model_name='command-xlarge-nightly'):
    """Answer ``query`` with Cohere's chat API, grounded on notes retrieved by ``prepare_content``.

    Parameters:
    - query (str): User question (expected to contain a patient ID).
    - index, df, embedding_model: Passed unchanged to ``prepare_content``.
    - model_name (str): Cohere model id. NOTE(review): confirm the default is still served.

    Returns:
    - (str, float): The model's answer text (``""`` if the response has no content
      blocks) and the retrieval score returned by ``prepare_content``.
    """
    co = cohere.ClientV2(cohere_api_key)
    retrieved_content, tot_score = prepare_content(query,index, df,embedding_model)
    prompt_template = """
      You are an expert medical assistant tasked with answering questions based on the provided clinical context.
      The context contains important sections or headers. Please ensure that you pay close attention to these headers
      and use the most relevant section to answer the query. Your response should be clear, concise, medically accurate,
      and tailored to the specific medical scenario. Always provide a full-sentence answer to the query based on the
      information from the most appropriate section of the context.

      Context: {context}

      Question: {query}

      Answer:
      """
    combined_prompt = prompt_template.format(context=retrieved_content, query=query)

    response = co.chat(
        model=model_name,
        messages=[{"role": "user", "content": combined_prompt}],
    )
    # message.content is a list of content blocks; the answer text is in the first one.
    content_blocks = response.message.content or []
    answer = content_blocks[0].text if content_blocks else ""
    print(answer)

    return answer, tot_score


Main: build the index and define the query handler

In [None]:
# Folder of .txt notes; set MEDICAL_NOTES_DIR to use a folder other than the Colab default.
NOTES_DIR = os.getenv("MEDICAL_NOTES_DIR", "/content/sample_data")
# chunk_choice for read_all_txt_files:
#   0 fixed-size, 1 semantic, 2 section-based, 3 recursive, 4 spaCy, anything else NLTK.
CHUNK_CHOICE = 1

print("Reading all text files")
all_contents, ids = read_all_txt_files(NOTES_DIR, CHUNK_CHOICE)

print("Preparing Faiss vector database")
faiss_index, content_df, metadata = prepare_vector_database(embedding_model,all_contents,ids)

def process_query(query:str):
    """Gradio callback: answer ``query`` from the indexed notes.

    Returns one value per output component wired in the UI, ``(answer_text, score)``.
    The score is a plain float (or ``None`` when there is nothing to score) because
    it feeds a ``gr.Number``, which cannot display a formatted string.
    """
    if not query or not query.strip():
        # Still return two values: Gradio expects exactly one per output component.
        return "Please enter a query", None

    # answer_with_content() performs retrieval itself; calling prepare_content here
    # as well would embed and search the query twice.
    answer, score = answer_with_content(query,faiss_index,content_df,embedding_model)
    print(f"Confidence Score: {score}")

    return answer, float(score)



Launching the Gradio interface

with gr.Blocks() as demo:
    # Title and usage hint shown above the form.
    gr.Markdown("Medical RAG system")
    gr.Markdown("Ask query and get answer based on the notes provided.")

    # Input row: the free-text question.
    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your query",
            placeholder="Type your question here...",
        )

    submit_btn = gr.Button("Get Answer")

    # Output rows: generated answer, then the retrieval score.
    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=5)
    with gr.Row():
        score_output = gr.Number(label="Score")

    # process_query returns a 2-tuple; its items fill these outputs in order.
    submit_btn.click(
        fn=process_query,
        inputs=query_input,
        outputs=[answer_output, score_output],
    )


if __name__ == "__main__":
    # debug=True: see Gradio's launch() docs (blocks the cell and surfaces errors in notebooks).
    demo.launch(debug=True)