# TB Chest X-Ray Classification with DenseNet121

This notebook demonstrates how to train a deep learning model to detect Tuberculosis (TB) from Chest X-Ray images. 

### System Overview
- **Architecture**: DenseNet121 (Transfer Learning from ImageNet)
- **Input**: 224x224 RGB Images
- **Classes**: Normal vs. Tuberculosis
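- **Augmentations**: Rotation, Perspective Shift, Color Jitter, Random Grayscale, Gaussian Blur, Random Resized Crop, Horizontal Flip (to simulate the low-quality images typically captured with mobile phones in the field)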
- **Framework**: PyTorch

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# --- Configuration ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 16        # samples per mini-batch for both data loaders
LEARNING_RATE = 1e-3   # initial learning rate handed to the optimizer
EPOCHS = 10            # number of passes over the training split

# UPDATE THIS PATH ON KAGGLE
DATA_DIR = "/kaggle/input/chest-xray-tb-dataset/data/images"

print(f"Using device: {DEVICE}")

## 1. Data Preprocessing & Augmentation

We use a custom Dataset class that applies robust data augmentation. These augmentations (rotation, blur, color jitter) are specifically chosen to make the model robust against low-quality images often captured by mobile phones in field settings.

In [None]:
class TBDataset(Dataset):
    """Chest X-ray images read from ``<root_dir>/Normal`` and ``<root_dir>/Tuberculosis``.

    Each item is an ``(image, label)`` pair where ``image`` is the transformed
    RGB image and ``label`` is a 0-dim ``torch.long`` tensor (see
    ``classEncoding``).

    Args:
        root_dir: Directory that contains one sub-folder per class.
        train: When true (the default, matching the previous behaviour) the
            random augmentation pipeline is applied. When false a
            deterministic resize + centre-crop pipeline is used instead, which
            is what a validation/test dataset should use so that metrics are
            repeatable.
    """

    # Use integer class labels: 0=Normal, 1=Tuberculosis
    classEncoding = {
        'Normal': 0,
        'Tuberculosis': 1
    }

    # Lower-case file suffixes that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, train=True):
        self.train = train
        self.image_names = []
        self.labels = []
        # (train_flag, pipeline) built lazily on first access; see _pipeline().
        self._cached_pipeline = None

        for category, class_index in self.classEncoding.items():
            class_folder = os.path.join(root_dir, category)
            if not os.path.isdir(class_folder):
                print(f"Warning: {class_folder} not found. Make sure dataset is attached.")
                continue

            # os.listdir() order is arbitrary; sorting keeps the index -> file
            # mapping stable across runs and machines.
            for file_name in sorted(os.listdir(class_folder)):
                if not file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                self.image_names.append(os.path.join(class_folder, file_name))
                self.labels.append(torch.tensor(class_index, dtype=torch.long))

    def __len__(self):
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.image_names)

    def get_transforms(self):
        """Build the preprocessing pipeline for the current ``train`` mode.

        Returns:
            A ``transforms.Compose`` that ends in ``ToTensor`` + ``Normalize``
            and yields a 224x224 crop.
        """
        # Standard ImageNet channel mean / std, matching the pretrained backbone.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        if not self.train:
            # Deterministic evaluation pipeline: no random augmentation.
            return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

        return transforms.Compose([
            transforms.Resize(256),

            # --- Robustness Augmentations ---
            transforms.RandomRotation(degrees=15),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.1, hue=0.05),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0))], p=0.3),
            # -------------------------------

            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    def _pipeline(self):
        """Return the cached transform, rebuilding it if ``train`` was toggled."""
        cached = self._cached_pipeline
        if cached is None or cached[0] != self.train:
            cached = (self.train, self.get_transforms())
            self._cached_pipeline = cached
        return cached[1]

    def __getitem__(self, index):
        """Load image ``index`` from disk and return ``(tensor, label)``."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(self.image_names[index]) as raw:
            image = raw.convert('RGB')
        return self._pipeline()(image), self.labels[index]

## 2. Model Architecture (DenseNet121)

We use DenseNet121 pre-trained on ImageNet. The final classification layer is replaced to output 2 classes (Normal, TB).

In [None]:
class DenseNet121(nn.Module):
    """torchvision DenseNet121 with its classifier swapped for a 2-class head.

    The wrapped network is stored on ``self.model``; the replacement
    classifier is ``Dropout(0.5)`` followed by ``Linear(1024, 2)``, so the
    forward pass returns two raw scores (logits) per input image.
    """

    def __init__(self):
        super().__init__()

        # Use ImageNet pretrained weights
        backbone = torchvision.models.densenet121(weights="DEFAULT")

        # Replace the stock classifier with a small head for the 2 classes.
        head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1024, 2),
        )
        backbone.classifier = head

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

## 3. Training Loop

We implement a training loop that tracks the training loss and the validation accuracy after every epoch, and saves the weights with the best validation accuracy so far to `tb_model_best.pt`.

In [None]:
def train_model(seed=None):
    """Train DenseNet121 on the images under ``DATA_DIR`` and report progress.

    The dataset is split 80/20 into a training and a validation subset. After
    every epoch the validation accuracy is printed, and whenever it improves
    the weights are written to ``tb_model_best.pt``.

    Args:
        seed: Optional integer used to seed the train/validation split so it
            can be reproduced. ``None`` (the default) keeps the previous
            behaviour of drawing the split from PyTorch's global RNG.

    Returns:
        ``(model, testloader)`` - the model as it is after the final epoch and
        the validation ``DataLoader``. If no images are found the function
        prints a message and returns ``(None, None)`` so that the usual
        ``model, testloader = train_model()`` unpacking still works.
    """
    # 1. Prepare Data
    full_dataset = TBDataset(DATA_DIR)
    if len(full_dataset) == 0:
        print("No images found. Check DATA_DIR path.")
        return None, None

    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    # NOTE(review): both subsets wrap the same dataset object, so the
    # validation images go through the same transform pipeline as training.
    train_set, test_set = random_split(full_dataset, [train_size, test_size], **split_kwargs)

    trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # 2. Setup Model
    model = DenseNet121().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=1e-4)
    # Multiply the learning rate by 0.1 every 5 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # 3. Training Loop
    best_acc = 0.0
    print(f"Starting training on {len(train_set)} images, validating on {len(test_set)} images...")

    for epoch in range(EPOCHS):
        model.train()  # evaluate() leaves the model in eval mode
        running_loss = 0.0

        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()
        # Mean of the per-batch losses for this epoch.
        epoch_loss = running_loss / len(trainloader)

        # Validation
        acc = evaluate(model, testloader)
        print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {epoch_loss:.4f} | Val Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), "tb_model_best.pt")
            print("  -> Saved Best Model")

    # NOTE(review): the returned model holds the last-epoch weights, which may
    # differ from the checkpoint saved in tb_model_best.pt.
    return model, testloader

def evaluate(model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader`` as a percentage.

    The model is switched to eval mode and left there; gradients are disabled
    while predictions are computed.

    Args:
        model: Module whose output has one score per class along dim 1.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in the range 0-100, or ``0.0`` when the dataloader yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            predicted = torch.argmax(model(images), dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total

## 4. Run Training
Uncomment the line below to run training (requires uploaded dataset).

In [None]:
# model, testloader = train_model()

## 5. Evaluation & Confusion Matrix
Visualizing performance on the test set.

In [None]:
def plot_confusion_matrix(model, dataloader):
    """Plot a 2x2 confusion matrix and print a classification report.

    Predictions are collected over every batch in ``dataloader`` with the
    model in eval mode and gradients disabled.

    Args:
        model: Module whose output has one score per class along dim 1.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        None. Shows a matplotlib figure and prints the report; if the
        dataloader yields no samples a message is printed instead.
    """
    class_names = ['Normal', 'TB']
    class_ids = list(range(len(class_names)))  # 0=Normal, 1=TB

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(DEVICE))
            predicted = torch.argmax(outputs, dim=1)
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if not all_labels:
        print("No samples to evaluate.")
        return

    # Passing labels= pins the matrix to 2x2 even when one class is absent
    # from this split, so it always lines up with the two tick labels below.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.show()

    # labels= keeps target_names aligned with the class ids; zero_division=0
    # reports 0 for a class that has no predictions or no samples.
    print(classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    ))

# plot_confusion_matrix(model, testloader)