In [None]:
import subprocess



# Uninstall old gradio first, then install latest.
# `sys.executable -m pip` installs into the interpreter running this notebook;
# a bare "pip" is resolved through PATH and can belong to another environment.
# NOTE(review): modules already imported in this kernel session are not
# reloaded by a reinstall — a kernel restart may be needed to pick them up.
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "uninstall", "-y", "gradio", "gradio-client"],
    check=False,  # best-effort: the packages may not be installed yet
)
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q",
     "gradio", "pillow", "accelerate", "transformers"],
    check=True,
)

from kaggle_secrets import UserSecretsClient

# Hugging Face token stored as the Kaggle secret "HF_TOKEN"; it is passed as
# `token=` to the from_pretrained calls that load the model and processor.
secrets = UserSecretsClient()
hf_token = secrets.get_secret("HF_TOKEN")



import torch

from transformers import AutoProcessor, AutoModelForCausalLM

from transformers import StoppingCriteria, StoppingCriteriaList

import numpy as np

from PIL import Image

import time

import gradio as gr

import re



print("Gradio version:", gr.__version__)

model_id = "google/medgemma-4b-it"
# NOTE(review): the log line below and the UI text say "MedGemma 1.5", but this
# id does not carry a "1.5" tag — confirm which checkpoint is intended.
print("Loading MedGemma 1.5...")

processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
# `dtype` replaces `torch_dtype`; the installed transformers warns
# "`torch_dtype` is deprecated! Use `dtype` instead!".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token,
)
print("Model loaded on:", next(model.parameters()).device)

# Token id of the "<end_of_turn>" marker, watched by the stopping criterion.
# convert_tokens_to_ids() falls back to the unk id for unknown tokens, which
# would make generation stop on any unknown token — fail loudly instead.
stop_token_id = processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")
if stop_token_id is None or stop_token_id == processor.tokenizer.unk_token_id:
    raise ValueError("Tokenizer has no '<end_of_turn>' token for model " + model_id)



class StopOnEndOfTurn(StoppingCriteria):
    """Stopping criterion that fires when a sequence emits the end-of-turn token.

    ``__call__`` returns one boolean per batch row (the return type that
    ``transformers.StoppingCriteria`` documents), so each row in a batch is
    judged on its own newest token rather than on row 0's.

    Args:
        token_id: Token id that ends a turn. Defaults to the module-level
            ``stop_token_id`` resolved from the tokenizer.
    """

    def __init__(self, token_id=None):
        super().__init__()
        self.token_id = stop_token_id if token_id is None else token_id

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is (batch, seq_len); compare every row's last token.
        return input_ids[:, -1] == self.token_id


# Shared criteria list passed to every model.generate() call.
stop_criteria = StoppingCriteriaList([StopOnEndOfTurn()])





def truncate_at_sentence(text, max_chars=600):
    """Clean a raw MedGemma generation and trim it to a readable length.

    Processing order:
      1. Remove QA-style artifacts: everything from "final answer:" or
         "the final answer is" to the end of that line, and ``\\boxed{...}``
         wrappers (their contents are kept; other braces are untouched).
      2. Cut the text just before the earliest "garbage" phrase (disclaimers,
         off-topic boilerplate) that starts after the first 30 characters.
      3. If the text is still longer than ``max_chars``, cut after the last
         sentence terminator (``.``, ``!`` or ``?`` followed by whitespace)
         inside the first ``max_chars`` characters, provided it lies past the
         midpoint; otherwise hard-cut at ``max_chars``.

    Args:
        text: Decoded model output.
        max_chars: Maximum length of the returned string.

    Returns:
        The cleaned, whitespace-stripped text, at most ``max_chars`` long.
    """
    # Step 1: QA artifacts. '.' does not match newlines, so these remove only
    # up to the end of the line on which the phrase appears.
    text = re.sub(r'(?i)final answer:.*', '', text)
    text = re.sub(r'(?i)the final answer is.*', '', text)
    text = re.sub(r'\\boxed\{([^{}]*)\}', r'\1', text)  # unwrap \boxed{...}
    text = text.replace('\\boxed{', '')  # unbalanced remnant, e.g. cut mid-generation
    text = text.strip()

    # Step 2: cut at the EARLIEST garbage phrase in the text. Matching is
    # case-insensitive; occurrences in the first 30 chars are ignored so a
    # short answer is never cut down to almost nothing.
    garbage_phrases = (
        "example reports", "reference purposes",
        "please consult", "thank you", "thankyou",
        "this response provides", "not constitute",
        "intellectual property", "femtogram",
        "nanosecond", "election", "calamit",
        "do not use", "etcetera", "disclaimer",
        "for your reference",
    )
    lowered = text.lower()
    cut_points = [
        idx for idx in (lowered.find(phrase, 31) for phrase in garbage_phrases)
        if idx != -1
    ]
    if cut_points:
        text = text[:min(cut_points)].rstrip(" .,;:")
    text = text.strip()

    if len(text) <= max_chars:
        return text

    # Step 3: last sentence end within the limit. Looking one char past the
    # limit lets a terminator at the very edge count, while "3." in "3.5 cm"
    # (no following whitespace) does not.
    window = text[:max_chars + 1]
    ends = [m.start() for m in re.finditer(r'[.!?](?=\s)', window)]
    boundary = ends[-1] if ends else -1
    if boundary > max_chars // 2:
        return text[:boundary + 1].strip()
    return text[:max_chars].strip()





def generate_with_image(image_pil, prompt, max_tokens=200):
    """Ask MedGemma about one image and return its cleaned-up reply.

    The image and prompt form a single user turn rendered through the
    processor's chat template. Decoding is greedy (``do_sample=False``) with
    a repetition penalty of 1.2, and ends after ``max_tokens`` new tokens or
    when the shared end-of-turn stopping criterion fires.

    Args:
        image_pil: Image placed before the prompt in the user turn.
        prompt: Instruction text for the model.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The decoded continuation (special tokens skipped), passed through
        ``truncate_at_sentence`` with a 600-character limit.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image_pil},
            {"type": "text", "text": prompt},
        ],
    }
    encoded = processor.apply_chat_template(
        [user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = encoded.to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    generation_kwargs = dict(
        max_new_tokens=max_tokens,
        stopping_criteria=stop_criteria,
        repetition_penalty=1.2,
        do_sample=False,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # The returned sequence starts with the prompt; keep only what follows it.
    reply_ids = sequences[0][prompt_len:]
    reply = processor.tokenizer.decode(reply_ids, skip_special_tokens=True)
    return truncate_at_sentence(reply, max_chars=600)





def generate_text_only(prompt, max_tokens=100):
    """Run a text-only MedGemma turn and return its cleaned-up reply.

    Args:
        prompt: Instruction text sent as the user turn.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The decoded continuation (special tokens skipped), passed through
        ``truncate_at_sentence`` with a 400-character limit.
    """
    # Render the prompt with the processor's chat template, the same way
    # generate_with_image does. A hand-written "user\n...\nmodel\n" string
    # lacks the model's turn-marker special tokens, so the model never sees a
    # proper user/model exchange and may not close its reply with the
    # "<end_of_turn>" token the stopping criterion watches for.
    messages = [
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            stopping_criteria=stop_criteria,
            repetition_penalty=1.2,
            do_sample=False,
        )

    # Drop the echoed prompt tokens and decode only the new ones.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    result = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    return truncate_at_sentence(result, max_chars=400)





def rule_based_triage(findings, context):
    """Assign a triage level from free-text findings using keyword rules.

    The findings are lower-cased and split into clauses on '.', ';', newlines
    and the contrastive words "but" / "however". A keyword is a positive
    mention when it occurs as a whole word (optionally pluralised with "s" or
    "es") in a clause with no negation cue before it in that same clause.

    Args:
        findings: Free-text findings (model output). ``None`` is treated as "".
        context: Clinical context. Currently unused; kept for interface
            compatibility.

    Returns:
        "CRITICAL", "URGENT" or "ROUTINE": the most severe tier that has a
        positive mention, or "ROUTINE" if none does.
    """
    findings_text = (findings or "").lower()

    # Clause boundaries. "but"/"however" are included so that in
    # "no consolidation but a right pleural effusion" the leading "no" does
    # not negate the effusion.
    clause_split = re.compile(r"[.;\n]|\bbut\b|\bhowever\b")
    # Negation cues. The leading \b keeps "cannot"/"unclear" from matching
    # "not"/"clear" (so "cannot exclude pneumonia" stays positive); the
    # trailing \s keeps the original requirement of following whitespace.
    negation = re.compile(
        r"\b(?:no|not|without|negative|absence|clear)\s|\bunremarkable|\bfree of"
    )

    def has_positive_mention(keywords, text):
        # Whole-word keyword, optionally plural ("nodules", "masses").
        patterns = [
            re.compile(r"\b" + re.escape(kw) + r"(?:e?s)?\b") for kw in keywords
        ]
        for clause in clause_split.split(text):
            clause = clause.strip()
            if not clause:
                continue
            for pattern in patterns:
                match = pattern.search(clause)
                # Only text BEFORE the keyword in this clause can negate it.
                if match and not negation.search(clause[:match.start()]):
                    return True
        return False

    # CRITICAL 🔴 - Medical emergencies requiring immediate intervention
    critical_keywords = [
        "tension pneumothorax", "massive pleural effusion",
        "pulmonary edema", "edema", "severe pulmonary edema",
        "aortic dissection", "cardiac tamponade",
        "whiteout", "pneumoperitoneum", "tracheal deviation",
    ]
    if has_positive_mention(critical_keywords, findings_text):
        return "CRITICAL"

    # URGENT 🟡 - Acute abnormalities requiring prompt attention
    urgent_keywords = [
        "consolidation", "pneumonia", "pleural effusion", "cardiomegaly",
        "enlarged heart", "cardiac enlargement", "pacemaker",
        "airspace opacity", "increased opacity", "ground glass", "opacities",
        "infiltrate", "atelectasis", "mass", "nodule", "fracture", "pneumothorax",
        "pleural fluid",
    ]
    if has_positive_mention(urgent_keywords, findings_text):
        return "URGENT"

    # ROUTINE 🟢 - No positive findings detected
    return "ROUTINE"





def run_pipeline(image, patient_context):
    """Run the five-step triage/report pipeline for one chest X-ray.

    Steps: (1) MedGemma visual findings, (2) rule-based triage,
    (3) MedGemma findings summary plus a templated report, (4) templated
    patient summary, (5) safety flags. If Step 1 fails the pipeline stops
    and returns an explicit failure result instead of a triage level.

    Args:
        image: Uploaded image as a NumPy array, or any object with a
            PIL-style ``convert`` method; ``None`` when nothing was uploaded.
        patient_context: Free-text clinical context; blank means none given.

    Returns:
        Six strings: triage markdown, findings markdown, report markdown,
        patient-summary markdown, newline-joined safety flags and
        newline-joined execution log.
    """
    if image is None:
        return ("Please upload an image", "", "", "", "", "")

    if not patient_context or patient_context.strip() == "":
        patient_context = "No clinical context provided"

    log = []
    start_total = time.time()

    try:
        if isinstance(image, np.ndarray):
            img_pil = Image.fromarray(image.astype("uint8")).convert("RGB")
        else:
            img_pil = image.convert("RGB")
        # NOTE(review): a fixed (512, 512) target ignores the aspect ratio, so
        # non-square radiographs are stretched — confirm this is intended.
        img_pil = img_pil.resize((512, 512))
    except Exception as e:
        return ("Error loading image: " + str(e), "", "", "", "", "")

    # STEP 1 - visual findings from MedGemma
    t = time.time()
    log.append("Step 1: Extracting visual findings with MedGemma...")
    try:
        findings = generate_with_image(
            img_pil,
            "You are an expert radiologist. Carefully analyze this chest X-ray.\n"
            "Patient context: " + patient_context + "\n\n"
            "Write exactly 5 lines, one per region. Keep it professional and brief:\n"
            "Lungs: [describe findings]\n"
            "Heart: [describe findings]\n"
            "Mediastinum: [describe findings]\n"
            "Pleura: [describe findings]\n"
            "Bones: [describe findings]\n",
            250
        )
    except Exception as e:
        # Abort here. Continuing would run the triage rules over this error
        # text, which contains no abnormality keywords, and the study would be
        # labelled ROUTINE ("No acute findings identified") — a false all-clear.
        error_msg = "Error in visual analysis: " + str(e)
        log.append("Step 1 failed (" + str(round(time.time()-t, 1)) + "s): " + error_msg)
        log.append("Pipeline aborted - no triage level assigned")
        return (
            "## ANALYSIS FAILED\n\n" + error_msg
            + "\n\nNo triage level assigned. Manual radiologist review required.",
            "## Detailed Findings\n\n" + error_msg,
            "",
            "",
            "ANALYSIS FAILED - Manual radiologist review required\n"
            "For demonstration purposes only - Not a medical device",
            "\n".join(log),
        )
    log.append("Step 1 complete (" + str(round(time.time()-t, 1)) + "s)")

    # STEP 2 - keyword-based triage of the findings text
    t = time.time()
    log.append("Step 2: Classifying triage level...")
    triage_level = rule_based_triage(findings, patient_context)
    triage_descriptions = {
        "CRITICAL": "Critical findings identified requiring immediate radiologist attention.",
        "URGENT": "Significant findings identified requiring prompt radiologist review.",
        "ROUTINE": "No acute findings identified. Standard reporting timeline applies."
    }
    triage_reason = triage_descriptions[triage_level]
    log.append("Step 2 complete: " + triage_level + " (" + str(round(time.time()-t, 1)) + "s)")

    # STEP 3 - report: MedGemma-written findings paragraph + templated sections
    t = time.time()
    log.append("Step 3: Generating radiology report with MedGemma...")
    try:
        findings_paragraph = generate_text_only(
            "You are a radiologist. Write exactly 2 professional sentences "
            "summarizing these chest X-ray findings. "
            "Stop after 2 sentences. No extra commentary.\n\n"
            "Findings: " + findings[:350],
            90
        )
    except Exception:
        # Fall back to the raw findings, cut at a sentence boundary rather
        # than mid-word.
        findings_paragraph = truncate_at_sentence(findings, max_chars=200)

    impression_map = {
        "CRITICAL": (
            "1. Critical cardiopulmonary finding identified.\n"
            "2. Immediate radiologist review and clinical intervention required."
        ),
        "URGENT": (
            "1. Significant cardiopulmonary finding identified.\n"
            "2. Prompt radiologist review and clinical correlation recommended."
        ),
        "ROUTINE": (
            "1. No acute cardiopulmonary abnormality identified.\n"
            "2. Findings within normal limits for patient age and clinical context."
        )
    }
    recommendation_map = {
        "CRITICAL": "Immediate radiologist review and clinical intervention required.",
        "URGENT": "Radiologist review recommended within 1 hour. Clinical correlation advised.",
        "ROUTINE": "No immediate follow-up required. Routine clinical management."
    }
    report = (
        "**EXAMINATION:** PA chest radiograph\n\n"
        "**INDICATION:** " + patient_context + "\n\n"
        "**COMPARISON:** None available\n\n"
        "**FINDINGS:** " + findings_paragraph + "\n\n"
        "**IMPRESSION:**\n" + impression_map[triage_level] + "\n\n"
        "**RECOMMENDATION:** " + recommendation_map[triage_level]
    )
    log.append("Step 3 complete (" + str(round(time.time()-t, 1)) + "s)")

    # STEP 4 - plain-language patient summary chosen by triage level
    t = time.time()
    log.append("Step 4: Creating patient summary...")
    patient_summaries = {
        "CRITICAL": (
            "Your chest X-ray has shown some findings that need urgent attention "
            "from your doctor right away. "
            "Please do not leave the hospital — your care team has been alerted "
            "and will see you very soon. "
            "This does not mean the worst, but it is important we act quickly "
            "to keep you safe."
        ),
        "URGENT": (
            "Your chest X-ray has shown some findings that your doctor needs "
            "to review soon. "
            "This is not an emergency, but your care team will follow up with "
            "you shortly to discuss next steps. "
            "Please let a nurse know if your symptoms get worse while you wait."
        ),
        "ROUTINE": (
            "Your chest X-ray looks generally normal and no urgent problems "
            "were found. "
            "Your doctor will review the full results with you at your appointment. "
            "If your symptoms change or worsen before then, please contact "
            "your healthcare provider."
        )
    }
    patient_summary = patient_summaries[triage_level]
    log.append("Step 4 complete (" + str(round(time.time()-t, 1)) + "s)")

    # STEP 5 - safety flags
    t = time.time()
    log.append("Step 5: Running safety validation...")
    safety_flags = []
    if triage_level == "CRITICAL":
        safety_flags.append("CRITICAL - Immediate radiologist review required")
    elif triage_level == "URGENT":
        safety_flags.append("URGENT - Radiologist review within 1 hour")
    else:
        safety_flags.append("ROUTINE - Standard reporting timeline")
    safety_flags.append("AI-generated draft - Must be verified by a qualified radiologist")
    safety_flags.append("For demonstration purposes only - Not a medical device")
    log.append("Step 5 complete (" + str(round(time.time()-t, 1)) + "s)")
    log.append("Total pipeline time: " + str(round(time.time()-start_total, 1)) + "s")

    # Real emoji characters; the previous literals were UTF-8 bytes
    # mis-decoded as cp1252 (e.g. "ðŸ”´") and rendered as garbage.
    emoji_map = {"CRITICAL": "🔴", "URGENT": "🟡", "ROUTINE": "🟢"}
    emoji = emoji_map.get(triage_level, "⚪")

    return (
        "## " + emoji + " " + triage_level + "\n\n" + triage_reason,
        "## Detailed Findings\n\n" + findings,
        "## Radiology Report\n\n" + report,
        "## Patient Summary\n\n" + patient_summary,
        "\n".join(safety_flags),
        "\n".join(log)
    )





# Gradio UI: inputs (image + clinical context) in the left column, pipeline
# outputs in the right column. Inside the `with` blocks, components appear in
# the order they are created.
with gr.Blocks(title="ChestAI Copilot") as demo:
    # Header and disclaimer banner.
    # NOTE(review): this text (and the architecture footer) says "MedGemma 1.5"
    # while model_id is "google/medgemma-4b-it" — confirm the intended name.
    gr.Markdown(
        "# ChestAI Copilot\n"
        "### Agentic Chest X-Ray Triage and Reporting Assistant\n"
        "**Powered by Google MedGemma 1.5 (HAI-DEF)** | **MedGemma Impact Challenge**\n\n"
        "> Demonstration only. Not a medical device. "
        "All outputs must be verified by a qualified radiologist.\n\n"
        "---\n\n"
        "**How it works:** Upload a chest X-ray and the 5-step "
        "AI agent pipeline produces Triage + Report + Patient Summary"
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            # type="numpy" -> run_pipeline receives a NumPy array (or None when
            # nothing is uploaded), matching its isinstance(np.ndarray) branch.
            img_input = gr.Image(
                label="Upload Chest X-Ray",
                type="numpy",
                height=300
            )
            context_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 65yo male, cough and fever x3 days, history of CHF",
                lines=3
            )
            # Sample contexts for manual testing; they are not auto-filled.
            gr.Markdown(
                "**Test contexts to copy/paste:**\n"
                "- **Normal:** 28yo female, routine pre-employment screening.\n"
                "- **Cardiomegaly:** 72yo female, shortness of breath, history of heart failure and pacemaker.\n"
                "- **Pneumonia:** 55yo male, fever, cough, right-sided chest pain x3 days.\n"
                "- **Edema:** 68yo male, acute severe respiratory distress, audible crackles."
            )
            submit_btn = gr.Button("Analyze Chest X-Ray", variant="primary")

        # Right column: outputs, each filled by one element of the 6-tuple
        # returned by run_pipeline.
        with gr.Column(scale=2):
            triage_output = gr.Markdown(label="Triage Level")
            safety_output = gr.Textbox(label="Safety Flags", lines=3)
            with gr.Tabs():
                with gr.TabItem("Radiology Report"):
                    report_output = gr.Markdown()
                with gr.TabItem("Detailed Findings"):
                    findings_output = gr.Markdown()
                with gr.TabItem("Patient Summary"):
                    patient_output = gr.Markdown()
                with gr.TabItem("Pipeline Log"):
                    log_output = gr.Textbox(lines=10, label="Agent Execution Log")

    # Static architecture footer.
    gr.Markdown(
        "---\n"
        "### Architecture\n"
        "```\n"
        "Upload CXR -> [Agent 1: Visual Analysis (MedGemma)] "
        "-> [Agent 2: Triage Classification] "
        "-> [Agent 3: Report Generation (MedGemma)] "
        "-> [Agent 4: Patient Summary] "
        "-> [Agent 5: Safety Validation] -> Dashboard\n"
        "```\n\n"
        "**Model:** MedGemma 1.5 4B-IT (multimodal) from Google HAI-DEF\n\n"
        "*Built for the MedGemma Impact Challenge on Kaggle*"
    )

    # The order of `outputs` must match the order of run_pipeline's returned
    # tuple: triage, findings, report, patient summary, safety flags, log.
    submit_btn.click(
        fn=run_pipeline,
        inputs=[img_input, context_input],
        outputs=[
            triage_output,
            findings_output,
            report_output,
            patient_output,
            safety_output,
            log_output
        ]
    )

# share=True creates the public *.gradio.live URL seen in the cell output
# (it expires after one week). show_error=True surfaces handler errors in the UI.
# NOTE(review): a public link lets anyone with the URL upload images to this
# model — confirm that is acceptable. debug=True presumably keeps the cell
# running while the app is live; verify against the Gradio launch() docs.
demo.queue().launch(share=True, debug=True, show_error=True)

Found existing installation: gradio 6.6.0
Uninstalling gradio-6.6.0:
  Successfully uninstalled gradio-6.6.0
Found existing installation: gradio_client 2.1.0
Uninstalling gradio_client-2.1.0:
  Successfully uninstalled gradio_client-2.1.0
Gradio version: 6.6.0
Loading MedGemma 1.5...


The image processor of type `Gemma3ImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/883 [00:00<?, ?it/s]

Model loaded on: cuda:0
* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://e1845c7b38541fffa3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
