# VisionEmbeddingModel Tutorial

This notebook demonstrates how to use the `VisionEmbeddingModel` for medical imaging tasks in PyHealth.

**Contributors:** Josh Steier

**Overview:**
- Load a medical imaging dataset (NIH ChestX-ray14 in this notebook)
- Configure the `VisionEmbeddingModel` with different backbones
- Build an end-to-end classification pipeline
- Train and evaluate a multilabel chest X-ray classifier

## 0. Environment Setup

Seed the random number generators for reproducibility and import the required libraries.

In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn

from pyhealth.datasets import ChestXray14Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import VisionEmbeddingModel

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
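
Seeding the RNGs makes weight initialization and shuffling repeatable, but some cuDNN and cuBLAS kernels can still be nondeterministic on GPU. The sketch below shows the extra PyTorch switches you could add to the setup cell if exact reproducibility matters more than speed. It assumes PyTorch 1.11 or newer (for `warn_only`), and `warn_only=True` is a choice made here so that ops without a deterministic implementation emit a warning instead of raising.

```python
import os

import torch

# cuBLAS needs a fixed workspace size for deterministic matmuls on CUDA >= 10.2;
# set it before any cuBLAS work runs.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Ask cuDNN for deterministic kernels and turn off autotuning, which can
# select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Warn (rather than raise) when an op has no deterministic implementation.
torch.use_deterministic_algorithms(True, warn_only=True)

print("Deterministic algorithms enabled:", torch.are_deterministic_algorithms_enabled())
```

The trade-off is speed: deterministic kernels and disabled autotuning can make training noticeably slower, so many users enable these only when debugging or reporting final numbers.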

**Note:** If the setup imports fail with `A module that was compiled using NumPy 1.x cannot be run in NumPy 2.x` or `AttributeError: np.unicode_ was removed in the NumPy 2.0 release`, an installed package was built against NumPy 1.x. The easiest fix is to downgrade NumPy (for example `pip install "numpy<2"`) or upgrade the affected package, then restart the kernel.

## 1. Load Dataset

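The ChestX-ray14 dataset is downloaded automatically from the NIH. With `partial=True`, only the first image archive (about 5k images) is fetched, and `dev=True` restricts processing to a small subset for quick iteration.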

In [None]:
dataset = ChestXray14Dataset(
    root="./data/chestxray14",
    download=True,   # Auto-download from NIH
    partial=True,    # Only first archive (~5k images)
    dev=True,
)

## 2. Set Task

In [None]:
sample_dataset = dataset.set_task()

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
    raise RuntimeError(
        "The task did not produce any samples. "
        "Disable dev mode or check dataset path."
    )
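
Each ChestX-ray14 image can carry several of 14 findings at once, so the target is a 14-dimensional multi-hot vector rather than a single class index. The plain-Python sketch below illustrates that encoding, assuming findings arrive as a `|`-separated string as in the NIH metadata file. The `multi_hot` helper is hypothetical: it only shows the target shape the classifier later expects, not PyHealth's internal label processing.

```python
# The 14 ChestX-ray14 disease labels; "No Finding" images map to all zeros.
CXR14_LABELS = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax",
    "Consolidation", "Edema", "Emphysema", "Fibrosis",
    "Pleural_Thickening", "Hernia",
]


def multi_hot(findings: str, labels=CXR14_LABELS) -> list[int]:
    """Hypothetical helper: turn a '|'-separated findings string into a 0/1 vector."""
    present = set(findings.split("|"))
    return [1 if name in present else 0 for name in labels]


print(multi_hot("Effusion|Infiltration"))
```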

## 3. Split Dataset

In [None]:
train_ds, val_ds, test_ds = split_by_patient(
    sample_dataset, [0.7, 0.15, 0.15], seed=SEED
)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")
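
Splitting by patient rather than by image matters because ChestX-ray14 has multiple follow-up images for many patients, and an image-level split could place the same patient in both train and test. A minimal sketch of the idea follows, assuming each sample is a dict with a `patient_id` key. The function `patient_level_split` is a hypothetical illustration; PyHealth's `split_by_patient` works on its own dataset objects and may round split sizes differently.

```python
import random


def patient_level_split(samples, ratios=(0.7, 0.15, 0.15), seed=42):
    """Shuffle unique patient IDs, cut them by `ratios`, then route every sample
    to the split that owns its patient, so no patient spans two splits."""
    assert abs(sum(ratios) - 1.0) < 1e-6
    patients = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patients)

    n_train = int(len(patients) * ratios[0])
    n_val = int(len(patients) * (ratios[0] + ratios[1])) - n_train
    train_ids = set(patients[:n_train])
    val_ids = set(patients[n_train:n_train + n_val])

    train, val, test = [], [], []
    for s in samples:
        pid = s["patient_id"]
        if pid in train_ids:
            train.append(s)
        elif pid in val_ids:
            val.append(s)
        else:
            test.append(s)
    return train, val, test


# Toy data: 20 patients with 3 images each.
toy = [{"patient_id": f"p{i}", "image": f"img_{i}_{j}.png"} for i in range(20) for j in range(3)]
tr, va, te = patient_level_split(toy)
print(len(tr), len(va), len(te))
```

Note that the ratios apply to patients, not images, so the image counts per split only approximate 70/15/15 when patients have different numbers of images.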

## 4. Create DataLoaders

In [None]:
BATCH_SIZE = 16

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

## 5. Inspect Batch Structure

In [None]:
batch = next(iter(train_loader))

for key, value in batch.items():
    if hasattr(value, "shape"):
        print(f"{key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"{key}: {type(value).__name__}(len={len(value)})")

## 6. Test VisionEmbeddingModel Backbones

In [None]:
backbone_configs = [
    ("patch", {"backbone": "patch", "patch_size": 16}),
    ("cnn", {"backbone": "cnn"}),
    ("resnet18", {"backbone": "resnet18", "pretrained": True}),
]

# Get image field name from schema
image_field = list(sample_dataset.input_schema.keys())[0]

print("VisionEmbeddingModel backbone comparison:\n")

for name, config in backbone_configs:
    model = VisionEmbeddingModel(
        dataset=sample_dataset,
        embedding_dim=128,
        use_cls_token=True,
        **config,
    )
    
    info = model.get_output_info(image_field)
    n_params = sum(p.numel() for p in model.parameters())
    
    print(f"{name}:")
    print(f"  Output: {info['num_tokens']} tokens ({info['num_patches']} patches + CLS)")
    print(f"  Parameters: {n_params:,}")
    print()
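
For the `patch` backbone, the token count follows the usual ViT-style patchify arithmetic: an H×W image cut into P×P patches yields (H/P)·(W/P) patches, plus one CLS token. The sketch below reproduces this with a strided `nn.Conv2d` and a learnable CLS token, assuming a 224×224 single-channel input (the real input size depends on the dataset's image processing). `PatchTokenizer` is a hypothetical illustration of the technique, not the source of `VisionEmbeddingModel`.

```python
import torch
import torch.nn as nn


class PatchTokenizer(nn.Module):
    """ViT-style patch embedding: a Conv2d with kernel_size == stride == patch_size
    maps each non-overlapping patch to one E-dimensional token."""

    def __init__(self, in_channels=1, patch_size=16, embedding_dim=128):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embedding_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))

    def forward(self, x):  # x: (B, C, H, W)
        tokens = self.proj(x)                            # (B, E, H/P, W/P)
        tokens = tokens.flatten(2).transpose(1, 2)       # (B, num_patches, E)
        cls = self.cls_token.expand(x.size(0), -1, -1)   # (B, 1, E)
        return torch.cat([cls, tokens], dim=1)           # (B, num_patches + 1, E)


tokenizer = PatchTokenizer()
out = tokenizer(torch.randn(2, 1, 224, 224))
print(out.shape)  # torch.Size([2, 197, 128]): 14 * 14 = 196 patches + 1 CLS
```

Halving the patch size quadruples the number of tokens, which is the main cost lever when choosing `patch_size`.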

## 7. Build Classifier with VisionEmbeddingModel

In [None]:
class VisionClassifier(nn.Module):
    """End-to-end classifier using VisionEmbeddingModel."""
    
    def __init__(self, dataset, embedding_dim=128, backbone="resnet18", num_classes=14):
        super().__init__()
        
        self.vision_encoder = VisionEmbeddingModel(
            dataset=dataset,
            embedding_dim=embedding_dim,
            backbone=backbone,
            pretrained=True,
            use_cls_token=True,
            dropout=0.1,
        )
        
        self.classifier = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, num_classes),
        )
        
        self.image_key = list(dataset.input_schema.keys())[0]
        self.label_key = list(dataset.output_schema.keys())[0]
        # PyHealth's Trainer reads `model.mode` to pick its metric function.
        self.mode = "multilabel"
    
    @property
    def device(self):
        return self.vision_encoder.device
    
    def forward(self, **kwargs):
        images = kwargs[self.image_key].to(self.device)
        labels = kwargs[self.label_key].to(self.device).float()
        
        embeddings = self.vision_encoder({self.image_key: images})
        cls_token = embeddings[self.image_key][:, 0, :]  # CLS token
        logits = self.classifier(cls_token)
        
        # Multilabel classification loss
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
        y_prob = torch.sigmoid(logits)
        
        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": labels,
            "logit": logits,
        }


# ChestX-ray14 has 14 disease labels (multilabel classification)
model = VisionClassifier(
    dataset=sample_dataset,
    embedding_dim=128,
    backbone="resnet18",
    num_classes=14,
).to(device)

print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
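
The head above sees only the CLS token (`embeddings[:, 0, :]`) and trains with `binary_cross_entropy_with_logits`, which applies an independent sigmoid to each of the 14 labels. The standalone check below runs the same head layout on random tokens (dummy shapes, no PyHealth needed) and confirms that the built-in loss equals the explicit per-label binary cross-entropy averaged over batch and labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, E, C = 4, 197, 128, 14  # batch, tokens (incl. CLS), embedding dim, labels

# Same layout as VisionClassifier.classifier.
head = nn.Sequential(
    nn.LayerNorm(E),
    nn.Linear(E, E),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(E, C),
).eval()  # eval() disables dropout so the check is deterministic

tokens = torch.randn(B, T, E)                 # stand-in for encoder output
labels = torch.randint(0, 2, (B, C)).float()  # random multi-hot targets

logits = head(tokens[:, 0, :])                # CLS pooling -> (B, C)
loss = F.binary_cross_entropy_with_logits(logits, labels)

# Explicit form: -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))], averaged.
# logsigmoid(-z) == log(1 - sigmoid(z)), written this way for numerical stability.
manual = -(labels * F.logsigmoid(logits) + (1 - labels) * F.logsigmoid(-logits)).mean()
print(loss.item(), manual.item())
```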

## 8. Train Model

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["roc_auc_samples", "f1_samples"],
    device=device,
)

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=10,
    optimizer_params={"lr": 1e-4},
    monitor="roc_auc_samples",
)
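
`roc_auc_samples` and `f1_samples` are sample-averaged: each image's predictions are scored as a small binary problem and the per-image scores are averaged. Assuming these follow scikit-learn's `average="samples"` convention (the metric names suggest so, and the 0.5 F1 threshold below is also an assumption), the sketch contrasts them with macro averaging (per label, then averaged) on made-up probabilities. One consequence of the definition: per-image ROC-AUC is undefined for an image with no positive label, such as a "No Finding" image, and scikit-learn's binary ROC-AUC rejects inputs where only one class is present.

```python
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

# Toy data: 3 images x 4 labels (made-up numbers). Every row and every column
# contains both classes, so both averaging modes are defined.
y_true = np.array([
    [1, 0, 0, 1],
    [0, 1, 1, 0],
    [1, 1, 0, 0],
])
y_prob = np.array([
    [0.9, 0.2, 0.1, 0.6],
    [0.3, 0.8, 0.7, 0.2],
    [0.7, 0.4, 0.5, 0.1],
])

# Sample-averaged: one ROC-AUC per row (image), then the mean over rows.
auc_samples = roc_auc_score(y_true, y_prob, average="samples")
# Macro-averaged: one ROC-AUC per column (label), then the mean over labels.
auc_macro = roc_auc_score(y_true, y_prob, average="macro")

# F1 needs hard predictions; threshold the probabilities at 0.5.
y_pred = (y_prob >= 0.5).astype(int)
f1_samples_score = f1_score(y_true, y_pred, average="samples")

print(f"ROC-AUC samples={auc_samples:.4f} macro={auc_macro:.4f} F1 samples={f1_samples_score:.4f}")
```

Here every label is ranked perfectly across images (macro 1.0), yet the third image ranks a negative label above a positive one, which only the sample-averaged score reflects.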

## 9. Evaluate

In [None]:
print("Test Results:")
metrics = trainer.evaluate(test_loader)
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

## Summary

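`VisionEmbeddingModel` converts images into token embeddings of shape `(B, num_tokens, E)`, where `num_tokens = num_patches + 1` when `use_cls_token=True`. These embeddings can be used for: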

- Standalone classification (using CLS token)
- Multimodal fusion with EHR/text embeddings

Backbones: `patch`, `cnn`, `resnet18`, `resnet50`
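
Because the encoder emits a `(B, num_tokens, E)` sequence, one simple way to fuse it with another modality is to bring both to the same width E, concatenate them along the token axis, and let a Transformer encoder attend across modalities. The sketch below uses random tensors as stand-ins for the image tokens and for a hypothetical EHR token sequence (`ehr_tokens`, with an arbitrary length of 30). `TokenFusion` shows one common early-fusion pattern under these assumptions; it is not a PyHealth fusion API.

```python
import torch
import torch.nn as nn


class TokenFusion(nn.Module):
    """Early fusion: concatenate image and EHR tokens, mix them with self-attention,
    and classify from the image CLS token (position 0)."""

    def __init__(self, embedding_dim=128, num_classes=14, num_layers=2, num_heads=4):
        super().__init__()
        # Learned modality embeddings tell the encoder which tokens came from where.
        self.modality = nn.Embedding(2, embedding_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=num_heads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embedding_dim, num_classes)

    def forward(self, image_tokens, ehr_tokens):
        # image_tokens: (B, T_img, E) with CLS at index 0; ehr_tokens: (B, T_ehr, E)
        image_tokens = image_tokens + self.modality.weight[0]
        ehr_tokens = ehr_tokens + self.modality.weight[1]
        fused = self.encoder(torch.cat([image_tokens, ehr_tokens], dim=1))
        return self.head(fused[:, 0, :])  # logits: (B, num_classes)


fusion = TokenFusion()
image_tokens = torch.randn(2, 197, 128)  # e.g. 196 patches + CLS
ehr_tokens = torch.randn(2, 30, 128)     # hypothetical EHR token sequence
logits = fusion(image_tokens, ehr_tokens)
print(logits.shape)
```

A cheaper late-fusion alternative is to concatenate the image CLS vector with a pooled EHR vector and feed the result to an MLP; token-level fusion costs more compute but lets individual patches attend to individual EHR events.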