# CheXpert BiomedCLIP (ViT-B/16) Training Notebook

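This notebook fine-tunes the BiomedCLIP ViT-B/16 image encoder (`microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224`) with an MLP classification head for 14-label multi-label classification on the CheXpert dataset, using PyTorch, open_clip and timm. If BiomedCLIP cannot be loaded, a timm ViT (Large, else Base) is used instead. Performance is reported as per-class and mean validation ROC-AUC.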

In [None]:
# 1. Install dependencies for BiomedCLIP training
!pip install timm torch torchvision scikit-learn pandas tqdm albumentations --quiet
!pip install open_clip_torch transformers datasets --quiet
!pip install huggingface_hub --quiet

## 2. Imports

In [None]:
import os
import random

import albumentations as A
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import roc_auc_score
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# Optional BiomedCLIP backends. Each flag records whether its import succeeded,
# so the model-setup cells in section 5 can decide which loader to try.
try:
    import open_clip
except ImportError:
    OPENCLIP_AVAILABLE = False
    print("⚠️ open_clip not available")
else:
    OPENCLIP_AVAILABLE = True

try:
    from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("⚠️ transformers not available")
else:
    TRANSFORMERS_AVAILABLE = True

for backend_name, backend_ok in (("OpenCLIP", OPENCLIP_AVAILABLE),
                                 ("Transformers", TRANSFORMERS_AVAILABLE)):
    print(f"{backend_name} available: {backend_ok}")

## 3. Configurations
Set up paths, label names, and hyperparameters for fine-tuning BiomedCLIP ViT-B/16.

In [None]:
# Download (or locate the cached copy of) the CheXpert-small dataset via kagglehub.
# kagglehub is imported here rather than in the imports cell so that this optional
# download step stays self-contained; `dataset_path` is the local dataset root.
import kagglehub

print("Downloading CheXpert dataset from Kaggle...")
dataset_path = kagglehub.dataset_download("willarevalo/chexpert-v10-small")
print(f"Dataset downloaded to: {dataset_path}")

In [None]:
# Dataset location. If the kagglehub download cell has run, use the path it returned;
# otherwise fall back to the Kaggle input mount. Image paths in the CSVs are relative
# to IMG_ROOT.
# NOTE(review): assumes the dataset root contains `CheXpert-v1.0-small/` (as the
# /kaggle/input mount does) — confirm for downloads made outside Kaggle.
IMG_ROOT = globals().get("dataset_path", "/kaggle/input/chexpert-v10-small")
DATA_ROOT = os.path.join(IMG_ROOT, "CheXpert-v1.0-small")
CSV_TRAIN = os.path.join(DATA_ROOT, 'train.csv')
CSV_VALID = os.path.join(DATA_ROOT, 'valid.csv')

LABELS = [
    'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion',
    'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
    'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'
]
NUM_CLASSES = len(LABELS)

# Reproducibility: seed Python, NumPy and PyTorch (torch.manual_seed covers all devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
BATCH_SIZE = 64
IMG_SIZE = 224  # input resolution of BiomedCLIP's ViT-B/16 (vit_base_patch16_224)
EPOCHS = 50  # total over both training stages; the training loop splits it at EPOCHS // 2
LR_BACKBONE = 1e-5  # low LR for the pre-trained backbone during fine-tuning
LR_HEAD = 1e-3  # higher LR for the freshly initialised classification head
WEIGHT_DECAY = 0.01
WARMUP_EPOCHS = 5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-label positive-class weights (same order as LABELS); passed as `pos_weight`
# to BCEWithLogitsLoss when USE_FOCAL_LOSS is False.
CLASS_WEIGHTS = torch.tensor([0.8, 3.0, 2.0, 1.2, 4.0, 2.5, 2.5, 3.0, 2.0, 3.5, 1.5, 1.5, 3.0, 1.2]).to(DEVICE)
assert CLASS_WEIGHTS.numel() == NUM_CLASSES, "CLASS_WEIGHTS must have one entry per label"

# Training strategy flags
FREEZE_BACKBONE = True  # Start with frozen backbone
USE_FOCAL_LOSS = True  # Better for imbalanced data
USE_LABEL_SMOOTHING = True  # Regularization technique

print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Image size: {IMG_SIZE}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Epochs: {EPOCHS}")
print(f"Freeze backbone: {FREEZE_BACKBONE}")
print(f"Use focal loss: {USE_FOCAL_LOSS}")
print(f"Use label smoothing: {USE_LABEL_SMOOTHING}")

## 4. Data Preparation
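Define a PyTorch Dataset for CheXpert with augmentations suited to chest X-rays. Uncertain (-1) and missing labels are treated as negative (0), and training targets can optionally be label-smoothed.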

In [None]:
class CheXpertDataset(Dataset):
    """CheXpert multi-label dataset yielding ``(image, labels)`` pairs.

    Label policy: missing (NaN) and uncertain (-1.0) labels are both mapped to 0.0.
    When the global ``USE_LABEL_SMOOTHING`` is set and ``is_train`` is True, targets
    are smoothed towards 0.5 as ``y * (1 - 0.1) + 0.1 / 2``.

    Args:
        csv_path: CSV with a ``Path`` column plus one column per entry of ``LABELS``.
        img_root: directory that the CSV ``Path`` values are relative to.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=ndarray)`` and returning a dict with key ``'image'``.
        is_train: enables label smoothing (training split only).
    """

    def __init__(self, csv_path, img_root, transform=None, is_train=True):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.transform = transform
        self.is_train = is_train

        # Uncertain (-1.0) and missing labels are both treated as negative.
        self.df[LABELS] = self.df[LABELS].fillna(0)
        self.df[LABELS] = self.df[LABELS].replace(-1.0, 0.0)

        if USE_LABEL_SMOOTHING and is_train:
            smoothing = 0.1
            self.df[LABELS] = self.df[LABELS] * (1 - smoothing) + smoothing / 2

        # Cache labels and paths once; a pandas row lookup per item is slow in the
        # DataLoader hot path.
        self._labels = self.df[LABELS].to_numpy(dtype=np.float32)
        self._paths = self.df['Path'].tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_root, self._paths[idx])
        # The context manager closes the file handle right away instead of leaving
        # it to garbage collection inside worker processes.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']

        labels = torch.tensor(self._labels[idx])
        return image, labels

# BiomedCLIP (OpenAI-CLIP) normalization statistics, shared by both pipelines.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def _random_resized_crop(**kwargs):
    """Build ``A.RandomResizedCrop`` to IMG_SIZE x IMG_SIZE across albumentations versions.

    albumentations 2.x takes ``size=(h, w)`` and rejects positional height/width
    (``RandomResizedCrop(IMG_SIZE, IMG_SIZE, scale=...)`` raises TypeError there);
    older releases (e.g. 1.3.x) only take positional height/width.
    """
    try:
        return A.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), **kwargs)
    except TypeError:
        return A.RandomResizedCrop(IMG_SIZE, IMG_SIZE, **kwargs)


# Training augmentations (applied to uint8 RGB arrays, so CLAHE runs before Normalize).
train_transform = A.Compose([
    _random_resized_crop(scale=(0.85, 1.0)),  # Less aggressive cropping for medical images
    # NOTE(review): horizontal flips mirror laterality; confirm this is acceptable for the task.
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.3),
    A.Rotate(limit=10, p=0.3),  # Reduced rotation for medical accuracy
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.3),
    A.GaussianBlur(blur_limit=3, p=0.1),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2),  # Contrast enhancement
    A.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ToTensorV2()
])

# Deterministic evaluation pipeline: resize + normalize only.
valid_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ToTensorV2()
])

# Create datasets and dataloaders
train_ds = CheXpertDataset(CSV_TRAIN, IMG_ROOT, transform=train_transform, is_train=True)
valid_ds = CheXpertDataset(CSV_VALID, IMG_ROOT, transform=valid_transform, is_train=False)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"Training samples: {len(train_ds)}")
print(f"Validation samples: {len(valid_ds)}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(valid_loader)}")

## 5. BiomedCLIP Model Setup Options
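Three loaders are tried in order: BiomedCLIP via OpenCLIP, BiomedCLIP via Transformers, and a timm ViT fallback. Run the cells in order; each later option only loads a model when the earlier ones did not, so Restart & Run All uses the first loader that succeeds.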

In [None]:
# OPTION 1: BiomedCLIP via OpenCLIP (Recommended)
# Runs when open_clip is available; on failure `model` stays None so Option 2/3 take over.

if OPENCLIP_AVAILABLE:
    try:
        print("Loading BiomedCLIP via OpenCLIP...")

        # Load BiomedCLIP. The returned preprocessing transforms are unused; images
        # are normalized by the albumentations pipelines in section 4.
        clip_model, _, _ = open_clip.create_model_and_transforms(
            'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
        )

        class BiomedCLIPClassifier(nn.Module):
            """MLP classification head on top of the BiomedCLIP image embedding.

            Args:
                clip_model: open_clip model exposing ``encode_image`` and ``visual``.
                num_classes: number of output logits.
                freeze_backbone: if True, backbone parameters get
                    ``requires_grad=False``; the backbone then runs without autograd
                    and is kept in eval mode even when the wrapper is in train mode.
            """

            def __init__(self, clip_model, num_classes, freeze_backbone=True):
                super().__init__()
                self.clip_model = clip_model
                self.freeze_backbone = freeze_backbone

                if freeze_backbone:
                    for param in self.clip_model.parameters():
                        param.requires_grad = False
                    print("🔒 Backbone frozen for initial training")
                else:
                    print("🔓 Backbone unfrozen for fine-tuning")

                feature_dim = self._infer_feature_dim(clip_model)
                self.classifier = nn.Sequential(
                    nn.Dropout(0.2),
                    nn.Linear(feature_dim, 1024),
                    nn.BatchNorm1d(1024),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(1024, 512),
                    nn.BatchNorm1d(512),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(512, num_classes)
                )

            @staticmethod
            def _infer_feature_dim(clip_model):
                """Return the width of ``clip_model.encode_image`` outputs.

                Uses ``visual.output_dim`` when the vision tower exposes it; some
                open_clip towers may not, so otherwise probe with one dummy image.
                """
                feature_dim = getattr(clip_model.visual, 'output_dim', None)
                if feature_dim is not None:
                    return feature_dim
                was_training = clip_model.training
                clip_model.eval()
                device = next(clip_model.parameters()).device
                with torch.no_grad():
                    probe = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE, device=device)
                    feature_dim = clip_model.encode_image(probe).shape[-1]
                clip_model.train(was_training)
                return feature_dim

            def train(self, mode=True):
                """Set train/eval mode, keeping a frozen backbone in eval mode."""
                super().train(mode)
                if self.freeze_backbone:
                    self.clip_model.eval()
                return self

            def forward(self, images):
                # Mixed precision is controlled by the caller's autocast context
                # (train_one_epoch / evaluate), so none is forced here.
                if self.freeze_backbone:
                    # No autograd graph is needed through a frozen backbone.
                    with torch.no_grad():
                        image_features = self.clip_model.encode_image(images)
                else:
                    image_features = self.clip_model.encode_image(images)
                return self.classifier(image_features)

            def unfreeze_backbone(self):
                """Unfreeze backbone for fine-tuning"""
                for param in self.clip_model.parameters():
                    param.requires_grad = True
                self.freeze_backbone = False
                print("🔓 Backbone unfrozen for fine-tuning")

        model = BiomedCLIPClassifier(clip_model, NUM_CLASSES, freeze_backbone=FREEZE_BACKBONE)
        model = model.to(DEVICE)

        MODEL_TYPE = "BiomedCLIP-OpenCLIP"
        print(f"✅ Successfully loaded {MODEL_TYPE}")

    except Exception as e:
        print(f"❌ Failed to load BiomedCLIP via OpenCLIP: {e}")
        model = None
else:
    print("❌ OpenCLIP not available. Try Option 2 or 3.")
    model = None

In [None]:
# OPTION 2: BiomedCLIP via Transformers (Alternative)
# Runs only when Option 1 did not produce a model. `globals().get` keeps this cell
# runnable on its own even if the Option 1 cell was skipped.
model = globals().get("model")

if TRANSFORMERS_AVAILABLE and model is None:
    try:
        print("Loading BiomedCLIP via Transformers...")

        # NOTE(review): this hub repo appears to be packaged for open_clip; if
        # CLIPModel.from_pretrained cannot read it, the except below reports the
        # failure and Option 3 takes over.
        model_name = "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
        clip_model = CLIPModel.from_pretrained(model_name)

        class BiomedCLIPTransformersClassifier(nn.Module):
            """MLP classification head on top of a transformers ``CLIPModel`` image embedding.

            Args:
                clip_model: transformers ``CLIPModel`` (uses ``get_image_features``).
                num_classes: number of output logits.
                freeze_backbone: if True, backbone parameters get
                    ``requires_grad=False``; the backbone then runs without autograd
                    and is kept in eval mode even when the wrapper is in train mode.
            """

            def __init__(self, clip_model, num_classes, freeze_backbone=True):
                super().__init__()
                self.clip_model = clip_model
                self.freeze_backbone = freeze_backbone

                if freeze_backbone:
                    for param in self.clip_model.parameters():
                        param.requires_grad = False
                    print("🔒 Backbone frozen for initial training")

                # get_image_features returns projected embeddings of width projection_dim.
                feature_dim = clip_model.config.projection_dim
                self.classifier = nn.Sequential(
                    nn.Dropout(0.2),
                    nn.Linear(feature_dim, 1024),
                    nn.BatchNorm1d(1024),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(1024, 512),
                    nn.BatchNorm1d(512),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(512, num_classes)
                )

            def train(self, mode=True):
                """Set train/eval mode, keeping a frozen backbone in eval mode."""
                super().train(mode)
                if self.freeze_backbone:
                    self.clip_model.eval()
                return self

            def forward(self, images):
                # Mixed precision is controlled by the caller's autocast context.
                if self.freeze_backbone:
                    with torch.no_grad():
                        image_features = self.clip_model.get_image_features(pixel_values=images)
                else:
                    image_features = self.clip_model.get_image_features(pixel_values=images)
                return self.classifier(image_features)

            def unfreeze_backbone(self):
                """Unfreeze backbone for fine-tuning."""
                for param in self.clip_model.parameters():
                    param.requires_grad = True
                self.freeze_backbone = False
                print("🔓 Backbone unfrozen for fine-tuning")

        model = BiomedCLIPTransformersClassifier(clip_model, NUM_CLASSES, freeze_backbone=FREEZE_BACKBONE)
        model = model.to(DEVICE)

        MODEL_TYPE = "BiomedCLIP-Transformers"
        print(f"✅ Successfully loaded {MODEL_TYPE}")

    except Exception as e:
        print(f"❌ Failed to load BiomedCLIP via Transformers: {e}")
        model = None
else:
    if model is not None:
        print("✅ Model already loaded, skipping Option 2")
    else:
        print("❌ Transformers not available. Try Option 3.")

In [None]:
# OPTION 3: Enhanced ViT Fallback (if BiomedCLIP unavailable)
# Runs only when neither Option 1 nor Option 2 produced a model. `globals().get`
# keeps this cell runnable even if the earlier option cells were skipped.
model = globals().get("model")

if model is None:
    print("Loading enhanced ViT as fallback...")

    # Prefer ViT-Large, else ViT-Base. num_classes=0 drops timm's classification
    # head so the backbone returns a feature vector of width num_features.
    # NOTE(review): timm ViT checkpoints may expect different normalization stats than
    # the CLIP values used in section 4; compare with base_model.pretrained_cfg.
    try:
        model_name = 'vit_large_patch16_224'
        base_model = timm.create_model(model_name, pretrained=True, num_classes=0)
    except Exception as exc:
        print(f"⚠️ Could not load vit_large_patch16_224 ({exc}); falling back to ViT-Base")
        model_name = 'vit_base_patch16_224'
        base_model = timm.create_model(model_name, pretrained=True, num_classes=0)
    feature_dim = base_model.num_features

    class EnhancedViTClassifier(nn.Module):
        """MLP classification head on top of a headless timm ViT backbone.

        Args:
            base_model: timm model created with ``num_classes=0``.
            feature_dim: width of the backbone output (``base_model.num_features``).
            num_classes: number of output logits.
            freeze_backbone: if True, backbone parameters get ``requires_grad=False``;
                the backbone then runs without autograd and is kept in eval mode even
                when the wrapper is in train mode.
        """

        def __init__(self, base_model, feature_dim, num_classes, freeze_backbone=True):
            super().__init__()
            self.backbone = base_model
            self.freeze_backbone = freeze_backbone

            if freeze_backbone:
                for param in self.backbone.parameters():
                    param.requires_grad = False
                print("🔒 Backbone frozen for initial training")

            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(feature_dim, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )

        def train(self, mode=True):
            """Set train/eval mode, keeping a frozen backbone in eval mode."""
            super().train(mode)
            if self.freeze_backbone:
                self.backbone.eval()
            return self

        def forward(self, x):
            if self.freeze_backbone:
                with torch.no_grad():
                    features = self.backbone(x)
            else:
                features = self.backbone(x)
            return self.classifier(features)

        def unfreeze_backbone(self):
            """Unfreeze backbone for fine-tuning."""
            for param in self.backbone.parameters():
                param.requires_grad = True
            self.freeze_backbone = False
            print("🔓 Backbone unfrozen for fine-tuning")

    model = EnhancedViTClassifier(base_model, feature_dim, NUM_CLASSES, freeze_backbone=FREEZE_BACKBONE)
    model = model.to(DEVICE)

    MODEL_TYPE = f"Enhanced-{model_name}"
    print(f"✅ Successfully loaded {MODEL_TYPE} as fallback")

# Model summary
if model is not None:
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nModel: {MODEL_TYPE}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Frozen parameters: {total_params - trainable_params:,}")
else:
    raise RuntimeError("❌ Failed to load any model. Check your installations.")

In [None]:
# Advanced Loss Functions and Optimizer Setup

class FocalLoss(nn.Module):
    """Binary focal loss computed from logits, for multi-label targets.

    Each element's BCE term is scaled by ``alpha * (1 - p_t) ** gamma``, where
    ``p_t = exp(-BCE)``, which down-weights elements the model already gets right.

    Args:
        alpha: scalar multiplier applied to every element's loss.
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha * BCE``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        per_elem_bce = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = (-per_elem_bce).exp()
        weighted = self.alpha * (1 - p_t).pow(self.gamma) * per_elem_bce

        reduce_fn = {'mean': torch.mean, 'sum': torch.sum}.get(self.reduction)
        return weighted if reduce_fn is None else reduce_fn(weighted)

# Setup loss function
if USE_FOCAL_LOSS:
    criterion = FocalLoss(alpha=1, gamma=2)
    print("✅ Using Focal Loss for better class balance")
else:
    # CLASS_WEIGHTS up-weights positives per label; it is only consumed on this branch.
    criterion = nn.BCEWithLogitsLoss(pos_weight=CLASS_WEIGHTS)
    print("✅ Using weighted BCE Loss")

# Setup optimizer with different learning rates
if hasattr(model, 'classifier'):
    if FREEZE_BACKBONE:
        # Only train classifier when backbone is frozen
        optimizer = optim.AdamW(model.classifier.parameters(), lr=LR_HEAD, weight_decay=WEIGHT_DECAY)
        print(f"✅ Optimizer setup for frozen backbone (LR: {LR_HEAD})")
    else:
        # Different learning rates for backbone and classifier
        backbone = model.clip_model if hasattr(model, 'clip_model') else model.backbone
        optimizer = optim.AdamW([
            {'params': backbone.parameters(), 'lr': LR_BACKBONE},
            {'params': model.classifier.parameters(), 'lr': LR_HEAD}
        ], weight_decay=WEIGHT_DECAY)
        print(f"✅ Optimizer setup for fine-tuning (Backbone LR: {LR_BACKBONE}, Head LR: {LR_HEAD})")
else:
    optimizer = optim.AdamW(model.parameters(), lr=LR_HEAD, weight_decay=WEIGHT_DECAY)
    print(f"✅ Standard optimizer setup (LR: {LR_HEAD})")

# The training loop (section 7) uses this optimizer/scheduler only for stage 1
# (EPOCHS // 2 epochs) and then replaces both, so the one-cycle schedule is sized to
# that stage; this lets the LR complete its warm-up and cosine anneal within stage 1.
# Warm-up lasts WARMUP_EPOCHS, capped at half the stage.
STAGE1_EPOCHS = EPOCHS // 2
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    # One peak LR per param group: [LR_HEAD] when frozen, [LR_BACKBONE, LR_HEAD] otherwise.
    max_lr=[group['lr'] for group in optimizer.param_groups],
    epochs=STAGE1_EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=min(WARMUP_EPOCHS / STAGE1_EPOCHS, 0.5),
    anneal_strategy='cos'
)

# Gradient scaler for mixed precision
scaler = GradScaler()

print("✅ Training setup complete!")

## 6. Training and Evaluation Functions
Define training and evaluation functions with mixed precision and comprehensive metrics.

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, scaler, scheduler):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    The scheduler is stepped once per batch, as OneCycleLR expects.

    Args:
        model: module to train (put into train mode here).
        loader: DataLoader yielding ``(images, labels)``.
        optimizer, criterion, scaler, scheduler: training components.

    Returns:
        Mean per-sample loss over the samples actually processed (0.0 for an empty
        loader). The divisor counts only samples seen, so a dropped last batch
        (``drop_last=True``) no longer lowers the reported loss.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, labels in tqdm(loader, desc="Training"):
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad()

        # Mixed precision forward pass
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Mixed precision backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    return running_loss / max(samples_seen, 1)

def _per_class_auc(all_labels, all_outputs):
    """Return one ROC-AUC per label column of the (N, NUM_CLASSES) arrays.

    A column whose labels take a single value has no defined AUC and yields NaN;
    any other scoring error is printed and also yields NaN.
    """
    aucs = []
    for i in range(NUM_CLASSES):
        if len(np.unique(all_labels[:, i])) < 2:
            aucs.append(np.nan)
            continue
        try:
            aucs.append(roc_auc_score(all_labels[:, i], all_outputs[:, i]))
        except Exception as e:
            print(f"Error computing AUC for {LABELS[i]}: {e}")
            aucs.append(np.nan)
    return aucs


def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` and return a list of per-class ROC-AUCs.

    Args:
        model: classifier returning one logit per label.
        loader: DataLoader yielding ``(images, labels)``.

    Returns:
        List of NUM_CLASSES floats, NaN where a class's AUC is undefined.
    """
    model.eval()
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images = images.to(DEVICE, non_blocking=True)

            with autocast():
                outputs = model(images)

            # Sigmoid in float32: fp16 probabilities saturate near 0/1 and create
            # score ties that distort the AUC ranking.
            all_outputs.append(torch.sigmoid(outputs.float()).cpu().numpy())
            all_labels.append(labels.numpy())

    return _per_class_auc(np.concatenate(all_labels), np.concatenate(all_outputs))

## 7. Training Loop
Train the selected model in two stages — frozen backbone, then full fine-tuning — logging per-class validation AUC every epoch and checkpointing the best model.

In [None]:
# Enhanced Training Loop with Two-Stage Training Strategy
#   Stage 1: epochs [0, STAGE1_EPOCHS) with the optimizer/scheduler built in the setup
#            cell (head only when FREEZE_BACKBONE is True).
#   Stage 2: epochs [STAGE1_EPOCHS, EPOCHS) with the backbone unfrozen and a fresh
#            optimizer/scheduler (backbone LR = LR_BACKBONE, head LR = LR_HEAD).
STAGE1_EPOCHS = EPOCHS // 2
# Stage 2 runs EPOCHS - STAGE1_EPOCHS epochs. Its OneCycleLR must be sized to exactly
# that count; sizing it to EPOCHS // 2 makes scheduler.step() raise ValueError on the
# final batches whenever EPOCHS is odd.
STAGE2_EPOCHS = EPOCHS - STAGE1_EPOCHS
BEST_CHECKPOINT_PATH = 'chexpert_biomedclip_vit_best.pth'

best_mean_auc = 0
training_history = {'train_loss': [], 'val_auc': [], 'mean_auc': []}


def _print_class_aucs(aucs):
    """Print one line per label with its validation AUC, or N/A when undefined."""
    print("\nClass-wise AUC scores:")
    for label, auc in zip(LABELS, aucs):
        if np.isnan(auc):
            print(f"  {label:25}: AUC = N/A (insufficient data)")
        else:
            print(f"  {label:25}: AUC = {auc:.4f}")


def _print_learning_rates(optimizer):
    """Print the current LR: one line for a single param group, backbone/head for two."""
    groups = optimizer.param_groups
    if len(groups) == 1:
        print(f"Current LR: {groups[0]['lr']:.6f}")
    else:
        print(f"Current Backbone LR: {groups[0]['lr']:.6f}")
        print(f"Current Head LR: {groups[1]['lr']:.6f}")


def _run_epoch(epoch, stage, optimizer, scheduler):
    """Train and validate one epoch, record history, and checkpoint a new best model.

    Reads model, loaders, criterion and scaler from the notebook globals and
    updates the global ``best_mean_auc``.
    """
    global best_mean_auc

    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, scaler, scheduler)
    print(f"Train Loss: {train_loss:.4f}")

    aucs = evaluate(model, valid_loader)
    mean_auc = np.nanmean(aucs)
    _print_class_aucs(aucs)
    print(f"\nMean AUC: {mean_auc:.4f}")
    _print_learning_rates(optimizer)

    training_history['train_loss'].append(train_loss)
    training_history['val_auc'].append(aucs)
    training_history['mean_auc'].append(mean_auc)

    if mean_auc > best_mean_auc:
        best_mean_auc = mean_auc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_mean_auc': best_mean_auc,
            'aucs': aucs,
            'labels': LABELS,
            'stage': stage
        }, BEST_CHECKPOINT_PATH)
        print(f"🎉 New best model saved! Mean AUC: {best_mean_auc:.4f}")


print("Starting Enhanced Two-Stage Training...")
print(f"Stage 1: Frozen backbone training ({STAGE1_EPOCHS} epochs)")
print(f"Stage 2: Fine-tuning entire model ({STAGE2_EPOCHS} epochs)")
print("-" * 80)

# Stage 1: Train with frozen backbone
print("\n=== STAGE 1: FROZEN BACKBONE TRAINING ===")
for epoch in range(STAGE1_EPOCHS):
    print(f"\nEpoch {epoch+1}/{STAGE1_EPOCHS} (Stage 1)")
    print("-" * 40)
    _run_epoch(epoch, 1, optimizer, scheduler)

# Stage 2: Unfreeze backbone for fine-tuning
print("\n=== STAGE 2: FINE-TUNING ENTIRE MODEL ===")
if hasattr(model, 'unfreeze_backbone'):
    model.unfreeze_backbone()

    # New optimizer with separate learning rates for backbone and head
    backbone = model.clip_model if hasattr(model, 'clip_model') else model.backbone
    optimizer = optim.AdamW([
        {'params': backbone.parameters(), 'lr': LR_BACKBONE},
        {'params': model.classifier.parameters(), 'lr': LR_HEAD}
    ], weight_decay=WEIGHT_DECAY)

    # New scheduler sized to the number of stage-2 epochs actually run below
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[LR_BACKBONE, LR_HEAD],
        epochs=STAGE2_EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.2,
        anneal_strategy='cos'
    )

    print(f"✅ Fine-tuning setup complete (Backbone LR: {LR_BACKBONE}, Head LR: {LR_HEAD})")

for epoch in range(STAGE1_EPOCHS, EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS} (Stage 2)")
    print("-" * 40)
    _run_epoch(epoch, 2, optimizer, scheduler)
    print("-" * 40)

print("\n" + "=" * 80)
print(f"Training completed! Best Mean AUC: {best_mean_auc:.4f}")
print("=" * 80)

In [None]:
# Performance Analysis and Optimization
print("\n" + "=" * 60)
print("PERFORMANCE ANALYSIS")
print("=" * 60)

mean_auc_history = np.asarray(training_history['mean_auc'], dtype=float)
stage_split = EPOCHS // 2  # index of the first stage-2 epoch, as in the training loop


def _nanmax(values):
    """Max ignoring NaN; NaN (instead of raising) when empty or all-NaN."""
    values = np.asarray(values, dtype=float)
    if values.size == 0 or np.isnan(values).all():
        return float('nan')
    return float(np.nanmax(values))


# Stage comparison needs at least one completed stage-2 epoch; otherwise the
# stage-2 slice is empty and max() would raise ValueError.
if len(mean_auc_history) > stage_split:
    stage1_best = _nanmax(mean_auc_history[:stage_split])
    stage2_best = _nanmax(mean_auc_history[stage_split:])
    improvement = stage2_best - stage1_best

    print(f"Stage 1 Best AUC: {stage1_best:.4f}")
    print(f"Stage 2 Best AUC: {stage2_best:.4f}")
    print(f"Fine-tuning Improvement: {improvement:.4f} ({improvement*100:.2f}%)")

# Identify best and worst performing classes at the best epoch. np.argmax would
# pick an all-NaN epoch as "best"; np.nanargmax skips NaN entries.
if len(mean_auc_history) > 0 and not np.isnan(mean_auc_history).all():
    best_aucs = training_history['val_auc'][int(np.nanargmax(mean_auc_history))]
    class_performance = [(LABELS[i], auc) for i, auc in enumerate(best_aucs) if not np.isnan(auc)]
    class_performance.sort(key=lambda x: x[1], reverse=True)

    print(f"\nBest Performing Classes:")
    for label, auc in class_performance[:3]:
        status = "✅" if auc >= 0.90 else "🟡" if auc >= 0.80 else "🔴"
        print(f"  {status} {label:25}: AUC = {auc:.4f}")

    print(f"\nWorst Performing Classes (Need Attention):")
    for label, auc in class_performance[-3:]:
        status = "✅" if auc >= 0.90 else "🟡" if auc >= 0.80 else "🔴"
        print(f"  {status} {label:25}: AUC = {auc:.4f}")

# Recommendations based on the best mean validation AUC
print(f"\n" + "=" * 60)
print("RECOMMENDATIONS FOR 95%+ ACCURACY")
print("=" * 60)

if best_mean_auc < 0.95:
    print("Current performance is below 95% target. Consider:")
    print("• Increase training epochs to 100-150")
    print("• Use test-time augmentation (TTA)")
    print("• Implement ensemble of multiple models")
    print("• Add more data augmentation techniques")
    print("• Use different loss functions (e.g., AUC loss)")
    print("• Try different learning rate schedules")
    print("• Consider using larger image sizes (384x384 or 512x512)")
else:
    print("🎉 Congratulations! You've achieved 95%+ accuracy!")
    print("Consider these optimizations:")
    print("• Model compression for deployment")
    print("• Knowledge distillation")
    print("• Quantization for faster inference")

In [None]:
# Test-Time Augmentation (TTA) for Enhanced Performance

# Normalization applied by the validation pipeline (section 4). TTA undoes it to
# recover uint8 pixels, augments them, and re-applies it.
TTA_NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
TTA_NORM_STD = [0.26862954, 0.26130258, 0.27577711]

# (horizontal_flip, rotation_degrees) for each TTA view, in the order they are used.
# Every view is deterministic so TTA scores are reproducible from run to run.
TTA_VIEWS = [(False, 0), (True, 0), (False, 5), (False, -5), (True, 3)]


def _build_tta_transform(flip, degrees):
    """Return a deterministic albumentations pipeline for one TTA view."""
    ops = [A.Resize(IMG_SIZE, IMG_SIZE)]
    if flip:
        ops.append(A.HorizontalFlip(p=1.0))
    if degrees:
        # A (d, d) range pins the angle; a scalar limit would sample from [-|d|, |d|].
        ops.append(A.Rotate(limit=(degrees, degrees), p=1.0))
    ops.append(A.Normalize(mean=TTA_NORM_MEAN, std=TTA_NORM_STD))
    ops.append(ToTensorV2())
    return A.Compose(ops)


def evaluate_with_tta(model, loader, num_augmentations=5):
    """Evaluate ``model`` with test-time augmentation and return per-class AUCs.

    Batches from ``loader`` are expected to hold CLIP-normalized ``(B, C, H, W)``
    float tensors, as produced by ``valid_transform``. Each batch is de-normalized
    to uint8 once, passed through the first ``num_augmentations`` views of
    ``TTA_VIEWS`` (at most ``len(TTA_VIEWS)`` are available), and the sigmoid
    probabilities are averaged across views.

    Args:
        model: classifier returning one logit per label.
        loader: DataLoader yielding ``(images, labels)``.
        num_augmentations: number of TTA views to average.

    Returns:
        List of NUM_CLASSES ROC-AUC values (NaN where a class's AUC is undefined).

    Raises:
        ValueError: if ``num_augmentations`` < 1.
    """
    if num_augmentations < 1:
        raise ValueError(f"num_augmentations must be >= 1, got {num_augmentations}")

    model.eval()
    tta_transforms = [_build_tta_transform(flip, deg) for flip, deg in TTA_VIEWS[:num_augmentations]]
    mean = np.array(TTA_NORM_MEAN, dtype=np.float32)
    std = np.array(TTA_NORM_STD, dtype=np.float32)
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating with TTA"):
            # (B, C, H, W) normalized floats -> (B, H, W, C) uint8 pixels, once per batch.
            pixels = images.permute(0, 2, 3, 1).cpu().numpy() * std + mean
            pixels = np.clip(np.rint(pixels * 255), 0, 255).astype(np.uint8)

            view_probs = []
            for tta_transform in tta_transforms:
                tta_batch = torch.stack(
                    [tta_transform(image=px)['image'] for px in pixels]
                ).to(DEVICE)
                with autocast():
                    outputs = model(tta_batch)
                # float32 sigmoid avoids fp16 saturation ties in the averaged scores.
                view_probs.append(torch.sigmoid(outputs.float()).cpu())

            # Average predictions across all views
            all_outputs.append(torch.stack(view_probs).mean(dim=0).numpy())
            all_labels.append(labels.numpy())

    all_outputs = np.concatenate(all_outputs)
    all_labels = np.concatenate(all_labels)

    # Compute AUC for each class (NaN when a class has a single label value)
    aucs = []
    for i in range(NUM_CLASSES):
        try:
            if len(np.unique(all_labels[:, i])) > 1:
                auc = roc_auc_score(all_labels[:, i], all_outputs[:, i])
            else:
                auc = np.nan
        except Exception as e:
            print(f"Error computing AUC for {LABELS[i]}: {e}")
            auc = np.nan
        aucs.append(auc)

    return aucs


print("✅ Test-Time Augmentation (TTA) function ready!")
print("Call evaluate_with_tta(model, valid_loader) to evaluate with TTA for potentially higher accuracy.")

## 8. Final Model Saving and Results Summary
Save the final model and display comprehensive training results.

In [None]:
# Save final model and a summary of the run.
mean_auc_history = np.asarray(training_history['mean_auc'], dtype=float)
final_mean_auc = float(mean_auc_history[-1]) if mean_auc_history.size else float('nan')
final_train_loss = training_history['train_loss'][-1] if training_history['train_loss'] else float('nan')

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'training_history': training_history,
    'final_mean_auc': final_mean_auc,
    'best_mean_auc': best_mean_auc,
    'config': {
        # The model actually loaded in section 5 (BiomedCLIP or a timm fallback).
        'model_name': MODEL_TYPE,
        'img_size': IMG_SIZE,
        'batch_size': BATCH_SIZE,
        'epochs': EPOCHS,
        # Two learning rates are used; the notebook defines no single `LR`.
        'lr_backbone': LR_BACKBONE,
        'lr_head': LR_HEAD,
        'weight_decay': WEIGHT_DECAY,
        'num_classes': NUM_CLASSES,
        'labels': LABELS
    }
}, 'chexpert_biomedclip_vit_final.pth')

print('✅ Final model saved as chexpert_biomedclip_vit_final.pth')
print('✅ Best model saved as chexpert_biomedclip_vit_best.pth')

# Display final results summary
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)
print(f"Model: {MODEL_TYPE}")
print(f"Dataset: CheXpert")
print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Epochs Trained: {len(training_history['train_loss'])}")
print(f"Total Parameters: {total_params:,}")
print(f"Best Mean AUC: {best_mean_auc:.4f}")
print(f"Final Mean AUC: {final_mean_auc:.4f}")
print(f"Final Train Loss: {final_train_loss:.4f}")
print("=" * 60)

# Per-class AUCs at the best epoch; nanargmax skips epochs whose mean AUC is NaN.
if mean_auc_history.size and not np.isnan(mean_auc_history).all():
    best_aucs = training_history['val_auc'][int(np.nanargmax(mean_auc_history))]
    valid_aucs = [(LABELS[i], auc) for i, auc in enumerate(best_aucs) if not np.isnan(auc)]
    valid_aucs.sort(key=lambda x: x[1], reverse=True)

    print("\nBest Model Performance by Class:")
    for label, auc in valid_aucs[:5]:  # Top 5
        print(f"  {label:25}: AUC = {auc:.4f}")

    if len(valid_aucs) > 5:
        print("  ...")
        for label, auc in valid_aucs[-3:]:  # Bottom 3
            print(f"  {label:25}: AUC = {auc:.4f}")