In [None]:
# ==========================================
# SECTION 1: SETUP & INSTALLATIONS
# ==========================================
# Run this first - installs all required packages

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install kaggle opendatasets pillow matplotlib seaborn scikit-learn
!pip install pytorch-grad-cam
!pip install timm  # For better model architectures

print(" All packages installed!")

# ==========================================
# SECTION 2: IMPORTS
# ==========================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import shutil
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# Report the installed PyTorch version and whether a CUDA device is visible,
# so a CPU-only runtime is noticed before the long-running cells start.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried for its name.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ==========================================
# SECTION 3: DATASET DOWNLOAD
# ==========================================
# This downloads ~5GB of data, takes 10-15 minutes
# NOTE(review): opendatasets presumably needs Kaggle API credentials for these
# URLs — confirm how they are supplied in the target runtime.

import opendatasets as od

# No destination directory is passed to od.download(); the organization cell
# later reads from '/content/<dataset-slug>/...', so this presumably assumes
# the working directory is /content (Colab default) — TODO confirm.

# Dataset 1: Kaggle Chest X-Ray (Pneumonia) — source of the NORMAL and
# PNEUMONIA classes.
print("Downloading Pneumonia Dataset...")
od.download("https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia")

# Dataset 2: COVID-19 Radiography Database — source of the COVID class.
print("Downloading COVID-19 Dataset...")
od.download("https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database")

# Dataset 3: TB Chest X-ray Database (Alternative public source) — source of
# the TB class.
print("Downloading TB Dataset...")
od.download("https://www.kaggle.com/datasets/raddar/tuberculosis-chest-xrays-shenzhen")

print(" All datasets downloaded!")

# ==========================================
# SECTION 4: DATA ORGANIZATION
# ==========================================
# Organize all datasets into a unified structure

# Create organized directory structure:
#   <base_dir>/<split>/<category>  for every split x category combination.
# The '/content' prefix is an absolute path; it presumably targets Google
# Colab — adjust when running elsewhere.
base_dir = '/content/xray_dataset'
os.makedirs(base_dir, exist_ok=True)

# exist_ok=True makes the cell safe to re-run.
for split in ['train', 'val', 'test']:
    for category in ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']:
        os.makedirs(f'{base_dir}/{split}/{category}', exist_ok=True)

print(" Directory structure created!")

# Function to copy and organize images
def organize_images(source_paths, dest_base, label, split_ratio=None):
    """Copy the images of one class into train/val/test folders.

    All image files (.png/.jpg/.jpeg, case-insensitive) found directly inside
    the given source folders are pooled, shuffled with NumPy's global RNG
    (call ``np.random.seed`` beforehand for a reproducible split) and copied
    to ``<dest_base>/<split>/<label>/``.

    Args:
        source_paths: Iterable of directories to collect images from.
            Directories that do not exist are reported and skipped.
        dest_base: Root of the organized dataset.
        label: Class name; used as the destination sub-folder and as the
            prefix of the copied file names.
        split_ratio: Mapping with 'train' and 'val' fractions (the remainder
            goes to 'test'). Defaults to 70% / 15% / 15%.

    Returns:
        None. A summary line with the split sizes is printed.
    """
    # A dict literal as a default argument would be shared between calls.
    if split_ratio is None:
        split_ratio = {'train': 0.7, 'val': 0.15, 'test': 0.15}

    all_images = []
    for source_path in source_paths:
        if not os.path.exists(source_path):
            # A wrong path would otherwise silently produce an empty class.
            print(f" Warning: source folder not found, skipping: {source_path}")
            continue
        # Sorted listing so that a seeded shuffle gives the same split on
        # every machine, independent of directory enumeration order.
        all_images.extend(
            os.path.join(source_path, f)
            for f in sorted(os.listdir(source_path))
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    # Shuffle
    np.random.shuffle(all_images)

    # Split (fractions are truncated; leftovers end up in the test split)
    n_train = int(len(all_images) * split_ratio['train'])
    n_val = int(len(all_images) * split_ratio['val'])

    train_imgs = all_images[:n_train]
    val_imgs = all_images[n_train:n_train + n_val]
    test_imgs = all_images[n_train + n_val:]

    # Copy files
    failed = 0
    for split, imgs in [('train', train_imgs), ('val', val_imgs), ('test', test_imgs)]:
        dest_dir = f'{dest_base}/{split}/{label}'
        os.makedirs(dest_dir, exist_ok=True)
        for i, img_path in enumerate(tqdm(imgs, desc=f'{label} {split}')):
            try:
                # The copy is byte-for-byte; the '.jpg' suffix is only a
                # uniform naming scheme (the source may be a PNG). Readers
                # that decode by content, such as PIL, are unaffected.
                shutil.copy(img_path, f'{dest_dir}/{label}_{split}_{i}.jpg')
            except OSError as e:
                # Best effort: keep going, but do not hide the failure.
                failed += 1
                print(f" Warning: could not copy {img_path}: {e}")

    print(f" {label}: Train={len(train_imgs)}, Val={len(val_imgs)}, Test={len(test_imgs)}")
    if failed:
        print(f" {label}: {failed} file(s) could not be copied")

# Organize NORMAL images (from pneumonia dataset).
# Only the dataset's own train/ and test/ folders are pooled here; they are
# re-split 70/15/15 by organize_images.
print(" Organizing NORMAL images...")
organize_images(
    ['/content/chest-xray-pneumonia/chest_xray/train/NORMAL',
     '/content/chest-xray-pneumonia/chest_xray/test/NORMAL'],
    base_dir, 'NORMAL'
)

# Organize PNEUMONIA images (same source dataset, same pooling).
print(" Organizing PNEUMONIA images...")
organize_images(
    ['/content/chest-xray-pneumonia/chest_xray/train/PNEUMONIA',
     '/content/chest-xray-pneumonia/chest_xray/test/PNEUMONIA'],
    base_dir, 'PNEUMONIA'
)

# Organize COVID images
print(" Organizing COVID images...")
organize_images(
    ['/content/covid19-radiography-database/COVID-19_Radiography_Dataset/COVID/images'],
    base_dir, 'COVID'
)

# Organize TB images.
# NOTE(review): every image in this folder is labelled TB. The Shenzhen set
# reportedly also contains normal (non-TB) films, distinguished by a file-name
# suffix — verify against the dataset's metadata before trusting the TB class.
# NOTE(review): organize_images only lists files directly inside the given
# folder; confirm the images are not nested one level deeper.
print(" Organizing TB images...")
organize_images(
    ['/content/tuberculosis-chest-xrays-shenzhen/images'], # Corrected path
    base_dir, 'TB'
)

print(" All data organized!")

# ==========================================
# SECTION 5: DATA EXPLORATION
# ==========================================

# Count images in each category
def count_images(base_path):
    """Count the image files in every split/category folder under ``base_path``.

    Files are matched by extension (.png, .jpg, .jpeg, case-insensitive), the
    same rule ``ChestXrayDataset`` uses, so the reported numbers agree with
    what the dataset will actually load.

    Args:
        base_path: Root directory laid out as ``<base_path>/<split>/<category>``.

    Returns:
        dict: ``{split: {category: count}}`` for the splits 'train', 'val',
        'test' and the categories 'NORMAL', 'PNEUMONIA', 'COVID', 'TB'.
        A folder that does not exist is counted as 0.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    stats = {}
    for split in ['train', 'val', 'test']:
        stats[split] = {}
        for category in ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']:
            path = f'{base_path}/{split}/{category}'
            if not os.path.isdir(path):
                # Missing folder: report an empty class instead of crashing.
                stats[split][category] = 0
                continue
            stats[split][category] = sum(
                name.lower().endswith(image_exts) for name in os.listdir(path)
            )
    return stats

# `stats` is {split: {category: count}}; it is reused later to derive the
# class weights for the loss function.
stats = count_images(base_dir)
print(" Dataset Statistics:")
# Transposed so that rows are splits and columns are categories.
print(pd.DataFrame(stats).T)

# Visualize distribution: one bar chart per split, one bar per category.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for idx, split in enumerate(['train', 'val', 'test']):
    data = stats[split]
    # Colors follow the category order NORMAL, PNEUMONIA, COVID, TB.
    axes[idx].bar(data.keys(), data.values(), color=['green', 'blue', 'red', 'orange'])
    axes[idx].set_title(f'{split.upper()} Set Distribution')
    axes[idx].set_ylabel('Number of Images')
    axes[idx].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# ==========================================
# SECTION 6: DATASET CLASS
# ==========================================

class ChestXrayDataset(Dataset):
    """Chest X-ray images read from ``<root_dir>/<split>/<class_name>/``.

    Labels are the index of the class folder in ``self.classes``
    (NORMAL=0, PNEUMONIA=1, COVID=2, TB=3).

    Args:
        root_dir: Root of the organized dataset.
        split: Sub-folder to read ('train', 'val' or 'test').
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.classes = ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Parallel lists: images[i] is the file path, labels[i] its class index.
        self.images = []
        self.labels = []

        for class_name in self.classes:
            class_dir = os.path.join(root_dir, split, class_name)
            if not os.path.isdir(class_dir):
                # A missing class folder simply contributes no samples.
                continue
            # os.listdir order is arbitrary; sorting makes sample indices
            # stable between runs and machines.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and passed through ``self.transform``
        when one was given; the label is an int class index.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the underlying file as soon as the
        # pixel data has been decoded by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# ==========================================
# SECTION 7: DATA AUGMENTATION & TRANSFORMS
# ==========================================

# Training transforms with augmentation.
# Every image is resized to 224x224, randomly flipped / rotated (up to 10
# degrees) / brightness- and contrast-jittered / translated (up to 10% per
# axis), then converted to a tensor and normalized per channel.
# NOTE(review): a horizontal flip mirrors left/right anatomy on a chest film —
# confirm this augmentation is acceptable for the target use.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    # These mean/std values look like the usual ImageNet statistics, matching
    # the pretrained backbone used later — the same numbers are used to undo
    # the normalization when images are displayed.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test transforms (no augmentation): same resize and normalization
# as training so the model sees inputs on the same scale.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets; only the training split gets the augmenting transform.
train_dataset = ChestXrayDataset(base_dir, split='train', transform=train_transform)
val_dataset = ChestXrayDataset(base_dir, split='val', transform=val_transform)
test_dataset = ChestXrayDataset(base_dir, split='test', transform=val_transform)

print(f" Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# ==========================================
# SECTION 8: DATA LOADERS
# ==========================================

batch_size = 32

# Only the training loader shuffles; validation and test keep dataset order so
# their predictions line up with the dataset's image list.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(" Data loaders created!")

# Visualize sample batch
def show_batch(dataloader, n=8):
    """Plot up to ``n`` images from the first batch of ``dataloader``.

    Images are de-normalized with the mean/std used by the transforms so they
    display with natural contrast, and titled with their class name.

    Args:
        dataloader: Loader yielding ``(images, labels)`` batches, where images
            were normalized with the transform mean/std.
        n: Maximum number of images to show. The grid is 4 columns wide and
            at least 2 rows high, and grows when ``n`` exceeds 8.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    images, labels = next(iter(dataloader))

    classes = ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    n_show = min(n, len(images))
    n_cols = 4
    # Ceiling division; never fewer than the 2 rows of the default layout.
    n_rows = max(2, -(-n_show // n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
    axes = axes.flatten()

    for i in range(n_show):
        # CHW tensor -> HWC array, then undo Normalize(mean, std).
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        img = np.clip(img * std + mean, 0, 1)

        axes[i].imshow(img)
        axes[i].set_title(f"Label: {classes[int(labels[i])]}", fontsize=12)
        axes[i].axis('off')

    # Hide grid cells that received no image (short batch or n < grid size).
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# "\ " is not a valid escape sequence (SyntaxWarning on recent Python); a
# leading newline matches the other section banners printed by this notebook.
print("\n Sample images from training set:")
show_batch(train_loader)

# ==========================================
# SECTION 9: MODEL ARCHITECTURE
# ==========================================

class XRayClassifier(nn.Module):
    """EfficientNet-B0 backbone with a small MLP head for X-ray classification.

    Args:
        num_classes: Number of output logits.
        pretrained: When True, the backbone starts from ImageNet weights;
            otherwise it is randomly initialized.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()

        # torchvision >= 0.13 selects weights through `weights=`; the legacy
        # `pretrained=` flag is deprecated there. IMAGENET1K_V1 is what
        # `pretrained=True` resolved to, so the initialization is unchanged.
        weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.efficientnet_b0(weights=weights)
        else:
            # Older torchvision without the multi-weight API.
            self.backbone = models.efficientnet_b0(pretrained=pretrained)

        # Replace the stock classifier with
        # dropout -> 512 hidden units -> ReLU -> dropout -> num_classes logits.
        # classifier[1] is the stock Linear layer; its input width is reused.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        return self.backbone(x)

# Initialize model on the GPU when one is available, otherwise on the CPU.
# `device` is reused by the loss weights, training loop and inference code.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = XRayClassifier(num_classes=4, pretrained=True).to(device)

print(f" Model created on {device}")
# Counts every parameter, trainable or not.
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# ==========================================
# SECTION 10: LOSS & OPTIMIZER
# ==========================================

# Inverse-frequency class weights to counter imbalance in the training split.
# A class without any training image gets weight 0.0 (this avoids a division
# by zero), so that class contributes nothing to the loss. Dropping such a
# class from the problem entirely would be the more robust long-term option.
class_counts = [stats['train'][name] for name in ('NORMAL', 'PNEUMONIA', 'COVID', 'TB')]
class_weights = torch.tensor(
    [0.0 if count <= 0 else 1.0 / count for count in class_counts],
    dtype=torch.float32,
)
# Rescale so that the weights add up to 4, i.e. the number of classes.
class_weights = class_weights.div(class_weights.sum()).mul(4)
class_weights = class_weights.to(device)

# Weighted cross-entropy, AdamW, and a scheduler that halves the learning rate
# after 3 epochs without improvement of the monitored (minimized) value.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

print(" Loss function and optimizer configured!")

# ==========================================
# SECTION 11: TRAINING FUNCTIONS
# ==========================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one full pass over ``dataloader``.

    Args:
        model: Network to optimize; switched to training mode.
        dataloader: Yields ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Running mean over the batches seen so far; dividing by the total
        # number of batches would under-report the loss until the last step.
        pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100. * correct / total})

    if num_batches == 0:
        raise ValueError("train_epoch: dataloader yielded no batches")

    return running_loss / num_batches, 100. * correct / total

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Yields ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent, predictions,
        labels)`` where the last two are flat lists of class indices in
        dataloader order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Running mean over the batches seen so far, not over the total
            # batch count, so the displayed loss is meaningful mid-epoch.
            pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100. * correct / total})

    if num_batches == 0:
        raise ValueError("validate_epoch: dataloader yielded no batches")

    return running_loss / num_batches, 100. * correct / total, all_preds, all_labels

# ==========================================
# SECTION 12: TRAINING LOOP
# ==========================================

num_epochs = 20
best_val_acc = 0.0
# Per-epoch metrics; read by the plotting cell after training.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

# The banner's emoji had been mis-decoded into mojibake; restored here.
print("\n🚀 Starting Training...\n")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 50)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc, val_preds, val_labels = validate_epoch(model, val_loader, criterion, device)

    # Update scheduler: ReduceLROnPlateau is stepped with the monitored
    # metric, here the validation loss.
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best model: a checkpoint is written only when validation accuracy
    # strictly improves, so the file always holds the best epoch so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }, 'best_xray_model.pth')
        print(f"New best model saved! Val Acc: {val_acc:.2f}%")

    print()

print(f" Training Complete! Best Val Accuracy: {best_val_acc:.2f}%")

# ==========================================
# SECTION 13: PLOT TRAINING HISTORY
# ==========================================

# Two side-by-side panels: loss on the left, accuracy on the right.
# The x axis is the 0-based epoch index (position in the history lists).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
ax1.plot(history['train_loss'], label='Train Loss', marker='o')
ax1.plot(history['val_loss'], label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Accuracy plot (values are percentages, as returned by the epoch functions)
ax2.plot(history['train_acc'], label='Train Acc', marker='o')
ax2.plot(history['val_acc'], label='Val Acc', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# ==========================================
# SECTION 14: TEST EVALUATION
# ==========================================

# Load best model. map_location lets a checkpoint written on a GPU runtime be
# restored on whatever device this session is using (including CPU-only).
checkpoint = torch.load('best_xray_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(" Best model loaded!")

# Evaluate on test set
test_loss, test_acc, test_preds, test_labels = validate_epoch(model, test_loader, criterion, device)

print(f" Test Set Performance:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

# Classification Report
classes = ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']
# Passing the full label set keeps the report and the matrix aligned with
# `classes` even when a class never occurs in the test labels or predictions.
class_ids = list(range(len(classes)))
print(" Classification Report:")
print(classification_report(test_labels, test_preds, labels=class_ids,
                            target_names=classes, zero_division=0))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(test_labels, test_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Per-class accuracy (recall): correct predictions / true samples of the class.
# Classes without test samples are reported as n/a instead of dividing by zero.
support = cm.sum(axis=1)
class_accuracy = np.divide(cm.diagonal(), support,
                           out=np.zeros(len(classes), dtype=float),
                           where=support > 0)
for i, cls in enumerate(classes):
    if support[i] > 0:
        print(f"{cls} Accuracy: {class_accuracy[i]*100:.2f}%")
    else:
        print(f"{cls} Accuracy: n/a (no test samples)")

# ==========================================
# SECTION 15: EXPORT MODEL
# ==========================================

# Export for deployment.
# At this point `model` holds the weights restored from the best checkpoint.
torch.save(model.state_dict(), 'xray_model_weights.pth')
# NOTE(review): saving the whole module pickles the object itself; loading it
# elsewhere presumably requires the XRayClassifier class to be importable
# under the same module path — prefer the weights-only file where possible.
torch.save(model, 'xray_model_complete.pth')

# Save model info: test accuracy (percent), class order, expected input size
# and parameter count, for consumers of the exported weights.
model_info = {
    'accuracy': test_acc,
    'classes': classes,
    'input_size': (224, 224),
    'num_parameters': sum(p.numel() for p in model.parameters())
}

import json
with open('model_info.json', 'w') as f:
    json.dump(model_info, f, indent=2)

print(" Model exported successfully!")
print("Files saved:")
print("  - best_xray_model.pth (checkpoint)")
print("  - xray_model_weights.pth (weights only)")
print("  - xray_model_complete.pth (complete model)")
print("  - model_info.json (metadata)")

# ==========================================
# SECTION 16: GRADCAM IMPLEMENTATION
# ==========================================

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

class GradCAMVisualizer:
    """Thin wrapper that produces Grad-CAM heatmaps for an ``XRayClassifier``.

    Args:
        model: Classifier exposing ``model.backbone.features``.
        device: Device the model lives on; stored for reference.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        # Target the last block of the backbone's feature extractor.
        target_layer = model.backbone.features[-1]
        self.cam = GradCAM(model=model, target_layers=[target_layer])

    def generate_heatmap(self, image_tensor, predicted_class):
        """Generate the Grad-CAM heatmap of one image for one class.

        Args:
            image_tensor: A single image without batch dimension; a batch
                dimension is added here. It should already be on the model's
                device.
            predicted_class: Integer index of the class to explain.

        Returns:
            The heatmap for the single input image (first element of the
            batch returned by the CAM object).
        """
        # `targets` must be callables that pick a score out of the model
        # output; a bare int is not callable and makes the CAM call fail.
        targets = [ClassifierOutputTarget(int(predicted_class))]
        grayscale_cam = self.cam(input_tensor=image_tensor.unsqueeze(0),
                                 targets=targets)
        return grayscale_cam[0]

# Test GradCAM
gradcam = GradCAMVisualizer(model, device)

# Get sample images: the first 4 images of the first test batch.
# NOTE(review): the loop below indexes 4 samples unconditionally, so a test
# set with fewer than 4 images would fail here.
sample_images, sample_labels = next(iter(test_loader))
sample_images = sample_images[:4].to(device)
sample_labels = sample_labels[:4]

# Predicted class per sample; gradients are not needed for this forward pass.
model.eval()
with torch.no_grad():
    outputs = model(sample_images)
    _, predictions = outputs.max(1)

# Visualize: top row = original images, bottom row = Grad-CAM overlays.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    # Original image: CHW tensor -> HWC array, undo the transform's
    # normalization and clamp to [0, 1] for display.
    img = sample_images[i].cpu().numpy().transpose(1, 2, 0)
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)

    # Generate heatmap for the class the model predicted (not the true label).
    heatmap = gradcam.generate_heatmap(sample_images[i], predictions[i].item())

    # Overlay
    cam_image = show_cam_on_image(img, heatmap, use_rgb=True)

    # Plot
    axes[0, i].imshow(img)
    axes[0, i].set_title(f"True: {classes[sample_labels[i]]}", fontsize=10)
    axes[0, i].axis('off')

    axes[1, i].imshow(cam_image)
    axes[1, i].set_title(f"Pred: {classes[predictions[i]]}", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle('GradCAM Visualizations - Red areas show where AI is looking', fontsize=14)
plt.tight_layout()
plt.show()

print("GradCAM visualization complete!")

# ==========================================
# SECTION 17: INFERENCE FUNCTION
# ==========================================

def predict_xray(model, image_path, device, transform):
    """Classify a single X-ray image file and explain the decision.

    Args:
        model: Trained classifier; it is put into eval mode.
        image_path: Path of the image file to classify.
        device: Device the input tensor is moved to.
        transform: Preprocessing applied to the RGB PIL image; must return a
            tensor without batch dimension.

    Returns:
        tuple: ``(prediction, confidence, heatmap)`` where ``prediction`` is
        the class name, ``confidence`` is the softmax probability of that
        class expressed in percent, and ``heatmap`` is the Grad-CAM heatmap
        for the predicted class.
    """
    class_names = ['NORMAL', 'PNEUMONIA', 'COVID', 'TB']

    model.eval()

    # Read the file, force 3 channels, preprocess and add the batch dimension.
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Forward pass only; keep the most probable class and its probability.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        top_prob, top_idx = probs.max(1)

    class_id = top_idx.item()
    label_name = class_names[class_id]
    percent = top_prob.item() * 100

    # A fresh visualizer is built per call; the heatmap explains `class_id`.
    visualizer = GradCAMVisualizer(model, device)
    cam_map = visualizer.generate_heatmap(batch[0], class_id)

    return label_name, percent, cam_map

# Test inference: smoke-test predict_xray on the first image of the test
# split, using the non-augmenting validation transform.
test_image_path = test_dataset.images[0]
prediction, confidence, heatmap = predict_xray(model, test_image_path, device, val_transform)

print(f" Inference Test:")
print(f"Prediction: {prediction}")
# `confidence` is already a percentage (0-100).
print(f"Confidence: {confidence:.2f}%")

# ==========================================
# SECTION 18: SAVE FOR DEPLOYMENT
# ==========================================

# Create deployment package: bundle the exported artifacts into one zip file
# in the current working directory.
import zipfile

deployment_files = [
    'best_xray_model.pth',
    'xray_model_weights.pth',
    'xray_model_complete.pth',
    'model_info.json'
]

# No compression argument is given, so zipfile's default applies.
with zipfile.ZipFile('xray_model_deployment.zip', 'w') as zipf:
    for file in deployment_files:
        # Files that were not produced are skipped rather than failing.
        if os.path.exists(file):
            zipf.write(file)

print(" Deployment package created: xray_model_deployment.zip")
print(" ALL DONE! Your model is ready for deployment!")
print(f" Final Metrics:")
print(f"   Test Accuracy: {test_acc:.2f}%")
# Size of the fully pickled model file, in MiB; this line raises if that file
# does not exist.
print(f"   Model Size: ~{os.path.getsize('xray_model_complete.pth') / (1024*1024):.1f} MB")
print(f"   Classes: {', '.join(classes)}")
print(f" Download the model files to deploy in your backend!")