<a href="https://colab.research.google.com/github/vgu-its24-psd/MedDiag/blob/main/MedTool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install packages
!pip install -qU langchain-qdrant
!pip install -qU langchain-huggingface

In [None]:
# Import packages
from transformers import pipeline # to create the pipeline to LLM Model in Hugging Faces
# Qdrant from langchain
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.storage import InMemoryStore
from langchain.retrievers.multi_vector import MultiVectorRetriever
# HuggingFace embedding model for Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
import torch
import os
from PIL import Image

In [None]:
"""
Dengue Fever Clinical Diagnostic System using MedGemma
This system analyzes patient symptoms and medical images to assess dengue fever likelihood
"""
class DengueDiagnosticSystem:
    """Dengue fever assessment built on MedGemma plus Qdrant retrieval.

    A case is processed in three steps: context documents are fetched from a
    Qdrant collection by similarity search, a chat-style prompt (text plus
    optional images) is assembled, and the prompt is run through the
    ``google/medgemma-4b-it`` image-text-to-text pipeline.
    """

    # NOTE(security): these connection defaults are retained only so that the
    # existing no-argument construction keeps working. The API key is
    # committed in source; it should be rotated and supplied through the
    # QDRANT_API_KEY environment variable (or the constructor) instead.
    DEFAULT_QDRANT_URL = "https://2fe338c1-dc5a-45ea-98fc-5a653ed6567d.us-east4-0.gcp.cloud.qdrant.io:6333"
    DEFAULT_QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9BdzZ5Q-sMPLQAJQJ-Q5dVMwXqT_2J6IJoz6wWCuYoo"

    def __init__(self, *, qdrant_url=None, qdrant_api_key=None,
                 collection_name="demo_collection", device=None):
        """Load the MedGemma pipeline and connect the Qdrant-backed retriever.

        Args:
            qdrant_url: Qdrant endpoint. Falls back to the ``QDRANT_URL``
                environment variable, then to ``DEFAULT_QDRANT_URL``.
            qdrant_api_key: Qdrant API key. Falls back to the
                ``QDRANT_API_KEY`` environment variable, then to
                ``DEFAULT_QDRANT_API_KEY``.
            collection_name: Name of the Qdrant collection to search.
            device: Device handed to the transformers pipeline. When omitted,
                "cuda" is used if torch reports a GPU, otherwise "cpu".
        """
        if device is None:
            # A hard-coded "cuda" fails on CPU-only hosts; pick what exists.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # MedGemma pipeline
        self.diagnostic_pipeline = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
            torch_dtype=torch.bfloat16,
            device=device,
        )

        # Qdrant client: explicit argument > environment > embedded default.
        client = QdrantClient(
            url=qdrant_url or os.environ.get("QDRANT_URL") or self.DEFAULT_QDRANT_URL,
            api_key=(
                qdrant_api_key
                or os.environ.get("QDRANT_API_KEY")
                or self.DEFAULT_QDRANT_API_KEY
            ),
        )
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

        # The retriever. Its in-memory docstore starts empty; only the vector
        # store is queried by rag_retriever().
        self.RAG = MultiVectorRetriever(
            vectorstore=QdrantVectorStore(
                client=client,
                collection_name=collection_name,
                embedding=embeddings,
            ),
            docstore=InMemoryStore(),
            id_key="doc_id",
        )

    def rag_retriever(self, user_input, topk=1):
        """Return the ``topk`` documents most similar to ``user_input``.

        Args:
            user_input: Free-text query passed to the vector store.
            topk: Number of documents to request.

        Returns:
            Whatever the vector store's ``similarity_search`` returns.
        """
        return self.RAG.vectorstore.similarity_search(user_input, k=topk)

    def generate_diagnosis(self, user_input, list_image_path, docs):
        """Generate diagnostic assessment using MedGemma.

        Args:
            user_input: Patient description supplied by the user.
            list_image_path: Paths of images to attach to the prompt; may be
                empty or None.
            docs: Retrieved documents; each must expose ``page_content``.

        Returns:
            The ``content`` of the last message in the pipeline's
            ``generated_text`` for the first result.
        """
        # Concatenate page_content of the retrieved documents.
        context = "\n".join(d.page_content for d in docs)
        system_instruction = "You are an expert clinical diagnostic AI assistant specializing in infectious diseases."
        role_instruction = f"""
        User input: {user_input}
        Contextual information: {context}

        You are a medical analysis model specializing in dengue fever diagnosis.
        Your task is to analyze the provided user input, which includes travel history, personal information (age, gender, infection history, symptoms, etc.),
        and relevant context documents retrieved from a Vector Database containing medical guidelines, dengue fever symptoms, risk factors, and epidemiological data.
        Based on this information, determine the likelihood of the user being infected with dengue fever.
        Input Processing:

        User Input: Extract and evaluate details such as:
        Travel history.
        Personal information.
        Reported symptoms.
        Other relevant details.

        Context Documents: Use the retrieved documents to cross-reference symptoms, risk factors, and regional data.

        Analysis Guidelines:

        Assess the presence of key dengue symptoms.
        Consider risk factors.
        Account for user demographics.
        Evaluate prior infection history.
        Use the context from the Vector Database to weigh the likelihood based on epidemiological patterns and clinical guidelines.

        Output Requirements:

        Provide only the final decision in the format: 'The user [might/might not] be infected with dengue fever.'
        Do not include any explanations, reasoning, or instructions in the response.
        Use 'might' if the analysis suggests a plausible chance of infection based on symptoms, travel history, or risk factors aligning with dengue fever.
        Use 'might not' if the analysis indicates insufficient evidence or low likelihood of dengue fever.

        Example Output:
        The user might be infected with dengue fever.
        """

        messages = self.build_messages(system_instruction, role_instruction, list_image_path)

        response = self.diagnostic_pipeline(text=messages, max_new_tokens=200)
        # Extract the text content from the output.
        return response[0]["generated_text"][-1]["content"]

    def build_messages(self, system_instruction: str, role_instruction: str,
                       images: "list[str] | None" = None):
        """Build the system + user chat messages for the pipeline.

        Args:
            system_instruction: Text of the system message.
            role_instruction: Text placed first in the user message.
            images: Image file paths to attach after the text. May be empty
                or None. Each entry is opened with ``Image.open``.

        Returns:
            A two-element list: the system message and the user message, each
            a dict with ``role`` and ``content`` keys.
        """
        user_content = [{"type": "text", "text": role_instruction}]
        # None (not a shared mutable list) is the default; treat it as empty.
        for path in images or []:
            # Read the pixel data eagerly, then let the context manager close
            # the file so no descriptor is left open per attached image.
            with Image.open(path) as im:
                im.load()
            user_content.append({"type": "image", "image": im})

        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_instruction}],
            },
            {
                "role": "user",
                "content": user_content,
            },
        ]

    def process_case(self, user_input, list_image_path=None):
        """Process a complete diagnostic case.

        Retrieves three context documents for ``user_input`` and passes them,
        together with any images, to ``generate_diagnosis``.

        Args:
            user_input: Patient description supplied by the user.
            list_image_path: Optional list of image paths; None means none.

        Returns:
            The text returned by ``generate_diagnosis``.
        """
        print("Processing diagnostic case...")
        if list_image_path is None:
            list_image_path = []
        docs = self.rag_retriever(user_input, topk=3)
        # Generate diagnosis
        return self.generate_diagnosis(user_input, list_image_path, docs)

In [None]:
# Initialize Diagnostic System
# Constructing the class loads the MedGemma pipeline and connects to Qdrant,
# so this runs once; the cells below reuse the `diagnostic_system` global.
diagnostic_system = DengueDiagnosticSystem()

In [None]:
# Interactive mode for Google Colab
def interactive_diagnosis():
    """Run one diagnosis interactively: prompt, process, print the report.

    Reads the symptoms and an optional image path from stdin. A blank or
    non-existent path means the case is processed without an image. The
    result comes from the module-level ``diagnostic_system``.
    """
    print("Dengue Fever Diagnostic System - Interactive Mode")
    print("=" * 50)

    # Collect the case description and the optional image.
    patient_symptoms = input("\nEnter patient symptoms: ")
    candidate = input("Enter image path (or press Enter to skip): ").strip()

    image_paths = []  # stays empty unless a usable path was given
    if candidate:
        if os.path.exists(candidate):
            image_paths.append(candidate)
        else:
            print(f"Warning: Image file '{candidate}' not found. Proceeding without image.")

    # Run diagnosis
    result = diagnostic_system.process_case(patient_symptoms, image_paths)

    # Display results
    rule = "=" * 60
    print("\n" + rule)
    print("DENGUE FEVER DIAGNOSTIC ASSESSMENT")
    print(rule)
    print(f"Symptoms: {patient_symptoms}")
    if image_paths:
        # At most one path is ever collected, so show the first.
        print(f"Image: {image_paths[0]}")
    print("\n CLINICAL ANALYSIS:")
    print(result)
    print(rule)

In [None]:
# Example usage in Colab:
# The string below only lists sample inputs to type at the prompt; it is not
# passed to anything.
"""
# Example symptoms to test:
# "sudden high fever 39°C, severe headache, retro-orbital pain, myalgia, skin rash, nausea"
# "fever, headache, muscle pain, petechiae, bleeding gums, abdominal pain"
# "mild fever, cough, runny nose" (should be low likelihood)
"""

# Prompts on stdin for symptoms and an optional image path, then prints the
# assessment.
interactive_diagnosis()

In [None]:
# test rag_retriever
# Called without `topk`, so the method's default of 1 document applies. The
# returned documents are the cell's last expression.
query = "I have fever and headache"
diagnostic_system.rag_retriever(query)

In [None]:
# test build_messages()
# build_messages opens every listed path with PIL, so this file must already
# exist at the given location before the cell is run.
image = ["/content/Skin rash from dengue fever_p1_img23_434c0412.png"]
diagnostic_system.build_messages("My system instruction", "my role instruction", image)
