# Importing necessary libraries

In [None]:
import torch
import torchvision

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import seaborn as sns
sns.set_style("whitegrid")

from PIL import Image
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

In [None]:
def set_device():
    """
    Set the device. CUDA if available, CPU otherwise

    Prints a hint to enable a GPU runtime when CUDA is not available.

    Args:
      None

    Returns:
      torch.device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    """
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda_available else "cpu")
    if not cuda_available:
        print("WARNING: For this notebook to perform best, "
              "if possible, in the menu under `Runtime` -> "
              "`Change runtime type.`  select `GPU` ")
    else:
        print("GPU is enabled in this notebook.")

    return device


# Device used by every model and batch below.
DEVICE = set_device()
print(DEVICE)

## Loading the data

In [None]:
# Dataset root with one sub-directory per split (ImageFolder layout).
train_dir, test_dir, val_dir = (
    "../input/chest-xray-pneumonia/chest_xray/" + split
    for split in ("train", "test", "val")
)

In [None]:
# ImageNet channel statistics. The pretrained VGG19 used later in this notebook
# was trained on inputs normalised with these values, so apply them here too.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time pipeline: resize/crop to 224x224, random geometric and
# brightness augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose((
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.RandomAffine(
        20,
        shear=10,
        interpolation=transforms.InterpolationMode.NEAREST
    ),
    transforms.RandomHorizontalFlip(p=0.2),
    # NOTE(review): vertical flips give anatomically implausible chest X-rays;
    # consider dropping this augmentation.
    transforms.RandomVerticalFlip(p=0.2),
    transforms.ColorJitter(brightness=[0.5, 2.0]),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
))

train_image = ImageFolder(train_dir, transform=train_transform)

# Deterministic pipeline shared by validation and test: same geometry and
# normalisation as training, no augmentation.
val_transform = transforms.Compose((
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
))

val_image = ImageFolder(val_dir, transform=val_transform)

test_image = ImageFolder(test_dir, transform=val_transform)

In [None]:
# Only the training loader shuffles. Validation/test use a fixed order, so
# their reported loss (a mean of per-batch means, which depends on how samples
# fall into batches) and prediction order are reproducible between runs.
train_loader = DataLoader(
    train_image,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

val_loader = DataLoader(
    val_image,
    batch_size=2,
    shuffle=False,
    num_workers=2,
)

test_loader = DataLoader(
    test_image,
    batch_size=2,
    shuffle=False,
    num_workers=2,
)

# Defining some helper functions

In [None]:
class LRScheduler():
    """
    Learning rate scheduler. If the validation loss does not decrease for the
    given number of `patience` epochs, then the learning rate is multiplied
    by the given `factor` (never going below `min_lr`).

    Thin wrapper around `torch.optim.lr_scheduler.ReduceLROnPlateau` in 'min'
    mode. It prints a message whenever the learning rate is reduced.
    """
    def __init__(
        self, optimizer, patience=5, min_lr=1e-6, factor=0.5
    ):
        """
        new_lr = max(old_lr * factor, min_lr)

        :param optimizer: the optimizer we are using
        :param patience: how many epochs to wait before updating the lr
        :param min_lr: least lr value to reduce to while updating
        :param factor: factor by which the lr should be updated
        """
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor

        # `verbose=True` is deprecated in recent PyTorch releases, so LR
        # reductions are reported by __call__ instead.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=self.patience,
            factor=self.factor,
            min_lr=self.min_lr,
        )

    def __call__(self, val_loss):
        """Step the scheduler with `val_loss` and report any LR reduction."""
        old_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.lr_scheduler.step(val_loss)
        for i, (old_lr, group) in enumerate(zip(old_lrs, self.optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"INFO: reducing learning rate of group {i} "
                      f"from {old_lr:.4e} to {group['lr']:.4e}")

class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    a certain number of epochs.

    Call the instance once per epoch with the validation loss. It sets
    `early_stop` to True once `patience` consecutive calls fail to improve on
    the best loss seen so far by more than `min_delta`.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        else:
            # Anything that is not a sufficient improvement (including an
            # unchanged loss, or a gain of exactly `min_delta`) counts as a
            # stagnant epoch.
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True


In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=''):
    """
    Make one pass over `loader` and return (mean batch loss, accuracy).

    When `optimizer` is given, the model is trained: train mode, gradients
    and optimizer steps. Otherwise it is only evaluated: eval mode, no autograd.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.
    correct, total = 0, 0
    # Autograd is only needed when training; evaluation skips graph building.
    with torch.set_grad_enabled(training), tqdm(loader, unit='batch') as tepoch:
        tepoch.set_description(desc)
        for data, target in tepoch:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            tepoch.set_postfix(loss=batch_loss)
            running_loss += batch_loss

            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    return running_loss / max(len(loader), 1), correct / max(total, 1)


def train(model, criterion, optimizer, device, train_loader, validation_loader, epochs,
          lr_scheduler_flag=False, early_stopping_flag=False,
          scheduler=None, stopper=None):
    """
    Train `model` for up to `epochs` epochs, validating after each one.

    :param model: the network to train
    :param criterion: loss function taking (output, target)
    :param optimizer: optimizer over the model's parameters
    :param device: device the batches are moved to
    :param train_loader: iterable of (data, target) training batches
    :param validation_loader: iterable of (data, target) validation batches
    :param epochs: maximum number of epochs to run
    :param lr_scheduler_flag: step an LR scheduler with the validation loss
           after every epoch. Uses `scheduler` if given, otherwise the
           notebook-level `lr_scheduler` object (backward compatibility).
    :param early_stopping_flag: likewise for `stopper`, falling back to the
           notebook-level `early_stopping` object.
    :param scheduler: optional callable taking the validation loss
           (e.g. an LRScheduler); used whenever it is provided
    :param stopper: optional object that is called with the validation loss
           and exposes an `early_stop` flag (e.g. an EarlyStopping); used
           whenever it is provided
    :return: (train_loss, train_acc, validation_loss, validation_acc), lists
             with one entry per completed epoch

    Training ends early when validation accuracy reaches 1.0 or when the
    stopper sets `early_stop`.
    """
    if lr_scheduler_flag and scheduler is None:
        scheduler = lr_scheduler  # notebook-global fallback
    if early_stopping_flag and stopper is None:
        stopper = early_stopping  # notebook-global fallback

    train_loss, validation_loss = [], []
    train_acc, validation_acc = [], []
    for _ in range(epochs):
        loss, acc = _run_epoch(model, train_loader, criterion, device,
                               optimizer=optimizer, desc='Training: ')
        train_loss.append(loss)
        train_acc.append(acc)

        loss, acc = _run_epoch(model, validation_loader, criterion, device,
                               desc='Validation: ')
        validation_loss.append(loss)
        validation_acc.append(acc)

        # Perfect validation accuracy ends training early.
        if validation_acc[-1] == 1:
            break
        if scheduler is not None:
            scheduler(validation_loss[-1])
        if stopper is not None:
            stopper(validation_loss[-1])
            if stopper.early_stop:
                break

    return train_loss, train_acc, validation_loss, validation_acc

In [None]:
def plot_loss_accuracy(train_loss, train_acc, validation_loss, validation_acc):
    """
    Draw side-by-side per-epoch curves: loss on the left, accuracy on the right.

    Each panel overlays the training and validation series, which are expected
    to have the same length as `train_loss`. The figure is shown immediately.
    """
    epoch_axis = range(len(train_loss))
    fig, axes = plt.subplots(1, 2)
    panels = (
        (train_loss, validation_loss, 'Training Loss', 'Validation Loss',
         'Loss', 'Epoch vs Loss'),
        (train_acc, validation_acc, 'Training Accuracy', 'Validation Accuracy',
         'Accuracy', 'Epoch vs Accuracy'),
    )
    for ax, (train_curve, val_curve, train_label, val_label, ylabel, title) in zip(axes, panels):
        ax.plot(list(epoch_axis), train_curve, label=train_label)
        ax.plot(list(epoch_axis), val_curve, label=val_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
    fig.set_size_inches(15.5, 5.5)
    plt.show()

# VGG19

In [None]:
# Pretrained VGG19 used as a frozen feature extractor. Its classifier stays
# trainable, and the classifier's last layer is replaced by a 2-way output.
_vgg19_weights = getattr(models, "VGG19_Weights", None)
if _vgg19_weights is not None:
    # Newer torchvision: `pretrained=True` is deprecated in favour of the
    # weights enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` loads.
    vgg_model = models.vgg19(weights=_vgg19_weights.IMAGENET1K_V1)
else:
    # Older torchvision without the weights enum.
    vgg_model = models.vgg19(pretrained=True)

# Freeze every parameter, then re-enable gradients for the classifier only.
vgg_model.requires_grad_(False)
vgg_model.classifier.requires_grad_(True)

num_ftrs = vgg_model.classifier[-1].in_features

# The replacement head is created after freezing, so its parameters are trainable.
vgg_model.classifier[-1] = nn.Sequential(
    nn.Linear(num_ftrs, 2)
)
vgg_model = vgg_model.to(DEVICE)
print(vgg_model)

In [None]:
# Optimisation setup for VGG19 fine-tuning. Adam is handed every parameter;
# the frozen ones never receive gradients, so they are not updated.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=vgg_model.parameters(), lr=1e-5)
lr_scheduler = LRScheduler(optimizer)  # class defaults: patience=5, factor=0.5
early_stopping = EarlyStopping(patience=3)

In [None]:
# Fine-tune VGG19 for up to 20 epochs with LR reduction on plateaus; the
# early-stopping object above is created but disabled for this run.
vgg_history = train(
    vgg_model, loss_fn, optimizer, DEVICE, train_loader, val_loader,
    epochs=20, lr_scheduler_flag=True, early_stopping_flag=False,
)
train_loss, train_acc, validation_loss, validation_acc = vgg_history
plot_loss_accuracy(*vgg_history)

In [None]:
# Dump the raw per-epoch history of the last training run, one metric per line
# (the last label previously lacked its leading newline).
print(
    "train_loss:", train_loss,
    "\ntrain_acc:", train_acc,
    "\nvalidation_loss:", validation_loss,
    "\nvalidation_acc:", validation_acc
)

## Evaluating VGG19

In [None]:
# Test-set accuracy and mean loss of the fine-tuned VGG19.
# Switch to eval mode explicitly (train-only layers such as dropout off) rather
# than relying on whatever mode the training cell left the model in.
vgg_model.eval()
loss_sum = 0.
total_correct = 0
total = len(test_image)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        output = vgg_model(images)
        loss_sum += loss_fn(output, labels).item()

        predictions = torch.argmax(output, dim=1)
        # .item() keeps the tally a plain int, so `pct` prints as a number
        # rather than a tensor repr.
        total_correct += (predictions == labels).sum().item()

pct = total_correct / total
print(f'\n Final accuracy is {pct}')
print(f' Mean test loss is {loss_sum / len(test_loader)}')

In [None]:
def get_y(net, test_loader, device):
    """
    Collect predicted and true class indices for every sample in a loader.

    The model is switched to eval mode and run without gradient tracking.
    Predictions are the argmax of the model output along dim 1.

    :param net: classifier returning per-class scores along dim 1
    :param test_loader: iterable of (data, label) batches
    :param device: device the batches are moved to (where `net` lives)
    :return: (y_pred, y_true) as 1-D CPU tensors, in loader order; both are
             empty int64 tensors when the loader yields nothing
    """
    net.eval()
    # Seed each list with an empty long tensor so an empty loader still
    # concatenates to an empty int64 result.
    pred_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]
    true_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            _, preds = torch.max(net(data), 1)
            pred_chunks.append(preds.view(-1).cpu())
            true_chunks.append(label.view(-1).cpu())

    # One concatenation at the end instead of re-copying the growing result
    # on every batch.
    return torch.cat(pred_chunks), torch.cat(true_chunks)

In [None]:
# Confusion matrix of the VGG19 test predictions: rows are true classes,
# columns predicted classes, both labelled with the dataset's class names.
# Passing `labels=` fixes the matrix to one row/column per class even if a
# class never appears in y_true or y_pred.
y_pred, y_true = get_y(vgg_model, test_loader, DEVICE)
confusion_matrix_df = pd.DataFrame(
    confusion_matrix(y_true, y_pred, labels=list(range(len(test_image.classes)))),
    index=test_image.classes,
    columns=test_image.classes,
)

In [None]:
plt.figure(figsize=(16,8))
# Cells hold integer counts. seaborn's default annotation format ('.2g') would
# print counts >= 100 in scientific notation (e.g. 3.9e+02), so force integers.
sns.heatmap(confusion_matrix_df, cmap="RdBu", annot=True, fmt="d")
plt.title('Confusion Matrix Heatmap', fontsize=16)
plt.xlabel("Predicted condition")
plt.ylabel("Actual condition")
plt.show()

In [None]:
# Binary precision / recall / F1 for the positive class (label index 1, sklearn's
# default pos_label; check test_image.classes for its name). Support is None
# when averaging.
precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="binary")

# Basic CNN

In [None]:
class CNN(nn.Module):
    """
    Small convolutional classifier producing 2 logits per image.

    conv(3->16, 3x3) -> ReLU -> conv(16->32, 3x3) -> ReLU -> maxpool(2x2)
    -> flatten -> fc(->128) -> ReLU -> fc(128->2)

    :param image_size: side length of the square, 3-channel input images.
           The default of 224 gives the original fc1 width of 387200.
    :raises ValueError: if `image_size` is too small to survive the
            convolutions and pooling (must be at least 6)
    """
    def __init__(self, image_size=224):
        super().__init__()
        # Each unpadded 3x3 conv trims 2 pixels per side length; the 2x2,
        # stride-2 pool then halves it (rounding down).
        pooled = (image_size - 4) // 2
        if pooled < 1:
            raise ValueError(f"image_size must be at least 6, got {image_size}")
        # Submodules are created in the original order so seeded parameter
        # initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.fc1 = nn.Linear(in_features=32 * pooled * pooled, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (batch, 3, image_size, image_size) tensor to (batch, 2) logits."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# From-scratch CNN on the selected device, trained with Adam (lr 1e-5) against
# a cross-entropy loss over its 2 output logits.
cnn_net = CNN().to(DEVICE)
cnn_optimizer = torch.optim.Adam(params=cnn_net.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

In [None]:
import torch.onnx

# Export the CNN to ONNX with a dynamic batch dimension, using a random
# 16-image batch only to trace the graph.
# NOTE(review): this cell runs before `cnn_net` is trained, so the exported
# file holds freshly initialised weights. Re-run it after training if the
# export is meant to contain the trained model.
onnx_dummy_input = torch.rand(16, 3, 224, 224).to(DEVICE)
torch.onnx.export(
    cnn_net,
    onnx_dummy_input,
    "cnn.onnx",
    export_params=True,        # embed the parameter values in the file
    opset_version=10,
    do_constant_folding=True,  # pre-compute constant sub-expressions
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

In [None]:
# Train the CNN for up to 20 epochs with neither LR scheduling nor early stopping.
cnn_history = train(
    cnn_net, loss_fn, cnn_optimizer, DEVICE, train_loader, val_loader,
    epochs=20,
)
train_loss, train_acc, validation_loss, validation_acc = cnn_history
plot_loss_accuracy(*cnn_history)

## Evaluating CNN

In [None]:
# Test-set accuracy and mean loss of the trained CNN.
# Switch to eval mode explicitly rather than relying on whatever mode the
# training cell left the model in.
cnn_net.eval()
loss_sum = 0.
total_correct = 0
total = len(test_image)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        output = cnn_net(images)
        loss_sum += loss_fn(output, labels).item()

        predictions = torch.argmax(output, dim=1)
        # .item() keeps the tally a plain int, so `pct` prints as a number
        # rather than a tensor repr.
        total_correct += (predictions == labels).sum().item()

pct = total_correct / total
print(f'\n Final accuracy is {pct}')
print(f' Mean test loss is {loss_sum / len(test_loader)}')

In [None]:
# Confusion matrix of the CNN test predictions: rows are true classes,
# columns predicted classes, both labelled with the dataset's class names.
# Passing `labels=` fixes the matrix to one row/column per class even if a
# class never appears in y_true or y_pred.
y_pred, y_true = get_y(cnn_net, test_loader, DEVICE)
confusion_matrix_df = pd.DataFrame(
    confusion_matrix(y_true, y_pred, labels=list(range(len(test_image.classes)))),
    index=test_image.classes,
    columns=test_image.classes,
)

In [None]:
plt.figure(figsize=(16,8))
# Cells hold integer counts. seaborn's default annotation format ('.2g') would
# print counts >= 100 in scientific notation (e.g. 3.9e+02), so force integers.
sns.heatmap(confusion_matrix_df, cmap="RdBu", annot=True, fmt="d")
plt.title('Confusion Matrix Heatmap', fontsize=16)
plt.xlabel("Predicted condition")
plt.ylabel("Actual condition")
plt.show()

In [None]:
# Binary precision / recall / F1 for the positive class (label index 1, sklearn's
# default pos_label; check test_image.classes for its name). Support is None
# when averaging.
precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="binary")