# **Initialization**

In [None]:
DEBUG = True  # Master switch: True prints debug output, False silences it


def debug_message(msg: str):
    """Write `msg` to stdout when the module-level DEBUG flag is on; otherwise do nothing."""
    if not DEBUG:
        return
    print(msg)

# Initialize API keys
import os

# Each placeholder is only a fallback: setdefault leaves a key that is already
# exported in the environment untouched instead of overwriting it.

# Hugging Face API key
os.environ.setdefault('HUGGINGFACE_API_KEY', 'your_huggingface_api_key_here')

# LangChain API key
os.environ.setdefault('LANGCHAIN_API_KEY', 'your_langchain_api_key_here')

# OpenAI API key
os.environ.setdefault('OPENAI_API_KEY', 'your_openai_api_key_here')

# Groq API key
os.environ.setdefault('GROQ_API_KEY', 'your_groq_api_key_here')

## Embedding

In [2]:
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model used both to index reports in the vector store and to embed
# a patient's report when searching for similar cases.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

## Loading Documents

In [None]:
from langchain_community.document_loaders import DirectoryLoader, TextLoader

# Load the dataset
dataset_path = r"/path/to/your/dataset_directory"
loader = DirectoryLoader(dataset_path, glob="*.txt", loader_cls=TextLoader, show_progress=True, use_multithreading=True)


def _report_index(file_name):
    """Return the trailing integer of a ``<name>_<N>.txt`` file name, or None if it has none."""
    stem, ext = os.path.splitext(file_name)
    suffix = stem.split('_')[-1]
    if ext.lower() != ".txt" or not (suffix.isascii() and suffix.isdigit()):
        return None
    return int(suffix)


# List only the numbered report files, in numeric order. Entries that do not
# follow the naming scheme (hidden files, sub-directories, other extensions)
# are skipped instead of crashing the integer sort key.
file_list = sorted(
    (name for name in os.listdir(dataset_path) if _report_index(name) is not None),
    key=_report_index,
)

# Keep only the first 10 reports
file_list = file_list[:10]

# Load documents
docs = [doc for file in file_list for doc in TextLoader(os.path.join(dataset_path, file)).load()]

# Display the first few documents for verification
for doc in docs[:1]:
    print(doc)

## Vectorstore

In [None]:
import os
import numpy as np
import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from uuid import uuid4
from langchain_core.documents import Document

# Determine the dimensionality of the embeddings dynamically by embedding a
# throwaway (empty) query and measuring the resulting vector.
embedding_dim = len(embeddings.embed_query(""))

# Initialize a flat L2 FAISS index sized to the embedding dimension
index = faiss.IndexFlatL2(embedding_dim)

# Create the (initially empty) vector store around that index
vectorstore = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={}
)

# Re-wrap every loaded document as a `Document` carrying the same content and metadata
documents = [
    Document(page_content=doc.page_content, metadata=doc.metadata)
    for doc in docs  # `docs` comes from the document-loading cell
]

# Generate one random UUID string per document to use as its store id
uuids = [str(uuid4()) for _ in range(len(documents))]

# Add documents to the vector store
vectorstore.add_documents(documents=documents, ids=uuids)

print(f"Successfully added {len(documents)} documents to the FAISS vector store.")

# Save and load the FAISS vector store.
# NOTE(review): the index is written only when "faiss_index" does not exist yet,
# so a directory left over from an earlier run is never refreshed and
# `new_vector_store` may then differ from `vectorstore` — confirm this is intended.
if not os.path.exists("faiss_index"):
    vectorstore.save_local("faiss_index")

# NOTE(review): allow_dangerous_deserialization=True presumably permits
# pickle-based loading; only load index directories you created yourself.
if os.path.exists("faiss_index"):
    new_vector_store = FAISS.load_local(
        "faiss_index", embeddings, allow_dangerous_deserialization=True
    )

## Chunking

In [None]:
import re

# Function to chunk the document
def chunk_document(doc_content):
    """Split a report into sections separated by blank lines.

    Args:
        doc_content (str): Full text of the report.

    Returns:
        list of str: The sections in document order. Sections that are empty
        or contain only whitespace are dropped so they are never scored or
        sent to the model; an empty document yields an empty list.
    """
    chunks = [section for section in doc_content.split("\n\n") if section.strip()]
    debug_message(f"Document chunked into {len(chunks)} sections.")
    return chunks

# Function to retrieve specific report based on the question
def retrieve_specific_report(question, dataset_dir=r"/path/to/your/dataset_directory"):
    """Load the report whose number is mentioned in `question`.

    The first standalone run of digits in the text (optionally prefixed by a
    single M, L or Q) is taken as the report number, and the file
    ``report_<number>.txt`` is looked up inside `dataset_dir`.

    Args:
        question (str): Free text containing a report / patient number.
        dataset_dir (str): Directory holding the ``report_<N>.txt`` files.

    Returns:
        Document: The first document loaded from the report file, or None when
        no number is found in the text or the file does not exist.
    """
    # Match either M, L, Q followed by digits, or just digits
    match = re.search(r'\b(?:[MLQ]?)(\d+)\b', question)
    if not match:
        debug_message("No match found for report identifier.")
        return None

    report_num = match.group(1)  # Only the numeric part is captured
    debug_message(f"Extracted report number: {report_num}")

    report_path = os.path.join(dataset_dir, f"report_{report_num}.txt")
    debug_message(f"Constructed report path: {report_path}")

    if not os.path.exists(report_path):
        debug_message("File does not exist.")
        return None

    debug_message("File exists.")
    doc = TextLoader(report_path).load()[0]
    debug_message(f"Loaded document: {doc.page_content[:100]}...")  # Preview of loaded content
    return doc

In [None]:
question = "What is the diagnosis of patient 0?"

# Retrieve specific report based on the question
specific_doc = retrieve_specific_report(question)

if specific_doc:
    # Chunk the document
    chunks = chunk_document(specific_doc.page_content)
    print(chunks)
else:
    # Still define `chunks` so later cells can test for emptiness instead of
    # failing with a NameError.
    chunks = []
    print("No specific report found for the given question.")

In [None]:
from sentence_transformers import CrossEncoder

# Initialize the cross-encoder model
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

if chunks:
    # Score every (question, chunk) pair
    reranked_scores = cross_encoder.predict([(question, chunk) for chunk in chunks])

    # Add scores to the chunks
    chunk_scores = list(zip(chunks, reranked_scores))

    # Sort chunks by score, best first
    sorted_chunks = sorted(chunk_scores, key=lambda x: x[1], reverse=True)

    # Print the scores for debugging purposes
    for chunk, score in sorted_chunks:
        print(f"Score: {score}\nChunk: {chunk}\n")

    # Get the top chunk
    top_chunk = sorted_chunks[0][0]

    print(top_chunk)
else:
    # Nothing to rank: define the same names with empty values rather than
    # indexing into an empty result.
    reranked_scores, chunk_scores, sorted_chunks, top_chunk = [], [], [], None
    print("No chunks available to rerank.")

## Chat Model

In [None]:
# Choose one
# 1
# from langchain_groq import ChatGroq

# llm = ChatGroq(
#     model="gemma2-9b-it",
#     temperature=0,
#     max_tokens=None,
#     timeout=None,
#     max_retries=2,
# )

# llm_rewrite = ChatGroq(
#     model="gemma2-9b-it",
#     temperature=0,
#     max_tokens=None,
#     timeout=None,
#     max_retries=2,
# )

# 2
from langchain_openai import ChatOpenAI

# The key is read from the OPENAI_API_KEY environment variable configured in
# the initialization cell, so the credential lives in exactly one place.

# Model that answers questions and writes summaries
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    api_key=os.environ["OPENAI_API_KEY"],
)

# Model that rewrites a multi-patient question into per-patient questions
llm_rewrite = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    api_key=os.environ["OPENAI_API_KEY"],
)

### Memory

In [9]:
from langchain_core.messages import SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

# Graph whose state is the running list of chat messages
workflow = StateGraph(state_schema=MessagesState)

# Every invocation in this file passes this config, i.e. the same thread id
config = {"configurable": {"thread_id": "1"}}


def call_model(state: MessagesState):
    """Graph node: prepend the medical-assistant system prompt and query `llm`.

    Args:
        state (MessagesState): Graph state; its "messages" entry holds the conversation so far.

    Returns:
        dict: ``{"messages": <model reply>}`` as the state update for the graph.
    """
    instructions = SystemMessage(
        content=(
            "You are a highly efficient and detail-oriented medical assistant. "
            "Provide direct, concise, and accurate answers based on the provided information, specifically addressing the mentioned patient Medical Record Number (MRN). Ensure responses include all relevant details, with a clear mention of the stage of the disease when applicable. Avoid unnecessary explanations, apologies, or corrections, focusing solely on delivering precise and complete information."
        )
    )
    reply = llm.invoke([instructions, *state["messages"]])
    return {"messages": reply}


# Single-node graph: START -> "model"
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile with a simple in-memory checkpointer
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

# **Pipeline**

In [10]:
# Global variable to store the context of the conversation
context = dict(
    mrns=[],           # MRNs taken from the most recent question that contained any
    docs={},           # MRN -> retrieved report document
    specific_docs={},
    chunks={},         # MRN -> list of report chunks
    top_chunks={},     # MRN -> best-ranked chunk for the last question
)

## Utils

In [11]:
# Utility functions
class Util:
    """Helpers for pulling patient MRNs out of free-text questions."""

    @staticmethod
    def extract_mrns(question):
        """Extract patient MRNs (with or without an M/L/Q prefix) from a question.

        Args:
            question (str): The user's question.

        Returns:
            list of str: The numeric part of each MRN, in order of first
            appearance. An MRN mentioned several times is returned once, so
            callers that loop over the result do not repeat work per patient.
        """
        matches = re.findall(r'\b(?:[MLQ]?)(\d+)\b', question)
        # dict.fromkeys drops repeats while keeping first-seen order
        mrns = list(dict.fromkeys(matches))
        debug_message(f"Extracted MRNs: {mrns}")
        return mrns

    @staticmethod
    def resolve_mrns(question):
        """
        Extract patient MRNs from the question. If none are found, fall back to
        the MRNs stored in 'context["mrns"]'.

        Args:
            question (str): The user's question.

        Returns:
            list of str: The MRNs to use; a non-empty extraction also replaces
            the list stored in 'context["mrns"]'.
        """
        new_mrns = Util.extract_mrns(question)

        if new_mrns:
            context["mrns"] = new_mrns
            debug_message(f"New MRNs found and updated in context: {new_mrns}")
        elif context["mrns"]:
            debug_message(f"No new MRNs found; reusing existing: {context['mrns']}")
        else:
            debug_message("No MRNs found in question, and none were in context.")

        return context["mrns"]

## Retrieval

In [35]:
# Document Retrieval
class DocRetriever:
    """
    Looks up the report document belonging to a patient MRN.
    """

    @staticmethod
    def retrieve_doc(patient_mrn):
        """
        Fetch the report for one patient and remember it in the shared context.

        Args:
            patient_mrn (str): The medical record number of the patient.

        Returns:
            Document: The report, also stored under context["docs"][patient_mrn];
            None when no report could be loaded.
        """
        debug_message(f"[DocRetriever] Action: Retrieving document | Patient MRN: {patient_mrn}")
        report = retrieve_specific_report(patient_mrn)

        # Guard clause: nothing to cache when the lookup came back empty
        if not report:
            debug_message(f"[DocRetriever] Status: No report found | Patient MRN: {patient_mrn}")
            return None

        cached_docs = context.setdefault("docs", {})
        cached_docs[patient_mrn] = report
        debug_message(f"[DocRetriever] Status: Document retrieved | Patient MRN: {patient_mrn} | Document Preview: {report.page_content[:100]}...")
        return report

# Document Chunking
class DocChunker:
    """
    Handles chunking and storing of document data for efficient processing.
    """

    @staticmethod
    def chunk_store_doc(patient_mrn, doc):
        """
        Chunk a specific document and store the chunks in the context.

        Args:
            patient_mrn (str): The medical record number of the patient.
            doc (Document): The document to chunk.

        Returns:
            list: The list of document chunks, also stored under
            context["chunks"][patient_mrn].
        """
        chunks = chunk_document(doc.page_content)
        context.setdefault("chunks", {})[patient_mrn] = chunks
        debug_message(f"[DocChunker] Action: Storing chunks | Patient MRN: {patient_mrn} | Total Chunks: {len(chunks)}")
        return chunks

    @staticmethod
    def retrieve_chunk_doc(patient_mrn):
        """
        Return the chunks for a patient's report, building them if necessary.

        Lookup order: chunks already in the context, then a document already
        in the context (chunked now), then a fresh retrieval. The chunk cache
        is checked directly because a document can be cached without ever
        having been chunked, e.g. after it was retrieved for summarization.

        Args:
            patient_mrn (str): The medical record number of the patient.

        Returns:
            list: The list of chunks for the document, or None if no document is found.
        """
        cached_chunks = context.get("chunks", {})
        if patient_mrn in cached_chunks:
            debug_message(f"[DocChunker] Status: Chunks retrieved from context for patient MRN: {patient_mrn}")
            return cached_chunks[patient_mrn]

        # Reuse a cached document when there is one; otherwise fetch it now
        doc = context.get("docs", {}).get(patient_mrn) or DocRetriever.retrieve_doc(patient_mrn)
        if not doc:
            debug_message(f"[DocChunker] Status: No document found for patient MRN: {patient_mrn}")
            return None
        return DocChunker.chunk_store_doc(patient_mrn, doc)

# Chunk Reranking
class ChunkReranker:
    """
    Orders report chunks by how relevant the cross-encoder judges them to a question.
    """

    @staticmethod
    def rerank(question, chunks):
        """
        Score each chunk against the question and return them best-first.

        Args:
            question (str): The question to evaluate relevance.
            chunks (list of str): The document chunks to rank.

        Returns:
            list: (chunk, score) tuples sorted by score in descending order.
        """

        def log_pairs(header, pairs):
            # Emit a header followed by one truncated line per (chunk, score) pair
            debug_message(header)
            for text, value in pairs:
                debug_message(f"Chunk: {text[:300]}... | Score: {value:.2f}")

        debug_message(f"[ChunkReranker] Action: Reranking chunks | Question: {question}")
        pairs_to_score = [(question, text) for text in chunks]
        scores = cross_encoder.predict(pairs_to_score)

        scored = list(zip(chunks, scores))
        log_pairs("[ChunkReranker] Chunk Scores:", scored)

        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        log_pairs("[ChunkReranker] Sorted Chunks:", ranked)

        return ranked

## Functionalities

### Question Answering

In [13]:
from langchain_core.messages import HumanMessage

class SingleQA:
    """
    Answers one question about one patient from the best-matching chunk of that patient's report.
    """

    @staticmethod
    def answer(patient_mrn, question):
        """
        Produce an answer for a patient-specific question.

        The patient's chunks are taken from the context (or built on demand),
        reranked against the question, and the top chunk is sent to the chat
        app together with the question.

        Args:
            patient_mrn (str): The medical record number of the patient.
            question (str): The question to answer.

        Returns:
            str: The generated answer, or an error message if no data is found.
        """
        debug_message(f"[SingleQA] Action: Retrieving answer | Patient MRN: {patient_mrn}")

        # Make sure chunks for this patient are present in the context
        known_chunks = context.get("chunks", {})
        if patient_mrn not in known_chunks:
            debug_message(f"[SingleQA] Status: Chunks not found in memory | Patient MRN: {patient_mrn}. Retrieving and chunking document.")
            fresh_chunks = DocChunker.retrieve_chunk_doc(patient_mrn)
            if not fresh_chunks:
                error_message = f"No report or chunks found for patient MRN: {patient_mrn}."
                debug_message(f"[SingleQA] Status: {error_message}")
                return error_message
            context.setdefault("chunks", {})[patient_mrn] = fresh_chunks

        patient_chunks = context["chunks"][patient_mrn]

        # Rank the chunks and remember the winner
        debug_message(f"[SingleQA] Action: Reranking chunks | Patient MRN: {patient_mrn}")
        ranking = ChunkReranker.rerank(question, patient_chunks)
        best_pair = ranking[0]
        top_chunk = best_pair[0]
        context.setdefault("top_chunks", {})[patient_mrn] = top_chunk

        # Ask the chat app, passing the question followed by the supporting chunk
        debug_message(f"[SingleQA] Action: Invoking app for response generation | Patient MRN: {patient_mrn}")
        result = app.invoke(
            {"messages": [HumanMessage(question), HumanMessage(top_chunk)]},
            config,
        )
        answer = result["messages"][-1].content

        debug_message(f"[SingleQA] Status: Answer generated | Patient MRN: {patient_mrn} | Answer: {answer}")
        return answer


class MultiQA:
    """
    Handles multi-patient questions by dividing questions, generating individual answers,
    and summarizing the results.
    """

    @staticmethod
    def _match_questions(lines, patient_mrns):
        """Map each MRN to the rewritten question line that mentions it.

        An MRN only matches as a complete number (not as part of a longer
        number), and a leading list marker such as "1." or "-" is removed
        before matching so an ordinal is never mistaken for an MRN.

        Args:
            lines (list of str): Candidate question lines.
            patient_mrns (list of str): MRNs to look for, in priority order.

        Returns:
            dict: MRN -> question text (list marker removed). A line is
            assigned to the first MRN it mentions; a later line for the same
            MRN replaces an earlier one.
        """
        list_marker = re.compile(r'^\s*(?:\d+[.)]|[-*])\s+')
        mrn_patterns = [
            (mrn, re.compile(r'(?<!\d)' + re.escape(mrn) + r'(?!\d)'))
            for mrn in patient_mrns
        ]

        questions = {}
        for line in lines:
            text = list_marker.sub('', line).strip()
            if not text:
                continue
            for mrn, pattern in mrn_patterns:
                if pattern.search(text):
                    questions[mrn] = text
                    break
        return questions

    @staticmethod
    def divide(question, patient_mrns):
        """
        Rewrite the question for each patient.

        Args:
            question (str): The original question.
            patient_mrns (list of str): List of patient MRNs to customize the question for.

        Returns:
            dict: A dictionary mapping each patient MRN to their customized question.

        Raises:
            ValueError: If no questions could be generated for the patients.
        """
        divide_prompt = (
            f"Rewrite the question: '{question}' for each of these patients: "
            f"{', '.join(patient_mrns)}. "
            "Include each patient's MRN in the rewritten question."
        )

        debug_message("[MultiQA] Action: Generating patient-specific questions.")
        input_msgs = [HumanMessage(divide_prompt)]
        output = llm_rewrite.invoke(input_msgs, config)

        questions_dict = MultiQA._match_questions(output.content.split('\n'), patient_mrns)

        if not questions_dict:
            error_message = "Failed to divide the question into patient-specific queries."
            debug_message(f"[MultiQA] Status: {error_message}")
            raise ValueError(error_message)

        debug_message(f"[MultiQA] Status: Divided questions generated | Questions: {questions_dict}")
        return questions_dict

    @staticmethod
    def summarize(all_answers):
        """
        Summarize answers for multiple patients.

        Args:
            all_answers (list of str): List of answers for individual patients.

        Returns:
            str: A summarized response for all answers.
        """
        combined_answers = "\n\n".join(all_answers)
        debug_message(f"[MultiQA] Action: Summarizing answers | Combined Answers: {combined_answers}")

        summary_question = "Provide a direct summary of the answers without any introductory text."
        input_msgs = [HumanMessage(summary_question), HumanMessage(combined_answers)]

        summary_output = app.invoke({"messages": input_msgs}, config)
        summary = summary_output["messages"][-1].content

        debug_message(f"[MultiQA] Status: Summary generated | Summary: {summary}")
        return summary

    @staticmethod
    def multi_answer(question, patient_mrns):
        """
        Generate answers for multiple patients and summarize the results.

        Args:
            question (str): The question to answer.
            patient_mrns (list of str): List of patient MRNs to answer for.

        Returns:
            str: A summarized answer for all patients.

        Raises:
            ValueError: Propagated from `divide` when no per-patient question could be produced.
        """
        debug_message("[MultiQA] Action: Initiating multi-answer process for multiple patients.")

        # Divide the question into patient-specific queries
        divided_questions = MultiQA.divide(question, patient_mrns)
        all_answers = []

        for patient_mrn, patient_question in divided_questions.items():
            debug_message(f"[MultiQA] Action: Processing question for patient MRN {patient_mrn} | Question: {patient_question}")
            answer = SingleQA.answer(patient_mrn, patient_question)
            all_answers.append(f"Patient {patient_mrn}: {answer}")

        debug_message("[MultiQA] Action: Combining and summarizing all answers.")
        return MultiQA.summarize(all_answers)

### Summarization

In [14]:
class Summarizer:
    """
    A class for generating summaries of medical reports for patients.
    Includes functionality to summarize entire reports, specific sections,
    and multiple patients' reports.
    """

    @staticmethod
    def _preview(text, limit=500):
        """Return `text` cut to `limit` characters plus "..." when it is longer, else `text` unchanged."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def summarize_report(patient_mrn):
        """
        Summarize the entire report for a single patient.

        Args:
            patient_mrn (str): The medical record number of the patient.

        Returns:
            str: A structured summary of the patient's report or an error message if no report is found.
        """
        debug_message(f"[Summarizer] Action: Starting summarization for patient MRN: {patient_mrn}")

        # Retrieve the document for the patient
        doc = DocRetriever.retrieve_doc(patient_mrn)
        if not doc:
            error_message = f"No report found for patient MRN: {patient_mrn}."
            debug_message(f"[Summarizer] Status: {error_message}")
            return error_message

        debug_message(f"[Summarizer] Status: Document retrieved | Patient MRN: {patient_mrn}")
        debug_message(Summarizer._preview(doc.page_content))

        # Build the summarization prompt: formatting instructions followed by the full report text
        input_msgs = [HumanMessage(
            f"You are a highly skilled medical assistant specializing in summarizing medical reports. "
            "Your task is to create concise, structured, and professional summaries intended for busy medical professionals. "
            "Follow this specific format, but only include sections present in the original report. "
            "Do not add subsections not explicitly mentioned in the report. "
            "\n\n"
            "### Patient Summary Report\n\n"
            "#### Patient Information (if provided in the report)\n"
            "- **Name:** [Patient's name]\n"
            "- **Date of Birth:** [Patient's DOB]\n"
            "- **Admission Dates:** [Admission and discharge dates]\n"
            "- **Primary Diagnosis:** [Diagnosis]\n"
            "- **Attending Physician:** [Physician's name]\n\n"
            "#### Reason for Admission (if provided in the report)\n"
            "Provide a brief description of the primary symptoms and reasons for admission, "
            "including any relevant context about delays in seeking care if applicable.\n\n"
            "#### Medical History (if provided in the report)\n"
            "- Briefly list relevant medical, family, and social history (e.g., smoking, hypertension, family illnesses).\n\n"
            "#### Diagnostic Findings (if provided in the report)\n"
            "Summarize key diagnostic tests and results, grouped by type (e.g., biopsy, imaging, blood tests). "
            "Use bullet points for clarity.\n\n"
            "#### Treatment Plan (if provided in the report)\n"
            "Outline the treatments administered during the hospital stay, "
            "including chemotherapy, surgery, radiation, and palliative care. Use concise bullet points.\n\n"
            "#### Hospital Course (if provided in the report)\n"
            "Summarize significant events and complications during the hospital stay, "
            "including patient response to treatments.\n\n"
            "#### Follow-Up Plan (if provided in the report)\n"
            "Detail the post-discharge follow-up plan, such as chemotherapy schedules, imaging, blood tests, and supportive care.\n\n"
            "#### Discharge Instructions (if provided in the report)\n"
            "Provide any specific discharge instructions given to the patient, including medication, diet, and activity recommendations.\n\n"
            "#### Prognosis and Long-Term Outlook (if provided in the report)\n"
            "Provide a brief prognosis and any long-term outlook information, including survival rates if applicable.\n\n"
            "#### Final Remarks (if provided in the report)\n"
            "Include any final remarks or recommendations provided by the attending physician.\n\n"
            f"{doc.page_content}"
        )]

        # Generate summary
        debug_message("[Summarizer] Action: Invoking app for summarization.")
        summary_output = app.invoke({"messages": input_msgs}, config)
        summary = summary_output["messages"][-1].content

        debug_message(f"[Summarizer] Status: Summary generated | Patient MRN: {patient_mrn}")
        debug_message(Summarizer._preview(summary))

        return summary

    @staticmethod
    def summarize_section(patient_mrn, section_title):
        """
        Summarize a specific section of a patient's report.

        The report text is sent along with the request so the model summarizes
        the actual section instead of relying on whatever earlier messages
        happen to contain.

        Args:
            patient_mrn (str): The medical record number of the patient.
            section_title (str): The title of the section to summarize.

        Returns:
            str: A summary of the specified section, or an error message if the report is not found.
        """
        debug_message(f"[Summarizer] Action: Starting section summarization | Patient MRN: {patient_mrn} | Section: '{section_title}'")

        # Retrieve the document for the patient
        doc = DocRetriever.retrieve_doc(patient_mrn)
        if not doc:
            error_message = f"No report found for patient MRN: {patient_mrn}."
            debug_message(f"[Summarizer] Status: {error_message}")
            return error_message

        # Build the section-specific prompt: the instruction followed by the report itself
        input_msgs = [HumanMessage(
            f"Summarize the section titled '{section_title}' in the report for patient {patient_mrn}.\n\n"
            f"{doc.page_content}"
        )]

        # Generate section summary
        debug_message("[Summarizer] Action: Invoking app for section summarization.")
        output = app.invoke({"messages": input_msgs}, config)
        summary = output["messages"][-1].content

        debug_message(f"[Summarizer] Status: Section summary generated | Patient MRN: {patient_mrn} | Section: '{section_title}'")
        debug_message(Summarizer._preview(summary))

        return summary

    @staticmethod
    def summarize_reports(patient_mrns):
        """
        Generate summaries for multiple patient reports.

        Args:
            patient_mrns (list of str): List of medical record numbers for the patients.

        Returns:
            str: One "Patient <mrn>:" block per patient, joined by blank lines.
            A patient without a report contributes the error message from
            `summarize_report` instead of a summary.
        """
        debug_message("[Summarizer] Action: Starting summarization for multiple patients.")

        all_summaries = []
        for patient_mrn in patient_mrns:
            debug_message(f"[Summarizer] Action: Processing summarization | Patient MRN: {patient_mrn}")
            summary = Summarizer.summarize_report(patient_mrn)
            all_summaries.append(f"Patient {patient_mrn}:\n{summary}")

        combined_summaries = "\n\n".join(all_summaries)
        debug_message("[Summarizer] Status: Completed summarization for all patients.")
        debug_message(Summarizer._preview(combined_summaries))

        return combined_summaries


### Similarity

In [15]:
class SimilarFinder:
    """
    A class to find and summarize cases similar to a specified patient
    by leveraging document embeddings and a FAISS vector store.
    """

    @staticmethod
    def find_similar(patient_mrn, k=5):
        """
        Find patients with similar cases to the specified patient using doc-level embeddings.

        The patient's own report is excluded from the results, so up to `k`
        other reports are compared.

        Args:
            patient_mrn (str): The medical record number of the patient to find similar cases for.
            k (int): The number of similar cases to retrieve (default is 5).

        Returns:
            str: A summary of similar cases, highlighting similarities and differences, or an error message.
        """
        debug_message(f"[SimilarFinder] Action: Finding patients similar to patient {patient_mrn}")

        # Retrieve the document for the given patient MRN
        doc = DocRetriever.retrieve_doc(patient_mrn)
        if not doc:
            error_message = f"No report found for patient {patient_mrn}."
            debug_message(f"[SimilarFinder] Status: {error_message}")
            return error_message

        # Generate an embedding for the patient's report
        debug_message("[SimilarFinder] Generating embedding for patient report.")
        patient_embed = embeddings.embed_query(doc.page_content)

        # Ask for one extra neighbour: when the patient's own report is indexed
        # it comes back as a hit and is discarded below.
        debug_message(f"[SimilarFinder] Searching for top {k} similar documents in vector store.")
        _, indices = vectorstore.index.search(
            np.array([patient_embed]).astype('float32'), k + 1
        )

        # Retrieve the documents for the similar cases
        similar_docs = []
        debug_message("[SimilarFinder] Retrieved the following similar documents:")
        for idx in indices[0]:
            if len(similar_docs) >= k:
                break
            doc_id = vectorstore.index_to_docstore_id.get(idx)
            if not doc_id:
                continue  # no store entry for this index position
            document = vectorstore.docstore.search(doc_id)
            # NOTE(review): a docstore miss appears to come back as a message
            # string rather than a Document, so require page_content — confirm.
            if not document or not hasattr(document, "page_content"):
                continue
            if document.page_content == doc.page_content:
                continue  # the query patient's own report
            similar_docs.append(document)
            debug_message(f"Doc ID: {doc_id} | Content Preview: {document.page_content[:300]}...")

        if not similar_docs:
            error_message = "No similar documents found in the vector store."
            debug_message(f"[SimilarFinder] Status: {error_message}")
            return error_message

        debug_message(f"[SimilarFinder] Found {len(similar_docs)} similar documents.")

        # Summarize each retrieved document
        summaries = []
        for similar_doc in similar_docs:
            doc_text = similar_doc.page_content
            # Fall back to the "source" metadata entry when no "doc_id" was recorded
            doc_id = similar_doc.metadata.get(
                'doc_id', similar_doc.metadata.get('source', 'Unknown ID')
            )
            debug_message(f"[SimilarFinder] Summarizing report for document ID: {doc_id}")

            # Only the first 1500 characters of each report are summarized
            truncated_text = doc_text[:1500]
            input_msgs = [
                HumanMessage(content=f"Summarize this medical report:\n{truncated_text}")
            ]
            summary_response = app.invoke({"messages": input_msgs}, config)
            summaries.append(summary_response["messages"][-1].content)

        # Combine summaries and generate a final comparison
        combined_summaries = "\n\n".join(summaries)
        final_summary_prompt = (
            "Combine and summarize the following patient cases, focusing on their similarities and differences:\n\n"
            f"{combined_summaries}"
        )
        debug_message("[SimilarFinder] Generating final summary for similar cases.")
        final_summary_response = app.invoke({"messages": [HumanMessage(content=final_summary_prompt)]}, config)

        final_summary = final_summary_response["messages"][-1].content
        debug_message("[SimilarFinder] Final summary generated:")
        debug_message(final_summary[:500] + "..." if len(final_summary) > 500 else final_summary)

        return final_summary

## Intent Classifier

In [None]:
from sentence_transformers import SentenceTransformer, util

# Example questions per intent label. IntentClassifier embeds these and assigns
# a new question the label whose examples it is most similar to on average.
# Note that "summarization" has 4 examples while the other intents have 10.
reference_examples = {
    # One question about one patient
    "single_qa": [
        "What is the diagnosis for patient 0?",
        "What is the date of admission for patient 1?",
        "Tell me the treatment plan for patient 2.",
        "What are the findings for patient 3?",
        "What medications is patient 4 taking?",
        "What is the discharge summary for patient 0?",
        "What is the medical history of patient 1?",
        "What are the lab results for patient 2?",
        "What is the prognosis for patient 3?",
        "What surgeries has patient 4 undergone?"
    ],
    # One question spanning several patients
    "multi_qa": [
        "Compare the reports for patients 0 and 1.",
        "What are the diagnoses for patients 2 and 3?",
        "What is the treatment plan for both patients 0 and 1?",
        "When did patients 2 and 3 get discharged?",
        "What are the names of patients 0 and 1?",
        "What are the symptoms for patients 0 and 1?",
        "Compare the lab results of patients 2 and 3.",
        "What are the medications for patients 0 and 1?",
        "What are the findings for patients 2 and 3?",
        "What is the prognosis for patients 0 and 1?"
    ],
    # Requests to find other patients resembling a given one
    "similar_patient": [
        "Find patients with cases similar to patient 0.",
        "Who has a similar diagnosis as patient 1?",
        "Show patients with the same condition as patient 2.",
        "Find patients with cases similar to patient 3.",
        "Who has a similar treatment plan as patient 4?",
        "Find patients with similar symptoms to patient 0.",
        "Who has a similar medical history as patient 1?",
        "Show patients with similar lab results as patient 2.",
        "Find patients with similar prognoses to patient 3.",
        "Who has undergone similar surgeries as patient 4?"
    ],
    # Requests for a summary of a patient's report
    "summarization": [
        "Summarize the report for patient 0.",
        "What are the key points from patient 1's file?",
        "Give me an overview of the report for patient 2.",
        "Summarize the report for patient 3."
    ]
}

# Question classifier and LLM response
class IntentClassifier:
    """Nearest-intent classifier built on sentence embeddings.

    Every intent in the module-level ``reference_examples`` mapping is
    represented by the embeddings of its example questions; a new question
    is assigned to the intent whose examples are, on average, most similar
    to it.
    """

    def __init__(self):
        """
        Load the "all-mpnet-base-v2" SentenceTransformer model and
        pre-compute the embeddings of all reference examples.
        """
        self.model = SentenceTransformer("all-mpnet-base-v2")
        self.ref_embeds = self._embed_refs()

    def _embed_refs(self):
        """
        Encode the example questions of every intent in ``reference_examples``.

        Returns:
            dict: Intent name -> tensor holding the embeddings of that
            intent's example questions.
        """
        debug_message("[IntentClassifier] Action: Embedding reference examples.")
        encoded = {
            label: self.model.encode(samples, convert_to_tensor=True)
            for label, samples in reference_examples.items()
        }
        debug_message("[IntentClassifier] Status: Reference examples embedded successfully.")
        return encoded

    def classify(self, question):
        """
        Pick the intent whose reference examples best match the question.

        Args:
            question (str): The question to classify.

        Returns:
            str: The intent with the highest mean cosine similarity between
            the question and that intent's reference examples.
        """
        debug_message(f"[IntentClassifier] Action: Classifying question intent | Question: {question}")
        query_vec = self.model.encode(question, convert_to_tensor=True)

        # One score per intent: the mean cosine similarity between the
        # question and all of that intent's reference embeddings.
        intent_scores = {
            label: util.pytorch_cos_sim(query_vec, ref_vecs).mean().item()
            for label, ref_vecs in self.ref_embeds.items()
        }

        # The winner is the first intent carrying the highest score.
        best_intent, _ = max(intent_scores.items(), key=lambda pair: pair[1])
        debug_message(f"[IntentClassifier] Status: Classified intent | Intent: {best_intent} | Scores: {intent_scores}")
        return best_intent

class LLMHandler:
    """Entry point that routes a user question to the matching handler."""

    # Built once, when the class body executes, and shared by every call so
    # the classifier is not re-created for each question.
    classifier = IntentClassifier()  # Instantiate the classifier

    @staticmethod
    def llm_response(question, summarize_report=False):
        """
        Generate a response for a given question by classifying intent and routing to the appropriate handler.

        Patient MRNs are taken from the question when present; otherwise the
        MRNs remembered in the global ``context`` from a previous question are
        reused. MRNs found in the question replace the remembered ones.

        Args:
            question (str): The input question to process.
            summarize_report (bool): Currently unused by this method; kept so
                existing callers that pass it keep working (default: False).

        Returns:
            str: The response generated based on the question intent, or a
            fallback message when the intent and the number of MRNs do not
            match any supported combination.
        """
        debug_message(f"[LLMHandler] Action: Processing question | Question: {question}")

        global context

        # Extract patient MRNs from the question
        patient_mrns = Util.extract_mrns(question)
        if not patient_mrns:
            debug_message("[LLMHandler] Status: No patient MRNs found in question. Using context MRNs.")
            # setdefault (instead of context["mrns"]) so a freshly reset
            # context that has no "mrns" entry yet yields an empty list
            # rather than raising KeyError.
            patient_mrns = context.setdefault("mrns", [])
        else:
            context["mrns"] = patient_mrns  # Update context with new patient MRNs

        # From here on patient_mrns and context["mrns"] are the same list, so
        # every check below only needs to look at patient_mrns.

        # Classify intent using the IntentClassifier
        intent = LLMHandler.classifier.classify(question)

        # A single-patient question cannot be answered as such when several
        # MRNs are active, so it is promoted to a multi-patient question.
        if len(patient_mrns) >= 2 and intent == "single_qa":
            intent = "multi_qa"
            debug_message("[LLMHandler] Status: Adjusted intent to multi_qa based on context.")

        # Dispatch on intent and on how many MRNs are available
        if intent == "similar_patient" and len(patient_mrns) == 1:
            debug_message(f"[LLMHandler] Action: Finding similar patients | Patient MRN: {patient_mrns[0]}")
            return SimilarFinder.find_similar(patient_mrns[0])

        elif intent == "single_qa" and len(patient_mrns) == 1:
            debug_message(f"[LLMHandler] Action: Answering single patient question | Patient MRN: {patient_mrns[0]}")
            return SingleQA.answer(patient_mrns[0], question)

        elif intent == "multi_qa" and len(patient_mrns) >= 2:
            # Also covers MRNs inherited from context: in that case
            # patient_mrns already is context["mrns"].
            debug_message(f"[LLMHandler] Action: Answering multiple patient question | Patient MRNs: {patient_mrns}")
            return MultiQA.multi_answer(question, patient_mrns)

        elif intent == "summarization":
            if len(patient_mrns) == 1:
                debug_message(f"[LLMHandler] Action: Summarizing report for single patient | Patient MRN: {patient_mrns[0]}")
                return Summarizer.summarize_report(patient_mrns[0])
            # NOTE(review): this path is also taken when no MRN is known at
            # all (empty list) — confirm Summarizer.summarize_patients copes.
            debug_message("[LLMHandler] Action: Summarizing reports for multiple patients.")
            return Summarizer.summarize_patients(patient_mrns)

        # Fallback for unsupported or unclear intents
        debug_message("[LLMHandler] Status: Intent unclear or unsupported. Returning fallback response.")
        return "I'm sorry, I couldn't understand your request. Could you rephrase it?"

# **Testing**

In [36]:
def reset_context():
    """
    Reset the module-level ``context`` dictionary to an empty state.

    Rebinds the global ``context`` to a fresh dictionary whose list entries
    are empty lists and whose mapping entries are empty dicts. The "mrns"
    entry is included because LLMHandler.llm_response reads and writes
    ``context["mrns"]``; without it, the first question that names no
    patient after a reset would fail with a KeyError.
    """
    global context
    context = {
        "ids": [],
        "mrns": [],  # MRNs remembered from the last question that named patients
        "docs": {},
        "specific_docs": {},
        "chunks": {},
        "top_chunks": {}
    }

# Start the testing section from a clean context.
reset_context()

In [None]:
# Smoke test: send one single-patient question through the router and
# show the answer under a banner.
question = "What is the diagnosis of patient 1?"
response = LLMHandler.llm_response(question)
print("\n---------- Response ----------", response, sep="\n")