In [None]:
import os

# Must be set before TensorFlow is imported (TensorBoard may pull it in):
# level '3' filters out TensorFlow's C++ INFO/WARNING/ERROR log lines.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Not a logging setting: this exposes only GPU 0 to the process. It only takes
# effect if set before CUDA is initialised, hence its place above `import torch`.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging of loss/accuracy curves
from torchvision import transforms, models

from tqdm import tqdm
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, confusion_matrix, roc_curve
)
import cv2

print("✓ Imports loaded (TensorBoard warnings silenced)")


In [None]:
class ChestXrayDataset(Dataset):
    """Dataset for chest X-ray pneumonia classification.

    ``root_dir`` must contain one sub-folder per class (``NORMAL`` and
    ``PNEUMONIA``). Every ``.jpeg``/``.jpg``/``.png`` file (any letter case)
    inside becomes one sample, labelled 0 (NORMAL) or 1 (PNEUMONIA).

    Args:
        root_dir: Split directory, e.g. ``<data_dir>/train``.
        transform: Optional callable applied to each RGB PIL image.
    """

    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self.class_to_idx = {'NORMAL': 0, 'PNEUMONIA': 1}

        for class_name, label in self.class_to_idx.items():
            class_folder = os.path.join(root_dir, class_name)

            # os.listdir order is arbitrary and filesystem-dependent; sorting
            # makes sample indices (and anything selected by index) reproducible.
            for filename in sorted(os.listdir(class_folder)):
                # Case-insensitive so e.g. "IMG_01.JPEG" is not silently skipped.
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_folder, filename))
                    self.labels.append(label)

        print(f"Loaded {len(self.image_paths)} images")
        print(f"  NORMAL: {self.labels.count(0)}")
        print(f"  PNEUMONIA: {self.labels.count(1)}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed when a transform is set."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

print("✓ ChestXrayDataset defined")

In [None]:
# Central experiment configuration: every tunable setting for data, model,
# optimisation, augmentation and output paths lives in this one dict.
config = dict(
    # Data / model
    data_dir='/kaggle/input/chest-xray-pneumonia/chest_xray',  # holds train/, val/, test/
    model_name='resnet18',          # resnet18 | resnet34 | resnet50 (see create_model)
    pretrained=True,
    num_classes=2,                  # NORMAL, PNEUMONIA

    # Optimisation
    learning_rate=0.0001,
    batch_size=32,
    num_epochs=50,
    patience=5,                     # early stopping: epochs without val-loss improvement

    # Augmentation (applied to the training split only)
    use_augmentation=True,
    augmentation=dict(
        random_horizontal_flip=True,
        random_rotation=10,         # rotation range in degrees
        color_jitter_brightness=0.2,
        color_jitter_contrast=0.2,
    ),

    # Class imbalance: per-class loss weights ordered [NORMAL, PNEUMONIA]
    use_class_weights=True,
    class_weights=[3.0, 1.0],
    optimizer='adam',
    weight_decay=1e-4,

    # Runtime
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=42,

    # Outputs
    save_dir='checkpoints/',
    best_model_path='checkpoints/best_model.pth',
    log_dir='runs/',                # TensorBoard event files
    print_freq=10,
)

print("Configuration:")
print(f"  Model: {config['model_name']}")
print(f"  Learning Rate: {config['learning_rate']}")
print(f"  Batch Size: {config['batch_size']}")
print(f"  Device: {config['device']}")

In [None]:
def set_seed(seed=42):
    """Seed every RNG used by the notebook for reproducible runs.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator, and PyTorch's generators on all CUDA devices.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

def create_model(model_name='resnet18', num_classes=2, pretrained=True):
    """Build a torchvision ResNet whose classification head has ``num_classes`` outputs.

    Args:
        model_name: One of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``.
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: Whether torchvision should load its pretrained backbone weights.

    Returns:
        The ResNet with a freshly initialised ``fc`` layer.

    Raises:
        ValueError: If ``model_name`` is not supported. (Previously an unknown
            name fell through the if/elif chain and crashed with an
            UnboundLocalError on ``model.fc``.)
    """
    supported = ('resnet18', 'resnet34', 'resnet50')
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    # NOTE: `pretrained=` is deprecated in newer torchvision in favour of
    # `weights=`; kept here for compatibility with the installed version.
    model = getattr(models, model_name)(pretrained=pretrained)

    # Replace the original classification head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def get_transforms(augment=False):
    """Build the training and validation/test preprocessing pipelines.

    Both pipelines resize to 224x224, convert to a tensor and normalise with
    the ImageNet channel statistics used by the pretrained ResNet backbone.
    When ``augment`` is true, the training pipeline also applies the random
    augmentations configured in ``config['augmentation']``.

    Args:
        augment: Whether to add random augmentation to the training pipeline.

    Returns:
        ``(train_transform, val_transform)``.
    """
    # ImageNet mean/std: must match the statistics the backbone was trained with.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_steps = [transforms.Resize((224, 224))]
    if augment:
        aug_cfg = config['augmentation']
        # Honour the config flag instead of always flipping.
        if aug_cfg.get('random_horizontal_flip', True):
            train_steps.append(transforms.RandomHorizontalFlip(p=0.5))
        train_steps.append(transforms.RandomRotation(aug_cfg['random_rotation']))
        train_steps.append(transforms.ColorJitter(
            brightness=aug_cfg['color_jitter_brightness'],
            contrast=aug_cfg['color_jitter_contrast'],
        ))
    train_steps += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_steps)

    # Validation/test preprocessing is deterministic: no augmentation.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    return train_transform, val_transform

def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train (switched to train mode).
        train_loader: Loader yielding ``(images, labels)`` batches.
        criterion: Loss function applied to ``(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, shown in the progress bar.

    Returns:
        ``(average_loss, accuracy_percent)`` over the samples seen this epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f'Epoch {epoch+1} [TRAIN]')

    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Scale the per-batch loss by the batch size so the epoch average
        # is weighted by sample count (the last batch may be smaller).
        loss_sum += batch_loss.item() * batch_images.size(0)
        n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

        running_acc = 100. * n_correct / n_seen
        progress.set_postfix({'loss': f'{batch_loss.item():.4f}', 'acc': f'{running_acc:.2f}%'})

    return loss_sum / len(train_loader.dataset), 100. * n_correct / n_seen

def validate(model, val_loader, criterion, device, epoch):
    """Score ``model`` on ``val_loader`` with gradients disabled.

    Args:
        model: Network to evaluate (switched to eval mode).
        val_loader: Loader yielding ``(images, labels)`` batches.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, shown in the progress bar.

    Returns:
        ``(average_loss, accuracy_percent)`` over the whole validation set.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(val_loader, desc=f'Epoch {epoch+1} [VAL]')

    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            loss_sum += batch_loss.item() * batch_images.size(0)
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            running_acc = 100. * n_correct / n_seen
            progress.set_postfix({'loss': f'{batch_loss.item():.4f}', 'acc': f'{running_acc:.2f}%'})

    return loss_sum / len(val_loader.dataset), 100. * n_correct / n_seen

print("✓ Training helper functions defined")

In [None]:
def train_model():
    """Train the model described by the global ``config``, with early stopping.

    Seeds all RNGs and builds the train/val loaders. Trains with (optionally
    class-weighted) cross-entropy and Adam. Logs per-epoch loss/accuracy to
    TensorBoard under ``config['log_dir']``. Saves a checkpoint to
    ``config['best_model_path']`` whenever validation loss improves, and stops
    after ``config['patience']`` consecutive epochs without improvement.

    NOTE: ``config['optimizer']`` is not consulted; Adam is always used.
    """
    set_seed(config['seed'])

    os.makedirs(config['save_dir'], exist_ok=True)
    # best_model_path is configured separately from save_dir; ensure its
    # directory exists too so torch.save cannot fail on the first improvement.
    os.makedirs(os.path.dirname(config['best_model_path']) or '.', exist_ok=True)
    os.makedirs(config['log_dir'], exist_ok=True)

    device = torch.device(config['device'])
    # Pinned host memory only helps host->GPU copies; it is pointless on CPU.
    pin_memory = device.type == 'cuda'

    print("=" * 60)
    print("TRAINING STARTED")
    print("=" * 60)

    # Prepare data
    train_transform, val_transform = get_transforms(augment=config['use_augmentation'])

    train_dataset = ChestXrayDataset(
        root_dir=os.path.join(config['data_dir'], 'train'),
        transform=train_transform
    )
    val_dataset = ChestXrayDataset(
        root_dir=os.path.join(config['data_dir'], 'val'),
        transform=val_transform
    )

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, num_workers=2, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=False, num_workers=2, pin_memory=pin_memory)

    # Model setup
    model = create_model(config['model_name'], config['num_classes'], config['pretrained'])
    model = model.to(device)

    if config['use_class_weights']:
        # Index order follows ChestXrayDataset.class_to_idx: [NORMAL, PNEUMONIA].
        class_weights = torch.tensor(config['class_weights'], dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0

    # Open the writer right before the loop and close it in `finally`, so the
    # event file is flushed even if an epoch raises or the run is interrupted.
    writer = SummaryWriter(config['log_dir'])
    try:
        for epoch in range(config['num_epochs']):
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
            val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Accuracy/val', val_acc, epoch)

            print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
            print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
            print(f"Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    'config': config
                }, config['best_model_path'])
                print("✓ Best model saved!")
            else:
                patience_counter += 1
                print(f"No improvement ({patience_counter}/{config['patience']})")

            if patience_counter >= config['patience']:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
    finally:
        writer.close()

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print("=" * 60)

# Run training
train_model()

In [None]:
def load_model(checkpoint_path, device):
    """Rebuild the trained classifier from a checkpoint written by ``train_model``.

    The architecture comes from the ``config`` saved inside the checkpoint.
    Checkpoints without one fall back to ResNet-18 with 2 classes. This lets
    a model trained with e.g. ``model_name='resnet50'`` load correctly instead
    of failing on a resnet18 state-dict mismatch.

    Args:
        checkpoint_path: Path to the ``.pth`` file saved with ``torch.save``.
        device: Device to map the tensors to and move the model onto.

    Returns:
        The model with the checkpoint weights, in eval mode, on ``device``.
    """
    # NOTE(security): torch.load may unpickle objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    saved_config = checkpoint.get('config', {})

    # pretrained=False: all weights come from the checkpoint, so skip the download.
    model = create_model(
        saved_config.get('model_name', 'resnet18'),
        saved_config.get('num_classes', 2),
        pretrained=False,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded (Epoch {checkpoint['epoch']+1})")
    return model

def evaluate_model(model, test_loader, device):
    """Collect predictions for every sample in ``test_loader``.

    Args:
        model: Trained classifier (switched to eval mode).
        test_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the images are moved to for inference.

    Returns:
        Three NumPy arrays in loader order: predicted class indices, true
        labels, and the softmax probability of class 1 (PNEUMONIA).
    """
    model.eval()
    predictions, targets, positive_probs = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            logits = model(images.to(device))
            batch_pred = logits.argmax(dim=1)
            batch_prob = F.softmax(logits, dim=1)[:, 1]

            predictions.extend(batch_pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            positive_probs.extend(batch_prob.cpu().numpy())

    return np.array(predictions), np.array(targets), np.array(positive_probs)

def generate_gradcam(model, image_tensor, target_layer, target_class=None):
    """Compute a Grad-CAM heatmap for the first image in ``image_tensor``.

    Args:
        model: Classifier returning logits indexed as ``output[0, class]``.
        image_tensor: Input batch; only sample 0 is explained.
        target_layer: Module whose output activations and gradients are used
            (e.g. ``model.layer4``). Its output is assumed to be
            (batch, channels, H, W); TODO confirm for non-CNN layers.
        target_class: Class index to explain; defaults to the predicted class.

    Returns:
        ``np.ndarray`` of shape (H, W) with values in [0, 1].

    Raises:
        IndexError: If ``target_class`` is out of range. The hooks are removed
            before the error propagates.
    """
    model.eval()
    gradients, activations = [], []

    def forward_hook(module, inputs, output):
        activations.append(output.detach())

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    handle_forward = target_layer.register_forward_hook(forward_hook)
    handle_backward = target_layer.register_full_backward_hook(backward_hook)
    try:
        output = model(image_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        model.zero_grad()
        output[0, target_class].backward()
    finally:
        # Always detach the hooks. If left attached after a failure, every
        # later forward/backward pass through this layer would keep firing them.
        handle_forward.remove()
        handle_backward.remove()

    grads = gradients[0][0].cpu().numpy()   # (channels, H, W) for sample 0
    acts = activations[0][0].cpu().numpy()  # (channels, H, W) for sample 0

    # Channel weights = spatially averaged gradients; CAM = ReLU(sum_c w_c * A_c).
    weights = grads.mean(axis=(1, 2), dtype=np.float64)
    cam = np.tensordot(weights, acts, axes=1)
    cam = np.maximum(cam, 0)
    cam = cam / (cam.max() + 1e-8)
    return cam

print("✓ Evaluation functions defined")

In [None]:
device = torch.device(config['device'])

# Load the best checkpoint written by train_model
model = load_model(config['best_model_path'], device)

# Test images must be preprocessed exactly like validation images; taking the
# pipeline from get_transforms keeps the two from drifting apart.
test_transform = get_transforms(augment=False)[1]

test_dataset = ChestXrayDataset(
    root_dir=os.path.join(config['data_dir'], 'test'),
    transform=test_transform
)

test_loader = DataLoader(test_dataset, batch_size=config['batch_size'],
                         shuffle=False, num_workers=2)

# Evaluate
y_pred, y_true, y_probs = evaluate_model(model, test_loader, device)

# Print metrics (label 1 = PNEUMONIA is the positive class)
print("\n" + "=" * 60)
print("TEST RESULTS")
print("=" * 60)
print(f"Accuracy:  {accuracy_score(y_true, y_pred)*100:.2f}%")
print(f"Precision: {precision_score(y_true, y_pred)*100:.2f}%")
print(f"Recall:    {recall_score(y_true, y_pred)*100:.2f}%")
print(f"F1 Score:  {f1_score(y_true, y_pred):.4f}")
print(f"AUC-ROC:   {roc_auc_score(y_true, y_probs):.4f}")

# Confusion matrix; labels=[0, 1] forces a 2x2 matrix even if one class is
# absent from both y_true and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NORMAL', 'PNEUMONIA'],
            yticklabels=['NORMAL', 'PNEUMONIA'],
            ax=ax)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
ax.set_title('Confusion Matrix')
fig.tight_layout()
fig.savefig('confusion_matrix.png', dpi=150)
plt.show()

print("\n✓ Evaluation complete!")

In [None]:
# Visualize Grad-CAM for a reproducible random set of test images
num_samples = 5
# Seeded generator: the same images are drawn on every run of this cell,
# regardless of how much global NumPy randomness earlier cells consumed.
rng = np.random.default_rng(config['seed'])
indices = rng.choice(len(test_dataset), size=min(num_samples, len(test_dataset)), replace=False)

# ImageNet statistics used by the Normalize transform; needed to undo it for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# squeeze=False keeps `axes` 2-D even when only one row is plotted.
fig, axes = plt.subplots(len(indices), 3, figsize=(12, 4 * len(indices)), squeeze=False)

for idx, sample_idx in enumerate(indices):
    image, label = test_dataset[int(sample_idx)]
    image_tensor = image.unsqueeze(0).to(device)

    # Get prediction
    with torch.no_grad():
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item()
        prob = F.softmax(output, dim=1)[0, pred_class].item()

    # Grad-CAM on model.layer4, upsampled to the input image size
    # (cv2.resize takes (width, height)).
    heatmap = generate_gradcam(model, image_tensor, target_layer=model.layer4)
    heatmap_resized = cv2.resize(heatmap, (image.shape[2], image.shape[1]))

    # Undo normalisation for display: channels-last, then x * std + mean.
    img_np = image.cpu().numpy().transpose(1, 2, 0)
    img_np = np.clip(std * img_np + mean, 0, 1)

    true_name = "PNEUMONIA" if label == 1 else "NORMAL"
    pred_name = "PNEUMONIA" if pred_class == 1 else "NORMAL"

    # Plot
    axes[idx, 0].imshow(img_np)
    axes[idx, 0].set_title(f'True: {true_name}')
    axes[idx, 0].axis('off')

    axes[idx, 1].imshow(heatmap_resized, cmap='jet')
    axes[idx, 1].set_title(f'Pred: {pred_name} ({prob:.1%})')
    axes[idx, 1].axis('off')

    axes[idx, 2].imshow(img_np, alpha=0.7)
    axes[idx, 2].imshow(heatmap_resized, cmap='jet', alpha=0.3)
    axes[idx, 2].set_title('Overlay')
    axes[idx, 2].axis('off')

fig.tight_layout()
fig.savefig('gradcam.png', dpi=150)
plt.show()

print("✓ Grad-CAM visualizations saved!")


# ============================================
# ROC CURVE
# ============================================
# Run this AFTER you have: y_true, y_pred, y_probs

from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import numpy as np

print("\n" + "="*60)
print("ROC CURVE ANALYSIS")
print("="*60)

# Calculate ROC curve
fpr, tpr, thresholds = roc_curve(y_true, y_probs)
auc_score = roc_auc_score(y_true, y_probs)

# Optimal operating point via Youden's index J = TPR - FPR
# (= sensitivity + specificity - 1), i.e. the threshold that maximises
# sensitivity + specificity.
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]

fig, ax = plt.subplots(figsize=(10, 8))

# ROC curve, random-classifier diagonal and shaded area under the curve
ax.plot(fpr, tpr, linewidth=3,
        label=f'ROC Curve (AUC = {auc_score:.4f})',
        color='#2E86AB')
ax.plot([0, 1], [0, 1], 'k--', linewidth=2,
        label='Random Classifier', alpha=0.5)
ax.fill_between(fpr, tpr, alpha=0.2, color='#2E86AB')
ax.plot(fpr[optimal_idx], tpr[optimal_idx], 'ro', markersize=12,
        label=f'Optimal Point (threshold={optimal_threshold:.3f})')

# Formatting
ax.set_xlabel('False Positive Rate', fontsize=14, fontweight='bold')
ax.set_ylabel('True Positive Rate (Sensitivity)', fontsize=14, fontweight='bold')
ax.set_title(f'ROC Curve - Pneumonia Detection\nAUC = {auc_score:.4f}',
             fontsize=15, fontweight='bold', pad=20)
ax.legend(loc='lower right', fontsize=12, framealpha=0.9)
ax.grid(alpha=0.3, linestyle='--')
ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.02])

fig.tight_layout()
plt.show()

# Print interpretation
print(f"\nROC-AUC Score: {auc_score:.4f}")
print("\nInterpretation:")
if auc_score >= 0.9:
    print("  ⭐ EXCELLENT - Model has excellent discrimination")
elif auc_score >= 0.8:
    print("  ✅ GOOD - Model has good discrimination")
elif auc_score >= 0.7:
    print("  ⚠️  FAIR - Model has acceptable discrimination")
else:
    print("  ❌ POOR - Model discrimination is poor")

print("\nOptimal Operating Point:")
print(f"  Threshold: {optimal_threshold:.3f}")
print(f"  Sensitivity (TPR): {tpr[optimal_idx]:.3f}")
print(f"  Specificity (1-FPR): {1-fpr[optimal_idx]:.3f}")

print("="*60)

In [None]:
# ============================================
# SAVE ALL VISUALIZATIONS TO FILES
# ============================================
# Run this in a NEW CELL at the END of your notebook
# Make sure you already have: y_true, y_pred, y_probs, model, test_dataset

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import torch
import torch.nn.functional as F
import cv2

print("="*60)
print("SAVING ALL VISUALIZATIONS TO FILES")
print("="*60)

# ============================================
# 1. CONFUSION MATRIX
# ============================================
print("\n[1/4] Generating confusion matrix...")

# labels=[0, 1] guarantees a 2x2 matrix (rows/cols: NORMAL, PNEUMONIA).
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NORMAL', 'PNEUMONIA'],
            yticklabels=['NORMAL', 'PNEUMONIA'],
            cbar_kws={'label': 'Count'},
            annot_kws={'size': 16, 'weight': 'bold'},
            ax=ax)

# Row-normalised percentages (share of each true class), drawn under the counts.
# Rows with no samples get 0% instead of a divide-by-zero NaN.
row_totals = cm.sum(axis=1, keepdims=True)
row_pct = np.divide(cm * 100.0, row_totals,
                    out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
for i in range(2):
    for j in range(2):
        ax.text(j + 0.5, i + 0.75,
                f'({row_pct[i, j]:.1f}%)',
                ha="center", va="center",
                color="gray", fontsize=11)

ax.set_ylabel('True Label', fontsize=14, fontweight='bold')
ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')

# Sensitivity = TP / actual positives; specificity = TN / actual negatives.
actual_pos = cm[1].sum()
actual_neg = cm[0].sum()
sensitivity = cm[1, 1] / actual_pos * 100 if actual_pos else 0.0
specificity = cm[0, 0] / actual_neg * 100 if actual_neg else 0.0

ax.set_title('Confusion Matrix - Pneumonia Detection\n' +
             f'Sensitivity: {sensitivity:.1f}% | Specificity: {specificity:.1f}%',
             fontsize=15, fontweight='bold', pad=20)

fig.tight_layout()
fig.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight', facecolor='white')
print("   ✓ Saved: confusion_matrix.png")
plt.close(fig)


# ============================================
# 2. ROC CURVE
# ============================================
print("\n[2/4] Generating ROC curve...")

fpr, tpr, thresholds = roc_curve(y_true, y_probs)
auc_score = roc_auc_score(y_true, y_probs)

# Optimal point = maximum of Youden's J (TPR - FPR).
optimal_idx = np.argmax(tpr - fpr)

fig, ax = plt.subplots(figsize=(10, 8))

# Plot ROC
ax.plot(fpr, tpr, linewidth=3,
        label=f'ROC Curve (AUC = {auc_score:.4f})',
        color='#2E86AB')
ax.plot([0, 1], [0, 1], 'k--', linewidth=2,
        label='Random Classifier', alpha=0.5)
ax.fill_between(fpr, tpr, alpha=0.2, color='#2E86AB')

# Mark optimal point
ax.plot(fpr[optimal_idx], tpr[optimal_idx], 'ro', markersize=12,
        label=f'Optimal (threshold={thresholds[optimal_idx]:.3f})')

ax.set_xlabel('False Positive Rate', fontsize=14, fontweight='bold')
ax.set_ylabel('True Positive Rate', fontsize=14, fontweight='bold')
ax.set_title(f'ROC Curve - Pneumonia Detection\nAUC = {auc_score:.4f}',
             fontsize=15, fontweight='bold', pad=20)
ax.legend(loc='lower right', fontsize=12)
ax.grid(alpha=0.3, linestyle='--')
ax.set_xlim([-0.02, 1.02])
ax.set_ylim([-0.02, 1.02])

fig.tight_layout()
fig.savefig('roc_curve.png', dpi=300, bbox_inches='tight', facecolor='white')
print("   ✓ Saved: roc_curve.png")
plt.close(fig)


# ============================================
# 3. METRICS SUMMARY
# ============================================
print("\n[3/4] Generating metrics summary...")

metrics = {
    'Accuracy': accuracy_score(y_true, y_pred),
    'Precision': precision_score(y_true, y_pred),
    'Recall': recall_score(y_true, y_pred),
    'F1 Score': f1_score(y_true, y_pred),
    'AUC-ROC': auc_score  # computed in the ROC section above
}

fig, ax = plt.subplots(figsize=(12, 7))
colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E']
bars = ax.barh(list(metrics.keys()), list(metrics.values()),
               color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

# Add value labels (bar i is centred at y = i)
for i, value in enumerate(metrics.values()):
    ax.text(value + 0.02, i, f'{value:.4f}',
            va='center', fontsize=13, fontweight='bold')

# Add excellence threshold line, labelled just below the top bar
ax.axvline(x=0.9, color='gray', linestyle='--', alpha=0.5, linewidth=2)
ax.text(0.91, len(metrics) - 0.7, 'Excellent\n(>0.9)', fontsize=10,
        color='gray', ha='left', style='italic')

ax.set_xlim([0, 1.15])
ax.set_xlabel('Score', fontsize=14, fontweight='bold')
ax.set_title('Model Performance Metrics Summary',
             fontsize=15, fontweight='bold', pad=20)
ax.grid(axis='x', alpha=0.3, linestyle='--')

# Caption built from the measured metrics, so it can never contradict them
# (it previously hard-coded a recall and AUC from one particular run).
ax.text(0.5, -0.15,
        f"Recall (sensitivity): {metrics['Recall']:.2%}  |  "
        f"AUC-ROC: {metrics['AUC-ROC']:.4f}",
        ha='center', fontsize=11, style='italic', color='#555',
        transform=ax.transAxes)

fig.tight_layout()
fig.savefig('metrics_summary.png', dpi=300, bbox_inches='tight', facecolor='white')
print("   ✓ Saved: metrics_summary.png")
plt.close(fig)


# ============================================
# 4. GRAD-CAM VISUALIZATIONS
# ============================================
print("\n[4/4] Generating Grad-CAM visualizations...")

# Select diverse examples
def select_samples(y_true, y_pred, y_probs):
    """Pick up to 5 illustrative test indices for Grad-CAM.

    Candidates, in order (categories with no members are skipped):
      1. the true positive with the highest P(PNEUMONIA),
      2. the true negative with the lowest P(PNEUMONIA),
      3. the first false positive,
      4. the first false negative,
      5. the sample whose P(PNEUMONIA) is closest to 0.5, unless already picked.
    """
    categories = (
        ((y_true == 1) & (y_pred == 1), lambda members: members[np.argmax(y_probs[members])]),
        ((y_true == 0) & (y_pred == 0), lambda members: members[np.argmin(y_probs[members])]),
        ((y_true == 0) & (y_pred == 1), lambda members: members[0]),
        ((y_true == 1) & (y_pred == 0), lambda members: members[0]),
    )

    picks = []
    for mask, choose in categories:
        members = np.nonzero(mask)[0]
        if len(members) > 0:
            picks.append(choose(members))

    closest_to_half = np.argmin(np.abs(y_probs - 0.5))
    if closest_to_half not in picks:
        picks.append(closest_to_half)

    return picks[:5]

selected = select_samples(y_true, y_pred, y_probs)

# squeeze=False keeps `axes` 2-D for any number of rows (no manual reshape).
fig, axes = plt.subplots(len(selected), 3, figsize=(15, 5*len(selected)), squeeze=False)

for idx, sample_idx in enumerate(selected):
    # Load image
    image, label = test_dataset[int(sample_idx)]
    image_tensor = image.unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item()
        prob = F.softmax(output, dim=1)[0, pred_class].item()

    # Grad-CAM on model.layer4, upsampled to the input size (cv2 takes (width, height))
    heatmap = generate_gradcam(model, image_tensor, target_layer=model.layer4)
    heatmap_resized = cv2.resize(heatmap, (image.shape[2], image.shape[1]))

    # Undo ImageNet normalisation for display: channels-last, then x * std + mean
    img_np = image.cpu().numpy().transpose(1, 2, 0)
    img_np = np.array([0.229, 0.224, 0.225]) * img_np + np.array([0.485, 0.456, 0.406])
    img_np = np.clip(img_np, 0, 1)

    # Check correctness
    correct = label == pred_class
    marker = "✓ CORRECT" if correct else "✗ INCORRECT"
    color = 'green' if correct else 'red'

    # Plot original (img_np is 3-channel RGB, so no colormap applies)
    axes[idx, 0].imshow(img_np)
    axes[idx, 0].set_title(
        f'{marker}\nTrue: {"PNEUMONIA" if label==1 else "NORMAL"}',
        fontsize=13, fontweight='bold', color=color
    )
    axes[idx, 0].axis('off')

    # Plot heatmap
    axes[idx, 1].imshow(heatmap_resized, cmap='jet')
    axes[idx, 1].set_title(
        f'Grad-CAM Heatmap\nPred: {"PNEUMONIA" if pred_class==1 else "NORMAL"}',
        fontsize=13, fontweight='bold'
    )
    axes[idx, 1].axis('off')

    # Plot overlay
    axes[idx, 2].imshow(img_np, alpha=0.7)
    axes[idx, 2].imshow(heatmap_resized, cmap='jet', alpha=0.4)
    axes[idx, 2].set_title(
        f'Overlay\nConfidence: {prob*100:.1f}%',
        fontsize=13, fontweight='bold'
    )
    axes[idx, 2].axis('off')

fig.suptitle('Grad-CAM Model Attention Analysis',
             fontsize=17, fontweight='bold', y=0.998)
fig.tight_layout()
fig.savefig('gradcam_visualizations.png', dpi=300, bbox_inches='tight', facecolor='white')
print("   ✓ Saved: gradcam_visualizations.png")
plt.close(fig)


# ============================================
# SUMMARY
# ============================================
print("\n" + "="*60)
print("✅ ALL VISUALIZATIONS SAVED SUCCESSFULLY!")
print("="*60)
print("\nGenerated Files:")
print("  1. confusion_matrix.png       (Performance matrix)")
print("  2. roc_curve.png              (ROC analysis)")
print("  3. metrics_summary.png        (Bar chart of metrics)")
print("  4. gradcam_visualizations.png (Model interpretability)")
print("\n📥 TO DOWNLOAD FROM KAGGLE:")
print("  → Look at right sidebar → 'Output' section")
print("  → Click download icon (⬇️) next to each .png file")
print("  → Save to your computer's images/ folder")
print("="*60)