## Project: Multiclass Classification of X-Ray Images

*Disclaimer*: This is my first public Kaggle project. Therefore, comments are very welcome. 

The first part (data loading and inspecting) was inspired by 

https://www.kaggle.com/digvijayyadav/deep-learning-and-transfer-learning-on-covid-19.

The latter part (neural network) was guided by an amazing PyTorch tutorial by freeCodeCamp.org,

https://www.youtube.com/watch?v=GIsg-ZUy0MY&t=23737s .

Thanks!


## Outline & Summary
* We use the CoronaHack-Chest X-Ray-Dataset, https://www.kaggle.com/praveengovi/coronahack-chest-xraydataset, to classify the X-ray images into four classes: *healthy*, *bacteria*, *COVID-19* and *other*.
* After investigating the data, we find that the provided test set does not distinguish between the different virus-caused pneumonia cases, i.e., we cannot separate the *COVID-19* class from the *other* class in the test set. Therefore, we only use the available training data and split it into a training set and a validation/test set.
* We then set up a convolutional neural network with residual blocks. Concretely, we use the ResNet9 architecture.
* In addition, we employ the following techniques: *data normalization*, *data augmentation* (padding and random crop), *batch normalization*, *learning rate scheduling*, *weight decay* and *gradient clipping*. Note that the exact setup and hyperparameters for these techniques are not yet optimized, so there may still be room for improvement in the performance.
* We use the Adam optimizer.
* After 25 training epochs, we evaluate the performance of the model. We achieve an accuracy of about 83%. We calculate the confusion matrix and see that about 77% of the COVID-19 cases are identified correctly.




### Load Libraries

In [None]:
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data.dataloader import default_collate
# from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set()

### Loading and Inspecting the Data
Load Covid-19 X-Ray dataset and create training and test data objects:

In [None]:
data_path = '/kaggle/input/coronahack-chest-xraydataset/'
img_path = data_path + 'Coronahack-Chest-XRay-Dataset/Coronahack-Chest-XRay-Dataset'
train_img_dir = img_path + '/train'
test_img_dir = img_path + '/test'
img_dir = os.listdir(img_path)
df_meta = pd.read_csv(data_path+'Chest_xray_Corona_Metadata.csv')
df_meta_summary = pd.read_csv(data_path+'Chest_xray_Corona_dataset_Summary.csv')


In [None]:
df_meta.tail()

In [None]:
df_meta.info()

In [None]:
missing_entries = df_meta.isnull().sum()

In [None]:
missing_entries.plot(kind="barh");

Many entries in the Label_1_Virus_category and Label_2_Virus_category columns are NaNs. We want to replace them with a string 'unknown'.

In [None]:
#replace null data points to 'unknown'
df_meta.fillna('unknown', inplace=True)
df_meta.isnull().sum()

In [None]:
print((df_meta['Label_1_Virus_category']).value_counts())
print((df_meta['Label_2_Virus_category']).value_counts())

We only have 58 images for COVID-19 (which are all in the training set). Adding more data on this class would certainly improve the model performance.

In [None]:
train_data = df_meta[df_meta['Dataset_type']=='TRAIN'].copy()  # copy so that columns can be added later without a SettingWithCopyWarning
test_data = df_meta[df_meta['Dataset_type']=='TEST'].copy()
assert train_data.shape[0] + test_data.shape[0] == df_meta.shape[0]
print(f"Shape of train data: {train_data.shape}")
print(f"Shape of test data: {test_data.shape}")

train_data.sample(5)

In [None]:
test_data.sample(5)

In [None]:
sns.countplot(x=train_data['Label_1_Virus_category']);

In [None]:
sns.countplot(x=train_data['Label_2_Virus_category']);

In [None]:
sns.countplot(x=test_data['Label_1_Virus_category']);

In [None]:
sns.countplot(x=test_data['Label_2_Virus_category']);

*Note:* A problem occurs in the test set: none of the virus-caused pneumonia cases are labeled any further (i.e., all Label_2_Virus_category entries are missing/unknown). We therefore do not know whether a virus-caused pneumonia case in the test set is due to COVID-19 or to another virus, and hence we cannot assess the accuracy with the provided test set. For this reason, we discard the provided test set and instead split the training data into a training set and a validation set.

## Define output classes

We aim at classifying four distinct categories: COVID-19 cases vs. healthy vs. bacteria-caused vs. other virus-caused pneumonia.


In [None]:
train_data.loc[train_data['Label'].eq('Normal'), 'class'] = 'healthy';
train_data.loc[(train_data['class'].ne('healthy') & train_data['Label_1_Virus_category'].eq('bacteria')), 'class'] = 'bacteria';
train_data.loc[(train_data['class'].ne('healthy') & train_data['class'].ne('bacteria') & train_data['Label_2_Virus_category'].eq('COVID-19')), 'class'] = 'COVID-19';
train_data.loc[(train_data['class'].ne('healthy') & train_data['class'].ne('bacteria') & train_data['class'].ne('COVID-19')), 'class'] = 'other';
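The four assignments above amount to a simple priority rule: healthy first, then bacteria, then COVID-19, and everything else is *other*. A minimal plain-Python sketch of that rule, assuming missing labels were already replaced by 'unknown' (the function name `assign_class` and the placeholder label 'not-normal' are hypothetical and not part of the notebook):

```python
def assign_class(label, label_1, label_2):
    """Map one metadata row to a class name, using the same priority as the cell above."""
    if label == 'Normal':
        return 'healthy'
    if label_1 == 'bacteria':
        return 'bacteria'
    if label_2 == 'COVID-19':
        return 'COVID-19'
    return 'other'  # any remaining non-healthy, non-bacterial, non-COVID-19 case


# 'not-normal' stands in for any value of the Label column other than 'Normal'.
print(assign_class('Normal', 'unknown', 'unknown'))
print(assign_class('not-normal', 'bacteria', 'unknown'))
print(assign_class('not-normal', 'unknown', 'COVID-19'))
print(assign_class('not-normal', 'unknown', 'unknown'))
```

Because the checks are ordered, a row is never assigned to two classes, which is what the chained `.ne(...)` conditions achieve in the pandas version.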


In [None]:
target_dict = {'healthy' : 0,
               'bacteria' : 1,
               'COVID-19' : 2,
               'other' : 3}
train_data['target'] = train_data['class'].map(target_dict);

In [None]:
sns.countplot(x=train_data['class']);

In [None]:
train_data.sample(10)

## Display X-Ray Images

Let's display some of the X-Ray images.

In [None]:
def plot_images(path,class_str,numdisplay):
    fig, ax = plt.subplots(numdisplay,2, figsize=(15,2.5*numdisplay))
    for row,file in enumerate(path):
        image = plt.imread(file)
#         print(image.shape)
        ax[row,0].imshow(image, cmap=plt.cm.bone)
        ax[row,1].hist(image.ravel(), 256, [0,256])
        ax[row,0].axis('off')
        if row == 0:
            ax[row,0].set_title('Images')
            ax[row,1].set_title('Histograms')
    fig.suptitle('Class='+class_str,size=16)
    plt.show()    


In [None]:
def display_class_images(img_path,dataset,train_or_test_str,classlabel,numdisplay):
    path = dataset[dataset['class']==classlabel]['X_ray_image_name'].values
    sample_path = path[:numdisplay]
    img_dir = img_path+"/"+train_or_test_str
    sample_path = list(map(lambda x: os.path.join(img_dir,x), sample_path))
    plot_images(sample_path,classlabel,numdisplay)



In [None]:
display_class_images(img_path,train_data,"train","healthy",4)

In [None]:
display_class_images(img_path,train_data,"train","COVID-19",4)

In [None]:
display_class_images(img_path,train_data,"train","bacteria",4)

In [None]:
display_class_images(img_path,train_data,"train","other",4)

We cannot use a convenient dataset class for images like *ImageFolder* from torchvision.datasets, as the data is not organized in subfolders named after the classes. Instead, the class labels are given in the metadata file, and the images are in a separate directory. We therefore write a custom dataset:

In [None]:
class CustomDataSet(Dataset):
    def __init__(self, main_dir,meta_data, transform):
        self.main_dir = main_dir
        self.meta_data = meta_data
        self.transform = transform
        self.total_imgs = os.listdir(main_dir)

    def __len__(self):
        return len(self.meta_data)

    def __getitem__(self, idx):
        meta_data = self.meta_data.iloc[idx] 
        filename = meta_data['X_ray_image_name']
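        try:
            file_idx = self.total_imgs.index(filename)
        except ValueError:
            # Image listed in the metadata but missing on disk; my_collate drops such samples
            print(f"Image not found: {filename}")
            return None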
        img_loc = os.path.join(self.main_dir, self.total_imgs[file_idx])
        image = Image.open(img_loc).convert("RGB")
        image = image.resize((128,128))
        tensor_image = self.transform(image)
        tensor_label = torch.tensor(meta_data['target'].item())
        return tensor_image, tensor_label

def my_collate(batch):
    """Collate a batch into tensors, skipping samples for which the dataset returned None."""
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)


*Note*: We resize the images to 128 x 128 pixels. The original images are much larger. A larger size would retain more information and might improve the performance, at the cost of a more CPU/GPU-intensive training.
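To get a feeling for the cost of a larger image size, here is a rough back-of-the-envelope sketch. It only counts the memory of one input batch, assuming float32 values, three channels and a batch of 32 images; the helper name `batch_megabytes` is hypothetical. The activations inside the network grow by the same factor, so the real cost is considerably higher.

```python
def batch_megabytes(side, batch_size=32, channels=3, bytes_per_value=4):
    """Memory of one float32 input batch of square images, in MiB."""
    return batch_size * channels * side * side * bytes_per_value / 2**20


for side in (128, 256, 512):
    print(side, batch_megabytes(side))
# 128 6.0
# 256 24.0
# 512 96.0
```

Doubling the side length quadruples the number of pixels, and with it the memory and (roughly) the compute per batch.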

Set the batch size. (Note: The batch size may still be tuned to optimize performance.)

In [None]:
batch_size=32

### Data normalization and augmentation

The following stats are used for the normalization of the input data. Further below we calculate the mean and the standard deviation for each of the three channels (RGB). These are set here as hard-coded values.

In [None]:
calc_normalization_stats = False # Set this if you want to evaluate the stats

In [None]:
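# Per-channel statistics, stored as (std, mean); see the detour below for how they are evaluated.
# Caution: tt.Normalize(*stats) takes (mean, std), so as written the first tuple is used as the mean and the second as the std.
stats = ((0.0093, 0.0093, 0.0092),(0.4827, 0.4828, 0.4828))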

For the training set, we apply a randomized data augmentation: we pad each image by 8 pixels and then take a random crop of size 128 x 128 pixels. The padding is done in *edge* mode.

Furthermore, we normalize the data, as described above.

In [None]:
if calc_normalization_stats:
    train_tfms = tt.ToTensor()
else:
    train_tfms = tt.Compose([tt.RandomCrop(128, padding=8, padding_mode='edge'), tt.ToTensor(), tt.Normalize(*stats, inplace = True)])
test_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats, inplace = True)])
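What the pad-and-crop augmentation does can be illustrated without any image library. The following is a plain-Python sketch on a tiny 2-D list, not the torchvision implementation; the helper names `pad_edge` and `random_crop` are hypothetical.

```python
import random


def pad_edge(img, pad):
    """Pad a 2-D list of pixels by repeating its border values ('edge' mode)."""
    rows = [[row[0]] * pad + list(row) + [row[-1]] * pad for row in img]
    top = [list(rows[0]) for _ in range(pad)]
    bottom = [list(rows[-1]) for _ in range(pad)]
    return top + rows + bottom


def random_crop(img, size, rng):
    """Cut a size x size window at a uniformly random offset."""
    top = rng.randint(0, len(img) - size)       # randint is inclusive on both ends
    left = rng.randint(0, len(img[0]) - size)
    return [row[left:left + size] for row in img[top:top + size]]


# Toy 4 x 4 "image" padded by 1 pixel; the notebook uses 128 x 128 images and a padding of 8.
img = [[r * 4 + c for c in range(4)] for r in range(4)]
padded = pad_edge(img, 1)
crop = random_crop(padded, 4, random.Random(0))
print(len(padded), len(padded[0]))  # 6 6
print(len(crop), len(crop[0]))      # 4 4
```

With a padding of 8 and a crop size equal to the original 128 pixels, the crop window can start at 17 positions along each axis, so every image is seen in up to 17 x 17 = 289 slightly shifted versions.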

In [None]:
train_data.shape

We now do a random split of the available dataset (`train_data`) into a training set (`train_ds`, 75%) and a validation/test set (`test_ds`, 25%).

In [None]:
train_ds, test_ds = train_test_split(train_data, test_size=0.25,random_state= 1, shuffle = True)
train_ds, test_ds = train_ds.reset_index(drop=True), test_ds.reset_index(drop=True)
train_ds.shape, test_ds.shape



In [None]:
train_dataset = CustomDataSet(train_img_dir, train_ds, transform=train_tfms)    
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,                            
                          num_workers=0, pin_memory=True, collate_fn=my_collate)

test_dataset = CustomDataSet(train_img_dir, test_ds, transform=test_tfms)
test_loader = DataLoader(test_dataset , batch_size=2*batch_size, shuffle=False, 
                         num_workers=0, pin_memory=True, collate_fn=my_collate)

### Detour: Get statistics for data normalization

(needed only in initial run)

Calculate mean and standard deviation over all train images, later used for normalization (see above).

In [None]:
if calc_normalization_stats:
    mean_per_batch = []
    stdev_per_batch = []
    for images, labels in train_loader:
        # Per-channel statistics over the batch, height and width dimensions
        batch_std, batch_mean = torch.std_mean(images, [0, 2, 3])
        mean_per_batch.append(batch_mean)
        stdev_per_batch.append(batch_std)

    # Average the per-batch values (an approximation of the statistics over all training images)
    channel_mean = torch.stack(mean_per_batch).mean(0)
    channel_std = torch.stack(stdev_per_batch).mean(0)
    print(channel_std, channel_mean)
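If exact statistics over all training pixels are wanted, one option is to accumulate sums and squared sums per channel instead of combining per-batch values. The sketch below shows the idea in plain Python; the data layout (a batch is a list of images, an image is a list of channels, a channel is a flat list of pixel values) and the name `channel_stats` are assumptions for illustration. It returns the population standard deviation.

```python
import math


def channel_stats(batches, num_channels=3):
    """Exact per-channel mean and population std, accumulated batch by batch."""
    count = [0] * num_channels
    total = [0.0] * num_channels
    total_sq = [0.0] * num_channels
    for batch in batches:
        for image in batch:
            for c, pixels in enumerate(image):
                count[c] += len(pixels)
                total[c] += sum(pixels)
                total_sq[c] += sum(p * p for p in pixels)
    mean = [total[c] / count[c] for c in range(num_channels)]
    # Var[x] = E[x^2] - E[x]^2; clamp at zero to guard against rounding errors
    std = [math.sqrt(max(total_sq[c] / count[c] - mean[c] ** 2, 0.0))
           for c in range(num_channels)]
    return mean, std


# Two toy single-channel batches of different sizes
toy_batches = [[[[0.0, 1.0]]], [[[1.0, 0.0]], [[0.5, 0.5]]]]
print(channel_stats(toy_batches, num_channels=1))
```

Because every pixel contributes equally, the result does not depend on the batch size or on a smaller last batch.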

### Check the data batches 

In [None]:
def show_batch(dl):
    for images,labels in dl:
        print(images.shape, labels.shape)
        fig, ax = plt.subplots(figsize=(8,8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images,nrow=8).permute(1,2,0))
        break

In [None]:
show_batch(train_loader)

In [None]:
show_batch(test_loader)

In [None]:
len(train_loader)

In [None]:
len(test_loader)

In [None]:
def show_example(img, label):
    print('Label: ', "("+str(label)+")")
    plt.imshow(img.permute(1, 2, 0))

## Setting up Neural Network

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item()/len(preds))

In [None]:
class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
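One subtlety of `validation_epoch_end` is that it averages the per-batch accuracies with equal weight. If the last batch is smaller than the others, this can differ slightly from the fraction of correctly classified images. A small sketch with made-up counts (the function names are hypothetical):

```python
def mean_of_batch_accuracies(correct_per_batch, size_per_batch):
    """Unweighted mean of per-batch accuracies, as in validation_epoch_end."""
    accs = [c / n for c, n in zip(correct_per_batch, size_per_batch)]
    return sum(accs) / len(accs)


def overall_accuracy(correct_per_batch, size_per_batch):
    """Fraction of all samples that were classified correctly."""
    return sum(correct_per_batch) / sum(size_per_batch)


# Hypothetical counts: two full batches of 64 and a final batch of only 8 images
correct = [48, 48, 8]
sizes = [64, 64, 8]
print(mean_of_batch_accuracies(correct, sizes))
print(overall_accuracy(correct, sizes))
```

In this toy case the first value is about 0.83 and the second about 0.76. With many batches of equal size the two numbers are usually very close.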

#### A simple convolutional neural network

*Note*: This was my first attempt. I leave it here for reference, but commented out.

In [None]:
# class CnnModel(ImageClassificationBase):
#     def __init__(self):
#         super().__init__()
#         self.network = nn.Sequential(
#             nn.Conv2d(3, 128, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2), # output: 256 x 64 x 64

#             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2), # output: 512 x 32 x 32

#             nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2), # output: 1024 x 16 x 16

#             nn.Conv2d(1024, 2048, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2), # output: 2048 x 8 x 8


#             nn.Flatten(), 
#             nn.Linear(2048*8*8, 512),
#             nn.ReLU(),
#             nn.Linear(512, 128),
#             nn.ReLU(),
#             nn.Linear(128, 4))
        
#     def forward(self, xb):
#         return self.network(xb)

#### A model with residual blocks (ResNet9)

![resnet-9](https://github.com/lambdal/cifar10-fast/raw/master/net.svg?sanitize=true)

In [None]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
              nn.BatchNorm2d(out_channels), 
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet9(ImageClassificationBase):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        
        self.classifier = nn.Sequential(nn.MaxPool2d(16), 
                                        nn.Flatten(), 
                                        nn.Linear(512, num_classes))
        
    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        return out

Uncomment this if you want to see the model summary:

In [None]:
#model = ResNet9(3,4)
#model
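To see why the classifier starts with `nn.MaxPool2d(16)` followed by `nn.Linear(512, num_classes)`, it helps to follow the spatial size of the feature maps through the network. The sketch below is plain-Python bookkeeping, not a PyTorch call; it assumes that the 3 x 3 convolutions with padding 1 keep the size and that each `nn.MaxPool2d(2)` halves it. The function name `resnet9_spatial_trace` is hypothetical.

```python
def resnet9_spatial_trace(side):
    """Follow the side length of a square input through the ResNet9 defined above."""
    trace = [('input', side)]
    layers = [('conv1', False), ('conv2', True), ('res1', False),
              ('conv3', True), ('conv4', True), ('res2', False)]
    for name, pooled in layers:
        if pooled:
            side //= 2          # nn.MaxPool2d(2) halves height and width
        trace.append((name, side))
    side //= 16                 # nn.MaxPool2d(16) in the classifier
    trace.append(('classifier pool', side))
    features = 512 * side * side  # number of values reaching nn.Linear
    return trace, features


trace, features = resnet9_spatial_trace(128)
print(trace)
print(features)  # 512
```

For 128 x 128 inputs the final pooling reduces the 16 x 16 maps to 1 x 1, so exactly 512 features reach the linear layer. For 256 x 256 inputs the same bookkeeping gives 2048 features, so the pooling size or the linear layer would have to be adapted when changing the image size.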

## Setting up GPU device

In [None]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [None]:
device = get_default_device()
device

In [None]:
train_loader = DeviceDataLoader(train_loader, device)
test_loader = DeviceDataLoader(test_loader, device)

## Setting up the Model Training

Some notes on the employed techniques:

**Learning rate scheduling**: The learning rate is changed after every training batch. We use the *one-cycle learning rate policy*, which starts at a low learning rate, gradually increases it over roughly the first 30% of the training steps, and then gradually decreases it to a very low value until the end of training. After training, we plot the learning rate as a function of the batch number.

**Weight decay**: A regularization technique that prevents the weights from becoming too large. This is done by adding a penalty proportional to the squared magnitude of the weights to the loss function.

**Gradient clipping**: This constrains each gradient value to a limited range, which prevents very large gradients from causing drastic parameter changes in a single update.
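The shape of the one-cycle schedule can be sketched in a few lines of plain Python. This is an approximation for illustration, not the code of `torch.optim.lr_scheduler.OneCycleLR`: the parameter values (30% warm-up, start at `max_lr / 25`, cosine interpolation, very small final value) are assumptions chosen to resemble that scheduler's defaults, and the exact step boundaries may differ.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate one-cycle learning rate at a given step (cosine warm-up, cosine decay)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = pct_start * total_steps

    def cosine(start, end, fraction):
        # Moves smoothly from start (fraction = 0) to end (fraction = 1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * fraction) + 1.0)

    if step <= warmup_steps:
        return cosine(initial_lr, max_lr, step / warmup_steps)
    return cosine(max_lr, min_lr, (step - warmup_steps) / (total_steps - warmup_steps))


# Learning rate at a few points of a hypothetical run with 100 steps and max_lr = 0.01
for step in (0, 15, 30, 65, 100):
    print(step, one_cycle_lr(step, 100, 0.01))
```

The rate starts at `max_lr / 25`, peaks at `max_lr` after about 30% of the steps, and ends several orders of magnitude below the starting value.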

In [None]:
@torch.no_grad()
def evaluate(model, test_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in test_loader]
    return model.validation_epoch_end(outputs)

# Simple fit function, was used with the simple CNN model:
# def fit(epochs, lr, model, train_loader, test_loader, opt_func=torch.optim.SGD):
#     history = []
#     optimizer = opt_func(model.parameters(), lr)
#     for epoch in range(epochs):
#         # Training Phase 
#         model.train()
#         train_losses = []
#         for batch in train_loader:
#             loss = model.training_step(batch)
#             train_losses.append(loss)
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
#         # Validation phase
#         result = evaluate(model, test_loader)
#         result['train_loss'] = torch.stack(train_losses).mean().item()
#         model.epoch_end(epoch, result)
#         history.append(result)
#     return history

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

# More sophisticated fit function, with the following features:
# Learning rate scheduling, weight decay, gradient clipping

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []
    
    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))
    
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            train_losses.append(loss.detach())  # store the value only, without the autograd graph
            loss.backward()
            
            # Gradient clipping
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
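The effect of gradient clipping by value and of an L2-style weight decay on a single update can be shown with plain numbers. This is a simplified sketch, not the PyTorch implementation; the function names are hypothetical, and the decay is modelled as adding `weight_decay * w` to each gradient component, which is the gradient of the penalty `weight_decay / 2 * w**2`.

```python
def clip_by_value(grads, clip):
    """Clamp every gradient component to the interval [-clip, clip]."""
    return [max(-clip, min(clip, g)) for g in grads]


def add_l2_decay(grads, weights, weight_decay):
    """Add the gradient of an L2 penalty, weight_decay * w, to each component."""
    return [g + weight_decay * w for g, w in zip(grads, weights)]


# Made-up gradients and weights, with the clip value and weight decay used in this notebook
grads = [0.5, -0.03, -2.0]
weights = [1.0, -2.0, 0.0]
clipped = clip_by_value(grads, clip=0.1)
print(clipped)  # [0.1, -0.03, -0.1]
print(add_l2_decay(clipped, weights, weight_decay=1e-4))
```

Components inside the allowed range pass through unchanged, while outliers are cut to plus or minus the clip value. The decay term then pulls every weight slightly toward zero, more strongly for large weights.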

In [None]:
# model = to_device(CnnModel(), device)
model = to_device(ResNet9(3, 4), device)

#### Load previously fitted parameters (if needed):

In [None]:
load = False

if load:
    PATH = 'COVID-19_classification-resnet9.pth'
    model.load_state_dict(torch.load(PATH, map_location=device))  # map_location lets GPU-trained weights load on CPU

Test that (untrained) model gives (meaningless) predictions as output:

In [None]:
history = [evaluate(model,test_loader)]
history

## Train the model

We can now start the training. We use the Adam optimizer, which uses momentum and adaptive learning rates for faster training. We set the number of epochs to 25.

In [None]:
num_epochs = 25
opt_func = torch.optim.Adam
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4

In [None]:
#history = fit(num_epochs, lr, model, train_loader, test_loader, opt_func)
history += fit_one_cycle(num_epochs, max_lr, model, train_loader, test_loader, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func)

After about 1h of training (25 epochs), the accuracy converges to about 83%.

Save the trained parameters:

In [None]:
torch.save(model.state_dict(), 'COVID-19_classification-resnet9.pth')

## Performance evaluation

Let's first look at the evolution of the accuracy and the loss in the training/validation set.

In [None]:
def plot_accuracies(history):
    accuracies = [x['val_acc'] for x in history]
    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs');

In [None]:
def plot_losses(history):
    train_losses = [x.get('train_loss') for x in history]
    val_losses = [x['val_loss'] for x in history]
    plt.plot(train_losses, '-bx')
    plt.plot(val_losses, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs');

In [None]:
plot_accuracies(history)

In [None]:
plot_losses(history)

The trend indicates that our model isn't overfitting quite yet.

Let's also look at the evolution of the learning rate, as adjusted by the learning rate scheduler:

In [None]:
def plot_lrs(history):
    lrs = np.concatenate([x.get('lrs', []) for x in history])
    plt.plot(lrs)
    plt.xlabel('Batch no.')
    plt.ylabel('Learning rate')
    plt.title('Learning Rate vs. Batch no.');

In [None]:
plot_lrs(history)

Finally, we would like to calculate the confusion matrix. For this, we need to obtain the true class labels and the predicted classes from our trained model for the test set.

In [None]:
@torch.no_grad()
def get_all_preds_and_targets(model, loader):
    model.eval()  # make sure batch normalization uses its running statistics
    all_preds = torch.tensor([])
    all_targets = torch.tensor([])
    for batch in loader:
        images, labels = batch

        outputs = model(images)
        _, preds = torch.max(outputs, dim=1)
        all_preds = torch.cat((all_preds, preds),dim=0)
        all_targets = torch.cat((all_targets, labels),dim=0)
    return all_preds, all_targets

In [None]:
device = torch.device('cpu')
test_loader = DeviceDataLoader(test_loader.dl, device)  # re-wrap the underlying DataLoader instead of nesting wrappers
model = to_device(model, device)

In [None]:
with torch.no_grad():
    predictions, targets = get_all_preds_and_targets(model, test_loader)

Let's first look at the count of the class predictions and the true values:

In [None]:
sns.countplot(x=predictions.numpy());

In [None]:
sns.countplot(x=targets.numpy());

Now we print a classification report and calculate the confusion matrix:

In [None]:
#let's print a classification report
from sklearn.metrics import classification_report, confusion_matrix
print(classification_report(targets, predictions))

In [None]:
con_mat = confusion_matrix(targets, predictions)
con_mat = con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis]
plt.figure(figsize = (10,10))
plt.title('CONFUSION MATRIX')
sns.heatmap(con_mat, cmap='coolwarm',
            yticklabels=['Healthy', 'Bacteria','COVID-19','other virus'],
            xticklabels=['Healthy', 'Bacteria','COVID-19','other virus'],
            annot=True);

(*note*: The rows indicate the true class, the columns give the predicted class.)

About 77% of the COVID-19 cases are identified correctly. Given the relatively small number of COVID-19 images in the training set (58), I think this is not too bad.
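The row normalization used for the heatmap turns each diagonal entry into the fraction of a true class that was identified correctly (the recall of that class). A plain-Python sketch with made-up labels, for illustration only (the function names are hypothetical, and the numbers are not results from this notebook):

```python
def confusion_matrix_counts(targets, predictions, num_classes):
    """Count matrix with true classes as rows and predicted classes as columns."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, predictions):
        matrix[t][p] += 1
    return matrix


def row_normalize(matrix):
    """Divide each row by its sum, so that every row adds up to 1."""
    return [[v / sum(row) if sum(row) else 0.0 for v in row] for row in matrix]


# Made-up labels: 0 = healthy, 1 = bacteria, 2 = COVID-19, 3 = other
toy_targets     = [0, 0, 1, 1, 2, 2, 2, 2, 3, 3]
toy_predictions = [0, 0, 1, 3, 2, 2, 2, 3, 3, 1]

counts = confusion_matrix_counts(toy_targets, toy_predictions, 4)
normalized = row_normalize(counts)
print(counts)
print([normalized[i][i] for i in range(4)])  # [1.0, 0.5, 0.75, 0.5]
```

In this toy example three of the four COVID-19 cases are predicted correctly, so the corresponding diagonal entry is 0.75.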

## Outlook

There are certainly many possibilities for improvements:

1. None of the hyperparameters have been fine-tuned yet. One could also try more/different settings for the data augmentation and normalization. Also, a different ratio between training and test data may lead to a better performance.

2. Increase the size of the training set. The number of COVID-19 images is low, and adding more data should improve the 77% identification rate.

3. Try a different architecture?

Do you have more ideas? I would be happy to hear/read your comments!
