# TimeImageProcessor Tutorial

This notebook demonstrates how to use the `TimeImageProcessor` for multimodal PyHealth pipelines.

**Contributors:** Josh Steier

## Overview

The `TimeImageProcessor` is a time-aware image processor that pairs image loading with temporal metadata. It is designed for tasks where each patient has **multiple images taken at different times** (e.g., serial chest X-rays during an ICU stay).

**Input:** `(List[image_path], List[time_diff_from_first_admission])`

**Output:** `(N×C×H×W image tensor, length-N timestamp tensor, "image")`

### Steps
1. Set up the environment
2. Create synthetic time-stamped chest X-ray data
3. Use the processor standalone
4. Verify chronological sorting
5. Truncate with `max_images`
6. Apply normalization
7. Integrate with `create_sample_dataset`
8. Summarize per-patient output shapes
9. Verify multimodal compatibility for downstream fusion
10. Clean up

## 1. Environment Setup

In [None]:
import os
import random
import shutil
import tempfile

import numpy as np
import torch
from PIL import Image

from pyhealth.datasets import create_sample_dataset
from pyhealth.processors.time_image_processor import TimeImageProcessor

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Create Synthetic Time-Stamped X-ray Data

We simulate a scenario where each patient has 1–5 chest X-rays taken at different times during their hospital stay. Each image gets a timestamp representing days from the patient's first admission.

In [None]:
DATA_ROOT = tempfile.mkdtemp(prefix="time_image_example_")
images_dir = os.path.join(DATA_ROOT, "images")
os.makedirs(images_dir, exist_ok=True)

NUM_PATIENTS = 20
MAX_IMAGES_PER_PATIENT = 5

samples = []
for pid in range(NUM_PATIENTS):
    # Each patient has 1-5 X-rays taken at different times
    n_images = np.random.randint(1, MAX_IMAGES_PER_PATIENT + 1)

    # Time differences from first admission in days
    time_diffs = sorted(np.random.uniform(0, 30, size=n_images).tolist())
    time_diffs[0] = 0.0  # First image always at t=0

    image_paths = []
    for j in range(n_images):
        # Synthetic grayscale X-ray with noise
        img_array = np.random.normal(80, 25, (224, 224))

        # Add lung-shaped regions
        y, x = np.ogrid[:224, :224]
        left_mask = ((x - 72)**2 / 3000 + (y - 112)**2 / 8000) < 1
        right_mask = ((x - 152)**2 / 3000 + (y - 112)**2 / 8000) < 1
        img_array[left_mask] -= 20
        img_array[right_mask] -= 20

        img_array = np.clip(img_array, 0, 255).astype(np.uint8)
        img = Image.fromarray(img_array, mode="L")

        img_path = os.path.join(images_dir, f"p{pid:03d}_t{j:02d}.png")
        img.save(img_path)
        image_paths.append(img_path)

    # Synthetic binary mortality label (alternates by patient index)
    label = pid % 2

    samples.append({
        "patient_id": f"p{pid}",
        "visit_id": f"v{pid}",
        "chest_xray": (image_paths, time_diffs),
        "label": label,
    })

print(f"Created {NUM_PATIENTS} patients in {DATA_ROOT}")
print(f"Images per patient: 1-{MAX_IMAGES_PER_PATIENT}")

In [None]:
# Inspect a sample patient
sample = samples[0]
paths, times = sample["chest_xray"]

print(f"Patient {sample['patient_id']}:")
print(f"  Number of images: {len(paths)}")
print(f"  Times (days from admission): {[round(t, 1) for t in times]}")
print(f"  Mortality label: {sample['label']}")
print(f"  Image paths:")
for p, t in zip(paths, times):
    print(f"    t={t:5.1f}d  {os.path.basename(p)}")
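
The synthetic timestamps in this section are drawn at random. With real data they would typically be derived from acquisition datetimes. The sketch below shows one way to compute them with the standard library only; `days_since_first_admission` and the example datetimes are hypothetical and not part of PyHealth.

```python
from datetime import datetime
from typing import List


def days_since_first_admission(
    admission_time: datetime, image_times: List[datetime]
) -> List[float]:
    """Return each image's offset from admission, in fractional days."""
    seconds_per_day = 24 * 60 * 60
    return [
        (t - admission_time).total_seconds() / seconds_per_day
        for t in image_times
    ]


# Hypothetical example: three X-rays during one stay.
admission = datetime(2024, 1, 1, 8, 0)
acquired = [
    datetime(2024, 1, 1, 8, 0),   # at admission
    datetime(2024, 1, 2, 20, 0),  # 36 hours later
    datetime(2024, 1, 4, 8, 0),   # 72 hours later
]
example_diffs = days_since_first_admission(admission, acquired)
print(example_diffs)  # [0.0, 1.5, 3.0]
```

The resulting list can be paired with the image paths to form the `(image_paths, time_diffs)` tuple the processor expects.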

## 3. Standalone Processor Usage

The `TimeImageProcessor` takes a tuple `(image_paths, time_diffs)` and returns `(images_tensor, timestamps_tensor, "image")`.

Key behaviors:
- **Sorts images chronologically** by timestamp
- **Truncates** to the `max_images` most recent images, if `max_images` is set
- Returns the `"image"` tag for modality routing in the multimodal embedding model
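
The sorting and truncation behaviors listed above can be sketched in plain Python. This illustrates the described semantics only and is not the processor's actual implementation; `sort_and_truncate` is a hypothetical helper, and its length check is an assumption rather than documented behavior.

```python
from typing import List, Optional, Tuple


def sort_and_truncate(
    paths: List[str],
    times: List[float],
    max_images: Optional[int] = None,
) -> Tuple[List[str], List[float]]:
    """Sort (path, time) pairs chronologically, then keep the latest ones."""
    if len(paths) != len(times):
        raise ValueError("paths and times must have the same length")
    if max_images is not None and max_images < 1:
        raise ValueError("max_images must be a positive integer")

    # Sort pairs by timestamp; Python's sort is stable for ties.
    pairs = sorted(zip(paths, times), key=lambda pair: pair[1])

    # Keep only the most recent `max_images` entries, still in time order.
    if max_images is not None:
        pairs = pairs[-max_images:]

    sorted_paths = [p for p, _ in pairs]
    sorted_times = [t for _, t in pairs]
    return sorted_paths, sorted_times


demo_paths, demo_times = sort_and_truncate(
    ["c.png", "a.png", "b.png"], [7.5, 0.0, 2.0], max_images=2
)
print(demo_paths)  # ['b.png', 'c.png']
print(demo_times)  # [2.0, 7.5]
```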

In [None]:
proc = TimeImageProcessor(
    image_size=224,
    mode="L",
)

images, timestamps, tag = proc.process(sample["chest_xray"])

print(f"Input: ({len(paths)} paths, {len(times)} timestamps)")
print(f"")
print(f"Output:")
print(f"  images shape:     {images.shape}  # (N, C, H, W)")
print(f"  timestamps shape: {timestamps.shape}  # (N,)")
print(f"  modality tag:     {tag!r}")
print(f"")
print(f"  images dtype:     {images.dtype}")
print(f"  timestamps dtype: {timestamps.dtype}")
print(f"  pixel range:      [{images.min():.3f}, {images.max():.3f}]")
print(f"  timestamps:       {timestamps.tolist()}")

## 4. Chronological Sorting Verification

Even if image paths are provided in random order, the processor always returns them sorted by timestamp.

In [None]:
# Provide images in reverse order
reversed_paths = list(reversed(paths))
reversed_times = list(reversed(times))

print(f"Input order (reversed):")
for p, t in zip(reversed_paths, reversed_times):
    print(f"  t={t:5.1f}d  {os.path.basename(p)}")

_, sorted_timestamps, _ = proc.process((reversed_paths, reversed_times))

print(f"\nOutput timestamps (sorted): {sorted_timestamps.tolist()}")
is_sorted = bool(torch.all(sorted_timestamps[:-1] <= sorted_timestamps[1:]))
print(f"Correctly sorted: {is_sorted}")

## 5. Truncation with `max_images`

When `max_images` is set, the processor keeps only the **most recent** images (by timestamp). This is useful for patients with many X-rays where you want to cap compute.

In [None]:
proc_truncated = TimeImageProcessor(
    image_size=224,
    mode="L",
    max_images=2,
)

imgs_trunc, ts_trunc, _ = proc_truncated.process(sample["chest_xray"])

print(f"Original images: {len(paths)}")
print(f"max_images: 2")
print(f"Output images: {imgs_trunc.shape[0]}")
print(f"Kept timestamps: {ts_trunc.tolist()}")
print(f"(These are the {imgs_trunc.shape[0]} most recent observations)")

## 6. Normalization

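Mean/std normalization (in the style used for ImageNet-pretrained backbones) can be applied for pretrained backbone compatibility. This example uses `mean=[0.5]` and `std=[0.5]` for single-channel images rather than the three-channel ImageNet statistics.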

In [None]:
proc_norm = TimeImageProcessor(
    image_size=128,
    mode="L",
    normalize=True,
    mean=[0.5],
    std=[0.5],
)

imgs_norm, ts_norm, _ = proc_norm.process(sample["chest_xray"])

print(f"Without normalization:")
print(f"  pixel range: [{images.min():.3f}, {images.max():.3f}]")
print(f"")
print(f"With normalization (mean=0.5, std=0.5):")
print(f"  pixel range: [{imgs_norm.min():.3f}, {imgs_norm.max():.3f}]")
print(f"  output shape: {imgs_norm.shape}")
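
To see why the normalized range lands near [-1, 1], recall that mean/std normalization conventionally computes `(x - mean) / std`. The sketch below works through that arithmetic in plain Python. It assumes pixel values are scaled to [0, 1] before normalization, which is typical of torchvision-style pipelines but is an assumption about this processor; `normalize_pixel` is a hypothetical helper.

```python
def normalize_pixel(value: float, mean: float, std: float) -> float:
    """Apply (value - mean) / std to a pixel already scaled to [0, 1]."""
    return (value - mean) / std


# With mean=0.5 and std=0.5, the [0, 1] range maps onto [-1, 1].
for raw in (0.0, 0.5, 1.0):
    print(raw, "->", normalize_pixel(raw, mean=0.5, std=0.5))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```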

## 7. Integration with `create_sample_dataset`

The processor is registered as `"time_image"` in PyHealth's processor registry, so it can be used in task schemas. Here we show it working with `create_sample_dataset`.

In [None]:
dataset = create_sample_dataset(
    samples=samples,
    input_schema={
        "chest_xray": "time_image",
    },
    output_schema={
        "label": "binary",
    },
    input_processors={
        "chest_xray": TimeImageProcessor(
            image_size=224,
            mode="L",
            max_images=4,
        ),
    },
    dataset_name="time_xray_example",
)

print(f"Dataset: {dataset}")
print(f"Total samples: {len(dataset)}")
print(f"Input schema:  {dataset.input_schema}")
print(f"Output schema: {dataset.output_schema}")

In [None]:
# Inspect a processed sample
processed = dataset[0]
print(f"Processed sample keys: {list(processed.keys())}")
print()

xray_data = processed["chest_xray"]
if isinstance(xray_data, tuple):
    img_tensor, ts_tensor, modality_tag = xray_data
    print(f"chest_xray output:")
    print(f"  images shape:     {img_tensor.shape}  # (N, C, H, W)")
    print(f"  timestamps shape: {ts_tensor.shape}  # (N,)")
    print(f"  modality tag:     {modality_tag!r}")
else:
    print(f"  type: {type(xray_data)}")

print(f"\nlabel: {processed['label']}")

## 8. Per-Patient Output Shape Summary

Because patients have different numbers of X-rays, the first dimension (N) of the output tensors varies per patient. This is expected: the multimodal embedding model handles variable-length inputs via masking.

In [None]:
proc_demo = TimeImageProcessor(image_size=224, mode="L", max_images=4)

print(f"{'Patient':<10} {'N imgs':<8} {'Output Shape':<25} {'Time Range (days)'}")
print("-" * 65)

for i in range(min(10, len(samples))):
    s = samples[i]
    paths_i, times_i = s["chest_xray"]
    imgs_i, ts_i, _ = proc_demo.process((paths_i, times_i))
    print(
        f"{s['patient_id']:<10} "
        f"{len(paths_i):<8} "
        f"{str(tuple(imgs_i.shape)):<25} "
        f"[{ts_i[0]:.1f}, {ts_i[-1]:.1f}]"
    )
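
Variable-length outputs have to be padded before they can be stacked into a batch. The sketch below illustrates the general padding-plus-mask idea on timestamp sequences using plain Python lists. It is not PyHealth's collate logic, and `pad_with_mask` is a hypothetical helper.

```python
from typing import List, Tuple


def pad_with_mask(
    sequences: List[List[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Right-pad sequences to equal length and mark the real entries."""
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(list(seq) + [pad_value] * n_pad)
        mask.append([True] * len(seq) + [False] * n_pad)
    return padded, mask


# Hypothetical batch: three patients with 1, 3, and 2 images.
batch_times = [[0.0], [0.0, 4.2, 9.0], [0.0, 1.5]]
padded_times, valid_mask = pad_with_mask(batch_times)
print(padded_times)
# [[0.0, 0.0, 0.0], [0.0, 4.2, 9.0], [0.0, 1.5, 0.0]]
print(valid_mask)
# [[True, False, False], [True, True, True], [True, True, False]]
```

The same mask would apply along the N dimension of a padded image tensor, so that attention layers can ignore the padded slots.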

## 9. Multimodal Compatibility

The `TimeImageProcessor` output format is designed to feed directly into the **unified multimodal embedding model**:

```
TimeImageProcessor
    ↓
(N, C, H, W) images + (N,) timestamps + "image" tag
    ↓
VisionEncoder(images) → (B, P, E') patch embeddings
TimeEmbedding(timestamps) → temporal encoding
ModalityEmbedding("image") → modality type encoding
    ↓
Combined: (B, P, E') vision tokens
    ↓
Concatenate with other modalities:
    TextEncoder    → (B, T, E') text tokens
    TimeseriesProc → (B, S, E') timeseries tokens
    ↓
(B, P+T+S, E') → BottleneckTransformer
```

This matches the architecture specified in the multimodal design doc.

In [None]:
# Simulated multimodal input shapes for one patient
E_prime = 128  # shared embedding dimension

# Vision: TimeImageProcessor -> VisionEncoder
P = 49  # patches from CNN/ResNet backbone
vision_tokens = torch.randn(1, P, E_prime)

# Text: TextProcessor -> TextEncoder (Medical RoBERTa)
T = 64  # 128-token chunks
text_tokens = torch.randn(1, T, E_prime)

# Timeseries: TimeseriesProcessor -> TimeseriesEncoder
S = 48  # hourly lab values over 2 days
ts_tokens = torch.randn(1, S, E_prime)

# Concatenate for transformer fusion
combined = torch.cat([vision_tokens, text_tokens, ts_tokens], dim=1)

print(f"Vision tokens:     {tuple(vision_tokens.shape)}  # P={P} patches")
print(f"Text tokens:       {tuple(text_tokens.shape)}  # T={T} tokens")
print(f"Timeseries tokens: {tuple(ts_tokens.shape)}  # S={S} steps")
print(f"")
print(f"Combined sequence: {tuple(combined.shape)}  # P+T+S={P+T+S} tokens")
print(f"Ready for BottleneckTransformer input")
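
The pipeline diagram in this section does not specify how `TimeEmbedding` encodes timestamps. One common option is a sinusoidal encoding of the continuous time value, sketched below in plain Python purely as an illustration. `sinusoidal_time_encoding` and its `max_period` parameter are assumptions, not the method defined in the design doc.

```python
import math
from typing import List


def sinusoidal_time_encoding(
    t: float, dim: int, max_period: float = 10000.0
) -> List[float]:
    """Encode a scalar time as interleaved sin/cos features of length `dim`."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    encoding = []
    for i in range(dim // 2):
        # Frequencies decay geometrically, starting at 1 for i = 0.
        freq = max_period ** (-2 * i / dim)
        encoding.extend([math.sin(t * freq), math.cos(t * freq)])
    return encoding


enc_t0 = sinusoidal_time_encoding(0.0, dim=8)
enc_t3 = sinusoidal_time_encoding(3.5, dim=8)
print(len(enc_t0))  # 8
print(enc_t0[:2])   # [0.0, 1.0]
```

In a full model, an encoding like this would be projected to the shared dimension E' and added to the vision tokens of the corresponding image.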

## 10. Cleanup

In [None]:
shutil.rmtree(DATA_ROOT)
print(f"Cleaned up: {DATA_ROOT}")

## Summary

### TimeImageProcessor

| Feature | Details |
|---|---|
| **Registry name** | `"time_image"` |
| **Input** | `(List[image_path], List[time_diff])` |
| **Output** | `(N×C×H×W image tensor, length-N timestamp tensor, "image")` |
| **Sorting** | Chronological by timestamp |
| **Truncation** | `max_images` keeps most recent |
| **Normalization** | Optional; per-channel `mean`/`std` (e.g., ImageNet statistics for RGB) |
| **Mode** | RGB, L (grayscale), RGBA |

### Usage in task schema

```python
input_schema = {
    "chest_xray": ("time_image", {
        "image_size": 224,
        "mode": "RGB",
        "normalize": True,
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
        "max_images": 8,
    }),
}
```

### Downstream pipeline

```
TimeImageProcessor → VisionEmbeddingModel → MultimodalEmbeddingModel → BottleneckTransformer
```