<a href="https://colab.research.google.com/github/vgu-its24-psd/MedDiag/blob/main/MedTool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MMRAG - LLM - START

In [None]:
# Install packages
!pip install -qU langchain-qdrant
!pip install -qU langchain-huggingface

In [None]:
# Import packages
from transformers import pipeline # to create the pipeline to LLM Model in Hugging Faces
# Qdrant from langchain
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.storage import InMemoryStore
from langchain.retrievers.multi_vector import MultiVectorRetriever
# HuggingFace embedding model for Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
import torch
import os
from PIL import Image

In [None]:
# Qdrant connection settings. Read from the environment so credentials are not
# typed into (and accidentally committed with) the notebook; both default to
# "" exactly as before when the variables are unset.
QdrantURL = os.environ.get("QDRANT_URL", "")
QdrantAPIKey = os.environ.get("QDRANT_API_KEY", "")

In [None]:
"""
Dengue Fever Clinical Diagnostic System (Qdrant RAG + Gemma vision-language model)
This system analyzes patient symptoms and optional medical images to assess dengue fever likelihood.
The active checkpoint is google/gemma-3-4b-it; MedGemma variants are kept commented out in __init__.
"""
class DengueDiagnosticSystem:
    """Dengue assessment via Qdrant retrieval plus an image-text-to-text model.

    Attributes:
        diagnostic_pipeline: Hugging Face ``image-text-to-text`` pipeline
            (``google/gemma-3-4b-it`` in bfloat16 on CUDA).
        RAG: ``MultiVectorRetriever`` over the Qdrant collection
            ``demo_collection``. Only its ``vectorstore`` is queried; the
            in-memory docstore starts empty.
    """

    def __init__(self):
        """Load the generation pipeline and connect to the Qdrant vector store.

        Uses the module-level ``QdrantURL`` / ``QdrantAPIKey`` settings.
        """
        # Generation pipeline. Alternative checkpoints, kept for quick switching:
        # self.diagnostic_pipeline = pipeline("image-text-to-text", model="google/medgemma-4b-it", torch_dtype=torch.bfloat16, device="cuda")
        # self.diagnostic_pipeline = pipeline("image-text-to-text", model="longbao128/gemma-4b-dengue-diagnosis", torch_dtype=torch.bfloat16, device="cuda")
        self.diagnostic_pipeline = pipeline(
            "image-text-to-text",
            model="google/gemma-3-4b-it",
            torch_dtype=torch.bfloat16,
            device="cuda",
        )

        # Qdrant client and the embedding model handed to the vector store
        client = QdrantClient(url=QdrantURL, api_key=QdrantAPIKey)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        # The retriever (its docstore starts empty; only the vectorstore is searched)
        self.RAG = MultiVectorRetriever(
            vectorstore=QdrantVectorStore(client=client, collection_name="demo_collection", embedding=embeddings),
            docstore=InMemoryStore(),
            id_key="doc_id",
        )

    def rag_retriever(self, user_input, topk = 1):
        """Return the ``topk`` documents in the vector store most similar to *user_input*."""
        return self.RAG.vectorstore.similarity_search(user_input, k=topk)

    def generate_diagnosis(self, user_input, list_image_path, docs):
        """Ask the model for a one-sentence dengue assessment.

        Args:
            user_input: Free-text patient description (symptoms, history, ...).
            list_image_path: Image file paths to attach to the request; may be empty.
            docs: Retrieved documents; their ``page_content`` is inlined into the
                prompt as context.

        Returns:
            The ``content`` of the last entry of the pipeline's ``generated_text``.
            The prompt asks for exactly "The user might be infected with dengue
            fever." or "The user might not be infected with dengue fever.", but
            compliance with that format is not enforced here.
        """
        # Concatenate the retrieved chunks into a single context string
        context = "\n".join(d.page_content for d in docs)
        system_instruction = "You are an expert clinical diagnostic AI assistant specializing in infectious diseases."
        role_instruction = f"""
        User input: {user_input}
        Retrieved context: {context}

        You are a medical analysis model specialized in dengue fever.

        TASK:
        Analyze the user input (travel history, age, gender, infection history, reported symptoms) together with the retrieved context (medical guidelines, clinical reports, epidemiological data, and any provided images). Determine the likelihood that the user is infected with dengue fever.

        ASSESSMENT RULES:
        - Consider key dengue symptoms, risk factors, demographics, and infection history.
        - Cross-check user details against contextual evidence and images if provided.
        - Decide based on alignment with clinical and epidemiological patterns.

        OUTPUT FORMAT:
        Respond with one sentence only, using exactly one of the two options:
        - "The user might be infected with dengue fever."
        - "The user might not be infected with dengue fever."

        Do not provide explanations, reasoning, or instructions.
        """

        messages = self.build_messages(system_instruction, role_instruction, list_image_path)

        response = self.diagnostic_pipeline(text=messages, max_new_tokens=200)
        # Chat-style input: the last message of generated_text is taken as the model's reply
        return response[0]["generated_text"][-1]["content"]

    def build_messages(self, system_instruction: str, role_instruction: str, images: "list[str] | None" = None):
        """Build the two-turn chat message list consumed by the pipeline.

        Args:
            system_instruction: Text of the system turn.
            role_instruction: Text of the user turn; placed before any images.
            images: Image file paths (anything ``PIL.Image.open`` accepts).
                ``None`` or an empty list means a text-only request.

        Returns:
            ``[system_message, user_message]``; the user message ``content`` is
            one text part followed by one image part per path.
        """
        user_content = [{"type": "text", "text": role_instruction}]
        for path in images or []:
            # Image.open is lazy and keeps the file open until the pixels are
            # loaded / the object is collected. copy() loads the data into
            # memory so the `with` block closes the file immediately (important
            # when the Flask endpoint feeds a stream of temp files).
            with Image.open(path) as im:
                user_content.append({"type": "image", "image": im.copy()})

        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_instruction}],
            },
            {
                "role": "user",
                "content": user_content,
            },
        ]

    def process_case(self, user_input, list_image_path=None):
        """Retrieve the top-3 context documents and return the model's assessment.

        Args:
            user_input: Free-text patient description.
            list_image_path: Optional image file paths; ``None`` (the default)
                means no images.
        """
        print("Processing diagnostic case...")
        docs = self.rag_retriever(user_input, topk=3)
        # Generate diagnosis
        return self.generate_diagnosis(user_input, list_image_path or [], docs)

In [None]:
# Initialize Diagnostic System: loads the generation pipeline on CUDA and
# connects to Qdrant with QdrantURL / QdrantAPIKey. The interactive mode, the
# Flask endpoint and the evaluation cells below all use this module-level
# `diagnostic_system` instance.
diagnostic_system = DengueDiagnosticSystem()

In [None]:
# Interactive mode for Google Colab
def interactive_diagnosis():
    """Prompt for symptoms and an optional image path, then print the assessment.

    Uses the module-level ``diagnostic_system`` instance. An image path that
    does not name an existing regular file is reported and ignored.
    """
    print("Dengue Fever Diagnostic System - Interactive Mode")
    print("="*50)

    # Get user input
    symptoms = input("\nEnter patient symptoms: ")

    image_path = input("Enter image path (or press Enter to skip): ").strip()
    if not image_path:
        image_paths = []  # no image supplied
    elif not os.path.isfile(image_path):
        # isfile rather than exists: a directory path would otherwise pass this
        # check and make Image.open fail later inside the pipeline call.
        print(f"Warning: Image file '{image_path}' not found. Proceeding without image.")
        image_paths = []
    else:
        image_paths = [image_path]

    # Run diagnosis
    result = diagnostic_system.process_case(symptoms, image_paths)

    # Display results
    print("\n" + "="*60)
    print("DENGUE FEVER DIAGNOSTIC ASSESSMENT")
    print("="*60)
    print(f"Symptoms: {symptoms}")
    if image_paths:
        print(f"Image: {image_paths[0]}")  # only one image can be entered here
    print("\n CLINICAL ANALYSIS:")
    print(result)
    print("="*60)

In [None]:
# Example usage in Colab. The string below is a no-op list of sample inputs.
# The model answers with one of two fixed sentences, so a "low likelihood"
# case should come back as "The user might not be infected with dengue fever."
"""
# Example symptoms to test:
# "sudden high fever 39°C, severe headache, retro-orbital pain, myalgia, skin rash, nausea"
# "fever, headache, muscle pain, petechiae, bleeding gums, abdominal pain"
# "mild fever, cough, runny nose" (should be low likelihood)
"""

# Prompts via input(), so this cell waits for the user's answers.
interactive_diagnosis()

In [None]:
# test rag_retriever: returns the single closest document (topk defaults to 1)
query = "I have fever and headache"
diagnostic_system.rag_retriever(query)

In [None]:
# test build_messages(): the path below is a sample image expected under
# Colab's /content directory — point it at an image that exists locally.
image = ["/content/Skin rash from dengue fever_p1_img23_434c0412.png"]
diagnostic_system.build_messages("My system instruction", "my role instruction", image)


# MMRAG - LLM - END

# Added interface to GUI - START

In [None]:
!pip install flask flask-cors pyngrok

In [None]:
from flask import Flask, request, jsonify
from pyngrok import ngrok
import base64
import tempfile

In [None]:
# ngrok auth token. Read from the NGROK_AUTHTOKEN environment variable so the
# secret is not stored in the notebook; defaults to "" as before when unset.
ngrokTOKEN = os.environ.get("NGROK_AUTHTOKEN", "")

In [None]:
# Assumes diagnostic_system = DengueDiagnosticSystem(...) was already
# initialized in an earlier cell; the /diagnose route below reuses it.

app = Flask(__name__)

def _decode_image_payload(img_b64):
    """Decode one base64-encoded image, tolerating a ``data:...;base64,`` prefix.

    Raises:
        ValueError: if *img_b64* is not a string or is not decodable base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if not isinstance(img_b64, str):
        raise ValueError("image payload must be a base64 string")
    if img_b64.startswith("data:") and "," in img_b64:
        # Browser data URLs: drop the "data:image/png;base64," header, which
        # would otherwise be partially decoded as image bytes.
        img_b64 = img_b64.split(",", 1)[1]
    return base64.b64decode(img_b64)


@app.route("/diagnose", methods=["POST"])
def diagnose():
    """POST /diagnose — run one diagnostic case from a JSON request.

    Expected JSON body: ``{"user_input": str, "images": [base64 str, ...]}``;
    both keys are optional. Responds with ``{"result": <model answer>}``, or
    HTTP 400 with ``{"error": ...}`` when the body or an image is malformed.
    Decoded images live in temporary files only for the duration of the
    request and are deleted afterwards.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    user_input = data.get("user_input", "")
    images_b64 = data.get("images", [])
    if not isinstance(images_b64, list):
        return jsonify({"error": "'images' must be a list of base64 strings"}), 400

    list_image_path = []
    try:
        for img_b64 in images_b64:
            try:
                image_bytes = _decode_image_payload(img_b64)
            except ValueError:
                return jsonify({"error": "invalid base64 image payload"}), 400
            # mkstemp returns an already-open OS descriptor; fdopen takes
            # ownership of it so it is closed rather than leaked.
            fd, tmp_path = tempfile.mkstemp(suffix=".png")
            list_image_path.append(tmp_path)
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)

        result = diagnostic_system.process_case(user_input, list_image_path)
    finally:
        # Best-effort cleanup of this request's temp files (also on errors).
        for tmp_path in list_image_path:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return jsonify({"result": result})

# Authenticate ngrok with the token stored in `ngrokTOKEN`
# (set it in the cell above before running this one).
ngrok.set_auth_token(ngrokTOKEN)

# tunnel: expose local port 5000 publicly so an external GUI can POST /diagnose
public_url = ngrok.connect(5000)
print("Public URL:", public_url)

# Serve the Flask app on the tunnelled port; this call blocks the cell while
# the server runs.
app.run(port=5000)

# Added interface to GUI - END

# Evaluation - START

In [None]:
!pip install -qU ragas
!pip install -qU pandas numpy

In [None]:
# Import RAGAS and evaluation dependencies
import pandas as pd
import numpy as np
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_precision,
    context_recall,
    answer_correctness,
    answer_similarity
)
from datasets import Dataset
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Load and prepare evaluation dataset.
# 'eval_dengue.csv' is read from the current working directory; the
# 'dengue.*' columns are consumed by prepare_evaluation_data below.
eval_df = pd.read_csv('eval_dengue.csv')
print("Evaluation dataset shape:", eval_df.shape)
print("\nDataset columns:", eval_df.columns.tolist())
print("\nFirst few rows:")
print(eval_df.head())

In [None]:
# Prepare evaluation data from dengue CSV

# Yes/no symptom columns and the phrase each contributes to the patient query,
# in question order. Column names are copied verbatim from the CSV header;
# 'servere_headche' and 'addominal_pain' look like typos in the dataset itself
# — NOTE(review): confirm against the file before "fixing" them.
_YES_NO_SYMPTOM_COLUMNS = (
    ('dengue.servere_headche', "severe headache"),
    ('dengue.pain_behind_the_eyes', "retro-orbital pain"),
    ('dengue.joint_muscle_aches', "muscle and joint aches"),
    ('dengue.metallic_taste_in_the_mouth', "metallic taste"),
    ('dengue.appetite_loss', "loss of appetite"),
    ('dengue.addominal_pain', "abdominal pain"),
    ('dengue.nausea_vomiting', "nausea and vomiting"),
    ('dengue.diarrhoea', "diarrhea"),
)


def _is_positive_number(value):
    """Return True if *value* converts to a float greater than zero."""
    try:
        return float(value) > 0
    except (TypeError, ValueError):
        return False


def prepare_evaluation_data(eval_df, sample_size=10):
    """Convert the first *sample_size* dengue patient rows into evaluation cases.

    Args:
        eval_df: DataFrame with 'dengue.*' columns. 'dengue.p_i_d' and
            'dengue.dengue' are required; symptom, temperature, residence and
            days columns are optional.
        sample_size: Number of leading rows to convert.

    Returns:
        List of dicts with keys 'patient_id', 'question', 'ground_truth'
        (one of the two sentences the model is asked to produce) and
        'actual_diagnosis' (the raw 'dengue.dengue' value).
    """
    evaluation_data = []

    for _, row in eval_df.head(sample_size).iterrows():
        # Symptom phrases: fever first (when a positive temperature is recorded)
        symptoms = []
        temp = row.get('dengue.current_temp')
        # NOTE(review): temperature is assumed to be in °F — confirm against the data.
        if pd.notna(temp) and _is_positive_number(temp):
            symptoms.append(f"fever {temp}°F")
        symptoms.extend(
            phrase for column, phrase in _YES_NO_SYMPTOM_COLUMNS if row.get(column) == 'yes'
        )

        # Build the query from non-empty fragments only, so missing fields do
        # not leave "Patient with . . ." gaps in the question.
        sentences = [
            f"Patient with {', '.join(symptoms)}" if symptoms else "Patient with no reported symptoms"
        ]
        residence = row.get('dengue.residence')
        if pd.notna(residence):
            # NOTE(review): a residence column is phrased as recent travel — confirm intent.
            sentences.append(f"recently traveled to {residence}")
        days = row.get('dengue.days')
        if pd.notna(days):
            sentences.append(f"symptoms for {days}")
        question = ". ".join(sentences) + ". Could this be dengue fever?"

        # Ground truth answer, phrased exactly like the model's two allowed outputs
        ground_truth = "The user might be infected with dengue fever." if row['dengue.dengue'] == 'yes' else "The user might not be infected with dengue fever."

        evaluation_data.append({
            'patient_id': row['dengue.p_i_d'],
            'question': question,
            'ground_truth': ground_truth,
            'actual_diagnosis': row['dengue.dengue']
        })

    return evaluation_data

# Prepare evaluation dataset from the first 10 rows of eval_df and preview
# the first 3 generated cases.
eval_data = prepare_evaluation_data(eval_df, sample_size=10)
print(f"Prepared {len(eval_data)} evaluation cases")
for i, case in enumerate(eval_data[:3]):
    print(f"\nCase {i+1}:")
    print(f"Question: {case['question']}")
    print(f"Ground Truth: {case['ground_truth']}")
    print(f"Actual: {case['actual_diagnosis']}")

In [None]:
# RAG System Evaluation with RAGAS
def evaluate_diagnostic_system(diagnostic_system, eval_data):
    """Run the RAG pipeline on each case and package the outputs as a RAGAS dataset.

    Args:
        diagnostic_system: Object exposing ``rag_retriever(query, topk)`` and
            ``generate_diagnosis(query, images, docs)``.
        eval_data: List of dicts with 'patient_id', 'question' and
            'ground_truth' keys.

    Returns:
        ``datasets.Dataset`` with columns question / answer / contexts /
        ground_truth. Each ``contexts`` entry holds one string per retrieved
        document, so chunk-level metrics (e.g. context precision) can score
        the chunks individually. Cases that raise are reported and skipped.
    """
    questions = []
    answers = []
    contexts = []
    ground_truths = []

    print("Running diagnostic system on evaluation cases...")

    for i, case in enumerate(eval_data):
        try:
            print(f"Processing case {i+1}/{len(eval_data)}: {case['patient_id']}")

            query = case['question']
            ground_truth = case['ground_truth']

            # Retrieve once and generate from exactly those documents
            docs = diagnostic_system.rag_retriever(query, topk=3)
            diagnosis = diagnostic_system.generate_diagnosis(query, [], docs)

            # RAGAS expects a list of context strings per sample: one per chunk.
            # Keep a single "" when nothing was retrieved so the column stays list[str].
            case_contexts = [doc.page_content for doc in docs] or [""]
        except Exception as e:
            print(f"Error processing case {case['patient_id']}: {str(e)}")
            continue

        questions.append(query)
        answers.append(diagnosis)
        contexts.append(case_contexts)
        ground_truths.append(ground_truth)

    # Create RAGAS dataset
    return Dataset.from_dict({
        "question": questions,
        "answer": answers,
        "contexts": contexts,
        "ground_truth": ground_truths
    })

# Run evaluation: this cell only builds the RAGAS-format dataset;
# ragas.evaluate() is not called here.
print("Starting RAGAS evaluation...")
ragas_dataset = evaluate_diagnostic_system(diagnostic_system, eval_data)
print(f"Created dataset with {len(ragas_dataset)} samples")

In [None]:
def evaluate_diagnostic_system_custom(diagnostic_system, eval_data):
    """Run the RAG + LLM pipeline on each case and collect outputs for local scoring.

    Args:
        diagnostic_system: Object exposing ``rag_retriever(query, topk)`` and
            ``generate_diagnosis(query, images, docs)``.
        eval_data: List of dicts with 'patient_id', 'question' and
            'ground_truth' keys.

    Returns:
        Dict of parallel lists: 'question', 'answer', 'contexts' (each a
        one-element list holding the newline-joined retrieved text) and
        'ground_truth'. Cases that raise are reported and skipped.
    """
    questions = []
    answers = []
    contexts = []
    ground_truths = []

    print("Running diagnostic system evaluation using MedGemma + RAG...")

    for i, case in enumerate(eval_data):
        try:
            print(f"Processing case {i+1}/{len(eval_data)}: {case['patient_id']}")

            query = case['question']
            ground_truth = case['ground_truth']

            # Retrieve once and generate from those same documents (the same
            # topk=3 retrieval process_case performs). Retrieving a second time
            # for scoring would double the work and would not guarantee the
            # recorded context is the one the model actually saw.
            docs = diagnostic_system.rag_retriever(query, topk=3)
            diagnosis = diagnostic_system.generate_diagnosis(query, [], docs)  # no images for now
            context_text = "\n".join(doc.page_content for doc in docs)
        except Exception as e:
            print(f"Error processing case {case['patient_id']}: {str(e)}")
            continue

        questions.append(query)
        answers.append(diagnosis)
        contexts.append([context_text])
        ground_truths.append(ground_truth)

    return {
        "question": questions,
        "answer": answers,
        "contexts": contexts,
        "ground_truth": ground_truths
    }

import re

# Word tokens, lower-cased, punctuation dropped ("fever." -> "fever").
_WORD_RE = re.compile(r"\w+")

# Clinical vocabulary used by the faithfulness heuristic.
_MEDICAL_TERMS = frozenset({"fever", "headache", "dengue", "symptoms", "temperature", "travel", "infection"})

# Column order of the per-case results frame (also used when it is empty).
_RESULT_COLUMNS = [
    'question', 'answer', 'ground_truth', 'contexts',
    'answer_relevancy', 'answer_correctness', 'context_precision',
    'faithfulness', 'answer_similarity',
]


def _tokenize(text):
    """Return the set of lower-cased word tokens in *text* (empty set for falsy input)."""
    return set(_WORD_RE.findall(text.lower())) if text else set()


def _diagnosis_label(text):
    """Classify *text* by the two template sentences.

    Returns 'positive', 'negative', 'ambiguous' (both phrases present) or
    None (neither phrase present).
    """
    lowered = text.lower()
    says_negative = "might not be infected" in lowered
    says_positive = "might be infected" in lowered  # not a substring of the negative phrase
    if says_positive and says_negative:
        return "ambiguous"
    if says_positive:
        return "positive"
    if says_negative:
        return "negative"
    return None


def _answer_correctness(answer, ground_truth):
    """Score 1.0 / 0.0 for a template answer vs. the label, with 0.5 partial credit fallback."""
    expected = _diagnosis_label(ground_truth)
    predicted = _diagnosis_label(answer)
    if predicted == "ambiguous":
        # Giving both options is not a diagnosis; it must not score as correct.
        return 0.0
    if expected in ("positive", "negative") and predicted in ("positive", "negative"):
        return 1.0 if predicted == expected else 0.0

    # Answer does not use the exact template: keep the original loose heuristics.
    answer_lower = answer.lower()
    if expected == "positive" and "might not" in answer_lower:
        return 0.0
    if expected == "negative" and "might be" in answer_lower and "might not" not in answer_lower:
        return 0.0
    return 0.5 if "dengue" in answer_lower and "dengue" in ground_truth.lower() else 0.0


def calculate_custom_metrics(eval_results):
    """Compute lightweight, API-free evaluation metrics for each case.

    Args:
        eval_results: Dict of parallel lists 'question', 'answer', 'contexts'
            (list of strings per case) and 'ground_truth'.

    Returns:
        An object whose ``to_pandas()`` returns a DataFrame with one row per
        case and the columns in ``_RESULT_COLUMNS`` (present even when there
        are no cases). Metrics are word-overlap heuristics:

        - answer_relevancy: share of question words that appear in the answer.
        - answer_correctness: see ``_answer_correctness``.
        - context_precision: share of context words that appear in the answer.
        - faithfulness: share of the answer's medical terms also in the
          context; 0.5 (neutral) when the answer has no medical terms.
        - answer_similarity: Jaccard overlap of answer and ground-truth words.
    """
    print("Calculating custom evaluation metrics...")

    results = []

    for i, question in enumerate(eval_results['question']):
        answer = eval_results['answer'][i]
        ground_truth = eval_results['ground_truth'][i]
        case_contexts = eval_results['contexts'][i]
        # Score against all retrieved chunks, not just the first one
        context = "\n".join(case_contexts) if case_contexts else ""
        answer_text = answer or ""

        question_words = _tokenize(question)
        answer_words = _tokenize(answer_text)
        context_words = _tokenize(context)
        ground_words = _tokenize(ground_truth)

        # 1. Answer Relevancy (keyword overlap with question)
        answer_relevancy = len(question_words & answer_words) / len(question_words) if question_words else 0

        # 2. Answer Correctness (does the predicted label match?)
        answer_correctness = _answer_correctness(answer_text, ground_truth)

        # 3. Context Precision (how much of the context vocabulary the answer uses)
        context_precision = len(context_words & answer_words) / len(context_words) if context_words else 0

        # 4. Faithfulness: medical terms in the answer should be supported by the
        # context. Neutral 0.5 only when the answer contains no medical terms;
        # terms with no support in the context score 0.
        answer_medical = _MEDICAL_TERMS & answer_words
        if answer_medical:
            faithfulness = len(answer_medical & context_words) / len(answer_medical)
        else:
            faithfulness = 0.5

        # 5. Answer Similarity (Jaccard word overlap with ground truth)
        union = answer_words | ground_words
        similarity = len(answer_words & ground_words) / len(union) if union else 0

        results.append({
            'question': question,
            'answer': answer,
            'ground_truth': ground_truth,
            'contexts': case_contexts,
            'answer_relevancy': answer_relevancy,
            'answer_correctness': answer_correctness,
            'context_precision': context_precision,
            'faithfulness': faithfulness,
            'answer_similarity': similarity
        })

    # Explicit columns keep the schema stable even when no case succeeded
    results_df = pd.DataFrame(results, columns=_RESULT_COLUMNS)

    # Minimal stand-in for the RAGAS result object (only to_pandas is used)
    class CustomResults:
        def __init__(self, df):
            self.df = df
        def to_pandas(self):
            return self.df

    return CustomResults(results_df)

# Run evaluation: generate answers for eval_data with the RAG pipeline, then
# score them with the local heuristic metrics (no external APIs). The model
# actually used is whichever checkpoint DengueDiagnosticSystem.__init__
# loads (google/gemma-3-4b-it in the current code).
print("Starting evaluation with your MedGemma + RAG system...")
eval_results = evaluate_diagnostic_system_custom(diagnostic_system, eval_data)
evaluation_results = calculate_custom_metrics(eval_results)
print(f"Evaluation completed on {len(eval_results['question'])} test cases")

In [None]:
# Display Comprehensive Results.
# NOTE: despite the "RAGAS" banner, evaluation_results here comes from the
# word-overlap heuristics in calculate_custom_metrics.
print("\n" + "="*50)
print("RAGAS EVALUATION RESULTS")
print("="*50)

# Convert results to a DataFrame for display
results_df = evaluation_results.to_pandas()
print(f"\nEvaluation completed on {len(results_df)} test cases")
print("\nOverall Metrics Summary:")
print("-" * 50)

# Mean of every metric column (the text columns are skipped)
for column in results_df.columns:
    if column not in ['question', 'answer', 'contexts', 'ground_truth']:
        mean_score = results_df[column].mean()
        print(f"{column.replace('_', ' ').title()}: {mean_score:.4f}")

print("\nDetailed Results:")
print("-" * 50)
print(results_df.round(4))

# Best / worst cases by answer relevancy. Skipped when there are no rows
# (e.g. every case failed): idxmax/idxmin raise ValueError on an empty Series.
if 'answer_relevancy' in results_df.columns and not results_df.empty:
    print("\n" + "="*60)
    print("CASE ANALYSIS")
    print("="*60)

    # Best performing case
    best_idx = results_df['answer_relevancy'].idxmax()
    print(f"\nBest Performing Case (Answer Relevancy: {results_df.loc[best_idx, 'answer_relevancy']:.4f}):")
    print(f"Question: {results_df.loc[best_idx, 'question']}")
    print(f"Generated Answer: {results_df.loc[best_idx, 'answer']}")
    print(f"Ground Truth: {results_df.loc[best_idx, 'ground_truth']}")

    # Worst performing case
    worst_idx = results_df['answer_relevancy'].idxmin()
    print(f"\nWorst Performing Case (Answer Relevancy: {results_df.loc[worst_idx, 'answer_relevancy']:.4f}):")
    print(f"Question: {results_df.loc[worst_idx, 'question']}")
    print(f"Generated Answer: {results_df.loc[worst_idx, 'answer']}")
    print(f"Ground Truth: {results_df.loc[worst_idx, 'ground_truth']}")

print("\n" + "="*80)

In [None]:
# Headline averages of the heuristic metrics in results_df (built in the
# previous cell). Each metric is printed only if its column is present.
print("\n" + "="*80)
print("PERFORMANCE ANALYSIS")
print("="*80)

# Analyze results and provide insights
if 'answer_relevancy' in results_df.columns:
    relevancy_score = results_df['answer_relevancy'].mean()
    print(f"\nANSWER RELEVANCY: {relevancy_score:.4f}")

if 'faithfulness' in results_df.columns:
    faithfulness_score = results_df['faithfulness'].mean()
    print(f"FAITHFULNESS: {faithfulness_score:.4f}")

if 'context_precision' in results_df.columns:
    precision_score = results_df['context_precision'].mean()
    print(f"CONTEXT PRECISION: {precision_score:.4f}")

# answer_correctness is the closest of these to diagnostic accuracy
if 'answer_correctness' in results_df.columns:
    correctness_score = results_df['answer_correctness'].mean()
    print(f"\nANSWER CORRECTNESS: {correctness_score:.4f}")


# Evaluation - END