# 🧠 Visual AI in Healthcare with FiftyOne – VQA and Classification with MedGemma  
**Harnessing MedGemma’s multimodal medical expertise for visual question answering and image classification**

This notebook is part of the **“Visual AI in Healthcare with FiftyOne”** workshop. In this hands-on example, we use the **instruction-tuned MedGemma 4B** model (from Hugging Face) to run **visual question answering (VQA)** and **classification** tasks on a small medical dataset via FiftyOne’s Hugging Face integration.

🔬 **What you’ll learn in this notebook:**

- How to **load the SLAKE dataset** from Hugging Face using `fiftyone.utils.huggingface`  
- How to **run MedGemma (4B-it) locally** with a custom system prompt and input image  
- How to **perform VQA and classification** using the `image-text-to-text` pipeline  
- How to **store MedGemma’s output** back into your FiftyOne samples for analysis  

⚙️ **MedGemma quick facts**:
- Developed by **Google**, MedGemma is built on Gemma 3 with a **SigLIP image encoder**
- Trained on diverse medical images and text (CXR, dermoscopy, histopathology, ophthalmology)
- Supports **instruction tuning** for healthcare applications like report generation, diagnosis support, and multimodal Q&A  
- Available on [Hugging Face](https://huggingface.co/google/medgemma-4b-it) and via [Google Model Garden](https://cloud.google.com/model-garden)  

📚 **Part of the notebook series:**
1. `01_load_arcade_dataset.ipynb` – Load and visualize the ARCADE dataset.  
2. `02_load_deeplesion_balanced.ipynb` – Curate and balance the DeepLesion dataset.  
3. `03_vlms_analysis_arcade.ipynb` – Use vision foundation models (VFMs) such as NVLabs_CRADIOV3 to understand the ARCADE dataset.  
4. `04_finetune_yolo8_stenosis.ipynb` – Train and integrate YOLOv8 for stenosis detection.  
5. `05_medsam2_ct_scan.ipynb` – Run MedSAM2 on CT scans for segmentation.  
6. `06_nvidia_vista_segmentation.ipynb` – Explore NVIDIA-VISTA-3D.  
7. `07_medgemma_vqa.ipynb` – Perform visual question answering and classification with MedGemma.

All notebooks are standalone but are best experienced sequentially.


### ✅ Requirements

Before running this notebook, make sure you have the required libraries installed.

You can install them using `pip`:

```bash
pip install fiftyone "huggingface_hub>=0.20.0" torch torchvision transformers accelerate bitsandbytes
```

### 📦 Load the SLAKE Dataset from Hugging Face

We start by loading a subset of the **SLAKE** dataset, a semantically labeled, knowledge-enhanced benchmark for medical visual question answering.  
It pairs medical images with textual questions and answers covering several body regions and imaging modalities.

We use the `load_from_hub()` utility from `fiftyone.utils.huggingface` to fetch the dataset directly from Hugging Face and import it into FiftyOne.

🔹 **Key actions in this step:**
- Load 50 multimodal samples (`max_samples=50`) for quick experimentation
- Assign a name `SLAKE` to the dataset for reference in the FiftyOne App
- Set `overwrite=True` to reset the dataset if it already exists locally

In [None]:
import fiftyone as fo

from fiftyone.utils.huggingface import load_from_hub

dataset = load_from_hub(
    "Voxel51/SLAKE",
    name="SLAKE",
    overwrite=True,
    max_samples=50
    )

In [None]:
dataset

### Register and Load MedGemma from a Custom Zoo Source

FiftyOne allows you to register custom model sources and use them just like built-in zoo models.

In this step, we:
- Register a **custom Zoo model source** pointing to a GitHub repo that defines the MedGemma integration
- Download the **instruction-tuned MedGemma 4B** model (`google/medgemma-4b-it`) using the integration defined in `https://github.com/harpreetsahota204/medgemma`
- Load the model using FiftyOne’s `load_zoo_model()` interface with optional quantization for efficient inference

🔹 **Why use a custom zoo?**  
It lets us easily manage models that are not part of the default FiftyOne model zoo but are useful for specific domains — like MedGemma for healthcare.

> If running this for the first time, you may need to uncomment `install_requirements=True` to automatically install dependencies listed by the model repo.


In [None]:
import fiftyone.zoo as foz

foz.register_zoo_model_source("https://github.com/harpreetsahota204/medgemma", overwrite=True)

foz.download_zoo_model(
    "https://github.com/harpreetsahota204/medgemma",
    model_name="google/medgemma-4b-it", 
)

model = foz.load_zoo_model(
    "google/medgemma-4b-it",
    quantized=True,
    # install_requirements=True  # uncomment to install the model's requirements if they're not already installed
    )

### 🏷️ Run Modality Classification with MedGemma

You can use this model for zero-shot classification, which adds a [FiftyOne Classification](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.Classification) to each sample in your dataset. Here we apply it to classify each image by its **imaging modality**.

🔹 **What’s happening here:**
- We extract the **distinct modality labels** from the dataset (`modality.label`)
- We set the model’s operation to `"classify"`
- We define a **custom prompt** that guides MedGemma to classify each image into **exactly one of the known modalities**
- Finally, we run `apply_model()` to store MedGemma’s predictions in a new field called `"pred_modality"`

This showcases MedGemma's zero-shot classification capabilities when paired with structured prompts and small medical datasets.

> You can inspect predictions later in the FiftyOne App and compare them with ground truth labels.


In [None]:
modality_labels = dataset.distinct("modality.label")

model.operation = "classify"

model.prompt = "As a medical expert your task is to classify this image into exactly one of the following types: " + ", ".join(modality_labels)

dataset.apply_model(model, label_field="pred_modality")
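Because the model answers in free text, its output may not exactly match one of the allowed labels (for example, it may differ in case or be wrapped in a sentence). The sketch below shows one way to build the same style of prompt and snap a raw answer onto the allowed label set. Both helpers (`build_classification_prompt`, `snap_to_label`) are hypothetical and not part of the MedGemma zoo integration, and the label values are illustrative stand-ins for `dataset.distinct("modality.label")`.

```python
import re

# Hypothetical helpers for illustration; not part of the MedGemma zoo integration.


def build_classification_prompt(labels):
    """Build a zero-shot prompt that lists each allowed label once, sorted."""
    unique_labels = sorted(set(labels))
    return (
        "As a medical expert your task is to classify this image into "
        "exactly one of the following types: " + ", ".join(unique_labels)
    )


def snap_to_label(raw_output, labels):
    """Map a free-form answer onto one allowed label, or return None.

    Tries a case-insensitive exact match first, then looks for an allowed
    label appearing as a whole word in the answer (longest labels first).
    """
    text = raw_output.strip().lower()
    lookup = {label.lower(): label for label in labels}
    if text in lookup:
        return lookup[text]
    for key in sorted(lookup, key=len, reverse=True):
        if re.search(r"\b" + re.escape(key) + r"\b", text):
            return lookup[key]
    return None


# Illustrative labels; in the notebook these come from dataset.distinct("modality.label")
labels = ["CT", "MRI", "X-Ray"]
print(build_classification_prompt(labels))
print(snap_to_label("  mri ", labels))
print(snap_to_label("This looks like a CT scan of the abdomen.", labels))
print(snap_to_label("Ultrasound", labels))
```

Returning `None` for unmatched answers (rather than guessing) keeps unparseable outputs visible when you later compare predictions against ground truth.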

Compare the ground-truth modality with MedGemma's prediction for the first sample:

In [None]:
sample = dataset.first()
print(sample["modality"])
print(sample["pred_modality.classifications"])

# Using MedGemma for VQA

You can also use the model for visual question answering (VQA), either with the same prompt for every [Sample](https://docs.voxel51.com/api/fiftyone.core.sample.html#module-fiftyone.core.sample) in the Dataset or with a per-sample question.

By default, the model uses a built-in system prompt, which you can inspect as follows:

In [None]:
print(model.system_prompt)

You can customize the system prompt as follows to guide the model's response. Note that we are using an existing field on the sample by passing `prompt_field="question"` into the [`apply_model`](https://docs.voxel51.com/api/fiftyone.core.dataset.html) method of the [Dataset](https://docs.voxel51.com/api/fiftyone.core.dataset.html).
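Conceptually, `prompt_field` switches from one shared prompt to a per-sample prompt. Here is a minimal sketch of that selection logic, assuming plain dictionaries stand in for FiftyOne samples; `resolve_prompts` is a hypothetical helper, not the integration's actual code, and the second question is a toy value.

```python
# Hypothetical sketch of per-sample prompt selection; the actual integration may differ.
# Plain dicts stand in for FiftyOne samples.


def resolve_prompts(samples, default_prompt=None, prompt_field=None):
    """Return one prompt per sample.

    With prompt_field set, each sample supplies its own prompt from that field;
    otherwise every sample shares default_prompt.
    """
    prompts = []
    for sample in samples:
        prompt = sample.get(prompt_field) if prompt_field else default_prompt
        if not prompt:
            raise ValueError(f"No prompt available for sample {sample.get('id')!r}")
        prompts.append(prompt)
    return prompts


toy_samples = [
    {"id": "s1", "question_6": "Which organ is abnormal, heart or lung?"},
    {"id": "s2", "question_6": "Is there a mass in the image?"},  # toy question
]

# Same prompt for every sample (like setting model.prompt)
print(resolve_prompts(toy_samples, default_prompt="Describe this image."))
# Per-sample prompts (like passing prompt_field="question_6")
print(resolve_prompts(toy_samples, prompt_field="question_6"))
```

Raising on a missing prompt, instead of silently skipping, is a design choice for the sketch; it surfaces samples whose question field is empty.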

If you want the model output parsed as a [FiftyOne Classification](https://docs.voxel51.com/api/fiftyone.core.labels.html#fiftyone.core.labels.Classification), you must explicitly prompt the model to respond in the format this integration expects:


```json
{
    "classifications": [
        {
            "label": "your answer to the question",
            ...,
        }
    ]
}
```

Note how the system prompt below embeds this format:

In [None]:
model.operation="classify"

model.system_prompt = """You have expert-level medical knowledge in radiology, histopathology, ophthalmology, and dermatology.

You will be asked a question and are required to provide your answer. Your answer must be in the following format:

```json
{
    "classifications": [
        {
            "label": "your answer to the question",
            ...,
        }
    ]
}
```

Always return your response as valid JSON wrapped in ```json blocks and respond only with one answer.
"""

dataset.apply_model(
    model, 
    label_field="pred_answer_6", 
    prompt_field="question_6"
    )
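The JSON wrapper matters because free text has to be turned back into labels. The following is a minimal sketch of such a parser, assuming the response contains a json-tagged fenced block and falling back to the raw text otherwise. `parse_classification_labels` is hypothetical; the real integration's parsing may differ. The fence string (`FENCE`) is built programmatically only so the snippet renders cleanly in Markdown.

```python
import json
import re

# Hypothetical parser for illustration; the real integration's parsing may differ.
FENCE = "`" * 3  # built programmatically so this snippet renders inside Markdown
JSON_BLOCK_RE = re.compile(FENCE + r"json\s*(.*?)\s*" + FENCE, re.DOTALL)


def parse_classification_labels(response):
    """Extract labels from a {"classifications": [{"label": ...}]} response.

    Uses the first json-tagged fenced block if present, else the raw text.
    Returns an empty list when the payload is not valid JSON of that shape.
    """
    match = JSON_BLOCK_RE.search(response)
    payload = match.group(1) if match else response.strip()
    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return []
    if not isinstance(data, dict):
        return []
    return [
        item["label"]
        for item in data.get("classifications", [])
        if isinstance(item, dict) and "label" in item
    ]


fenced = FENCE + 'json\n{"classifications": [{"label": "Heart"}]}\n' + FENCE
print(parse_classification_labels(fenced))
print(parse_classification_labels('{"classifications": [{"label": "Lung"}]}'))
print(parse_classification_labels("The heart is abnormal."))
```

A plain-prose answer such as the last example yields no labels, which is why the system prompt insists on valid JSON.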

In [None]:
dataset.first()['question_6']

'Which organ is abnormal, heart or lung?'

In [None]:
dataset.first()['answer_6']

<Classification: {
    'id': '6830a40a7a3316c437168c0e',
    'tags': [],
    'label': 'Heart',
    'confidence': None,
    'logits': None,
}>

In [None]:
dataset.first()['pred_answer_6.classifications']

[<Classification: {
     'id': '6830d5a0d69a8c1f151f7bc2',
     'tags': [],
     'label': 'Heart',
     'confidence': None,
     'logits': None,
 }>]
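Here the prediction matches the ground truth exactly, but free-text answers often differ only in case or trailing punctuation. A minimal sketch of normalized exact-match scoring for closed-ended questions, using plain strings in place of the `Classification` labels above; `normalize_answer` and `closed_vqa_match` are hypothetical helpers.

```python
import string

# Hypothetical helpers; plain strings stand in for Classification labels.


def normalize_answer(text):
    """Lowercase and strip surrounding whitespace and punctuation."""
    return text.strip().strip(string.punctuation).strip().lower()


def closed_vqa_match(ground_truth, predicted_labels):
    """True if any predicted label equals the ground truth after normalization."""
    target = normalize_answer(ground_truth)
    return any(normalize_answer(label) == target for label in predicted_labels)


print(closed_vqa_match("Heart", ["Heart"]))   # True
print(closed_vqa_match("Heart", ["heart."]))  # True
print(closed_vqa_match("Heart", ["Lung"]))    # False
print(closed_vqa_match("Heart", []))          # False
```

Exact matching is only appropriate for closed-ended questions; open-ended answers like the `vqa` output below need a different comparison.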

For open-ended generation, use `vqa` mode. With either mode, you can apply the same question to every sample, as shown below:

In [None]:
model.system_prompt = None # we need to clear the custom system prompt  

model.operation="vqa"

model.prompt = "Describe any anomalies that you observe in this image."

dataset.apply_model(model, label_field="open_response")

In [None]:
print(dataset.first()["open_response"])

Okay, I will analyze the chest X-ray provided.

**Observations:**

*   **Cardiomegaly:** The heart appears enlarged (cardiomegaly). This is evident by the increased prominence of the cardiac silhouette, particularly the left ventricle.
*   **Mediastinal Devices:** There are multiple devices visible within the mediastinum, including a pacemaker/ICD device. The leads appear to be in standard positions.
*   **Pulmonary Vascularity:** The pulmonary vasculature appears relatively normal in terms of distribution and prominence.
*   **Lung Fields:** The lung fields appear clear with no obvious consolidation, effusions, or masses.
*   **Bones:** The bony structures of the ribs and sternum appear intact.
*   **Mediastinal Width:** The mediastinal width appears to be within normal limits.

**Potential Anomalies/Considerations:**

*   **Cardiomegaly:** The cardiomegaly could be due to various factors, including hypertension, valvular heart disease, cardiomyopathy, or congenital heart defects. Fur

Alternatively, you can read each sample's question from a field:

In [None]:
model.operation="vqa"

dataset.apply_model(
    model, 
    label_field="answer_5_vqa", 
    prompt_field="question_5"
    )

Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


In [None]:
dataset.first()["question_5"]

In [None]:
dataset.first()["answer_5"]

In [None]:
dataset.first()["answer_5_vqa"]

# Evaluating the model

You can use FiftyOne's [Evaluation API](https://docs.voxel51.com/user_guide/evaluation.html) to evaluate model performance via the SDK. 

For example:

In [None]:
results = dataset.evaluate_classifications(
    "pred_modality",
    gt_field="modality",
    eval_key="eval_simple",
)

You can then review as follows:

In [None]:
results.print_report()

In [None]:
plot = results.plot_confusion_matrix()
plot.show()
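`evaluate_classifications()` computes these metrics for you. Purely for intuition, here is a standard-library sketch of the two headline quantities, accuracy and a confusion matrix, on toy label lists (not results from this dataset). It is not FiftyOne's implementation; `accuracy` and `confusion_matrix` are hypothetical helpers.

```python
from collections import Counter

# Illustrative stdlib sketch; not FiftyOne's implementation of evaluate_classifications().


def accuracy(gt_labels, pred_labels):
    """Fraction of samples whose predicted label equals the ground truth."""
    if len(gt_labels) != len(pred_labels):
        raise ValueError("gt_labels and pred_labels must have the same length")
    if not gt_labels:
        return 0.0
    correct = sum(gt == pred for gt, pred in zip(gt_labels, pred_labels))
    return correct / len(gt_labels)


def confusion_matrix(gt_labels, pred_labels, classes):
    """Rows are ground-truth classes, columns are predicted classes."""
    counts = Counter(zip(gt_labels, pred_labels))
    return [[counts[(gt, pred)] for pred in classes] for gt in classes]


# Toy labels for illustration only
gt = ["CT", "CT", "MRI", "X-Ray"]
pred = ["CT", "MRI", "MRI", "X-Ray"]
classes = ["CT", "MRI", "X-Ray"]

print(accuracy(gt, pred))  # 0.75
for name, row in zip(classes, confusion_matrix(gt, pred, classes)):
    print(name, row)
```

Off-diagonal cells (here, one CT image predicted as MRI) are exactly what the confusion-matrix plot above highlights.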

You can, of course, do all of this in the App as shown here: