# CheXpert Vision Transformer (ViT) Training Notebook

This notebook trains a ViT model on the CheXpert dataset using PyTorch and timm.

In [None]:
# 1. Install dependencies
!pip install timm torch torchvision scikit-learn pandas tqdm --quiet

## 2. Imports

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
import torch.nn as nn
import torch.optim as optim

## 3. Configurations
Set up paths, label names, and hyperparameters.

In [None]:
# Dataset root. Set the CHEXPERT_ROOT environment variable to point at a
# different location without editing the notebook.
DATA_ROOT = os.environ.get('CHEXPERT_ROOT', '/Volumes/2TB/chest/CheXpert_Small')
CSV_TRAIN = os.path.join(DATA_ROOT, 'train.csv')
CSV_VALID = os.path.join(DATA_ROOT, 'valid.csv')
IMG_ROOT = DATA_ROOT  # image paths in CSV are relative to this

# The 14 CheXpert observations; their order defines the order of the
# model's output logits and of the per-label AUCs.
LABELS = [
    'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion',
    'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
    'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'
]
NUM_CLASSES = len(LABELS)
BATCH_SIZE = 16
IMG_SIZE = 224  # matches the 224px input of vit_base_patch16_224
EPOCHS = 5  # Increase for better results
LR = 1e-4  # AdamW learning rate

# Prefer CUDA, then Apple MPS when this PyTorch build provides it, then CPU.
# getattr guards against older PyTorch builds that lack torch.backends.mps.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

## 4. Data Preparation
Define a PyTorch Dataset for CheXpert. Both uncertain (-1.0) and missing (NaN) labels are mapped to 0.0, i.e. treated as negative (the "U-Zeros" policy from the CheXpert paper).

In [None]:
class CheXpertDataset(Dataset):
    """CheXpert chest X-rays paired with multi-label float32 targets.

    Missing labels (NaN) become 0.0. Uncertain labels (-1.0) become
    ``uncertain_value``, which defaults to 0.0 (the "U-Zeros" policy).

    Args:
        csv_path: CSV file with a ``Path`` column plus one column per label.
        img_root: Directory that the ``Path`` entries are joined onto.
        transform: Optional callable applied to each RGB PIL image.
        labels: Label columns to use, in output order. Defaults to the
            notebook-level ``LABELS`` list.
        uncertain_value: Target substituted for uncertain (-1.0) entries,
            e.g. 1.0 for the "U-Ones" policy.
    """

    def __init__(self, csv_path, img_root, transform=None, labels=None,
                 uncertain_value=0.0):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.transform = transform
        self.labels = list(LABELS if labels is None else labels)
        # Fill missing labels first, then remap the uncertain marker (-1.0).
        self.df[self.labels] = (
            self.df[self.labels].fillna(0).replace(-1.0, uncertain_value)
        )
        # Precompute paths and targets once. Calling df.iloc per sample would
        # build a mixed-dtype pandas Series for every item loaded.
        self._paths = self.df['Path'].tolist()
        self._targets = self.df[self.labels].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_root, self._paths[idx])
        # The context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        labels = torch.tensor(self._targets[idx])
        return image, labels

import multiprocessing

# timm's default vit_base_patch16_224 weights were pretrained with per-channel
# mean = std = 0.5 (timm's IMAGENET_INCEPTION_MEAN/STD), not the classic
# ImageNet statistics. Normalize the same way the backbone saw its inputs.
VIT_MEAN = (0.5, 0.5, 0.5)
VIT_STD = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # NOTE(review): horizontal flips mirror cardiac/mediastinal laterality;
    # confirm this augmentation is wanted for findings like Cardiomegaly.
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])
transform_valid = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

# DataLoader workers must unpickle CheXpertDataset. With the 'spawn' or
# 'forkserver' start methods (defaults on macOS/Windows and newer Linux
# Pythons), classes defined in a notebook cell cannot be found by the workers.
# In that case, load data in the main process instead.
NUM_WORKERS = 4 if multiprocessing.get_start_method() == 'fork' else 0
# Page-locked host memory only speeds up host->CUDA copies.
PIN_MEMORY = DEVICE == 'cuda'

train_ds = CheXpertDataset(CSV_TRAIN, IMG_ROOT, transform=transform_train)
valid_ds = CheXpertDataset(CSV_VALID, IMG_ROOT, transform=transform_valid)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

## 5. Model
Create a Vision Transformer (ViT) model using timm.

In [None]:
# ImageNet-pretrained ViT-B/16 with a NUM_CLASSES-output head, placed on DEVICE.
# nn.Module.to() moves parameters in place and returns the same module.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=NUM_CLASSES,
).to(DEVICE)

## 6. Loss and Optimizer
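Use BCEWithLogitsLoss for multi-label classification. Each label gets its own independent sigmoid, because a single image can show several findings at once.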

In [None]:
# One sigmoid + binary cross-entropy term per label, averaged over the batch;
# the model emits raw logits, so no explicit sigmoid is needed before the loss.
criterion = nn.BCEWithLogitsLoss(reduction='mean')
# AdamW over every model parameter at the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=LR)

## 7. Training and Evaluation
Train the model. After every epoch, report the validation ROC AUC for each label and the mean AUC across labels.

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimization pass over ``loader`` and return the mean loss.

    Args:
        model: Network mapping an image batch to one logit per label.
        loader: DataLoader yielding ``(images, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device that batches are moved to. Defaults to the
            notebook-level ``DEVICE``.

    Returns:
        The sample-weighted average loss over ``len(loader.dataset)`` samples.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    device = DEVICE if device is None else device
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(loader):
        # non_blocking lets copies from pinned host memory overlap compute.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        # Dropping gradients (instead of zero-filling) saves memory and a kernel.
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (BCEWithLogitsLoss default);
        # weighting by batch size keeps a short final batch from skewing the mean.
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)

def evaluate(model, loader, device=None):
    """Compute the per-label ROC AUC of ``model`` over ``loader``.

    Args:
        model: Network mapping an image batch to one logit per label.
        loader: DataLoader yielding ``(images, labels)`` batches, with labels
            as a 2-D ``(batch, num_labels)`` tensor.
        device: Device that images are moved to. Defaults to the
            notebook-level ``DEVICE``.

    Returns:
        A list with one AUC per label column. An entry is ``np.nan`` when the
        AUC is undefined for that label (e.g. only one class is present).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    device = DEVICE if device is None else device
    model.eval()
    label_batches = []
    score_batches = []
    with torch.no_grad():
        for images, labels in tqdm(loader):
            outputs = model(images.to(device, non_blocking=True))
            score_batches.append(torch.sigmoid(outputs).cpu().numpy())
            label_batches.append(labels.numpy())
    if not label_batches:
        raise ValueError("evaluate() received an empty loader; cannot compute AUC")
    all_scores = np.concatenate(score_batches)
    all_labels = np.concatenate(label_batches)
    aucs = []
    for i in range(all_labels.shape[1]):
        try:
            auc = roc_auc_score(all_labels[:, i], all_scores[:, i])
        except ValueError:
            # roc_auc_score raises ValueError when y_true contains a single
            # class; record NaN so callers can use np.nanmean.
            auc = np.nan
        aucs.append(auc)
    return aucs

# Alternate one training pass with a validation pass, printing the training
# loss, every label's AUC, and the NaN-ignoring mean AUC after each epoch.
for epoch_num in range(1, EPOCHS + 1):
    print(f"Epoch {epoch_num}/{EPOCHS}")
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    print(f"Train Loss: {train_loss:.4f}")
    aucs = evaluate(model, valid_loader)
    for label_name, label_auc in zip(LABELS, aucs):
        print(f"{label_name}: AUC = {label_auc:.4f}")
    # nanmean skips labels whose AUC was undefined (reported as NaN).
    print(f"Mean AUC: {np.nanmean(aucs):.4f}")

## 8. Save Model
Save the trained model weights.

In [None]:
# Persist only the parameter state_dict (reload with model.load_state_dict).
CHECKPOINT_PATH = 'chexpert_vit.pth'
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f'Model saved as {CHECKPOINT_PATH}')