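# Chest X-ray classification with ResNet18

Classifies chest X-rays as Normal, Bacterial Pneumonia or Viral Pneumonia by fine-tuning an ImageNet-pretrained ResNet18. Performance is estimated with 5-fold cross-validation, then measured on a held-out test set, and the model's attention is visualised with Grad-CAM.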

## Imports and basic setup

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import KFold
import random
import numpy as np

from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.transforms import ScaledTranslation, IdentityTransform

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# - python `random`
# - numpy
# - torch (weight init of the new heads and DataLoader shuffling; manual_seed also seeds CUDA)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

%matplotlib inline

## Dataset

In [None]:
# Hyperparameters
BATCH_SIZE = 128

# Dataset root containing the `train/` and `test/` ImageFolder trees.
# NOTE(review): ".../" looks like a placeholder — point DATA_DIR at the real location.
DATA_DIR = ".../xray_data_ternary"


def _make_preprocessing():
    """Resize to 256, centre-crop to 224, convert to a tensor and normalise with ImageNet mean/std."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# Preprocessing per split. No augmentation is applied yet: every split uses the same
# deterministic pipeline. Each split still gets its own object so augmentation can
# later be added to `train_transforms` without touching evaluation.
train_transforms = _make_preprocessing()
val_transforms = _make_preprocessing()
test_transforms = _make_preprocessing()

# Datasets
train_set = datasets.ImageFolder(f"{DATA_DIR}/train", train_transforms)
test_set = datasets.ImageFolder(f"{DATA_DIR}/test", test_transforms)

# Number of training images per class. minlength keeps a class with zero images in the
# listing, so each count stays paired with its own class name.
class_counts = torch.bincount(
    torch.tensor(train_set.targets, dtype=torch.long), minlength=len(train_set.classes)
)
for class_idx, (class_name, count) in enumerate(zip(train_set.classes, class_counts.tolist())):
    print(f'{class_name} ({class_idx}) - {count}')

## Model: training, cross-validation, testing and Grad-CAM

In [None]:
def _snapshot_state_dict(model):
    """Return a cloned copy of ``model.state_dict()`` that later optimizer steps cannot modify."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _labels_to_numpy(batches):
    """Concatenate a list of 1-D label tensors into a numpy array (empty list -> empty int64 array)."""
    if not batches:
        return np.empty(0, dtype=np.int64)
    return torch.cat(batches).cpu().numpy()


def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs, device=None):
    """Train ``model`` for ``num_epochs``, validating after every epoch, and restore the best weights.

    After each training epoch the model is evaluated on ``val_loader``. The weights from
    the epoch with the highest validation accuracy (the earliest such epoch on ties) are
    loaded back into ``model`` before returning.

    :param model: torch module to train (updated in place).
    :param criterion: loss function called as ``criterion(outputs, labels)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param train_loader: iterable of ``(inputs, labels)`` training batches.
    :param val_loader: iterable of ``(inputs, labels)`` validation batches.
    :param num_epochs: number of passes over ``train_loader``.
    :param device: device to move batches to; defaults to the device of the model's
        first parameter (CPU for a parameter-less model).
    :return: ``(model, (train_losses, train_accuracies),
        (val_losses, val_accuracies, y_pred, y_true))``.
        The loss/accuracy arrays hold one value per epoch. ``y_pred``/``y_true`` are the
        validation predictions and labels of the best epoch, i.e. of the returned weights.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    train_losses = np.empty(0)
    train_accuracies = np.empty(0)
    val_losses = np.empty(0)
    val_accuracies = np.empty(0)

    # state_dict() returns references to the live tensors, so the snapshot must be cloned,
    # otherwise later optimizer steps would silently overwrite the "best" weights.
    best_val_acc = float('-inf')
    best_model_wts = _snapshot_state_dict(model)
    best_y_pred = np.empty(0, dtype=np.int64)
    best_y_true = np.empty(0, dtype=np.int64)

    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        train_loss, train_acc = 0.0, 0.0
        train_loss_sum, train_correct, train_total = 0.0, 0, 0

        for batch_nr, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running, sample-weighted statistics (outputs/loss are from before this step)
            preds = torch.argmax(outputs, 1)
            train_loss_sum += loss.item() * inputs.size(0)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)
            train_loss = train_loss_sum / train_total
            train_acc = train_correct / train_total

            print('\r', f'Epoch {epoch + 1} Batch {batch_nr + 1}/{len(train_loader)} - Train loss: {train_loss} - Accuracy: {train_acc:.2f}', end='')
        print('')

        train_losses = np.append(train_losses, train_loss)
        train_accuracies = np.append(train_accuracies, train_acc)

        # ---- Validation pass ----
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        val_loss_sum, val_correct, val_total = 0.0, 0, 0
        pred_batches, label_batches = [], []

        with torch.no_grad():
            for batch_nr, (inputs, labels) in enumerate(val_loader):
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                preds = torch.argmax(outputs, 1)
                loss = criterion(outputs, labels)

                val_loss_sum += loss.item() * inputs.size(0)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                val_loss = val_loss_sum / val_total
                val_acc = val_correct / val_total

                # Collect per batch and concatenate once (avoids repeated torch.cat)
                pred_batches.append(preds.cpu())
                label_batches.append(labels.cpu())

                print('\r', f'Epoch {epoch + 1} Batch {batch_nr + 1}/{len(val_loader)} - Validation loss: {val_loss} - Accuracy: {val_acc:2.2f}', end='')
            print('')

        val_losses = np.append(val_losses, val_loss)
        val_accuracies = np.append(val_accuracies, val_acc)

        # Keep weights and predictions of the best epoch so far; strict `>` keeps the earliest on ties.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = _snapshot_state_dict(model)
            best_y_pred = _labels_to_numpy(pred_batches)
            best_y_true = _labels_to_numpy(label_batches)

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, (train_losses, train_accuracies), (val_losses, val_accuracies, best_y_pred, best_y_true)


def test_model(model, criterion, test_loader):
    test_losses = np.empty(0)
    y_pred = torch.empty((0,), dtype=torch.int64).to(device)
    y_true = torch.empty((0,), dtype=torch.int64).to(device)

    test_loss_sum = 0
    test_correct = 0
    test_total = 0

    model.eval()

    with torch.no_grad():
        for batch_nr, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            preds = torch.argmax(outputs, 1)

            # Calculate loss
            loss = criterion(outputs, labels)
            test_loss_sum += loss.item() * inputs.size(0)
            test_total += labels.size(0)
            test_correct += torch.sum(preds == labels)
            test_loss = test_loss_sum / test_total
            test_acc = test_correct / test_total

            y_pred = torch.cat((y_pred, preds), 0)
            y_true = torch.cat((y_true, labels), 0)

            # Print the epoch and loss
            print('\r', f'Batch {batch_nr + 1}/{len(test_loader)} - Test loss: {test_loss} - Accuracy: {test_acc:2.2f}', end='')
        print('')

    # Add the loss to the total epoch loss (item() turns a PyTorch scalar into a normal Python datatype)
    test_losses = np.append(test_losses, test_loss)

    # Print the accuracy and loss
    print(f'Test loss: {test_loss} - Accuracy: {test_acc:.2f}')

    return y_pred.detach().cpu().numpy(), y_true.detach().cpu().numpy(), test_losses


def _plot_epoch_curves(series, ylabel):
    """Draw one figure with a dotted line per ``(values, label)`` pair against 1-based epoch numbers."""
    plt.figure(figsize=(10, 5), dpi=150)
    n_epochs = 0
    for values, label in series:
        values = np.asarray(values)
        # Value k was recorded after epoch k+1, so plot it at x = k+1
        plt.plot(np.arange(1, len(values) + 1), values, "o:", label=label)
        n_epochs = max(n_epochs, len(values))
    plt.xlabel("No. of Epochs", fontdict={'size': 16})
    plt.ylabel(ylabel, fontdict={'size': 16})
    ticks = list(range(0, n_epochs + 1, 5))
    plt.xticks(ticks, ticks, fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend()
    plt.tight_layout()
    plt.show()


def _plot_confusion_with_metrics(y_true, y_pred, class_names):
    """Confusion-matrix heatmap with each class's precision/recall/F1 written on its diagonal cell."""
    # Fixing `labels` keeps the matrix and metric arrays aligned with `class_names` even if
    # a class never occurs; zero_division=0 reports 0 for classes with no samples.
    class_ids = list(range(len(class_names)))
    cf_matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
    precision = precision_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)

    fig, ax = plt.subplots()
    sns.heatmap(cf_matrix, annot=True, fmt='g', linewidth=.5, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    # Offset every count annotation by -0.3 data units. `offs + trans` applies the offset
    # before the data->display transform, presumably to move the count clear of the
    # metric text drawn below the cell centre.
    for t in ax.texts:
        t.set_transform(ScaledTranslation(0, -0.3, IdentityTransform()) + t.get_transform())
    ax.set_xlabel('Predicted Label', fontsize=16)
    ax.set_ylabel('True Label', fontsize=16)
    # Precision, recall and F1 are per-class metrics, so each class's values are written
    # once on its own diagonal cell. Column precision and row recall are not mixed in
    # off-diagonal cells.
    for k in class_ids:
        ax.text(k + 0.5, k + 0.5, f'P={precision[k]:2.2f}\nR={recall[k]:2.2f}\nF1={f1[k]:2.2f}',
                ha='center', va='top', color='black')
    plt.show()


def plot_results(losses, accuracies, y_pred, y_true, class_names=None):
    """Plot per-epoch loss and accuracy curves, then an annotated confusion matrix.

    :param losses: List of tuples of [0] loss array-like (one value per epoch) and [1] a label
    :param accuracies: List of tuples of [0] accuracy array-like (one value per epoch) and [1] a label
    :param y_pred: predicted class indices
    :param y_true: true class indices
    :param class_names: display names for class indices 0..K-1; defaults to the three X-ray classes
    """
    if class_names is None:
        class_names = ['Normal', 'Bacterial\nPneumonia', 'Viral\nPneumonia']
    sns.set_style("darkgrid")
    _plot_epoch_curves(losses, "Loss")
    _plot_epoch_curves(accuracies, "Accuracy")
    _plot_confusion_with_metrics(y_true, y_pred, class_names)

In [None]:
LEARNING_RATE = 0.0001
EPOCHS = 15
N_SPLITS = 5
# Seed shared by the fold split, the per-fold batch shuffling and the per-fold model init
CV_SEED = 42
num_classes = len(train_set.classes)

# Per-fold curves (one array per fold) and validation predictions pooled over all folds
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
val_predictions = np.empty(0, dtype=np.int64)
val_labels = np.empty(0, dtype=np.int64)

# Validation folds are drawn from the training images, but they should go through the
# evaluation pipeline (`val_transforms`), not `train_transforms`. Both views must index
# the same files so fold indices line up.
val_view = datasets.ImageFolder(train_set.root, val_transforms)
assert val_view.targets == train_set.targets, "train/val views must list the same images"

# Set up k-fold cross-validation
# NOTE(review): class weights below suggest imbalanced classes; StratifiedKFold would
# keep class proportions equal across folds.
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=CV_SEED)
for fold_idx, (train_idx, val_idx) in enumerate(kf.split(train_set.targets)):
    print(f'Fold no: {fold_idx + 1} -----------------------------------------------------------------')

    # SubsetRandomSampler reshuffles the fold's indices every epoch. A plain index array
    # passed as `sampler=` would replay one fixed order. The seeded generator makes the
    # shuffling reproducible.
    fold_generator = torch.Generator().manual_seed(CV_SEED + fold_idx)
    train_loader = DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx.tolist(), generator=fold_generator),
    )
    val_loader = DataLoader(torch.utils.data.Subset(val_view, val_idx.tolist()), batch_size=BATCH_SIZE, shuffle=False)

    # Class weights from this fold's training labels only
    fold_train_targets = [train_set.targets[i] for i in train_idx]
    class_weights = torch.tensor(
        compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=fold_train_targets),
        dtype=torch.float,
    ).to(device)

    # Model setup (seeded per fold so the freshly initialised head is reproducible)
    torch.manual_seed(CV_SEED + fold_idx)
    model = models.resnet18(weights="DEFAULT")
    model.fc = nn.Sequential(nn.Linear(model.fc.in_features, num_classes))
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), LEARNING_RATE)

    # Train the model
    trained_model, train_stats, val_stats = train_model(model, criterion, optimizer, train_loader, val_loader, EPOCHS)

    # Save this fold's results for later averaging; predictions are kept as integer labels
    val_loss, val_acc, val_pred, val_true = val_stats
    train_loss, train_acc = train_stats
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_predictions = np.concatenate([val_predictions, np.asarray(val_pred, dtype=np.int64)])
    val_labels = np.concatenate([val_labels, np.asarray(val_true, dtype=np.int64)])
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    plot_results([(train_loss, "Training"), (val_loss, "Validation")], [(train_acc, "Training"), (val_acc, "Validation")], val_pred, val_true)

# Aggregated visualisation:
# - curves: element-wise mean over folds
# - confusion matrix: over the validation predictions pooled from every fold
train_loss_avg = np.mean(train_losses, axis=0)
train_acc_avg = np.mean(train_accuracies, axis=0)
val_loss_avg = np.mean(val_losses, axis=0)
val_acc_avg = np.mean(val_accuracies, axis=0)
print('Average results across folds: -------------------------------------------------------------------')
plot_results([(train_loss_avg, "Training"), (val_loss_avg, "Validation")], [(train_acc_avg, "Training"), (val_acc_avg, "Validation")], val_predictions,
             val_labels)

In [None]:
# Re-display the fold-averaged curves and the pooled validation confusion matrix without retraining.
avg_loss_curves = [(train_loss_avg, "Training"), (val_loss_avg, "Validation")]
avg_acc_curves = [(train_acc_avg, "Training"), (val_acc_avg, "Validation")]
plot_results(avg_loss_curves, avg_acc_curves, y_pred=val_predictions, y_true=val_labels)

In [None]:
# Train on the full training set
def train_model_full(model, criterion, optimizer, train_loader, num_epochs, device=None, save_path="resnet_weights.pt"):
    """Train ``model`` on the whole training set (no validation) and save the final weights.

    :param model: torch module to train (updated in place).
    :param criterion: loss function called as ``criterion(outputs, labels)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param train_loader: iterable of ``(inputs, labels)`` training batches.
    :param num_epochs: number of passes over ``train_loader``.
    :param device: device to move batches to; defaults to the device of the model's
        first parameter (CPU for a parameter-less model).
    :param save_path: file that receives ``model.state_dict()`` after the last epoch;
        ``None`` skips saving.
    :return: ``(model, (train_losses, train_accuracies))`` with one value per epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    train_losses = np.empty(0)
    train_accuracies = np.empty(0)

    for epoch in range(num_epochs):
        model.train()
        # Plain floats/ints, so an empty loader still yields 0.0 instead of crashing on `.detach()`
        train_loss, train_acc = 0.0, 0.0
        train_loss_sum, train_correct, train_total = 0.0, 0, 0

        for batch_nr, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running, sample-weighted statistics (outputs/loss are from before this step)
            preds = torch.argmax(outputs, 1)
            train_loss_sum += loss.item() * inputs.size(0)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)
            train_loss = train_loss_sum / train_total
            train_acc = train_correct / train_total

            print('\r', f'Epoch {epoch + 1} Batch {batch_nr + 1}/{len(train_loader)} - Train loss: {train_loss} - Accuracy: {train_acc:2.2f}', end='')
        print('')

        train_losses = np.append(train_losses, train_loss)
        train_accuracies = np.append(train_accuracies, train_acc)

    # Save last model weights
    if save_path is not None:
        torch.save(model.state_dict(), save_path)

    return model, (train_losses, train_accuracies)

In [None]:
# Final model: train on the whole training set, then save its weights.
FULL_LEARNING_RATE = 0.0001
# NOTE(review): the cross-validation cell trains for EPOCHS epochs, while the final
# model trains for 5. Confirm this shorter schedule is intentional.
FULL_EPOCHS = 5
num_classes = len(train_set.classes)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics don't depend on order, so keep the test order deterministic
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
class_weights = torch.tensor(
    compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=train_set.targets),
    dtype=torch.float,
).to(device)

# Model setup
full_model = models.resnet18(weights="DEFAULT")
full_model.fc = nn.Sequential(nn.Linear(full_model.fc.in_features, num_classes))
full_model = full_model.to(device)
# Define our loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)
# Define our optimizer
optimizer = torch.optim.Adam(full_model.parameters(), lr=FULL_LEARNING_RATE)
full_model, train_stats = train_model_full(full_model, criterion, optimizer, train_loader, num_epochs=FULL_EPOCHS)

In [None]:
# Evaluate with an unweighted loss; the training criterion was class-weighted. Using a
# separate name leaves the training `criterion` untouched.
eval_criterion = nn.CrossEntropyLoss()
y_pred, y_true, test_loss = test_model(full_model, eval_criterion, test_loader)

# Define the label names for the plot
categories = ['Normal', 'Bacterial\nPneumonia', 'Viral\nPneumonia']
class_ids = list(range(len(categories)))
# Fixing `labels` keeps the matrix and metric arrays aligned with `categories` even if a
# class never occurs; zero_division=0 reports 0 for classes with no samples.
cf_matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
precision = precision_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
recall = recall_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)

# Create heatmap
fig, ax = plt.subplots()
sns.heatmap(cf_matrix, annot=True, fmt='g', linewidth=.5, cmap='Blues',
            xticklabels=categories, yticklabels=categories, ax=ax)
# Offset every count annotation by -0.3 data units. `offs + trans` applies the offset
# before the data->display transform, presumably to move the count clear of the metric
# text drawn below the cell centre.
for t in ax.texts:
    t.set_transform(ScaledTranslation(0, -0.3, IdentityTransform()) + t.get_transform())
ax.set_xlabel('Predicted Label', fontsize=16)
ax.set_ylabel('True Label', fontsize=16)
# Precision, recall and F1 are per-class metrics, so each class's values are written once
# on its own diagonal cell. Column precision and row recall are not mixed in other cells.
for k in class_ids:
    ax.text(k + 0.5, k + 0.5, f'P={precision[k]:.2f}\nR={recall[k]:.2f}\nF1={f1[k]:.2f}',
            ha='center', va='top', color='black')
plt.show()

In [None]:
# Summarise the test-set F1 score under the three averaging schemes.
F1_AVERAGES = (
    ("micro", "Micro average F1-score: {:.2f}"),
    ("macro", "Macro average F1-score: {:.2f}"),
    ("weighted", "Weighted average F1-score: {:.2f}"),
)
for average, template in F1_AVERAGES:
    print(template.format(f1_score(y_true, y_pred, average=average)))

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Grad-CAM on a random batch of training images
N_CAM_IMAGES = 10
CAM_COLS = 5
# Mean/std of the Normalize transform used for every split, needed to undo it for display
CAM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
CAM_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

cam_loader = DataLoader(train_set, batch_size=N_CAM_IMAGES, shuffle=True)
image_tensor, label_tensor = next(iter(cam_loader))
n_images = image_tensor.size(0)

# The network must see the same normalised input it was trained on. A de-normalised copy,
# clipped to [0, 1] per pixel, is used only for display. Per-row min-max scaling would
# distort the image and must not be fed to the model.
display_images = (image_tensor * CAM_STD + CAM_MEAN).clamp(0, 1)

# Rebuild the architecture and load the weights saved by train_model_full. Every
# parameter is overwritten, so no pretrained download is needed.
model = models.resnet18(weights=None)
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, len(train_set.classes)))
model.load_state_dict(torch.load('resnet_weights.pt', map_location=torch.device('cpu')))
model = model.to(device)
model.eval()
target_layers = [model.layer4[-1]]

# NOTE(review): recent pytorch-grad-cam versions no longer accept `use_cuda` and take the
# device from the model. Confirm against the installed version. Inputs are moved to
# `device` explicitly below either way.
cam = GradCAM(model=model, target_layers=target_layers)

model_inputs = image_tensor.to(device)
with torch.no_grad():
    prediction_tensor = torch.argmax(model(model_inputs), 1).cpu()

n_rows = (N_CAM_IMAGES + CAM_COLS - 1) // CAM_COLS
fig = plt.figure(figsize=(20, 8))
fig.suptitle('ResNet18 GradCAM Visualization')

for i in range(n_images):
    # Activation heatmap for image i (targets=None leaves the target class to GradCAM)
    grayscale_cam = cam(input_tensor=model_inputs[i].unsqueeze(0), targets=None)[0, :]

    img = display_images[i].numpy().transpose(1, 2, 0)
    label = int(label_tensor[i])
    prediction = int(prediction_tensor[i])

    ax = plt.subplot(n_rows, CAM_COLS, i + 1)
    ax.set_xticks([])
    ax.set_yticks([])

    visualizations = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    plt.imshow(visualizations)
    plt.title(f'Ground truth: {train_set.classes[label]}\n Prediction: {train_set.classes[prediction]}')

plt.show()