<a href="https://colab.research.google.com/github/GabrielWarner/PyHealth/blob/mimic-cxr-sentence-example/examples/radiology_sentence_classification_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Radiology Sentence Classification with PyHealth Tokenizer and Metrics

# **Author:** Gabriel Warner
# **NetID:** gsw3 (UIUC Online MCS)

#**Paper title (DL4H final project):**
#Integrating ChatGPT into Secure Hospital Networks: A Case Study on Improving Radiology Report Analysis

#**Paper link:**
#<ADD YOUR ARXIV / PDF / GITHUB LINK HERE>

#**Description of the task:**

#This notebook shows a small, reproducible example of **sentence-level radiology classification** using PyHealth's reusable components.
#We create a toy dataset of radiology report sentences labeled as **normal**, **abnormal**, or **uncertain**, then:

#- Use `pyhealth.tokenizer.Tokenizer` to turn tokens into indices
#- Train a small PyTorch text classifier
#- Evaluate using `pyhealth.metrics.multiclass.multiclass_metrics_fn`

#This example mirrors the sentence-level classification component from our DL4H final project and demonstrates how PyHealth modules can be applied to clinical NLP tasks.


In [1]:
# This example assumes you're running inside the PyHealth repo.
# If you're running it standalone (e.g., in Colab), uncomment the pip install:
# %pip install pyhealth

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from pyhealth.tokenizer import Tokenizer
from pyhealth.metrics import multiclass_metrics_fn  # documented in PyHealth metrics API


ModuleNotFoundError: No module named 'pyhealth'

In [None]:
# Seed both NumPy and PyTorch so repeated runs of the notebook match.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


In [None]:
# Toy corpus: three hand-written report sentences for each of the three classes,
# stored as (sentence, class name) pairs.
radiology_samples = [
    # --- normal ---
    ("Lungs are clear. No acute cardiopulmonary abnormality.", "normal"),
    ("No focal consolidation, pleural effusion, or pneumothorax.", "normal"),
    ("Cardiomediastinal silhouette is within normal limits.", "normal"),
    # --- abnormal ---
    ("Left lower lobe opacity consistent with pneumonia.", "abnormal"),
    ("Moderate right pleural effusion with associated atelectasis.", "abnormal"),
    ("Patchy bilateral ground-glass opacities concerning for infection.", "abnormal"),
    # --- uncertain ---
    ("Findings could represent early interstitial edema.", "uncertain"),
    ("Cannot exclude a small left apical pneumothorax.", "uncertain"),
    ("Opacity may represent atelectasis versus consolidation.", "uncertain"),
]

# Class name -> integer id, plus the inverse lookup for decoding predictions.
label2id = {"normal": 0, "abnormal": 1, "uncertain": 2}
id2label = dict((class_id, class_name) for class_name, class_id in label2id.items())

# Unzip the pairs into parallel lists; class names are mapped to their ids.
texts = [sentence for sentence, _ in radiology_samples]
labels = [label2id[class_name] for _, class_name in radiology_samples]

list(zip(texts, labels))


In [None]:
def simple_tokenize(text: str) -> List[str]:
    """Lowercase ``text``, delete periods and commas, and split on whitespace.

    Other punctuation (e.g. the hyphen in "ground-glass") is left in place.
    """
    cleaned = text.lower()
    for mark in (".", ","):
        cleaned = cleaned.replace(mark, "")
    return cleaned.split()

# Vocabulary for the tokenizer: every distinct token in the corpus, sorted so
# that the ordering (and hence the index assignment) is stable across runs.
token_space = sorted(
    set(token for sentence in texts for token in simple_tokenize(sentence))
)

token_space


In [None]:
# Build the PyHealth tokenizer over the corpus vocabulary, registering
# "<pad>" and "<unk>" as special tokens.
tokenizer = Tokenizer(
    tokens=token_space,
    special_tokens=["<pad>", "<unk>"],
)

# Both values are needed later to size and configure the embedding layer.
pad_index = tokenizer.get_padding_index()
vocab_size = tokenizer.get_vocabulary_size()

vocab_size, pad_index


In [None]:
MAX_LENGTH = 32  # fixed sequence length every sample is padded/truncated to


class RadiologySentenceDataset(Dataset):
    """Tiny radiology sentence-level dataset using PyHealth Tokenizer.

    Each sample returns:
      - input_ids: LongTensor of shape [max_length] holding token indices,
        truncated and right-padded with the tokenizer's padding index
      - label: LongTensor scalar (0=normal,1=abnormal,2=uncertain)

    Args:
        texts: Raw sentences.
        labels: Integer class id for each sentence (same length as ``texts``).
        tokenizer: PyHealth ``Tokenizer`` used to map tokens to indices.
        max_length: Fixed length of every returned ``input_ids`` tensor.
            Defaults to ``MAX_LENGTH``.
    """

    def __init__(
        self,
        texts: List[str],
        labels: List[int],
        tokenizer: Tokenizer,
        max_length: int = MAX_LENGTH,
    ):
        assert len(texts) == len(labels)
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of sentences in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode sentence ``idx`` and return ``(input_ids, label)`` tensors."""
        text = self.texts[idx]
        label = self.labels[idx]

        tokens = simple_tokenize(text)
        # Tokenizer expects 2D batch; we wrap [tokens] then unwrap index 0
        indices_2d = self.tokenizer.batch_encode_2d(
            batch=[tokens],
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        token_ids = list(indices_2d[0])

        # batch_encode_2d pads to the longest sequence in the batch it is given.
        # With a batch of one that is the sentence's own length, so sentences
        # would come back with different lengths and the default DataLoader
        # collate (torch.stack) would fail. Pad explicitly to the fixed length;
        # this is a no-op if the sequence is already max_length long.
        pad_id = self.tokenizer.get_padding_index()
        token_ids += [pad_id] * (self.max_length - len(token_ids))

        input_ids = torch.tensor(token_ids, dtype=torch.long)

        return input_ids, torch.tensor(label, dtype=torch.long)


In [None]:
# Stratified 2/1 split: 6 for "train", 3 for "val".
# The samples are grouped by class, so a plain texts[:6] / texts[6:] slice
# would put only normal/abnormal sentences in train and only "uncertain"
# sentences in val. Instead, hold out the last sentence of every class so
# each label appears in both splits.
last_index_per_label = {label: idx for idx, label in enumerate(labels)}
val_indices = sorted(last_index_per_label.values())
held_out = set(val_indices)
train_indices = [idx for idx in range(len(texts)) if idx not in held_out]

train_texts = [texts[idx] for idx in train_indices]
val_texts = [texts[idx] for idx in val_indices]
train_labels = [labels[idx] for idx in train_indices]
val_labels = [labels[idx] for idx in val_indices]

train_dataset = RadiologySentenceDataset(train_texts, train_labels, tokenizer)
val_dataset = RadiologySentenceDataset(val_texts, val_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=3, shuffle=False)

len(train_dataset), len(val_dataset)


In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters consumed by ``SimpleTextClassifier``."""

    # Number of rows in the embedding table (passed as ``num_embeddings``).
    vocab_size: int
    # Width of each token embedding and input size of the output layer.
    embed_dim: int = 32
    # Number of output logits produced by the classifier head.
    num_classes: int = 3
    # Token index treated as padding (passed as ``padding_idx``).
    pad_index: int = 0


class SimpleTextClassifier(nn.Module):
    """Minimal text classifier to pair with PyHealth Tokenizer.

    Embeds each token, mean-pools the embeddings over non-padding positions,
    and applies a single linear layer to produce class logits.

    Args:
        config: Object exposing ``vocab_size``, ``embed_dim``, ``num_classes``
            and ``pad_index`` (e.g. a ``ModelConfig``).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Keep the padding index on the module so forward() does not depend on
        # a notebook-level global that may be missing or hold another value.
        self.pad_index = config.pad_index
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embed_dim,
            padding_idx=config.pad_index,
        )
        self.fc = nn.Linear(config.embed_dim, config.num_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: [batch_size, seq_len] LongTensor

        Returns:
            logits: [batch_size, num_classes]
        """
        embedded = self.embedding(input_ids)  # [B, L, D]
        # simple average pooling over non-pad positions
        mask = (input_ids != self.pad_index).unsqueeze(-1)  # [B, L, 1]
        mask = mask.to(embedded.dtype)
        summed = (embedded * mask).sum(dim=1)
        # clamp avoids a division by zero for rows that are entirely padding
        denom = mask.sum(dim=1).clamp(min=1)
        pooled = summed / denom
        logits = self.fc(pooled)
        return logits


In [None]:
# Size the model from the tokenizer vocabulary and the label set.
config = ModelConfig(
    vocab_size=vocab_size,
    pad_index=pad_index,
    embed_dim=32,
    num_classes=len(label2id),
)

# Build the classifier and move its parameters to the selected device.
model = SimpleTextClassifier(config)
model = model.to(device)

# Cross-entropy over the class logits, optimised with Adam.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model


In [None]:
def run_epoch(loader: DataLoader, train: bool = True):
    """Run one pass over ``loader`` with the notebook-level model.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device`` defined in earlier cells.

    Args:
        loader: Yields ``(input_ids, labels)`` batches.
        train: If True, put the model in training mode and take an optimizer
            step per batch; otherwise run in eval mode without gradients.

    Returns:
        Tuple ``(avg_loss, all_logits, all_labels)``: the sample-weighted mean
        loss, and the logits and labels of every batch concatenated on CPU.
        For an empty loader the loss is 0.0 and both tensors are empty.
    """
    model.train(train)  # train(False) is equivalent to eval()

    all_logits = []
    all_labels = []

    total_loss = 0.0
    count = 0

    for input_ids, labels_batch in loader:
        input_ids = input_ids.to(device)
        labels_batch = labels_batch.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            logits = model(input_ids)
            loss = criterion(logits, labels_batch)

            if train:
                loss.backward()
                optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_size = labels_batch.size(0)
        total_loss += loss.item() * batch_size
        count += batch_size

        all_logits.append(logits.detach().cpu())
        all_labels.append(labels_batch.detach().cpu())

    avg_loss = total_loss / max(count, 1)

    if not all_logits:
        # torch.cat rejects an empty list, so return empty tensors instead of
        # raising when the loader produced no batches.
        return avg_loss, torch.empty(0), torch.empty(0, dtype=torch.long)

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    return avg_loss, all_logits, all_labels


In [None]:
EPOCHS = 10

for epoch in range(1, EPOCHS + 1):
    train_loss, _, _ = run_epoch(train_loader, train=True)
    # Only the loss is needed here. Discarding the other outputs avoids
    # rebinding `val_labels` (the list of label ids built with the split) to a
    # tensor, which would leave stale state behind for any later cell.
    val_loss, _, _ = run_epoch(val_loader, train=False)

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")


In [None]:
# Get predictions and probabilities on the full dataset (train + val).
# train=False runs the model in eval mode without updating any weights.
_, train_logits, train_labels_tensor = run_epoch(train_loader, train=False)
_, val_logits, val_labels_tensor = run_epoch(val_loader, train=False)

y_true = np.concatenate(
    [train_labels_tensor.numpy(), val_labels_tensor.numpy()], axis=0
)

# Convert logits to per-class probabilities, in the same order as y_true.
y_prob = torch.softmax(
    torch.cat([train_logits, val_logits], dim=0), dim=-1
).numpy()

# PyHealth spells the macro-averaged F1 metric "f1_macro"; "macro_f1" is not
# an accepted name and is rejected as an unknown metric.
metrics = multiclass_metrics_fn(
    y_true=y_true,
    y_prob=y_prob,
    metrics=["accuracy", "f1_macro"],
)

metrics
