<a href="https://colab.research.google.com/github/long-nguyen-bao-ts/test/blob/main/dengue_diagnostic_interactive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive Dengue Fever Diagnostic System

## Features:
- Clinical text-based diagnosis from symptoms and lab values
- Medical image-based diagnosis
- Combined report from clinical text and a medical image (each input is analyzed separately and both results are shown together)

## Setup and Installation

In [None]:
# Install dependencies
!pip install --upgrade --quiet transformers torch pillow gradio peft accelerate

In [None]:
import logging
import os
import re

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# Root logging for the notebook session: timestamp, level and message per line.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the diagnostic system defined below.
logger = logging.getLogger(__name__)

## Load Fine-tuned Model

In [None]:
class DengueDiagnosticSystem:
    """Wrap a fine-tuned image-text-to-text model for dengue fever diagnosis.

    The processor and model are loaded lazily on the first diagnosis request
    (see ``load_model``). Every ``diagnose_*`` method returns a Markdown string
    for display in the UI; failures are reported as Markdown error messages
    instead of being raised.
    """

    # Keywords suggesting the input is a dengue-related clinical description.
    DENGUE_KEYWORDS = (
        'dengue', 'fever', 'temperature', 'platelet', 'wbc', 'headache',
        'pain behind eyes', 'joint', 'muscle aches', 'rash', 'hemorrhage',
        'thrombocytopenia', 'hemoglobin', 'hematocrit', 'nausea', 'vomiting',
        'abdominal pain', 'patient', 'lab values', 'symptoms', 'diagnosis',
    )
    # Core subset: any one of these outweighs an "irrelevant" keyword hit.
    CORE_MEDICAL_KEYWORDS = ('dengue', 'fever', 'temperature', 'platelet', 'wbc')
    # Words indicating small talk / test input rather than a clinical case.
    IRRELEVANT_KEYWORDS = (
        'hello', 'hi', 'test', 'random', 'asdf', 'qwerty', '123',
        'weather', 'football', 'car', 'computer', 'programming',
        'movie', 'game', 'music', 'food', 'restaurant',
    )
    # Whole-word match: substring matching wrongly flagged clinical text such
    # as "history"/"chills" (contain "hi") or "cardiac"/"scar" (contain "car").
    _IRRELEVANT_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, IRRELEVANT_KEYWORDS)) + r")\b"
    )
    # Exact answer formats requested by the prompts, tolerating dash variants.
    _POSITIVE_EXACT_RE = re.compile(r"\byes\s*[-–—:]?\s*dengue\s+positive\b")
    _NEGATIVE_EXACT_RE = re.compile(r"\bno\s*[-–—:]?\s*dengue\s+negative\b")

    def __init__(self, model_id="longbao128/gemma-4b-dengue-diagnosis"):
        """Store the Hugging Face model id; no weights are loaded here."""
        self.model_id = model_id
        self.processor = None
        self.model = None
        self.model_loaded = False
        logger.info("Initializing Dengue Diagnostic System with model: %s", model_id)

    def load_model(self):
        """Load the fine-tuned model and processor (idempotent).

        Returns:
            True if the model is (now) loaded, False if loading failed.
        """
        if self.model_loaded:
            return True

        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(self.model_id)

            logger.info("Loading model... This may take a few minutes.")
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )

            # Greedy decoding so the same case always yields the same answer.
            self.model.generation_config.do_sample = False
            self.model.generation_config.pad_token_id = self.processor.tokenizer.eos_token_id
            self.processor.tokenizer.padding_side = "left"

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")
            logger.info("✅ Running on device: %s", self.model.device)
            return True

        except Exception as e:
            # Log the full traceback; the UI only needs a boolean.
            logger.exception("Failed to load model: %s", e)
            return False

    def is_dengue_related(self, clinical_info: str) -> tuple:
        """Check whether the input looks like a dengue-related clinical case.

        Returns:
            ``(True, "")`` when the input is acceptable, otherwise
            ``(False, markdown_message)`` explaining why it was rejected.
        """
        info_lower = clinical_info.lower()

        has_irrelevant = self._IRRELEVANT_RE.search(info_lower) is not None
        has_medical = any(keyword in info_lower for keyword in self.CORE_MEDICAL_KEYWORDS)

        if has_irrelevant and not has_medical:
            return False, """⚠️ **NOT RELATED TO DENGUE DIAGNOSIS**

The input appears to be unrelated to medical diagnosis or dengue fever.

**This system is specifically designed for dengue fever diagnosis.**

Please provide:
- Patient symptoms (fever, headache, pain behind eyes, joint/muscle aches, rash, etc.)
- Laboratory values (WBC, platelet count, hemoglobin, hematocrit, etc.)
- Clinical history and fever duration
- Patient demographics

If you need general medical assistance, please consult a healthcare professional."""

        # Require a minimum amount of text before running the model.
        if len(info_lower.split()) < 10:
            return False, """⚠️ **INSUFFICIENT INFORMATION**

Please provide more detailed clinical information for diagnosis, including:
- Patient symptoms and their duration
- Laboratory test results
- Fever status and temperature
- Other relevant clinical details

Example format:
"Patient from [location], fever duration: X days, current temperature: XXX°F
Lab values - WBC: X, Platelet: X
Present symptoms: [list]
Absent symptoms: [list]" """

        # Require at least two clinical keywords of any kind.
        medical_word_count = sum(1 for keyword in self.DENGUE_KEYWORDS if keyword in info_lower)
        if medical_word_count < 2:
            return False, """⚠️ **CANNOT PROCESS - NOT RELATED TO SYSTEM**

This system is specifically designed for dengue fever diagnosis. The provided information does not appear to contain relevant medical or clinical data.

**For dengue diagnosis, please provide:**
- Fever status and duration
- Clinical symptoms (headache, pain behind eyes, joint/muscle aches, rash, etc.)
- Laboratory results (platelet count, WBC, hemoglobin, etc.)
- Patient demographics and history

**This system cannot process:**
- General questions
- Non-medical inquiries
- Random text or test inputs"""

        return True, ""

    def create_clinical_prompt(self, clinical_info: str) -> str:
        """Create formatted prompt for clinical text diagnosis (matches training format)"""
        prompt = f"""You are a medical AI assistant specializing in dengue fever diagnosis. Analyze the following clinical case and provide a diagnosis with detailed explanation.

Clinical Information:
{clinical_info}

Instructions:
- Consider both present and absent symptoms in your diagnosis
- When symptom information is unknown, factor this uncertainty into your assessment
- Base your diagnosis on available clinical evidence
- Dengue fever typically presents with fever, headache, pain behind eyes, joint/muscle aches, and low platelet count

Question: Based on the available clinical data, does this patient have dengue fever?

Provide your diagnosis as either "Yes - Dengue positive" or "No - Dengue negative".
Then provide a brief explanation (1-2 sentences) of the key factors supporting your diagnosis, mentioning specific symptoms or lab values that influenced your decision."""
        return prompt

    def create_image_prompt(self) -> str:
        """Create formatted prompt for medical image diagnosis (matches training format)"""
        prompt = """You are a medical AI assistant analyzing medical imagery for signs of dengue fever.

Analyze this medical image carefully for visual evidence of dengue fever symptoms or complications.

Instructions:
- Look for skin manifestations (rash, petechiae)
- Consider any visible clinical signs
- Provide your assessment based on the visual evidence

Question: Based on the visual evidence in this medical image, does this show signs consistent with dengue fever?

Respond with "Yes - Dengue positive" or "No - Dengue negative".
Then briefly explain (1 sentence) what visual features you observed that support your diagnosis."""
        return prompt

    def parse_diagnosis(self, response_text: str) -> tuple:
        """Classify a model response as positive / negative / uncertain.

        Returns:
            ``(display_label, category, response_text)`` where ``category`` is
            one of ``"positive"``, ``"negative"`` or ``"uncertain"``.
        """
        response_lower = response_text.lower()

        # Exact answer formats the model was trained to emit.
        if self._POSITIVE_EXACT_RE.search(response_lower):
            return "🔴 DENGUE POSITIVE", "positive", response_text
        if self._NEGATIVE_EXACT_RE.search(response_lower):
            return "🟢 DENGUE NEGATIVE", "negative", response_text

        # Fallback on whole words only: substring checks matched "yes" inside
        # "eyes" (e.g. "pain behind eyes") and "no" inside "not"/"diagnosis".
        words = set(re.findall(r"[a-z]+", response_lower))
        if "dengue" in words:
            if words & {"positive", "yes"}:
                return "🔴 DENGUE POSITIVE (uncertain)", "positive", response_text
            if words & {"negative", "no"}:
                return "🟢 DENGUE NEGATIVE (uncertain)", "negative", response_text
        return "⚠️ UNCERTAIN DIAGNOSIS", "uncertain", response_text

    def _generate(self, messages, images=None, max_new_tokens=150) -> str:
        """Run chat-formatted ``messages`` through the model; return only the new text."""
        text = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        ).strip()

        processor_kwargs = {"text": text, "return_tensors": "pt", "padding": True}
        if images is not None:
            processor_kwargs["images"] = images
        inputs = self.processor(**processor_kwargs)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )

        # generate() returns prompt + continuation; decode only the continuation
        # instead of splitting the decoded text on the word "model", which
        # truncated any answer (or echoed input) containing that word.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.processor.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    def diagnose_text(self, clinical_info: str) -> str:
        """Diagnose dengue from free-text clinical information; return Markdown."""
        if not clinical_info or not clinical_info.strip():
            return "⚠️ **Error**: Please provide clinical information for diagnosis."

        # Reject irrelevant input before paying for model loading.
        is_valid, error_message = self.is_dengue_related(clinical_info)
        if not is_valid:
            return error_message

        if not self.model_loaded and not self.load_model():
            return "❌ **Error**: Failed to load model. Please check logs."

        try:
            logger.info("Processing clinical text diagnosis...")
            prompt = self.create_clinical_prompt(clinical_info)

            # Format exactly as in training
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt}
                    ]
                }
            ]

            logger.info("Generating diagnosis...")
            response = self._generate(messages, max_new_tokens=150)
            logger.info("Raw model response: %s", response)

            diagnosis, category, full_text = self.parse_diagnosis(response)

            preview = clinical_info[:300] + ('...' if len(clinical_info) > 300 else '')
            result = (
                f"### {diagnosis}\n\n"
                f"**AI Diagnostic Response:**\n{full_text}\n\n"
                f"**Clinical Information Provided:**\n{preview}"
            )

            logger.info("Diagnosis completed: %s", category)
            return result

        except Exception as e:
            logger.exception("Error during diagnosis: %s", e)
            return f"❌ **Error during diagnosis:** {e}"

    def diagnose_image(self, image_path: str) -> str:
        """Diagnose dengue from the image file at ``image_path``; return Markdown."""
        if not image_path or not image_path.strip():
            return "⚠️ **Error**: Please provide an image path for diagnosis."

        # The emptiness check above used the stripped path; open that same path
        # (pasted paths often carry stray whitespace/newlines).
        image_path = image_path.strip()

        # Validate the path before loading the (large) model.
        if not os.path.isfile(image_path):
            return f"❌ **Error**: Image file not found at: {image_path}"

        if not self.model_loaded and not self.load_model():
            return "❌ **Error**: Failed to load model. Please check logs."

        try:
            logger.info("Loading image from: %s", image_path)
            with Image.open(image_path) as img:
                image = img.convert('RGB')

            prompt = self.create_image_prompt()

            # Format exactly as in training
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]

            logger.info("Generating image-based diagnosis...")
            response = self._generate(messages, images=[image], max_new_tokens=100)
            logger.info("Raw model response: %s", response)

            diagnosis, category, full_text = self.parse_diagnosis(response)

            result = (
                f"### {diagnosis}\n\n"
                f"**AI Diagnostic Response:**\n{full_text}\n\n"
                f"**Image Analyzed:** {image_path}"
            )

            logger.info("Diagnosis completed: %s", category)
            return result

        except Exception as e:
            logger.exception("Error during image diagnosis: %s", e)
            return f"❌ **Error during diagnosis:** {e}"

    def diagnose_multimodal(self, clinical_info: str, image_path: str) -> str:
        """Run text and/or image diagnosis (independently) and join the reports."""
        has_text = bool(clinical_info and clinical_info.strip())
        has_image = bool(image_path and image_path.strip())

        if not has_text and not has_image:
            return "⚠️ **Error**: Please provide either clinical information or an image (or both) for diagnosis."

        sections = []
        if has_text:
            sections.append("## 📋 Clinical Data Analysis\n\n" + self.diagnose_text(clinical_info))
        if has_image:
            sections.append("## 🖼️ Medical Image Analysis\n\n" + self.diagnose_image(image_path))

        # Separator only *between* sections (previously dangled after text-only).
        return "# 🩺 Dengue Fever Diagnosis Results\n\n" + "\n\n---\n\n".join(sections)

# One shared system instance for the UI; the model weights are only loaded
# when the first diagnosis is requested.
diagnostic_system = DengueDiagnosticSystem(model_id="longbao128/gemma-4b-dengue-diagnosis")
print("✓ Diagnostic system initialized!")

## Launch Gradio Interface

In [None]:
# Create Gradio interface.
# NOTE: the app is launched with share=True (public URL). The image input is
# therefore an upload widget rather than a free-text server path: a path
# textbox let any visitor make the server open arbitrary files and, via the
# "file not found" message, probe which paths exist.
with gr.Blocks(theme=gr.themes.Soft(), title="Dengue Fever Diagnostic System") as demo:
    gr.Markdown(
        """
        # 🩺 Dengue Fever Diagnostic System

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input Clinical Information")

            clinical_input = gr.Textbox(
                label="Clinical Information",
                placeholder="""Enter patient symptoms, lab values, and clinical history...

Example format:
Patient from Bangalore, fever duration: 10 days, current temperature: 100.0°F
Lab values - WBC: 5.0, Platelet: 140.0
Present symptoms: severe headache, metallic taste, appetite loss
Absent symptoms: pain behind eyes, joint/muscle aches""",
                lines=10
            )

            # type="filepath" hands diagnose_multimodal a temp-file path string
            # (or None when no image is uploaded), matching its signature.
            image_input = gr.Image(
                label="Medical Image (optional)",
                type="filepath"
            )

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                diagnose_btn = gr.Button("🔍 Diagnose", variant="primary", scale=2)

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Diagnosis Result")
            output = gr.Markdown(label="Result")

    gr.Markdown("""
    ---
    ### 📋 Example Cases (Click to load)
    """)

    # Examples are text-only; None leaves the image input empty.
    gr.Examples(
        examples=[
            [
                """Patient from New Delhi, fever duration: 4 days, current temperature: 104.0°F.
Lab values - WBC: 1.0, Hemoglobin: 9.0, Hematocrit: 22.0, Platelet: 80.0.
Present symptoms: pain behind eyes, joint/muscle aches.
Absent symptoms: severe headache, metallic taste, appetite loss, abdominal pain, nausea/vomiting, diarrhea.
Clinical notes: Low WBC and platelet count concerning for dengue.""",
                None
            ],
            [
                """Patient from Bangalore, fever duration: 10 days, current temperature: 100.0°F.
Lab values - WBC: 5.0, Hemoglobin: 15.0, Platelet: 140.0.
Present symptoms: severe headache, metallic taste, appetite loss, abdominal pain, diarrhea.
Absent symptoms: pain behind eyes, joint/muscle aches, nausea/vomiting.
Clinical notes: Normal platelet count, symptoms not specific to dengue.""",
                None
            ],
            [
                """Patient from Jamaica, fever duration: 5 days, current temperature: 104.0°F.
Lab values - WBC: 5.0, Platelet: 120.0.
Present symptoms: metallic taste, appetite loss, abdominal pain, nausea/vomiting.
Absent symptoms: pain behind eyes, joint/muscle aches.
Symptom status unknown for: severe headache, diarrhea.""",
                None
            ],
            [
                """Patient from Miami, no fever reported, normal temperature.
Lab values - WBC: 7.0, Platelet: 200.0.
Present symptoms: mild cough, runny nose.
Absent symptoms: fever, headache, pain behind eyes, joint/muscle aches, rash.
Clinical notes: Symptoms more consistent with common cold.""",
                None
            ]
        ],
        inputs=[clinical_input, image_input],
        label="Click an example to load it"
    )

    # Event handlers
    def clear_inputs():
        """Reset the text box, the image upload and the result panel."""
        return "", None, ""

    diagnose_btn.click(
        fn=diagnostic_system.diagnose_multimodal,
        inputs=[clinical_input, image_input],
        outputs=output
    )

    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[clinical_input, image_input, output]
    )

# Launch interface
print("\n🚀 Launching Gradio interface...")
demo.launch(share=True, debug=True)