
# Multimodal Healthcare Classifier (CV + NLP) — PneumoniaMNIST

**Highlights:** Deep Learning • Computer Vision • NLP • Multimodal Fusion • PyTorch • Hugging Face

This notebook trains three models on a lightweight medical dataset:
- **Image-only**: ResNet-18 on chest X-rays (PneumoniaMNIST, via MedMNIST)
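- **Text-only**: DistilBERT on *synthetic* short clinical notes generated from the labels. Each note starts from a label-specific template (the normal and pneumonia template sets are disjoint), so the label can be read off the text exactly; treat text-only and fusion scores as a pipeline demonstration, not clinical evidence.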
- **Fusion**: Concatenates the image and text embeddings and classifies them jointly, aiming to improve on either modality alone (compare against both baselines).

**Why it's portfolio-ready**  
Clean config, baselines vs fusion, clear metrics (Accuracy/F1), confusion matrices, checkpoints, and reproducibility.



## 1) Environment setup (run locally)
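Uncomment and run the cell below **in your local environment**. The first `pip` line pulls CUDA 12.1 wheels; change the `--index-url` (e.g. `https://download.pytorch.org/whl/cpu`) to match your CUDA version or a CPU-only machine.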


In [None]:

# %%capture
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install transformers medmnist scikit-learn matplotlib tqdm



## 2) Imports & Config


In [None]:

import os
import random
import json
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel

from medmnist import PneumoniaMNIST
from medmnist import INFO as MED_INFO

@dataclass
class Config:
    """Hyper-parameters and paths shared by every cell of the notebook.

    Field order matters: ``Config(...)`` also accepts these positionally.
    """

    # Reproducibility
    seed: int = 42
    # Data pipeline (images are resized to an img_size x img_size square)
    img_size: int = 224
    batch_size: int = 64
    num_workers: int = 2
    # Optimisation
    lr: float = 2e-4
    epochs_fusion: int = 5
    epochs_baseline: int = 3
    # Where checkpoints are written
    model_dir: str = "checkpoints"
    # Evaluated once, when the class body is executed
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    # Text branch
    text_model_name: str = "distilbert-base-uncased"
    max_text_len: int = 64

# One shared configuration object; every later cell reads from ``cfg``.
cfg = Config()
# Create the checkpoint directory up front so training can save into it.
os.makedirs(name=cfg.model_dir, exist_ok=True)

def set_seed(seed: int):
    """Seed every RNG the notebook relies on and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global generator and torch (CPU plus
    all CUDA devices). cuDNN autotuning is switched off and deterministic
    algorithms are requested, trading some speed for repeatability.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed everything before any dataset or model is built.
set_seed(seed=cfg.seed)
print(f"Device: {cfg.device}")



## 3) Utilities (metrics & plotting)


In [None]:

def plot_confusion(y_true, y_pred, labels, title, save_path=None):
    """Draw (and optionally save) an annotated confusion matrix.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        labels: Display names; ``labels[k]`` names class index ``k``.
        title: Axes title.
        save_path: If given, the figure is also written to this path.
    """
    # Pin the class indices so the matrix is always len(labels) x len(labels).
    # Without this, a class missing from both y_true and y_pred shrinks the
    # matrix and the tick labels no longer line up with its rows/columns.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    fig, ax = plt.subplots(figsize=(4.5, 4))
    im = ax.imshow(cm, interpolation="nearest")
    fig.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        xticklabels=labels,
        yticklabels=labels,
        ylabel="True label",
        xlabel="Predicted label",
        title=title,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Cell text switches colour at half of the largest count.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], "d"),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    if save_path:
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()

def compute_metrics(y_true, y_pred, name: str):
    """Print accuracy, positive-class F1 and a per-class report; return the scores.

    Labels are 0 = "normal" and 1 = "pneumonia"; F1 is computed for class 1
    (scikit-learn's default ``pos_label``).

    Returns:
        dict with float ``"accuracy"`` and ``"f1"``.
    """
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    print(f"[{name}] Acc: {acc:.4f} | F1: {f1:.4f}")
    # Pass labels explicitly: otherwise classification_report raises a
    # ValueError whenever only one class occurs in y_true/y_pred, because
    # target_names has two entries.
    print(classification_report(y_true, y_pred, labels=[0, 1],
                                target_names=["normal", "pneumonia"],
                                zero_division=0))
    return {"accuracy": float(acc), "f1": float(f1)}



## 4) Data: PneumoniaMNIST (MedMNIST) & transforms


In [None]:

def get_medmnist_splits() -> Tuple[PneumoniaMNIST, PneumoniaMNIST, PneumoniaMNIST]:
    """Return the PneumoniaMNIST (train, val, test) splits, downloading if needed.

    Also prints the dataset's MedMNIST metadata: description, task and label map.
    """
    meta = MED_INFO["pneumoniamnist"]
    print(f"PneumoniaMNIST: {meta['description']} | Task: {meta['task']} | Labels: {meta['label']}")
    splits = tuple(
        PneumoniaMNIST(split=split_name, download=True)
        for split_name in ("train", "val", "test")
    )
    return splits

def make_transforms(img_size: int):
    """Build the (train, eval) torchvision pipelines for single-channel X-rays.

    Both pipelines take a 2-D uint8 array, resize it to ``img_size`` x
    ``img_size``, replicate the grey channel to 3 channels for the ResNet-18
    backbone, and normalise with the ImageNet mean/std. The train pipeline
    adds random horizontal flips and rotations of up to 10 degrees.

    Returns:
        ``(train_tf, eval_tf)``
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    # Grayscale(3) replicates the single channel exactly like a
    # ``Lambda(lambda t: t.repeat(3, 1, 1))``, but unlike a lambda it can be
    # pickled, so DataLoader workers started with the "spawn" method
    # (the default on Windows and macOS) do not crash.
    to_rgb = transforms.Grayscale(num_output_channels=3)
    train_tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    eval_tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    return train_tf, eval_tf



## 5) Synthetic clinical notes (NLP) + Dataset wrapper


In [None]:

# Template pools for synthetic notes. The normal and pneumonia pools are
# disjoint, so a note's opening sentence determines its label.
_NORMAL_NOTE_TEMPLATES = (
    "Patient with mild cough, afebrile, clear breath sounds, O2 sat stable.",
    "No respiratory distress noted; chest auscultation unremarkable.",
    "Vitals within normal limits; denies dyspnea; lungs clear to auscultation.",
)
_PNEUMONIA_NOTE_TEMPLATES = (
    "Productive cough with fever; crackles at bases; increased WBC; suspected pneumonia.",
    "Shortness of breath and chills; focal consolidation on exam; starting antibiotics.",
    "Febrile with hypoxia; chest findings consistent with community-acquired pneumonia.",
)
# Label-independent filler sentences appended to both classes.
_NOTE_ADDITIVES = (
    "Hx HTN.", "No allergies.", "Smoker; counseling provided.", "Recent viral illness.",
    "Follow-up CXR recommended.", "Sputum culture pending.",
)


def synthesize_note(label: int, rng: "random.Random | None" = None) -> str:
    """Generate a short synthetic clinical note for a class label.

    Args:
        label: 1 draws from the pneumonia templates; any other value from the
            normal templates.
        rng: Optional ``random.Random`` to draw from. Defaults to the global
            ``random`` module, so existing callers behave exactly as before;
            passing an instance makes generation reproducible without touching
            global RNG state.

    Returns:
        A template sentence followed by 0-2 distinct filler sentences.
    """
    r = random if rng is None else rng
    base = r.choice(_PNEUMONIA_NOTE_TEMPLATES if label == 1 else _NORMAL_NOTE_TEMPLATES)
    extra = " ".join(r.sample(_NOTE_ADDITIVES, k=r.randint(0, 2)))
    return (base + " " + extra).strip()

class MultimodalPneumo(Dataset):
    """PneumoniaMNIST split paired with one synthetic clinical note per sample.

    Notes are generated once, at construction, from each sample's label, so an
    index always maps to the same text. ``__getitem__`` returns a dict with the
    transformed image (``pixel_values``), the tokenised note (``input_ids``,
    ``attention_mask``), the integer ``label`` tensor and the raw ``text``.

    Args:
        base: Split to wrap; ``base[idx]`` must return ``(image, label)``.
        tf: Transform applied to the image given as a 2-D uint8 array.
        tokenizer: Hugging Face tokenizer used for the notes.
        max_len: Notes are padded/truncated to exactly this many tokens.
        seed: Seed used for note generation.
    """

    def __init__(self, base: PneumoniaMNIST, tf, tokenizer, max_len: int, seed: int = 42):
        self.base = base
        self.tf = tf
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Prefer the dataset's label array (when it exposes one of matching
        # length) so building the notes does not decode every image; otherwise
        # fall back to iterating the dataset.
        raw_labels = getattr(base, "labels", None)
        if raw_labels is None or len(raw_labels) != len(base):
            raw_labels = [label for _, label in base]

        # synthesize_note draws from the global RNG: seed it for reproducible
        # notes, then restore the previous state so constructing a dataset does
        # not silently reset randomness for the rest of the notebook.
        saved_state = random.getstate()
        random.seed(seed)
        try:
            self.notes = [synthesize_note(self._as_int(label)) for label in raw_labels]
        finally:
            random.setstate(saved_state)

    @staticmethod
    def _as_int(label) -> int:
        """Convert a scalar or single-element label array to a Python int."""
        # Flattening first avoids NumPy's deprecated int() on arrays with ndim > 0.
        return int(np.ravel(label)[0])

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        label = self._as_int(label)
        # Without a transform, MedMNIST yields PIL images, which have no
        # ``.squeeze()`` method. np.array accepts both PIL images and arrays;
        # squeeze then drops any singleton channel axis so ``tf`` gets 2-D input.
        img = np.array(img, dtype=np.uint8).squeeze()
        img = self.tf(img)

        text = self.notes[idx]
        enc = self.tokenizer(text, truncation=True, max_length=self.max_len,
                             padding="max_length", return_tensors="pt")
        # Drop the batch dimension of 1 that return_tensors="pt" adds.
        enc = {k: v.squeeze(0) for k, v in enc.items()}
        return {
            "pixel_values": img,
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": torch.tensor(label, dtype=torch.long),
            "text": text
        }



## 6) Models: Image encoder, Text encoder, and Fusion heads


In [None]:

class ImageEncoder(nn.Module):
    """ResNet-18 backbone followed by an MLP projecting features to ``out_dim``.

    The backbone's final ``fc`` layer is replaced by ``nn.Identity`` so it
    emits its pooled features; ``proj`` maps them through a 256-unit hidden
    layer (ReLU, dropout 0.2). No parameters are frozen here.
    """

    def __init__(self, out_dim=128, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)
        feat_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet
        self.proj = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, out_dim),
        )

    def forward(self, x):
        """Embed an image batch into ``out_dim``-dimensional vectors."""
        return self.proj(self.backbone(x))

class TextEncoder(nn.Module):
    """Hugging Face encoder plus an MLP projection of the first-token state.

    Args:
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``.
        out_dim: Size of the returned embedding.
        freeze_bert: If True, the transformer's weights get no gradients, and
            the transformer stays in eval mode (dropout off) even while the
            surrounding model trains. The projection head then learns from the
            same deterministic features it sees at evaluation time.
    """

    def __init__(self, model_name: str, out_dim=128, freeze_bert=True):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.freeze_bert = freeze_bert
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
            self.bert.eval()
        hidden = self.bert.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, out_dim)
        )

    def train(self, mode: bool = True):
        """Set train/eval mode, but keep a frozen transformer in eval mode.

        ``nn.Module.eval()`` routes through this method as ``train(False)``.
        """
        super().train(mode)
        if self.freeze_bert:
            self.bert.eval()
        return self

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token of each sequence.
        cls = outputs.last_hidden_state[:, 0]
        return self.proj(cls)

class FusionClassifier(nn.Module):
    """Late-fusion head: concatenate image and text embeddings, then classify.

    ``forward(zi, zt)`` concatenates along dim 1, so the last dimensions must
    be ``img_dim`` and ``txt_dim``; it returns ``num_classes`` logits per row.
    """

    def __init__(self, img_dim=128, txt_dim=128, num_classes=2):
        super().__init__()
        fused_dim = img_dim + txt_dim
        layers = [
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, zi, zt):
        fused = torch.cat((zi, zt), dim=1)
        return self.classifier(fused)

class ImageOnlyClassifier(nn.Module):
    """MLP head mapping an image embedding of size ``img_dim`` to class logits."""

    def __init__(self, img_dim=128, num_classes=2):
        super().__init__()
        hidden = 128
        self.head = nn.Sequential(
            nn.Linear(img_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, zi):
        logits = self.head(zi)
        return logits

class TextOnlyClassifier(nn.Module):
    """MLP head mapping a text embedding of size ``txt_dim`` to class logits."""

    def __init__(self, txt_dim=128, num_classes=2):
        super().__init__()
        hidden = 128
        self.head = nn.Sequential(
            nn.Linear(txt_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, zt):
        logits = self.head(zt)
        return logits

class MultimodalSystem(nn.Module):
    """Fusion model: image encoder + text encoder (frozen transformer) + fusion head.

    Both encoders emit 128-dimensional embeddings; the head outputs 2 logits.
    """

    def __init__(self, cfg):
        super().__init__()
        embed_dim = 128
        self.img_enc = ImageEncoder(out_dim=embed_dim, pretrained=True)
        self.txt_enc = TextEncoder(cfg.text_model_name, out_dim=embed_dim, freeze_bert=True)
        self.fusion = FusionClassifier(img_dim=embed_dim, txt_dim=embed_dim, num_classes=2)

    def forward(self, imgs, input_ids, attention_mask):
        # Arguments are evaluated left to right: image branch first, then text.
        return self.fusion(self.img_enc(imgs), self.txt_enc(input_ids, attention_mask))

class ImageOnlySystem(nn.Module):
    """Image-only baseline: ResNet-18 based encoder plus an MLP head.

    ``forward`` accepts (and ignores) extra positional arguments so it can be
    called with the same ``(imgs, input_ids, attention_mask)`` triple as the
    fusion model.
    """

    def __init__(self):
        super().__init__()
        self.img_enc = ImageEncoder(out_dim=128, pretrained=True)
        self.cls = ImageOnlyClassifier(img_dim=128, num_classes=2)

    def forward(self, imgs, *_):
        embedding = self.img_enc(imgs)
        logits = self.cls(embedding)
        return logits

class TextOnlySystem(nn.Module):
    """Text-only baseline: frozen-transformer text encoder plus an MLP head.

    ``forward`` takes the same ``(imgs, input_ids, attention_mask)`` triple as
    the fusion model and ignores the first (image) argument.
    """

    def __init__(self, cfg):
        super().__init__()
        self.txt_enc = TextEncoder(cfg.text_model_name, out_dim=128, freeze_bert=True)
        self.cls = TextOnlyClassifier(txt_dim=128, num_classes=2)

    def forward(self, _, input_ids, attention_mask):
        return self.cls(self.txt_enc(input_ids, attention_mask))



## 7) Train / Eval loops


In [None]:

def run_epoch(model, loaders, optim, criterion, stage, device):
    """Run one pass over a DataLoader and return loss and classification metrics.

    Args:
        model: Called as ``model(pixel_values, input_ids, attention_mask)``;
            returns logits of shape ``(batch, num_classes)``.
        loaders: A single DataLoader (despite the plural name) yielding dict
            batches with ``pixel_values``, ``input_ids``, ``attention_mask``
            and ``label``.
        optim: Optimizer; only used when ``stage == "train"``.
        criterion: Loss called as ``criterion(logits, labels)``.
        stage: ``"train"`` updates weights; any other value only evaluates.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy, f1, (y_true, y_pred))`` where ``avg_loss`` is
        the per-sample mean loss over the pass.
    """
    is_train = stage == "train"
    model.train(is_train)

    total_loss = 0.0
    n_seen = 0
    y_true, y_pred = [], []

    # Disable autograd outside training so evaluation does not build (and
    # hold on to) a computation graph for every batch.
    with torch.set_grad_enabled(is_train):
        for batch in tqdm(loaders, desc=f"{stage}", leave=False):
            imgs = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            if is_train:
                optim.zero_grad()

            logits = model(imgs, input_ids, attn)
            loss = criterion(logits, labels)

            if is_train:
                loss.backward()
                optim.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_seen += batch_size
            y_pred.extend(logits.argmax(dim=1).tolist())
            y_true.extend(labels.tolist())

    # Average over the samples actually processed (equal to len(dataset) for
    # these loaders); avoid ZeroDivisionError if nothing was processed.
    avg_loss = total_loss / n_seen if n_seen else 0.0
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return avg_loss, acc, f1, (y_true, y_pred)



## 8) Train baselines and fusion


In [None]:

tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name)
train_raw, val_raw, test_raw = get_medmnist_splits()
tf_train, tf_eval = make_transforms(cfg.img_size)

# Train uses the augmenting transform; val/test use the deterministic one.
train_ds = MultimodalPneumo(train_raw, tf_train, tokenizer, cfg.max_text_len, seed=cfg.seed)
val_ds = MultimodalPneumo(val_raw, tf_eval, tokenizer, cfg.max_text_len, seed=cfg.seed)
test_ds = MultimodalPneumo(test_raw, tf_eval, tokenizer, cfg.max_text_len, seed=cfg.seed)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

criterion = nn.CrossEntropyLoss()


def make_trainable_adamw(model):
    """Return AdamW over only the parameters that require gradients (skips frozen weights)."""
    return torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=cfg.lr)


def train_and_checkpoint(model, optim, epochs, tag, ckpt_name):
    """Train ``model`` for ``epochs`` epochs, checkpointing on best val accuracy.

    After each epoch the model is evaluated on ``val_loader``. Its state dict is
    saved to ``cfg.model_dir/ckpt_name`` whenever validation accuracy strictly
    beats the best seen so far (starting from 0.0).

    Returns:
        The best validation accuracy reached.
    """
    best_val = 0.0
    ckpt_path = os.path.join(cfg.model_dir, ckpt_name)
    for epoch in range(epochs):
        tr = run_epoch(model, train_loader, optim, criterion, "train", cfg.device)
        vl = run_epoch(model, val_loader, optim, criterion, "eval", cfg.device)
        print(f"[{tag}][Epoch {epoch+1}] train_loss={tr[0]:.4f} val_acc={vl[1]:.4f} val_f1={vl[2]:.4f}")
        if vl[1] > best_val:
            best_val = vl[1]
            torch.save(model.state_dict(), ckpt_path)
    return best_val


# --- Image-only
print("\n=== Baseline: Image-Only ===")
img_only = ImageOnlySystem().to(cfg.device)
optim_i = make_trainable_adamw(img_only)
best_val_img = train_and_checkpoint(img_only, optim_i, cfg.epochs_baseline, "ImageOnly", "image_only.pt")

# --- Text-only
print("\n=== Baseline: Text-Only ===")
txt_only = TextOnlySystem(cfg).to(cfg.device)
optim_t = make_trainable_adamw(txt_only)
best_val_txt = train_and_checkpoint(txt_only, optim_t, cfg.epochs_baseline, "TextOnly", "text_only.pt")

# --- Fusion
print("\n=== Multimodal Fusion (Image + Text) ===")
mm = MultimodalSystem(cfg).to(cfg.device)
optim_f = make_trainable_adamw(mm)
best_val_fusion = train_and_checkpoint(mm, optim_f, cfg.epochs_fusion, "Fusion", "fusion.pt")



## 9) Evaluate on test set + Confusion matrices


In [None]:

def evaluate_model(model, loader, name, device=None):
    """Evaluate ``model`` on ``loader``: print metrics, plot a confusion matrix.

    Args:
        model: Called as ``model(pixel_values, input_ids, attention_mask)``.
        loader: DataLoader yielding dict batches with ``pixel_values``,
            ``input_ids``, ``attention_mask`` and ``label``.
        name: Tag used in the printed report and the plot title.
        device: Device for inference; defaults to ``cfg.device``.

    Returns:
        dict with float ``"accuracy"`` and ``"f1"`` (F1 for class 1, pneumonia).
    """
    device = cfg.device if device is None else device
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for b in loader:
            logits = model(b["pixel_values"].to(device),
                           b["input_ids"].to(device),
                           b["attention_mask"].to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(b["label"].tolist())
    compute_metrics(y_true, y_pred, name)
    plot_confusion(y_true, y_pred, ["normal", "pneumonia"], f"Confusion ({name})")
    return {"accuracy": float(accuracy_score(y_true, y_pred)),
            "f1": float(f1_score(y_true, y_pred))}


# (checkpoint stem, model) pairs; the stem is also the reporting name.
eval_targets = (("image_only", img_only), ("text_only", txt_only), ("fusion", mm))

# Load best checkpoints (if the training cell saved them). weights_only=True
# restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids executing arbitrary pickled code.
for ckpt_stem, net in eval_targets:
    ckpt_path = os.path.join(cfg.model_dir, f"{ckpt_stem}.pt")
    if os.path.exists(ckpt_path):
        net.load_state_dict(torch.load(ckpt_path, map_location=cfg.device, weights_only=True))

results = {}
for ckpt_stem, net in eval_targets:
    results[ckpt_stem] = evaluate_model(net, test_loader, ckpt_stem)

# Save a quick summary of the test metrics next to the checkpoints.
results_path = os.path.join(cfg.model_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print(f"Saved test metrics to {results_path}")



## 10) Optional extensions
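- Unfreeze the last DistilBERT layers for fine-tuning with a lower LR. The ResNet-18 backbone is already fully fine-tuned here; consider giving it a lower LR than the new heads.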
- Swap DistilBERT with **ClinicalBERT** (if licensing/data access permits) for more clinical language.
- Add **Grad-CAM** for image explainability & attention visualization for text.
- Log experiments with **Weights & Biases** or **MLflow**.
- Replace synthetic notes with a de-identified corpus when available.
