# Respiratory Disease Classification
***
## Table of Contents
***

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
import warnings
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
from PIL import Image

from sklearn.model_selection import train_test_split
from numpy.typing import NDArray
from torchvision import datasets, transforms, models
from torchinfo import summary
from torchmetrics import Accuracy, F1Score
from pathlib import Path
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader, Subset, Dataset

In [None]:
# Seed every random number generator used in this notebook (Python's `random`,
# NumPy, PyTorch CPU and CUDA) with one shared value so runs are repeatable.
random_seed = 42
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(random_seed)

## 1. Introduction


## 2. Device Agnostic Code
Apple-silicon GPU acceleration (the `mps` backend) delivers a significant speed-up over the CPU for deep learning tasks, especially for large models and batch sizes. On machines with an NVIDIA GPU (e.g. Windows or Linux), the `cuda` backend is used instead of `mps`.

In [None]:
# Pick the fastest available backend: an NVIDIA GPU (CUDA) first, then
# Apple-silicon GPU (MPS), falling back to the CPU. This replaces the previous
# per-OS commented-out alternatives, so the same cell works everywhere.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
DEVICE

## 3. Loading Data
Retrieved from [COVID-19 Radiography Database](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database/data)

In [None]:
DATA_PATH = Path("_datasets/Radiography_Dataset")

# Tell the user up front whether the dataset root is present before later
# cells try to read images from it.
dir_status = "directory exists." if DATA_PATH.is_dir() else "directory NOT FOUND!"
print(f"{DATA_PATH} {dir_status}")

In [None]:
def walk_through_dir(dir_path: str) -> None:
    """
    Report, for every directory under ``dir_path`` (itself included), how many
    sub-directories and files it directly contains.

    Directories are visited top-down in ``os.walk`` order. Every file is counted
    and reported as an "image", regardless of its extension.

    Args:
        dir_path: Root directory to traverse.
    """
    for current_dir, sub_dirs, files in os.walk(top=dir_path):
        n_dirs, n_files = len(sub_dirs), len(files)
        print(f"{n_dirs} directories and {n_files} images found in {current_dir}")

In [None]:
# Print the number of sub-directories and files found in each folder under DATA_PATH.
walk_through_dir(DATA_PATH)

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset for a ``<root>/<class_name>/images/<file>`` layout.

    Each image's label is the name of the directory two levels above the file
    (``<class_name>``). Class names are sorted alphabetically and mapped to
    contiguous integer ids (``label2id`` / ``id2label``).

    Args:
        img_path: Dataset root directory (``str`` or ``Path``).
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: File suffixes to include, compared case-insensitively.
    """

    # Annotations only; every attribute is assigned per instance in __init__.
    img_path: Path
    transform: "transforms.Compose | None"
    extensions: frozenset[str]
    all_paths: list[Path]
    all_labels: list[str]
    categories: list[str]
    label2id: dict[str, int]
    id2label: dict[int, str]
    encoded_labels: list[int]

    def __init__(
        self,
        img_path: "str | Path",
        transform: "transforms.Compose | None" = None,
        extensions: "frozenset[str] | set[str]" = frozenset({".png", ".jpg", ".jpeg"}),
    ) -> None:
        self.img_path = Path(img_path)
        self.transform = transform
        self.extensions = frozenset(ext.lower() for ext in extensions)
        # Sort the paths so the sample order, and therefore the seeded
        # train/val/test split, does not depend on the filesystem's listing order.
        self.all_paths = sorted(
            path
            for path in self.img_path.glob("*/images/*")
            if path.suffix.lower() in self.extensions
        )
        # Label = folder two levels above the file: <root>/<label>/images/<file>
        self.all_labels = [path.parent.parent.name for path in self.all_paths]
        self.categories = sorted(set(self.all_labels))  # unique labels, sorted
        self.label2id = {label: index for index, label in enumerate(self.categories)}
        self.id2label = {index: label for label, index in self.label2id.items()}
        self.encoded_labels = [self.label2id[label] for label in self.all_labels]

    def __len__(self) -> int:
        return len(self.all_paths)

    def __getitem__(self, index: int) -> tuple[object, int]:
        """Return ``(image, label_index)``; the RGB image is transformed if a transform is set."""
        single_file_path = self.all_paths[index]
        try:
            # The context manager closes the file handle. convert() returns a
            # separate in-memory copy, so the image stays valid after the close.
            with Image.open(fp=single_file_path) as raw_img:
                img = raw_img.convert(mode="RGB")
        except Exception as e:
            # Report which file failed, then re-raise: continuing without an
            # image would only fail later with an unrelated error.
            print(f"Error opening image {single_file_path}: {e}")
            raise
        label_index = self.encoded_labels[index]

        if self.transform:
            img = self.transform(img)
        return img, label_index


In [None]:
# Base dataset with no transform: items are (RGB PIL image, label index) pairs.
dataset = CustomDataset(img_path=DATA_PATH)

## 4. Understanding Data

In [None]:
def show_random_images(dataset: Dataset, n_rows: int = 3, n_cols: int = 3) -> None:
    """Plot a grid of randomly chosen samples, each titled with its class name.

    Assumes ``dataset[i]`` returns ``(image, label_index)``, where ``image`` is
    something ``plt.imshow`` accepts (e.g. a PIL image). Also assumes the dataset
    exposes an ``id2label`` mapping, as ``CustomDataset`` does.

    Args:
        dataset: Dataset to sample from (with replacement).
        n_rows: Number of grid rows (default 3).
        n_cols: Number of grid columns (default 3).
    """
    # figsize is (width, height), so width scales with the number of columns.
    figure = plt.figure(figsize=(n_cols * 3, n_rows * 3))
    for position in range(1, n_rows * n_cols + 1):
        # randrange excludes len(dataset); randint(0, len) could pick an
        # out-of-range index.
        sample_index = random.randrange(len(dataset))
        img, label = dataset[sample_index]
        figure.add_subplot(n_rows, n_cols, position)
        plt.title(dataset.id2label[label])
        plt.axis("off")
        plt.imshow(img)
    plt.tight_layout()  # once, after every subplot exists
    plt.show()

In [None]:
# Visualise a grid of randomly chosen samples with their class names.
show_random_images(dataset)

In [None]:
# Inspect the class-index -> class-name mapping.
dataset.id2label

In [None]:
# Per-class image counts. np.unique returns the labels sorted, which is the
# same order CustomDataset uses for its label ids.
unique_vals, counts = np.unique(dataset.all_labels, return_counts=True)
distribution_columns = {"Class Label": unique_vals, "Count": counts}
df_dist = pd.DataFrame(distribution_columns)
print(df_dist)

In [None]:
def plot_distribution(dataset: Dataset) -> None:
    """Draw a bar chart of how many samples carry each class label.

    Reads ``dataset.all_labels`` (one label string per sample). Bars appear in
    sorted label order.
    """
    labels, label_counts = np.unique(dataset.all_labels, return_counts=True)
    distribution = pd.DataFrame({"Class Label": labels, "Count": label_counts})

    plt.figure(figsize=(10, 6))
    sns.barplot(
        data=distribution,
        x="Class Label",
        y="Count",
        hue="Class Label",
        palette="Set2",
    )
    plt.xlabel("Class Label")
    plt.ylabel("Count")
    plt.title("Distribution of Class Labels")
    plt.tight_layout()
    plt.show()


# Bar chart of per-class sample counts over the full (unsplit) dataset.
plot_distribution(dataset)

## 5. Preparing Data
### Splitting Dataset

In [None]:
def split_dataset(
    dataset,
    random_seed,
    holdout_size: float = 0.2,
    test_fraction: float = 0.5,
):
    """Split ``dataset`` into stratified train / validation / test ``Subset``s.

    First a ``holdout_size`` share of the samples is held out, stratified on
    ``dataset.encoded_labels``. The hold-out part is then split again, stratified,
    with ``test_fraction`` of it going to test and the rest to validation. The
    defaults give an 80 / 10 / 10 split.

    Args:
        dataset: Dataset exposing ``encoded_labels`` (one int label per sample).
        random_seed: Seed passed to both ``train_test_split`` calls.
        holdout_size: Fraction of all samples kept out of training.
        test_fraction: Fraction of the hold-out part assigned to the test set.

    Returns:
        ``(train_subset, val_subset, test_subset)``, all views onto ``dataset``.
    """
    all_indices = range(len(dataset))
    all_labels = dataset.encoded_labels

    train_indices, holdout_indices = train_test_split(
        all_indices,
        test_size=holdout_size,
        stratify=all_labels,
        random_state=random_seed,
    )

    val_indices, test_indices = train_test_split(
        holdout_indices,
        test_size=test_fraction,
        # Stratify the second split on the hold-out samples' own labels.
        stratify=[all_labels[i] for i in holdout_indices],
        random_state=random_seed,
    )

    train_data = Subset(dataset=dataset, indices=train_indices)
    val_data = Subset(dataset=dataset, indices=val_indices)
    test_data = Subset(dataset=dataset, indices=test_indices)
    return train_data, val_data, test_data

### Data Transformation
For transfer learning using pretrained models in PyTorch, it is a common and effective practice to normalise the dataset using the standard mean and standard deviation values of the ImageNet dataset, on which many pretrained models were originally trained. This ensures that the input data distribution matches the distribution expected by the pretrained model, leading to better convergence and improved performance during fine-tuning.

[Reference - PyTorch Forums](https://discuss.pytorch.org/t/discussion-why-normalise-according-to-imagenet-mean-and-std-dev-for-transfer-learning/115670)

In addition to normalisation, various data augmentation techniques (such as random flips and random rotations) are applied to increase data diversity and improve the model's generalisation capability.

In [None]:
BATCH_SIZE = 32
IMG_SIZE = 224
RANDOM_SEED = 42  # keep in sync with `random_seed` used for global seeding
IMAGE_NET_MEANS = [0.485, 0.456, 0.406]
IMAGE_NET_STDS = [0.229, 0.224, 0.225]
N_CLASSES = len(unique_vals)

# Resize to an explicit (height, width). An int `size` only rescales the shorter
# edge, so a non-square image would give a non-square tensor and break batching.
train_transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        # NOTE(review): vertical flips produce upside-down radiographs, which do
        # not occur in practice; consider removing this augmentation.
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGE_NET_MEANS, std=IMAGE_NET_STDS),
    ]
)

# Deterministic preprocessing for validation/test: resize + normalise only.
test_transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGE_NET_MEANS, std=IMAGE_NET_STDS),
    ]
)

### Preparing DataLoaders

In [None]:
class TransformedSubset(Dataset):
    """Apply a split-specific transform on top of a ``Subset`` of an untransformed dataset."""

    def __init__(self, subset: Subset, transform=None) -> None:
        self.subset = subset
        self.transform = transform
        # Mirror Subset's public attributes for code that expects a Subset.
        self.dataset = subset.dataset
        self.indices = subset.indices

    def __len__(self) -> int:
        return len(self.subset)

    def __getitem__(self, index):
        img, label = self.subset[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# All three Subsets share ONE underlying CustomDataset. Assigning
# `<split>.dataset.transform` per split overwrites the same attribute, which
# left test_transform on every split (no augmentation during training).
# Instead, keep the shared base untransformed and give each split its own wrapper.
dataset.transform = None
train_subset, val_subset, test_subset = split_dataset(dataset, RANDOM_SEED)
train_dataset = TransformedSubset(train_subset, transform=train_transform)
val_dataset = TransformedSubset(val_subset, transform=test_transform)
test_dataset = TransformedSubset(test_subset, transform=test_transform)


train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity-check the split sizes and the resulting number of batches per loader
# (drop_last is not set, so a final partial batch counts as one batch).
print(f"Batch size: {BATCH_SIZE}")
print(f"train_dataset: {len(train_dataset)} -> train_loader: {len(train_loader)}")
print(f"val_dataset: {len(val_dataset)} -> val_loader: {len(val_loader)}")
print(f"test_dataset: {len(test_dataset)} -> test_loader: {len(test_loader)}")

## 6. Transfer Learning
Transfer learning is a powerful technique in deep learning where a model pretrained on a large, general dataset is adapted for a related task. This practice improves performance while reducing the amount of training data and training time required.

Setting `param.requires_grad = False` in PyTorch freezes the model parameters (weights and biases), preventing gradient computations and updates during training. This allows parts of the model to remain unchanged while selectively training other layers.

### ResNet-50

In [None]:
class ResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 with a dropout + linear head for ``num_classes``.

    When ``is_frozen`` is truthy, every backbone parameter is frozen except the
    last residual stage (``layer4``) and the new ``fc`` head. The module is moved
    to ``device`` at the end of construction.
    """

    def __init__(
        self,
        num_classes: int,
        is_frozen: bool | None = True,
        device: torch.device | str = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        head_in_features = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=head_in_features, out_features=num_classes),
        )
        self.model = backbone

        if is_frozen:
            # requires_grad_ sets the flag on every parameter of the sub-tree.
            self.model.requires_grad_(False)
            self.model.layer4.requires_grad_(True)
            self.model.fc.requires_grad_(True)

        self.to(device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax) for a batch of images."""
        return self.model(x)

### DenseNet121

In [None]:
class DenseNet121(nn.Module):
    """ImageNet-pretrained DenseNet-121 with a dropout + linear head for ``num_classes``.

    When ``is_frozen`` is truthy, every parameter is frozen except the last dense
    block (``denseblock4``), the final batch norm (``norm5``) and the new
    classifier. The module is moved to ``device`` at the end of construction.
    """

    def __init__(
        self,
        num_classes: int,
        is_frozen: bool | None = True,
        device: torch.device | str = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=n_features, out_features=num_classes),
        )

        if is_frozen:
            for param in self.model.parameters():  # freeze everything first
                param.requires_grad = False
            # Look the blocks up by name rather than by position in `features`,
            # so the choice is explicit and survives layout changes.
            trainable_parts = (
                self.model.features.denseblock4,
                self.model.features.norm5,
                self.model.classifier,
            )
            for part in trainable_parts:
                for param in part.parameters():
                    param.requires_grad = True

        self.to(device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax) for a batch of images."""
        return self.model(x)


In [None]:
# Instantiate both backbones with the default partial freezing (is_frozen=True) on DEVICE.
resnet50 = ResNet50(num_classes=N_CLASSES, device=DEVICE)
densenet121 = DenseNet121(num_classes=N_CLASSES, device=DEVICE)

In [None]:
# Layer-by-layer summary (shapes, parameter counts, trainable flags) of each
# model for an input of shape (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE).
summary_input_size = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
for model in (resnet50, densenet121):
    model_report = summary(
        model=model,
        input_size=summary_input_size,
        verbose=0,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"],
    )
    print(model_report)
    print("\n")

## 7. Evaluation Metrics
We will use the following evaluation metrics:
- `torchmetrics.Accuracy`
- `torchmetrics.F1Score`

In [None]:
# Overall accuracy: the fraction of correctly classified samples.
accuracy = Accuracy(task="multiclass", num_classes=N_CLASSES).to(device=DEVICE)
# Macro-averaged F1: the unweighted mean of the per-class F1 scores. torchmetrics'
# default ("micro") multiclass F1 equals accuracy for single-label problems, so it
# would add no information and would hide weak minority-class performance.
f1 = F1Score(task="multiclass", num_classes=N_CLASSES, average="macro").to(device=DEVICE)

# Order matters: training code reports index 0 as accuracy and index 1 as F1.
metrics = [accuracy, f1]

## 8. Loss Function
### Cross-Entropy Loss
Cross-Entropy Loss is a loss function used for classification problems. It measures the difference between the true labels and the predicted probability distribution, i.e. the softmax of the model's output logits.

For a single data point, the cross-entropy loss is defined as:

\begin{align*}
    L = - \sum^{k}_{i=1}y_{i}\log{(\hat y_{i})}
\end{align*}

where:
- $y_i$: True label for the $i$-th class. If one-hot encoded, $y_{i} = 1$ for the correct class and $y_{i} = 0$ otherwise.
- $\hat y_i$: Predicted probability for the $i$-th class.
- $k$: Number of classes.

For a batch of $m$ data points:

\begin{align*}
    C = \dfrac{1}{m} \sum^{m}_{j=1} \left (- \sum^{k}_{i=1}y_{j, i}\log{(\hat y_{j, i})} \right)
\end{align*}

where:
- $C$: Average cross-entropy loss over the batch.
- $m$: Number of training examples (batch size).
- $k$: Number of classes.
- $y_{j, i} \in \{0, 1\}$: Indicator that the true class of sample $j$ is class $i$.
- $\hat y_{j, i} \in [0, 1]$: Predicted probability that sample $j$ belongs to class $i$.

In PyTorch:
- Use `nn.CrossEntropyLoss()` directly with raw logits.
- Do not apply `Softmax()` or `LogSoftmax()` manually before the loss.
- Internally, `nn.CrossEntropyLoss()` is equivalent to `nn.LogSoftmax(dim=1)` followed by `nn.NLLLoss()` (negative log-likelihood loss).

In [None]:
# Inversed frequency weights
class_weights = 1.0 / torch.tensor(counts, dtype=torch.float)
class_weights = (class_weights / class_weights.sum()).to(DEVICE)  # Normalised
for val, weight in zip(unique_vals, class_weights):
    print(f"Weight for {val}: {weight:.5f}")

criterion = nn.CrossEntropyLoss(weight=class_weights)

## 9. Optimiser
An optimiser in neural networks is used to adjust the parameters (weights and biases) of a model during training to minimise the loss. Optimisers are essential for enabling neural networks to learn from data: without them, the model would not improve over time.

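**AdamW** (a variant of the Adam optimiser) decouples weight decay from the gradient-based update. Adam with an L2 penalty added to the loss would rescale that penalty by its adaptive per-parameter learning rates; AdamW instead shrinks the weights directly at each step. This decoupling often improves a model's generalisation compared to Adam with L2 regularisation, reducing the risk of overfitting, especially in large-scale models.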

**ReduceLROnPlateau** is a learning rate scheduler that monitors a specified metric (usually validation loss) and reduces the learning rate by a given factor if the metric stops improving for a certain number of epochs (`patience`). This allows the optimiser to take smaller, more precise steps when progress plateaus, often leading to better final model performance.

In [None]:
LEARNING_RATE = 1e-3
DECAY_RATE = 1e-4


def _make_optimiser_and_scheduler(net: nn.Module):
    """Build an AdamW optimiser over ``net``'s trainable parameters plus its LR scheduler.

    The scheduler halves the learning rate (``factor=0.5``) once the monitored
    value has not decreased for 3 consecutive steps (``mode="min"``, ``patience=3``).
    """
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    optimiser = optim.AdamW(
        params=trainable_params,
        lr=LEARNING_RATE,
        weight_decay=DECAY_RATE,
    )
    scheduler = ReduceLROnPlateau(
        optimizer=optimiser,
        mode="min",
        patience=3,
        factor=0.5,
    )
    return optimiser, scheduler


# ! === ResNet50 ===
resnet50_optimiser, resnet50_scheduler = _make_optimiser_and_scheduler(resnet50)

# ! === DenseNet121 ===
densenet121_optimiser, densenet121_scheduler = _make_optimiser_and_scheduler(densenet121)

## 10. Training and Evaluation
1. Iterate through epochs
1. For each epoch, iterate through training batches, perform training steps, calculate the train loss and evaluation metrics per batch.
1. For each epoch, iterate through validation batches, perform validation steps, calculate the validation loss and evaluation metrics per batch.


### Training Steps
1. Zero the gradients
    - Clear the gradients from the previous iteration to prevent accumulation across batches.
1. Forward pass
    - Pass inputs through the model to obtain predictions.
1. Calculate loss and evaluation metrics per batch
    - Measure how far the predictions deviate from the true labels using a loss function.
    - Compute evaluation metrics (e.g., accuracy, F1 Score) for the current batch.
1. Backward pass
    - Compute gradients of the loss with respect to the model's parameters via backpropagation.
    - Call `optimiser.step()` to update each parameter $\theta$ using the computed gradients. For plain gradient descent the update is:
    
    $$
        \theta \leftarrow \theta - \eta \dfrac{\partial \mathcal{L}}{\partial \theta}
    $$
    where $\eta$ is the learning rate.
1. Average training loss and evaluation metrics
    - Calculate the mean loss and metric values across all batches in the epoch.


## 10. Training and Evaluation
1. Iterate through epochs
1. For each epoch, iterate through training batches, perform training steps, calculate the train loss and evaluation metrics per batch.
1. For each epoch, iterate through validation batches, perform validation steps, calculate the validation loss and evaluation metrics per batch.


### Training Steps
1. Zero the gradients
    - Clear the gradients from the previous iteration to prevent accumulation across batches.
1. Forward pass
    - Pass inputs through the model to obtain predictions.
1. Calculate loss and evaluation metrics per batch
    - Measure how far the predictions deviate from the true labels using a loss function.
    - Compute evaluation metrics (e.g., accuracy, F1 Score) for the current batch.
1. Backward pass
    - Compute gradients of the loss with respect to the model's parameters via backpropagation.
    - Call `optimiser.step()` to update each parameter $\theta$ using the computed gradients. For plain gradient descent the update is:
    
    $$
        \theta \leftarrow \theta - \eta \dfrac{\partial \mathcal{L}}{\partial \theta}
    $$
    where $\eta$ is the learning rate.
1. Average training loss and evaluation metrics
    - Calculate the mean loss and metric values across all batches in the epoch.


In [None]:
def train_step(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    optimiser: optim.Optimizer,
    metrics: list[nn.Module],
    device: torch.device,
) -> tuple[float, list[float]]:
    """Run one training epoch over ``data_loader`` and return its loss and metric scores.

    Each batch is moved to ``device``. For each batch the step zeroes the
    gradients, runs the forward pass, computes the loss, updates the metrics with
    the argmax predictions, backpropagates, and takes an optimiser step.

    Returns:
        ``(avg_loss, scores)``. ``avg_loss`` is the sum of the batch losses, each
        multiplied by its batch size, divided by the dataset size. With a
        class-weighted criterion this is only an approximation of the per-sample
        mean. ``scores`` holds ``metric.compute() * 100`` for each metric, in order.
    """
    model.train()
    for m in metrics:
        m.reset()

    running_loss = 0.0
    for batch_inputs, batch_labels in data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimiser.zero_grad()  # clear gradients left over from the previous batch
        logits = model(batch_inputs)
        loss = criterion(logits, batch_labels)
        # Scale by batch size so a smaller final batch is not over-weighted.
        running_loss += loss.item() * batch_inputs.size(0)

        predictions = torch.softmax(input=logits, dim=1).argmax(dim=1)
        for m in metrics:
            m.update(predictions, batch_labels)

        loss.backward()
        optimiser.step()

    mean_loss = running_loss / len(data_loader.dataset)
    return mean_loss, [m.compute().item() * 100 for m in metrics]

### Validation Steps
1. Forward pass
    - Set the model to evaluation mode with `model.eval()`, which disables dropout and makes batch normalisation use its running statistics; gradient tracking is switched off separately by running inside `torch.inference_mode()`.
    - Pass inputs through the model to obtain predictions.
1. Calculate loss and evaluation metrics per batch
    - Measure how far the predictions deviate from the true labels using a loss function.
    - Compute evaluation metrics (e.g., accuracy, F1-Score) for the current batch.
1. Average validation loss and evaluation metrics
    - Calculate the mean loss and metric values across all batches in the epoch.

In [None]:
def validation_step(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    metrics: list[nn.Module],
    device: torch.device,
) -> tuple[float, list[float]]:
    """Evaluate ``model`` on ``data_loader`` without updating it.

    The model is put in eval mode and all batches run under
    ``torch.inference_mode()``. Metrics are reset first and then updated with the
    argmax predictions of every batch.

    Returns:
        ``(avg_loss, scores)``. ``avg_loss`` is the sum of the batch losses, each
        multiplied by its batch size, divided by the dataset size. ``scores``
        holds ``metric.compute() * 100`` for each metric, in order.
    """
    model.eval()
    for m in metrics:
        m.reset()

    loss_sum = 0.0
    with torch.inference_mode():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_inputs.size(0)

            predictions = torch.softmax(input=logits, dim=1).argmax(dim=1)
            for m in metrics:
                m.update(predictions, batch_labels)

    scores = [m.compute().item() * 100 for m in metrics]
    return loss_sum / len(data_loader.dataset), scores

### Model Training and Evaluation Pipeline

In [None]:
def train_and_validate(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimiser: optim.Optimizer,
    scheduler: ReduceLROnPlateau,
    metrics: list[nn.Module],
    device: torch.device,
    total_epochs: int,
    early_stopping_patience: int = 5,
    checkpoint_dir: str | os.PathLike = ".",
) -> dict[str, list[float]]:
    """Train ``model`` for up to ``total_epochs`` epochs, validating after each one.

    After every epoch the scheduler is stepped with the validation loss. The
    model's ``state_dict`` is saved to ``<checkpoint_dir>/<classname>_best.pth``
    (class name lower-cased) whenever the validation loss improves. Training stops
    early once the validation loss has not improved for
    ``early_stopping_patience`` consecutive epochs.

    Args:
        model: Network to train; it is moved to ``device`` first.
        train_loader: Batches used for the optimisation steps.
        val_loader: Batches evaluated after every epoch.
        criterion: Loss applied to ``(logits, labels)``.
        optimiser: Optimiser over the model's trainable parameters.
        scheduler: ``ReduceLROnPlateau`` stepped with the validation loss.
        metrics: Metric objects; index 0 is reported as accuracy, index 1 as F1.
        device: Device that batches are moved to.
        total_epochs: Maximum number of epochs.
        early_stopping_patience: Consecutive epochs without validation-loss
            improvement before training stops (default 5).
        checkpoint_dir: Directory for the best-model checkpoint (default: cwd).

    Returns:
        Per-epoch history with keys ``train_loss``, ``train_acc``, ``train_f1``,
        ``val_loss``, ``val_acc`` and ``val_f1``.

    Note:
        On return the model holds the weights of the last epoch that ran, not
        the best one; load the checkpoint to recover the best weights.
    """
    model.to(device=device)
    model_name = model.__class__.__name__.lower()
    checkpoint_path = Path(checkpoint_dir) / f"{model_name}_best.pth"
    best_loss = float("inf")
    best_epoch = 0
    patience_counter = 0

    history: dict[str, list[float]] = {
        "train_loss": [],
        "train_acc": [],
        "train_f1": [],
        "val_loss": [],
        "val_acc": [],
        "val_f1": [],
    }

    start_time = time.perf_counter()  # monotonic clock for elapsed time

    for epoch in range(1, total_epochs + 1):
        train_loss, train_metrics = train_step(
            model=model,
            data_loader=train_loader,
            criterion=criterion,
            optimiser=optimiser,
            metrics=metrics,
            device=device,
        )
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_metrics[0])
        history["train_f1"].append(train_metrics[1])

        val_loss, val_metrics = validation_step(
            model=model,
            data_loader=val_loader,
            criterion=criterion,
            metrics=metrics,
            device=device,
        )
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_metrics[0])
        history["val_f1"].append(val_metrics[1])

        scheduler.step(val_loss)  # reduce the LR when validation loss plateaus

        # Report before the early-stopping check so the final epoch is shown too.
        print(f"Epoch [{epoch}/{total_epochs}]\n{'=' * 60}")
        print(
            f"{'Train Loss:':<12}{train_loss:>6.4f} | {'Train Accuracy:':<15}{train_metrics[0]:>6.2f}% | {'Train F1:':<10}{train_metrics[1]:>6.2f}%"
        )
        print(
            f"{'Val Loss:':<12}{val_loss:>6.4f} | {'Val Accuracy:':<15}{val_metrics[0]:>6.2f}% | {'Val F1:':<10}{val_metrics[1]:>6.2f}%"
        )

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save(obj=model.state_dict(), f=checkpoint_path)
        else:  # early stopping
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch}")
                break

    elapsed_time = time.perf_counter() - start_time
    print(f"Training and validation completed in {elapsed_time:.2f} seconds.\n")
    print(f"The best-performing model was saved at epoch: {best_epoch}")
    return history

In [None]:
EPOCHS = 20
EPOCH_RANGE = range(1, EPOCHS + 1)
MODEL_NAME_RESNET50 = "ResNet50"
MODEL_NAME_DENSENET121 = "DenseNet121"


def _fit(model_name, model, optimiser, scheduler):
    """Announce and run ``train_and_validate`` with the shared loaders, loss, metrics and device."""
    print(f"Training {model_name}...")
    return train_and_validate(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimiser=optimiser,
        scheduler=scheduler,
        metrics=metrics,
        device=DEVICE,
        total_epochs=EPOCHS,
    )


# The models are trained one after the other; they share the same `metrics` objects.
resnet50_history = _fit(
    MODEL_NAME_RESNET50, resnet50, resnet50_optimiser, resnet50_scheduler
)
densenet121_history = _fit(
    MODEL_NAME_DENSENET121, densenet121, densenet121_optimiser, densenet121_scheduler
)