# CheXpert BiomedCLIP ViT-G/14 Training Notebook

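This notebook trains a ViT-G/14 (Giant Vision Transformer) model on the CheXpert dataset for multi-label chest X-ray classification, using PyTorch and timm, with the goal of strong medical imaging performance. Note: although the title refers to BiomedCLIP, the model created in section 5 is timm's `vit_giant_patch14_224`; the notebook does not load a BiomedCLIP checkpoint.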

In [None]:
# 1. Install dependencies
!pip install timm torch torchvision scikit-learn pandas tqdm albumentations --quiet

## 2. Imports

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler

## 3. Configurations
Set up paths, label names, and hyperparameters optimized for BiomedCLIP ViT-G/14.

In [None]:
# Download and set up CheXpert dataset from Kaggle
# kagglehub is imported in this cell because the notebook's import cell does not
# import it; without this line the call below raises NameError on a fresh kernel.
import kagglehub

print("Downloading CheXpert dataset from Kaggle...")
dataset_path = kagglehub.dataset_download("willarevalo/chexpert-v10-small")
print(f"Dataset downloaded to: {dataset_path}")
# NOTE(review): dataset_path is only printed here; the configuration cell hard-codes
# its own paths. Confirm both refer to the same location.

In [None]:
# --- Dataset locations ---
IMG_ROOT = "/kaggle/input/chexpert-v10-small"  # image paths in CSV are relative to this
DATA_ROOT = "/kaggle/input/chexpert-v10-small/CheXpert-v1.0-small"
CSV_TRAIN, CSV_VALID = (os.path.join(DATA_ROOT, split_csv) for split_csv in ('train.csv', 'valid.csv'))

# --- Labels and per-class positive weights ---
# One (label, weight) pair per target column. LABELS and CLASS_WEIGHTS are both
# derived from this table, so their order always matches.
_LABEL_POS_WEIGHTS = (
    ('No Finding', 1.0),
    ('Enlarged Cardiomediastinum', 2.0),
    ('Cardiomegaly', 1.5),
    ('Lung Opacity', 1.0),
    ('Lung Lesion', 3.0),
    ('Edema', 1.5),
    ('Consolidation', 1.5),
    ('Pneumonia', 2.0),
    ('Atelectasis', 1.5),
    ('Pneumothorax', 2.0),
    ('Pleural Effusion', 1.0),
    ('Pleural Other', 1.0),
    ('Fracture', 2.0),
    ('Support Devices', 1.0),
)
LABELS = [label for label, _ in _LABEL_POS_WEIGHTS]
NUM_CLASSES = len(LABELS)

# --- Hyperparameters ---
BATCH_SIZE = 32      # samples per optimisation step
IMG_SIZE = 384       # images are resized/cropped to IMG_SIZE x IMG_SIZE
EPOCHS = 30          # passes over the training set
LR = 1e-4            # initial learning rate
WEIGHT_DECAY = 0.01  # weight decay for regularization

# --- Compute device ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Class weights for handling imbalance in CheXpert dataset (created directly on DEVICE)
CLASS_WEIGHTS = torch.tensor([weight for _, weight in _LABEL_POS_WEIGHTS], device=DEVICE)

print(f"Device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Image size: {IMG_SIZE}")
print(f"Number of classes: {NUM_CLASSES}")

## 4. Data Preparation
Define a PyTorch Dataset for CheXpert with enhanced augmentations suitable for medical imaging.

In [None]:
class CheXpertDataset(Dataset):
    """CheXpert images paired with multi-label targets read from a CSV file.

    Each item is ``(image, labels)``. The image is read from
    ``os.path.join(img_root, <value of the CSV 'Path' column>)`` and converted to
    RGB; ``labels`` is a float32 tensor with the ``LABELS`` columns in order.

    Args:
        csv_path: CSV file with a 'Path' column and one column per entry of ``LABELS``.
        img_root: Directory that the CSV 'Path' values are joined onto.
        transform: Optional callable invoked as ``transform(image=ndarray)`` whose
            result is indexed with ``['image']``. When None, the image is returned
            as the raw RGB numpy array.
        is_train: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, csv_path, img_root, transform=None, is_train=True):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.transform = transform
        self.is_train = is_train
        # Handle uncertain (-1.0) and NaN labels as 0.0
        self.df[LABELS] = self.df[LABELS].fillna(0).replace(-1.0, 0.0)
        # Materialise paths and targets once, so __getitem__ avoids a pandas row
        # lookup and an object-dtype conversion for every sample.
        self._paths = self.df['Path'].tolist()
        self._targets = self.df[LABELS].to_numpy(dtype=np.float32)

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self._paths)

    def __getitem__(self, idx):
        """Load the image at row ``idx`` and return ``(image, labels)``."""
        img_path = os.path.join(self.img_root, self._paths[idx])
        # The context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform is not None:
            image = self.transform(image=image)['image']

        # torch.tensor copies, so the returned tensor never aliases the cached array.
        labels = torch.tensor(self._targets[idx])
        return image, labels

# Enhanced augmentations for medical imaging
# Every pipeline ends with Normalize + ToTensorV2, so the dataset yields normalised
# channel-first tensors. The mean/std values are the standard ImageNet statistics.
# NOTE(review): confirm these match the statistics the pretrained weights expect.
train_transform = A.Compose([
    # Current albumentations releases take the output shape as size=(height, width);
    # the older positional (height, width) form is rejected by them.
    A.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.8, 1.0)),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.2),
    A.Rotate(limit=15, p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Validation uses a deterministic resize only, so metrics are comparable across epochs.
valid_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Create datasets and dataloaders
# Both loaders share every option except shuffling, which is on for training only.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

train_ds = CheXpertDataset(CSV_TRAIN, IMG_ROOT, transform=train_transform, is_train=True)
valid_ds = CheXpertDataset(CSV_VALID, IMG_ROOT, transform=valid_transform, is_train=False)

train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
valid_loader = DataLoader(valid_ds, shuffle=False, **_loader_opts)

# Report dataset and loader sizes as a quick sanity check.
print(f"Training samples: {len(train_ds)}")
print(f"Validation samples: {len(valid_ds)}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(valid_loader)}")

## 5. Model Setup
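Create a ViT-G/14 model (Giant Vision Transformer) with timm, with one output logit per CheXpert label, and set up the loss, optimizer, learning-rate scheduler, and mixed-precision gradient scaler. Note that the code loads timm's `vit_giant_patch14_224`; it does not load BiomedCLIP weights.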

In [None]:
# Create ViT-G/14 model - Giant Vision Transformer with 14x14 patches.
# img_size is passed explicitly: the '_224' architecture is otherwise built for
# 224x224 inputs, while the data pipeline produces IMG_SIZE x IMG_SIZE images,
# and the ViT patch embedding rejects inputs that do not match its configured size.
# NOTE(review): confirm that timm ships pretrained weights under this model name;
# if it does not, pretrained=True will fail and another checkpoint tag is needed.
model = timm.create_model(
    'vit_giant_patch14_224',
    pretrained=True,
    num_classes=NUM_CLASSES,
    img_size=IMG_SIZE,
)
model = model.to(DEVICE)

# Use BCEWithLogitsLoss with class weights for imbalanced dataset
# (pos_weight scales the loss contribution of positive targets, per class).
criterion = nn.BCEWithLogitsLoss(pos_weight=CLASS_WEIGHTS)

# Use AdamW optimizer with weight decay for better generalization
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Cosine annealing learning rate scheduler with warm restarts
# (first cycle lasts T_0 scheduler epochs; each later cycle is T_mult times longer).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

# Gradient scaler for mixed precision training (faster training with less memory)
scaler = GradScaler()

# Count model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model architecture: ViT-G/14 (Giant Vision Transformer)")

## 6. Training and Evaluation Functions
Define training and evaluation functions with mixed precision and comprehensive metrics.

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, scaler, scheduler):
    """Train the model for one epoch with mixed precision.

    The scheduler is advanced after every batch with a fractional epoch value
    (``completed_epochs + batches_done / batches_per_epoch``). Calling
    ``scheduler.step()`` with no argument once per batch would instead count each
    batch as a whole epoch, so a warm-restart schedule with ``T_0=5`` would
    restart every 5 batches rather than every 5 epochs.

    Args:
        model: Network to optimise; switched to training mode here.
        loader: Iterable of ``(images, labels)`` batches; must support ``len()``
            and expose ``.dataset``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scaler: Gradient scaler used for the mixed-precision backward pass and step.
        scheduler: LR scheduler whose ``step`` accepts a fractional epoch and whose
            ``last_epoch`` holds the number of whole epochs completed
            (e.g. ``CosineAnnealingWarmRestarts``).

    Returns:
        Sum of per-batch loss times batch size, divided by ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    batches_per_epoch = len(loader)
    # Whole epochs completed before this call; fractional progress is added to it.
    epochs_done = scheduler.last_epoch

    for batch_idx, (images, labels) in enumerate(tqdm(loader, desc="Training"), start=1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        # Mixed precision forward pass
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Mixed precision backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # On the last batch this reaches exactly epochs_done + 1, so the next call
        # starts from the correct whole-epoch count.
        scheduler.step(epochs_done + batch_idx / batches_per_epoch)

        running_loss += loss.item() * images.size(0)

    return running_loss / len(loader.dataset)

def evaluate(model, loader):
    """Evaluate the model and compute AUC scores for each class.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches, with labels as CPU tensors.

    Returns:
        List of ``NUM_CLASSES`` values, one per entry of ``LABELS``: the ROC AUC for
        that class, or ``np.nan`` when the class has only one distinct label value
        in the evaluated data or the AUC computation raised.
    """
    model.eval()
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images = images.to(DEVICE)

            with autocast():
                outputs = model(images)

            # Cast logits to float32 before the sigmoid. Under autocast they can be
            # half precision, where probabilities are coarsely quantised and saturate
            # to exactly 1.0, creating ties that distort the rank-based AUC.
            all_outputs.append(torch.sigmoid(outputs.float()).cpu().numpy())
            all_labels.append(labels.numpy())

    all_outputs = np.concatenate(all_outputs)
    all_labels = np.concatenate(all_labels)

    # Compute AUC for each class
    aucs = []
    for i in range(NUM_CLASSES):
        try:
            # Only compute AUC if there are both positive and negative samples
            if len(np.unique(all_labels[:, i])) > 1:
                auc = roc_auc_score(all_labels[:, i], all_outputs[:, i])
            else:
                auc = np.nan
        except Exception as e:
            # Best effort: a failing metric for one class must not abort training.
            print(f"Error computing AUC for {LABELS[i]}: {e}")
            auc = np.nan
        aucs.append(auc)

    return aucs

## 7. Training Loop
Train the BiomedCLIP ViT-G/14 model with comprehensive logging and model checkpointing.

In [None]:
# Training loop with best model saving
# Tracks per-epoch train loss and validation AUCs, and checkpoints the model
# whenever the mean validation AUC improves on the best value so far.
best_mean_auc = 0
training_history = dict(train_loss=[], val_auc=[], mean_auc=[])

print("Starting training...")
print(f"Training for {EPOCHS} epochs")
print("-" * 80)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 40)

    # Optimise on the training split.
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, scaler, scheduler)
    print(f"Train Loss: {train_loss:.4f}")

    # Score the validation split; NaN entries are ignored in the mean.
    aucs = evaluate(model, valid_loader)
    mean_auc = np.nanmean(aucs)

    # Per-class report.
    print("\nClass-wise AUC scores:")
    for class_name, class_auc in zip(LABELS, aucs):
        if np.isnan(class_auc):
            print(f"  {class_name:25}: AUC = N/A (insufficient data)")
        else:
            print(f"  {class_name:25}: AUC = {class_auc:.4f}")

    print(f"\nMean AUC: {mean_auc:.4f}")
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Record this epoch's metrics.
    for history_key, epoch_value in (
        ('train_loss', train_loss),
        ('val_auc', aucs),
        ('mean_auc', mean_auc),
    ):
        training_history[history_key].append(epoch_value)

    # Checkpoint only on a strict improvement of the mean AUC.
    if mean_auc > best_mean_auc:
        best_mean_auc = mean_auc
        torch.save(
            dict(
                epoch=epoch + 1,
                model_state_dict=model.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
                scheduler_state_dict=scheduler.state_dict(),
                best_mean_auc=best_mean_auc,
                aucs=aucs,
                labels=LABELS,
            ),
            'chexpert_biomedclip_vit_best.pth',
        )
        print(f"🎉 New best model saved! Mean AUC: {best_mean_auc:.4f}")

    print("-" * 40)

print("\n" + "=" * 80)
print(f"Training completed! Best Mean AUC: {best_mean_auc:.4f}")
print("=" * 80)

## 8. Final Model Saving and Results Summary
Save the final model and display comprehensive training results.

In [None]:
# Save final model
# The checkpoint bundles the last-epoch weights, optimizer/scheduler state, the full
# metric history, and the run configuration needed to rebuild the model.
_run_config = {
    'model_name': 'vit_giant_patch14_224',
    'img_size': IMG_SIZE,
    'batch_size': BATCH_SIZE,
    'epochs': EPOCHS,
    'lr': LR,
    'weight_decay': WEIGHT_DECAY,
    'num_classes': NUM_CLASSES,
    'labels': LABELS,
}

_final_checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'training_history': training_history,
    'final_mean_auc': training_history['mean_auc'][-1],
    'best_mean_auc': best_mean_auc,
    'config': _run_config,
}

torch.save(_final_checkpoint, 'chexpert_biomedclip_vit_final.pth')

print('✅ Final model saved as chexpert_biomedclip_vit_final.pth')
print('✅ Best model saved as chexpert_biomedclip_vit_best.pth')

# Display final results summary
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)
print(f"Model: BiomedCLIP ViT-G/14 (Giant Vision Transformer)")
print(f"Dataset: CheXpert")
print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Epochs Trained: {EPOCHS}")
print(f"Total Parameters: {total_params:,}")
print(f"Best Mean AUC: {best_mean_auc:.4f}")
print(f"Final Mean AUC: {training_history['mean_auc'][-1]:.4f}")
print(f"Final Train Loss: {training_history['train_loss'][-1]:.4f}")
print("=" * 60)

# Display best performing classes
if len(training_history['val_auc']) > 0:
    mean_auc_by_epoch = np.asarray(training_history['mean_auc'], dtype=float)
    # np.argmax returns the position of the first NaN when one is present, which would
    # report a failed epoch as the best one. nanargmax skips NaNs but raises when every
    # entry is NaN, so that case falls back to the first epoch.
    if np.isnan(mean_auc_by_epoch).all():
        best_epoch_idx = 0
    else:
        best_epoch_idx = int(np.nanargmax(mean_auc_by_epoch))

    best_aucs = training_history['val_auc'][best_epoch_idx]
    valid_aucs = [(LABELS[i], auc) for i, auc in enumerate(best_aucs) if not np.isnan(auc)]
    valid_aucs.sort(key=lambda x: x[1], reverse=True)

    # Take the bottom entries from what remains after the top 5, so no class is
    # printed twice when only 6 or 7 classes have a valid AUC.
    top_classes, remaining_classes = valid_aucs[:5], valid_aucs[5:]

    print("\nBest Model Performance by Class:")
    for label, auc in top_classes:  # Top 5
        print(f"  {label:25}: AUC = {auc:.4f}")

    if remaining_classes:
        print("  ...")
        for label, auc in remaining_classes[-3:]:  # Bottom 3 (at most)
            print(f"  {label:25}: AUC = {auc:.4f}")