# VisionEmbeddingModel Tutorial

This notebook demonstrates how to use the `VisionEmbeddingModel` for medical imaging tasks in PyHealth.

**Contributors:** Josh Steier 


**Overview:**
- Load a medical imaging dataset (MIMIC-CXR or custom)
- Configure the `VisionEmbeddingModel` with different backbones
- Build an end-to-end classification pipeline
- Train and evaluate on chest X-ray classification

## 1. Environment Setup

Seed the Python, NumPy, and PyTorch random number generators for reproducibility, import the required libraries, and select a device.

In [None]:
import os
import random
import tempfile
import shutil

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from pyhealth.datasets import create_sample_dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.processors import ImageProcessor

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
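Seeding alone does not make GPU runs bit-for-bit reproducible, because some cuDNN kernels are nondeterministic. The sketch below shows a stricter setup, assuming you accept the speed cost. The helper name `seed_everything` and its `strict` flag are hypothetical, not part of PyHealth.

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int, strict: bool = False) -> None:
    """Hypothetical helper: seed Python, NumPy and PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNGs on all devices, CUDA included
    if strict:
        # Prefer deterministic cuDNN kernels and disable kernel autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # cuBLAS reads this before its first call; set it early to take effect.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Warn (instead of raising) when an op has no deterministic version.
        torch.use_deterministic_algorithms(True, warn_only=True)


seed_everything(42)
first_draw = torch.rand(3)
seed_everything(42)
second_draw = torch.rand(3)
print(torch.equal(first_draw, second_draw))  # True
```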

## 2. Load Dataset

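**Option A:** Synthetic data (default; no download required). Generates 200 grayscale 224x224 images from 50 synthetic patients; images labeled 1 contain a bright circular "opacity".

**Option B:** MIMIC-CXR-JPG (requires credentialed PhysioNet access). Set `USE_SYNTHETIC = False` and point `MIMIC_CXR_ROOT` at your local copy.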

In [None]:
USE_SYNTHETIC = True
MIMIC_CXR_ROOT = "/path/to/physionet.org/files/mimic-cxr-jpg/2.0.0"

if USE_SYNTHETIC:
    DATA_ROOT = tempfile.mkdtemp(prefix="chest_xray_")
    images_dir = os.path.join(DATA_ROOT, "images")
    os.makedirs(images_dir, exist_ok=True)
    
    samples = []
    for i in range(200):
        label = i % 2
        img_array = np.random.normal(80, 25, (224, 224))
        y, x = np.ogrid[:224, :224]
        left_mask = ((x - 72)**2 / 3000 + (y - 112)**2 / 8000) < 1
        right_mask = ((x - 152)**2 / 3000 + (y - 112)**2 / 8000) < 1
        img_array[left_mask] -= 20
        img_array[right_mask] -= 20
        
        if label == 1:
            cx, cy = 112 + np.random.randint(-50, 50), 112 + np.random.randint(-30, 30)
            radius = np.random.randint(15, 40)
            opacity_mask = (x - cx)**2 + (y - cy)**2 <= radius**2
            img_array[opacity_mask] += 60 + np.random.randint(0, 40)
        
        img_array = np.clip(img_array, 0, 255).astype(np.uint8)
        img = Image.fromarray(img_array)  # 2-D uint8 array -> grayscale ("L") image
        img_path = os.path.join(images_dir, f"cxr_{i:04d}.png")
        img.save(img_path)
        
        samples.append({
            "patient_id": f"p{i // 4}",
            "visit_id": f"v{i}",
            "image": img_path,
            "label": label,
        })
    
    print(f"Created {len(samples)} synthetic images in {DATA_ROOT}")
    
else:
    DATA_ROOT = MIMIC_CXR_ROOT
    import pandas as pd
    
    metadata = pd.read_csv(f"{DATA_ROOT}/mimic-cxr-2.0.0-metadata.csv")
    labels = pd.read_csv(f"{DATA_ROOT}/mimic-cxr-2.0.0-chexpert.csv")
    df = metadata.merge(labels, on=["subject_id", "study_id"])
    
    def build_path(row):
        subject = f"p{row['subject_id']}"
        return f"{DATA_ROOT}/files/{subject[:3]}/{subject}/s{row['study_id']}/{row['dicom_id']}.jpg"
    
    df["image_path"] = df.apply(build_path, axis=1)
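    # .copy() avoids pandas' SettingWithCopyWarning on the assignment below
    df = df[df["image_path"].apply(os.path.exists)].head(1000).copy()
    # label = 1 when the CheXpert "No Finding" label is positive (a normal study)
    df["label"] = (df["No Finding"] == 1.0).astype(int)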
    
    samples = [
        {"patient_id": str(r["subject_id"]), "visit_id": str(r["study_id"]), 
         "image": r["image_path"], "label": r["label"]}
        for _, r in df.iterrows()
    ]
    print(f"Loaded {len(samples)} MIMIC-CXR images from {DATA_ROOT}")
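In the MIMIC-CXR-JPG branch, `build_path` relies on the dataset's folder layout: the first three characters of `p<subject_id>` name the top-level folder. The sketch below repeats that rule in plain Python on made-up IDs (not real patients) and shows how the `No Finding == 1.0` rule maps CheXpert-style values, where `-1.0` (uncertain) and `NaN` (not mentioned) both become 0. The function names are hypothetical.

```python
import math


def mimic_jpg_path(root: str, subject_id: int, study_id: int, dicom_id: str) -> str:
    # Same rule as build_path above: files/pXX/pXXXXXXXX/sYYYYYYYY/<dicom_id>.jpg
    subject = f"p{subject_id}"
    return f"{root}/files/{subject[:3]}/{subject}/s{study_id}/{dicom_id}.jpg"


def no_finding_label(value: float) -> int:
    # 1.0 -> 1; 0.0, -1.0 (uncertain) and NaN (not mentioned) -> 0
    return int(value == 1.0)


# Made-up identifiers for illustration only.
print(mimic_jpg_path("/data/mimic-cxr-jpg", 10000001, 50000001, "example"))
# /data/mimic-cxr-jpg/files/p10/p10000001/s50000001/example.jpg
print([no_finding_label(v) for v in [1.0, 0.0, -1.0, math.nan]])
# [1, 0, 0, 0]
```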

## 3. Prepare PyHealth Dataset

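Wrap the samples in a PyHealth `SampleDataset` with `create_sample_dataset()`, using an `ImageProcessor` to load and resize each image (grayscale for the synthetic data, RGB for MIMIC-CXR). Then split 70/15/15 by patient so that no patient's images appear in more than one split.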

In [None]:
image_mode = "L" if USE_SYNTHETIC else "RGB"

sample_dataset = create_sample_dataset(
    samples=samples,
    input_schema={"image": "image"},
    output_schema={"label": "binary"},
    input_processors={"image": ImageProcessor(image_size=224, mode=image_mode)},
    dataset_name="chest_xray",
)

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.7, 0.15, 0.15], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")
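A patient-level split assigns whole patients, not individual images, to each split, which keeps one patient's images from leaking between training and evaluation. The following is a simplified pure-Python illustration of that idea, not PyHealth's `split_by_patient` implementation; its shuffling and rounding details may differ. The function name `patient_split` is hypothetical.

```python
import random
from collections import defaultdict


def patient_split(records, ratios, seed=42):
    """Illustrative patient-level split: returns (train, val, test) lists."""
    by_patient = defaultdict(list)
    for r in records:
        by_patient[r["patient_id"]].append(r)

    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    n_train = int(len(patients) * ratios[0])
    n_val = int(len(patients) * ratios[1])
    groups = (
        patients[:n_train],
        patients[n_train:n_train + n_val],
        patients[n_train + n_val:],  # remaining patients go to test
    )
    return tuple([r for p in g for r in by_patient[p]] for g in groups)


# Same ID scheme as the synthetic data: 4 images per patient, 50 patients.
toy_records = [{"patient_id": f"p{i // 4}", "label": i % 2} for i in range(200)]
toy_train, toy_val, toy_test = patient_split(toy_records, [0.7, 0.15, 0.15])
print(len(toy_train), len(toy_val), len(toy_test))  # 140 28 32
```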

## 4. Inspect Batch Structure

Build PyHealth dataloaders and check the keys and tensor shapes of a batch before training.

In [None]:
BATCH_SIZE = 16

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, "shape"):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print(batch_summary)

label_targets = first_batch["label"]
if hasattr(label_targets, "shape"):
    preview = label_targets[:5].cpu().tolist()
else:
    preview = list(label_targets)[:5]
print(f"Sample labels: {preview}")

In [None]:
batch = next(iter(train_loader))

for key, value in batch.items():
    if hasattr(value, "shape"):
        print(f"{key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"{key}: {type(value).__name__}(len={len(value)})")
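The batch shapes above come from stacking per-sample tensors along a new leading dimension. A minimal sketch of that step in plain PyTorch, assuming each sample holds one `(C, H, W)` image tensor and a one-element label tensor; PyHealth's own collate function may differ in detail, and the sample contents here are random placeholders.

```python
import torch

# Placeholder per-sample outputs: 16 grayscale 224x224 images, 1-element labels.
toy_samples = [
    {"image": torch.rand(1, 224, 224), "label": torch.tensor([i % 2], dtype=torch.float32)}
    for i in range(16)
]

# Stack each field across samples, adding a batch dimension B in front.
toy_batch = {
    key: torch.stack([s[key] for s in toy_samples])
    for key in ("image", "label")
}
print(toy_batch["image"].shape)  # torch.Size([16, 1, 224, 224])
print(toy_batch["label"].shape)  # torch.Size([16, 1])
```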

## 4.5. VisionEmbeddingModel: Output Shape Verification

The core transformation performed by `VisionEmbeddingModel`:

```
Input:  (B, C, H, W)  →  e.g., (16, 1, 224, 224)
Output: (B, P+1, E)   →  e.g., (16, 50, 128)   # cnn/ResNet: 7x7 = 49 patches + 1 CLS token
                          (16, 197, 128)  # patch (patch_size=16): 14x14 = 196 patches + 1 CLS
```

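Each image becomes a sequence of `E`-dimensional token embeddings, the input layout a Transformer encoder expects, so image tokens can be passed to Transformer layers or concatenated with token sequences from other modalities.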
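The token count follows from the backbone's output grid. A `patch_size` of 16 tiles a 224x224 image into 14x14 = 196 patches, while the cnn and ResNet backbones produce 49 tokens (per the summary table), consistent with a 7x7 feature map and an overall stride of 32. The hypothetical helper below only does this arithmetic; it is not part of `VisionEmbeddingModel`.

```python
def num_tokens(image_size: int, stride: int, use_cls_token: bool = True) -> int:
    """Tokens per image for a non-overlapping square grid with the given stride."""
    grid = image_size // stride  # patches per side
    return grid * grid + int(use_cls_token)


print(num_tokens(224, 16))  # 197 -> patch backbone, patch_size=16
print(num_tokens(224, 32))  # 50  -> 7x7 feature map, e.g. ResNet
```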

In [None]:
from pyhealth.models import VisionEmbeddingModel

# Test all backbones and verify output shapes
backbone_configs = [
    ("patch", {"backbone": "patch", "patch_size": 16}),
    ("cnn", {"backbone": "cnn"}),
    ("resnet18", {"backbone": "resnet18", "pretrained": True}),
    ("resnet50", {"backbone": "resnet50", "pretrained": True}),
]

print("VisionEmbeddingModel Output Shape Verification")
print("=" * 60)
print(f"\nInput shape: {batch['image'].shape}  # (B, C, H, W)")
print()

for name, config in backbone_configs:
    model = VisionEmbeddingModel(
        dataset=sample_dataset,
        embedding_dim=128,
        use_cls_token=True,
        **config,
    )
    
    # Forward pass in eval mode (dropout off, BatchNorm uses running stats)
    model.eval()
    with torch.no_grad():
        output = model({"image": batch["image"]})
    
    output_shape = output["image"].shape
    info = model.get_output_info("image")
    n_params = sum(p.numel() for p in model.parameters())
    
    print(f"{name}:")
    print(f"  Output shape: {tuple(output_shape)}  # (B, num_tokens, E)")
    print(f"  Tokens: {info['num_tokens']} = {info['num_patches']} patches + 1 CLS")
    print(f"  Parameters: {n_params:,}")
    print()

## 5. Instantiate VisionEmbeddingModel

Instantiate each backbone with the same hyperparameters (`embedding_dim=128`, `use_cls_token=True`) and compare token counts and parameter footprints without running a forward pass.

In [None]:
from pyhealth.models import VisionEmbeddingModel

backbone_configs = [
    ("patch", {"backbone": "patch", "patch_size": 16}),
    ("cnn", {"backbone": "cnn"}),
    ("resnet18", {"backbone": "resnet18", "pretrained": True}),
    ("resnet50", {"backbone": "resnet50", "pretrained": True}),
]

print("VisionEmbeddingModel Backbone Comparison")
print("=" * 50)

for name, config in backbone_configs:
    model = VisionEmbeddingModel(
        dataset=sample_dataset,
        embedding_dim=128,
        use_cls_token=True,
        **config,
    )
    info = model.get_output_info("image")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n{name}:")
    print(f"  Tokens: {info['num_tokens']} ({info['num_patches']} patches + CLS)")
    print(f"  Parameters: {n_params:,}")

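## 6. Build an End-to-End Classifier

Wrap `VisionEmbeddingModel` in a small module that classifies from the CLS token and returns the `loss`, `y_prob`, and `y_true` entries the PyHealth Trainer consumes.

In [None]: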
class VisionClassifier(nn.Module):
    """End-to-end classifier using VisionEmbeddingModel."""
    
    def __init__(self, dataset, embedding_dim=128, backbone="resnet18"):
        super().__init__()
        
        self.vision_encoder = VisionEmbeddingModel(
            dataset=dataset,
            embedding_dim=embedding_dim,
            backbone=backbone,
            pretrained=True,
            use_cls_token=True,
            dropout=0.1,
        )
        
        self.classifier = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, 1),
        )
        
        self.feature_keys = ["image"]
        self.label_key = "label"
        self.mode = "binary"  
    
    @property
    def device(self):
        # Infer the device from this module's own parameters, so the property
        # works whether or not the encoder exposes a `device` attribute.
        return next(self.parameters()).device

    def forward(self, **kwargs):
        images = kwargs["image"].to(self.device)
        labels = kwargs["label"].to(self.device).float()

        # Get embeddings: (B, num_tokens, E)
        embeddings = self.vision_encoder({"image": images})

        # Use the CLS token (first token) for classification: (B, E)
        cls_token = embeddings["image"][:, 0, :]

        # Classification head, flattened to (B,). reshape(-1) keeps the batch
        # dimension even when B == 1, unlike squeeze(), which can drop it.
        logits = self.classifier(cls_token).reshape(-1)
        labels = labels.reshape(-1)

        # Binary cross-entropy computed on raw logits (numerically stable)
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)

        # Positive-class probabilities, shape (B,)
        y_prob = torch.sigmoid(logits)

        return {
            "loss": loss,
            "y_prob": y_prob,  # (B,)
            "y_true": labels,  # (B,)
        }


model = VisionClassifier(
    dataset=sample_dataset,
    embedding_dim=128,
    backbone="resnet18",
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
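The classifier trains on raw logits with `binary_cross_entropy_with_logits`, which fuses the sigmoid into the loss for numerical stability. For logit `z` and label `y`, the per-sample loss is `-(y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z)))`, averaged over the batch. This quick check on made-up values confirms the fused call matches that formula:

```python
import torch
import torch.nn.functional as F

demo_logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
demo_labels = torch.tensor([1.0, 0.0, 0.0, 1.0])

fused = F.binary_cross_entropy_with_logits(demo_logits, demo_labels)

# Manual version: mean of -(y*log p + (1-y)*log(1-p)) with p = sigmoid(z)
p = torch.sigmoid(demo_logits)
manual = -(demo_labels * torch.log(p) + (1 - demo_labels) * torch.log(1 - p)).mean()

print(torch.allclose(fused, manual))  # True
```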

## 7. Configure Trainer

Wrap the model in the PyHealth `Trainer`, which handles optimization, gradient clipping, and validation metrics (here ROC-AUC). `enable_logging=False` disables the Trainer's logging.

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["roc_auc"],
    device=str(device),
    enable_logging=False,
)

training_config = {
    "epochs": 5,
    "optimizer_params": {"lr": 1e-4},
    "max_grad_norm": 1.0,   # clip the global gradient norm at 1.0
    "monitor": "roc_auc",   # validation metric tracked after each epoch
}
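`max_grad_norm` caps the combined L2 norm of all gradients before each optimizer step. A plain-PyTorch sketch of one clipped step, assuming the Trainer does the equivalent internally; the toy model, optimizer, and data are placeholders chosen so the gradients are large enough to be clipped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_linear = nn.Linear(8, 1)
optimizer = torch.optim.Adam(toy_linear.parameters(), lr=1e-4)

# Placeholder batch with large regression targets so gradients are big.
x = torch.randn(16, 8)
target = torch.full((16,), 50.0)

optimizer.zero_grad()
loss = (toy_linear(x).view(-1) - target).pow(2).mean()
loss.backward()

# Rescale all gradients in place so their combined L2 norm is at most 1.0;
# the return value is the total norm *before* clipping.
norm_before = torch.nn.utils.clip_grad_norm_(toy_linear.parameters(), max_norm=1.0)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in toy_linear.parameters()]))
optimizer.step()

print(f"before: {norm_before:.2f}, after: {norm_after:.4f}")
```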

## 8. Train the Model

Train for the configured 5 epochs. Gradients are clipped to `max_grad_norm` at every step, and after each epoch the Trainer scores the validation loader with the `monitor` metric (ROC-AUC).

In [None]:
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **training_config,
)

## 9. Evaluate on Validation/Test Split

Score the validation and test splits with `trainer.evaluate()`, which runs inference in evaluation mode and returns ROC-AUC along with the mean loss.

In [None]:
evaluation_results = {}
for split_name, loader in {"validation": val_loader, "test": test_loader}.items():
    if loader is None:
        continue
    metrics = trainer.evaluate(loader)
    evaluation_results[split_name] = metrics
    formatted = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
    print(f"{split_name.title()} metrics: {formatted}")
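ROC-AUC equals the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. A dependency-free sketch of that pairwise definition, handy for sanity-checking small splits; it is quadratic in the number of samples, so it is not meant for large evaluations.

```python
def roc_auc(y_true, y_score):
    """Pairwise ROC-AUC: P(score_pos > score_neg), ties counted as 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```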

## 10. Inspect Sample Predictions

Run a quick inference pass on the validation split to preview predicted probabilities alongside ground-truth labels.

In [None]:
target_loader = val_loader if val_loader is not None else train_loader

y_true, y_prob, mean_loss = trainer.inference(target_loader)
positive_prob = y_prob if y_prob.ndim == 1 else y_prob[..., -1]
preview_pairs = list(zip(y_true[:5].tolist(), positive_prob[:5].tolist()))
print(f"Mean loss: {mean_loss:.4f}")
print(f"Preview (label, positive_prob): {preview_pairs}")
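To turn predicted probabilities into hard labels, threshold them; 0.5 is a conventional default, not a value tuned in this notebook. A small sketch that counts the resulting confusion-matrix cells; the example pairs are made up, and the function is hypothetical.

```python
def confusion_counts(pairs, threshold=0.5):
    """Count TP, FP, TN, FN for (label, positive_prob) pairs."""
    tp = fp = tn = fn = 0
    for label, prob in pairs:
        pred = int(prob >= threshold)
        if pred == 1 and label == 1:
            tp += 1
        elif pred == 1:
            fp += 1
        elif label == 0:
            tn += 1
        else:
            fn += 1
    return {"tp": tp, "fp": fp, "tn": tn, "fn": fn}


example_pairs = [(1.0, 0.91), (0.0, 0.62), (0.0, 0.08), (1.0, 0.33), (1.0, 0.57)]
print(confusion_counts(example_pairs))  # {'tp': 2, 'fp': 1, 'tn': 1, 'fn': 1}
```

The function accepts the `preview_pairs` list printed above in the same `(label, positive_prob)` format.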

## 11. Cleanup

In [None]:
if USE_SYNTHETIC:
    shutil.rmtree(DATA_ROOT)
    print(f"Cleaned up: {DATA_ROOT}")
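If a cell raises before this point, the synthetic directory is left on disk. An alternative sketch uses `tempfile.TemporaryDirectory`, which deletes the directory when the `with` block exits, including when an exception is raised. The trade-off is that every step that reads the files must run inside that block, which is awkward across notebook cells.

```python
import os
import tempfile

with tempfile.TemporaryDirectory(prefix="chest_xray_") as tmp_root:
    tmp_images_dir = os.path.join(tmp_root, "images")
    os.makedirs(tmp_images_dir)
    # ... write images and run the pipeline here ...
    print(os.path.isdir(tmp_images_dir))  # True

print(os.path.exists(tmp_root))  # False
```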

## Summary

### VisionEmbeddingModel

Converts medical images to sequence embeddings: `(B, C, H, W) → (B, num_tokens, E)`

**Backbones:**
| Backbone | Tokens (224px) | Parameters (approx.) | Use Case |
|----------|----------------|----------------------|----------|
| `patch` | 196 + CLS | ~300K | Lightweight, ViT-style |
| `cnn` | 49 + CLS | ~100K | Small encoder with a convolutional inductive bias |
| `resnet18` | 49 + CLS | ~11M | Pretrained ImageNet features |
| `resnet50` | 49 + CLS | ~24M | Larger pretrained backbone, higher compute |

**Key Features:**
- Pretrained ImageNet weights (ResNet backbones)
- Optional [CLS] token for classification
- Positional embeddings
- Compatible with PyHealth's `ImageProcessor`

**Multimodal Integration:**
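```python
import torch
import torch.nn as nn

B, E = 16, 128
vision_emb = torch.randn(B, 50, E)  # (B, P+1, E): VisionEmbeddingModel output, P patches + CLS
text_emb = torch.randn(B, 32, E)    # (B, T, E): placeholder text token embeddings
ehr_emb = torch.randn(B, 20, E)     # (B, S, E): placeholder EHR sequence embeddings
combined = torch.cat([vision_emb, text_emb, ehr_emb], dim=1)  # (B, P+1+T+S, E) = (16, 102, 128)
layer = nn.TransformerEncoderLayer(d_model=E, nhead=8, batch_first=True)
fused = nn.TransformerEncoder(layer, num_layers=2)(combined)  # Transformer fusion, same shape
```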