# Summary

- I approach the challenge as a two-level classification problem.
- The null values in the dataset are filled with predefined labels (see the code).
- I dropped the Label column and used the first- and second-category labels to classify each image at two levels.
- I randomly selected 15% of the COVID-19 labeled images from the train set and moved them to the test set.
- I dropped the samples with the Label_1 value Stress-Smoking and cleared the rare Label_2 values Streptococcus and SARS to reduce the output dimensionality.
- I used a pretrained ResNet-18 as the main feature extractor to benefit from transfer learning.
- I implemented 3 blocks of Conv-BatchNorm-ReLU layers as a side network to ResNet.
- The input image is passed through ResNet, and the model predicts the Level-1 classes.
- The **center-cropped version** of the input image is passed through the side network.
- The model predicts the Level-2 classes using both the side network's output and the Level-1 predictions.
- The Level-1 and Level-2 prediction losses are combined as a weighted sum, and the model's weights are updated using the combined loss.
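The combined loss from the last step can be sketched in plain Python for a single sample. This is a minimal illustration, not the notebook's code: `cross_entropy` and `combined_loss` are hypothetical helpers, where `cross_entropy` mirrors what `nn.CrossEntropyLoss` computes for one sample (log-softmax followed by negative log-likelihood), and the default weights (0.5, 1.0) match the `loss_weights` set in the training parameters. For a single sample, the Level-2 class weights cancel out under mean reduction, so they are left out here.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target]).

    Hypothetical helper; uses the log-sum-exp trick for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def combined_loss(logits1, y1, logits2, y2, loss_weights=(0.5, 1.0)):
    """Weighted sum of the Level-1 and Level-2 losses, as in training_step."""
    loss1 = cross_entropy(logits1, y1)
    loss2 = cross_entropy(logits2, y2)
    return loss_weights[0] * loss1 + loss_weights[1] * loss2

# One sample: Level-1 target 1 ('Pnemonia-Virus'), Level-2 target 2 ('COVID-19')
loss = combined_loss([0.2, 1.5, -0.3], 1, [0.1, 0.4, 1.2, -0.5], 2)
print(f"combined loss: {loss:.4f}")
```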

I was inspired by this excellent [notebook](https://www.kaggle.com/timstefaniak/multi-classification-of-x-ray-images) for data preparation and visualization.

## Import Necessary Packages

In [None]:
from typing import Union, List

import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from torchvision import models
from torch.cuda import device_count
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import f1, accuracy

from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix

sns.set()

In [None]:
data_dir = "../input/coronahack-chest-xraydataset"
image_dir = os.path.join(data_dir, "Coronahack-Chest-XRay-Dataset", "Coronahack-Chest-XRay-Dataset")
train_dir = os.path.join(image_dir, "train")
test_dir = os.path.join(image_dir, "test")
model_path = "model"

In [None]:
IMG_RESIZE = 224

## Read Metadata Dataframe and Analyze

In [None]:
meta_df = pd.read_csv(os.path.join(data_dir, 'Chest_xray_Corona_Metadata.csv'), index_col=[0])
meta_df.info()

### Read the Metadata Summary

In [None]:
meta_summary_df = pd.read_csv(os.path.join(data_dir, 'Chest_xray_Corona_dataset_Summary.csv'), index_col=[0])
print(meta_summary_df)

There are only 2 pneumonia samples that have Stress-Smoking as Label_1. I dropped these rows to reduce the output dimensionality.

Similarly, there are only 5 samples with Label_2 "Streptococcus" and 4 samples with Label_2 "SARS". To keep things simple, I also removed these labels from the dataset by setting Label_2 of these samples to NaN; the relabeling step below then places them in the generic 'Bacteria-unknown' and 'Virus-unknown' groups.

In [None]:
# drop rows whose Label_1 is 'Stress-Smoking'
meta_df.drop(meta_df[meta_df['Label_1_Virus_category'] == 'Stress-Smoking'].index, inplace=True)
# set Label_2 to NaN for samples labeled 'SARS' or 'Streptococcus'
meta_df.loc[meta_df['Label_2_Virus_category'].isin(['SARS', 'Streptococcus']), 'Label_2_Virus_category'] = np.nan

### Count Null Values and Plot Them

In [None]:
missing_values = meta_df.isnull().sum()
missing_values.loc[['Label_2_Virus_category', 'Label_1_Virus_category']].plot.barh()

#### Replace Null Values with Prespecified Labels

Since I will do a multi-stage classification, I fill the NaN values in the Label_1 column, which belong to the samples labeled 'Normal', with 'Normal' as well.

Then, I assign 
- 'Bacteria-unknown' to Label_2 column of all the samples with 'bacteria' Label_1
- 'Normal-2' to Label_2 column of all the samples with 'Normal' Label_1
- 'Virus-unknown' to Label_2 column of all the samples with 'Virus' Label_1 and without 'COVID-19' Label_2

Finally, I rename the Label_1 values as follows:
- 'Normal' -> 'Normal'
- 'Virus'-> 'Pnemonia-Virus'
- 'bacteria' -> 'Pnemonia-Bacteria'
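Applied to a single row, the rules above (together with the earlier SARS/Streptococcus clean-up) amount to the mapping below. This is a plain-Python sketch of the notebook's pandas operations, not code the notebook runs: `relabel` is a hypothetical helper, `None` stands in for NaN, and the Stress-Smoking rows are not covered because they were already dropped.

```python
def relabel(label_1, label_2):
    """Return the (Label_1, Label_2) pair the notebook ends up with for one row.

    None stands in for a missing (NaN) value.
    """
    # SARS and Streptococcus are too rare, so their Label_2 is cleared first
    if label_2 in ("SARS", "Streptococcus"):
        label_2 = None
    # A missing Label_1 means the sample is Normal
    if label_1 is None:
        label_1 = "Normal"
    # Fill Label_2 according to the Label_1 group
    if label_1 == "bacteria":
        label_2 = "Bacteria-unknown"
    elif label_1 == "Normal":
        label_2 = "Normal-2"
    elif label_1 == "Virus" and label_2 != "COVID-19":
        label_2 = "Virus-unknown"
    # Finally rename the Label_1 values; 'Normal' stays as it is
    renames = {"Virus": "Pnemonia-Virus", "bacteria": "Pnemonia-Bacteria"}
    return renames.get(label_1, label_1), label_2

for row in [(None, None), ("Virus", "COVID-19"), ("Virus", "SARS"), ("bacteria", "Streptococcus")]:
    print(row, "->", relabel(*row))
```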

In [None]:
column_nan_values = {'Label_1_Virus_category': 'Normal'}
meta_df.fillna(value=column_nan_values, inplace=True)

In [None]:
meta_df.loc[meta_df['Label_1_Virus_category'] == 'bacteria', 'Label_2_Virus_category'] = 'Bacteria-unknown'
meta_df.loc[meta_df['Label_1_Virus_category'] == 'Normal', 'Label_2_Virus_category'] = 'Normal-2'
meta_df.loc[(meta_df['Label_1_Virus_category'] == 'Virus') & (meta_df['Label_2_Virus_category'] != 'COVID-19'), 'Label_2_Virus_category'] = 'Virus-unknown'

In [None]:
new_label_1_names = {'Virus': 'Pnemonia-Virus', 'bacteria': 'Pnemonia-Bacteria'}
meta_df['Label_1_Virus_category'] = meta_df['Label_1_Virus_category'].replace(new_label_1_names)

### Assign Integer Labels to Use in Training and Testing the Model
We can drop the 'Label' column, as we will not use it anymore.

In [None]:
meta_df.drop(columns=['Label'], inplace=True)
print(meta_df.columns)

In [None]:
level_1_labels_to_ids = {
    'Normal' : 0,
    'Pnemonia-Virus': 1,
    'Pnemonia-Bacteria': 2
}

level_2_labels_to_ids = {
    'Normal-2' : 0,
    'Virus-unknown' : 1,
    'COVID-19' : 2,
    'Bacteria-unknown': 3
}

# Keep the reverse mappings as well, to restore label names from model predictions later on.
level_1_id2label = {v: k for k, v in level_1_labels_to_ids.items()}
level_2_id2label = {v: k for k, v in level_2_labels_to_ids.items()}

Since we will be classifying the images in two stages, we will assign two-level labels as level_1_target and level_2_target.

level_1_target and level_2_target will be representing the Label_1 and Label_2 columns respectively.
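Because every Level-2 class sits under exactly one Level-1 class, the two integer targets are not independent. A small parent lookup makes this explicit and could be used, for example, to count predictions whose two levels disagree. `LEVEL_2_PARENT`, `is_consistent` and `count_inconsistent` are hypothetical names introduced for illustration; the ids follow `level_1_labels_to_ids` and `level_2_labels_to_ids` above.

```python
# Level-2 id -> Level-1 id, derived from the two label mappings:
# 'Normal-2' (0) -> 'Normal' (0),
# 'Virus-unknown' (1) and 'COVID-19' (2) -> 'Pnemonia-Virus' (1),
# 'Bacteria-unknown' (3) -> 'Pnemonia-Bacteria' (2)
LEVEL_2_PARENT = {0: 0, 1: 1, 2: 1, 3: 2}

def is_consistent(level_1_id, level_2_id):
    """True if the Level-2 class belongs to the given Level-1 class."""
    return LEVEL_2_PARENT[level_2_id] == level_1_id

def count_inconsistent(preds_1, preds_2):
    """Number of samples whose Level-1 and Level-2 predictions disagree."""
    return sum(not is_consistent(p1, p2) for p1, p2 in zip(preds_1, preds_2))

# Only the third pair (1, 3) disagrees: 'Bacteria-unknown' is not a virus class
print(count_inconsistent([0, 1, 1, 2], [0, 2, 3, 3]))
```

After evaluation, the same check could be run on the predicted Level-1 and Level-2 ids.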

In [None]:
meta_df['level_1_target'] = meta_df['Label_1_Virus_category'].map(level_1_labels_to_ids)
meta_df['level_2_target'] = meta_df['Label_2_Virus_category'].map(level_2_labels_to_ids)

Let's see the new version of our dataframe

In [None]:
meta_df.sample(10)

### Balance Test Set Labels

There are 58 samples with Label_2 COVID-19 in the train set, but there are no such samples in the test set. In order to measure the accuracy on COVID-19 labels during testing, I randomly selected 15% of the COVID-19 samples and moved them to the test set.
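With the 58 COVID-19 samples mentioned above, `sample(frac=0.15)` moves round(0.15 × 58) = round(8.7) = 9 of them. The sketch below reproduces the idea with the standard library on hypothetical image names; `move_fraction` is an illustrative helper, and because it uses Python's `random` module rather than pandas' `random_state=1`, it will not pick the same rows as the notebook.

```python
import random

def move_fraction(train_ids, frac=0.15, seed=1):
    """Split off round(frac * len(train_ids)) ids at random.

    Returns (remaining_train_ids, moved_ids); the two lists are disjoint.
    """
    k = round(frac * len(train_ids))
    rng = random.Random(seed)  # seeded, so the split is reproducible
    moved = rng.sample(train_ids, k)
    moved_set = set(moved)
    remaining = [i for i in train_ids if i not in moved_set]
    return remaining, moved

# 58 hypothetical COVID-19 image names
covid_ids = [f"covid_{i:02d}.jpg" for i in range(58)]
remaining, moved = move_fraction(covid_ids)
print(len(remaining), len(moved))
```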

In [None]:
# get the indices of 15% of the COVID-19 labeled samples at random
test_idx = meta_df[meta_df['Label_2_Virus_category'] == 'COVID-19'].sample(frac=0.15, random_state=1).index
# Dataset_type stays 'TRAIN' for these rows because their image files are in the train/ folder

In [None]:
print("Number of test samples before adding COVID-19 samples: ", len(meta_df[meta_df['Dataset_type'] == 'TEST']))
print("Number of train samples before removing COVID-19 samples: ", len(meta_df[meta_df['Dataset_type'] == 'TRAIN']))

Get the COVID-19 labeled samples as a separate dataframe

In [None]:
covid_samples = meta_df.loc[test_idx]
print(f"{len(covid_samples)} COVID-19 samples will be moved to test set")

1. Concatenate covid_samples with the test samples
2. Remove the intersection of train samples and covid_samples from train_df

In [None]:
test_df = pd.concat([meta_df[meta_df['Dataset_type'] == 'TEST'], covid_samples])
train_df = meta_df[meta_df['Dataset_type'] == 'TRAIN']
train_df = train_df[~train_df['X_ray_image_name'].isin(covid_samples['X_ray_image_name'])]

In [None]:
print("Number of test samples after adding COVID-19 samples: ", len(test_df))
print("Number of train_df samples after removing COVID-19 samples: ", len(train_df))

Just in case, let's verify that all image files are accessible.

In [None]:
assert all([os.path.isfile(os.path.join(image_dir,dset.lower(),filename)) for filename, dset in train_df[['X_ray_image_name', 'Dataset_type']].values])
assert all([os.path.isfile(os.path.join(image_dir,dset.lower(),filename)) for filename, dset in test_df[['X_ray_image_name', 'Dataset_type']].values])

### Print the Frequency of Labels in Train and Test Data 

In [None]:
print("***Train data***\n")
print("-Label_1 Frequency-\n", train_df['Label_1_Virus_category'].value_counts(), "\n")
print("-Label_2 Frequency-\n", train_df['Label_2_Virus_category'].value_counts(), "\n")

print("***Test data***\n")
print("-Label_1 Frequency-\n", test_df['Label_1_Virus_category'].value_counts(), "\n")
print("-Label_2 Frequency-\n", test_df['Label_2_Virus_category'].value_counts(), "\n")

##### Print the Train and Test Set Sizes

In [None]:
print(f"Train set length: {len(train_df)}")
print(f"Test set length: {len(test_df)}")

In [None]:
class CovidDataset(Dataset):
    """Covid19 Chest X-Ray dataset class

    Args:
        df (pandas.DataFrame): DataFrame that contains the metadata of the dataset.
        root_dir (str): Path to the root directory that contains the train/ and test/ image folders.
        transform (callable): Transform to be applied to each image.
    """

    def __init__(self,
                 df,
                 root_dir,
                 transform):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_metadata = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir,img_metadata['Dataset_type'].lower(), img_metadata['X_ray_image_name'])

        image = Image.open(img_path).convert("RGB")

        # apply transforms to the image, i.e. resize, flip (train only), convert to tensor and normalize
        image = self.transform(image)

        # CrossEntropyLoss expects int64 class indices
        target_1 = torch.as_tensor(img_metadata['level_1_target'], dtype=torch.long)
        target_2 = torch.as_tensor(img_metadata['level_2_target'], dtype=torch.long)

        sample = {
            'image': image,
            'target_1': target_1,
            'target_2': target_2
        }
        return sample

## Build the Model

In [None]:
class Identity(nn.Module):
    """
    No operation layer. 
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class TwoLevelClassifier(nn.Module):

    def __init__(self,
                 num_level_1_classes,
                 num_level_2_classes,
                 img_size):
        super().__init__()

        h, w = img_size
        # edges of the center crop, which spans half of the height and half of the width
        self.h1 = h // 4
        self.h2 = self.h1 + h // 2
        self.w1 = w // 4
        self.w2 = self.w1 + w // 2

        self.resnet = models.resnet18(pretrained=True, progress=True)
        resnet_features = self.resnet.fc.in_features
        
        # We will not be using resnet's classifier directly, so change it to identity layer
        self.resnet.fc = Identity()
        
        # side stack is where the center-cropped image is fed. It consists of
        # 3 Conv-BatchNorm-ReLU blocks, AdaptiveAvgPooling and a Linear layer that
        # projects the features to num_level_1_classes dimensions, so they can be added to the level 1 logits
        self.side_stack = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_level_1_classes)
        )

        # Possible level 1 outcomes: 'Normal', 'Pnemonia-Virus' or 'Pnemonia-Bacteria'
        self.level1_classifier = nn.Linear(resnet_features, num_level_1_classes)
        # Possible level 2 outcomes: 'Normal-2', 'Virus-unknown', 'COVID-19' or 'Bacteria-unknown'
        self.level2_classifier = nn.Linear(num_level_1_classes, num_level_2_classes)

    def forward(self, x):
        
        # get the features from pretrained model
        features = self.resnet(x)
        
        # predict level 1 classes
        logits1 = self.level1_classifier(features)
        
        # center crop the image
        cropped_x = x[:, :, self.h1:self.h2, self.w1:self.w2]
        
        # pass the cropped image to side_stack to calculate level 2 features
        # and add level 1 class predictions to level 2 features so that 
        # level 1 classes can have an impact on level 2 classes
        level_2_feed = self.side_stack(cropped_x) + logits1
        
        # predict level 2 classes
        logits2 = self.level2_classifier(level_2_feed)

        return logits1, logits2
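Two bits of arithmetic help when reading `TwoLevelClassifier`. A centered crop covering half of each side of an `h × w` image starts at `h // 4` and ends at `h // 4 + h // 2`, i.e. rows and columns 56:168 for a 224 × 224 input. The side stack's convolutions then shrink that 112 × 112 crop according to the usual output-size formula floor((n + 2p - k) / s) + 1, and `AdaptiveAvgPool2d((1, 1))` reduces whatever remains to one value per channel, which is why the final `nn.Linear` takes 64 inputs. The helper names below are hypothetical, and the kernel, stride and padding values are copied from the side stack above.

```python
def center_crop_bounds(size):
    """Start and end index of a centered crop spanning half of `size`."""
    start = size // 4
    return start, start + size // 2

def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

start, end = center_crop_bounds(224)
n = end - start                                 # 112 x 112 crop
n = conv_out(n, kernel=5, stride=2, padding=3)  # 57 after the first block
n = conv_out(n, kernel=3, stride=2)             # 28 after the second block
n = conv_out(n, kernel=3, stride=2)             # 13 after the third block
print((start, end), n)
```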

### Define a PyTorch Lightning Wrapper Class to Train and Evaluate the Model

In [None]:
class ModelWrapper(LightningModule):
    def __init__(self, hparams, df_train=None, df_test=None):
        super().__init__()

        self.df_train = df_train
        self.df_test = df_test
        self.hparams = hparams
        self.batch_size = self.hparams['batch_size']
        self.lr = self.hparams['lr']
        self.num_workers = self.hparams['num_workers']

        # The datasets are only needed for training; load_from_checkpoint builds the wrapper without DataFrames
        if df_train is not None:
            # random horizontal flips are applied to the training images only;
            # both pipelines normalize with the ImageNet mean and std used by the pretrained ResNet
            train_transforms = transforms.Compose([
                transforms.Resize([IMG_RESIZE, IMG_RESIZE]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])

            test_transforms = transforms.Compose([
                transforms.Resize([IMG_RESIZE, IMG_RESIZE]),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])

            self.train_dataset = CovidDataset(df=df_train, root_dir=self.hparams['image_dir'], transform=train_transforms)
            self.test_dataset = CovidDataset(df=df_test, root_dir=self.hparams['image_dir'], transform=test_transforms)

        self.model = TwoLevelClassifier(
            num_level_1_classes=self.hparams['num_level1_classes'],
            num_level_2_classes=self.hparams['num_level2_classes'],
            img_size=self.hparams['img_size']
        )

        self.loss_level1 = nn.CrossEntropyLoss()
        self.loss_level2 = nn.CrossEntropyLoss(weight=torch.as_tensor(self.hparams['label2_weights']))

        self.loss_weights = self.hparams['loss_weights']

    def forward(self, batch):
        return self.model(batch)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def training_step(self, batch, batch_idx):

        x, y_level1, y_level2 = batch['image'], \
                                batch['target_1'], \
                                batch['target_2']

        logits1, logits2 = self(x)

        loss1 = self.loss_level1(logits1, y_level1)
        loss2 = self.loss_level2(logits2, y_level2)

        loss = loss1 * self.loss_weights[0] + loss2 * self.loss_weights[1]

        self.log('train/loss', loss, on_step=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):

        x, y_level1, y_level2 = batch['image'], \
                                batch['target_1'], \
                                batch['target_2']
        logits1, logits2 = self(x)

        loss1 = self.loss_level1(logits1, y_level1)
        loss2 = self.loss_level2(logits2, y_level2)
        loss = loss1 + loss2

        level1_preds = torch.argmax(logits1, dim=1)
        level2_preds = torch.argmax(logits2, dim=1)

        level1_acc = accuracy(level1_preds, y_level1)
        level2_acc = accuracy(level2_preds, y_level2)
        level1_f1 = f1(level1_preds, y_level1, self.hparams['num_level1_classes'])
        level2_f1 = f1(level2_preds, y_level2, self.hparams['num_level2_classes'])

        logs = loss, level1_acc, level2_acc, level1_f1, level2_f1

        return logs

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x[0] for x in outputs]).mean()
        avg_level1_acc = torch.stack([x[1] for x in outputs]).mean()
        avg_level2_acc = torch.stack([x[2] for x in outputs]).mean()
        avg_level1_f1 = torch.stack([x[3] for x in outputs]).mean()
        avg_level2_f1 = torch.stack([x[4] for x in outputs]).mean()

        self.log('val/loss', avg_loss, prog_bar=True, logger=True, on_epoch=True)
        self.log('val/level1_acc', avg_level1_acc, logger=True, on_epoch=True)
        self.log('val/level2_acc', avg_level2_acc, logger=True, on_epoch=True)
        self.log('val/level1_f1', avg_level1_f1, prog_bar=True, logger=True, on_epoch=True)
        self.log('val/level2_f1', avg_level2_f1, prog_bar=True, logger=True, on_epoch=True)

    def configure_optimizers(self):

        if self.hparams['optimizer'] == 'adam':
            optimizer = optim.Adam(self.model.parameters(), self.lr)
        else:  # SGDWithMomentum
            optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)

        return {
            'optimizer': optimizer,
            'lr_scheduler': optim.lr_scheduler.StepLR(optimizer, 
                                                      step_size=self.hparams['sch_step_size'], 
                                                      gamma=self.hparams['sch_gamma'])
        }

    def on_train_end(self):
        ckpt_path = os.path.join(self.trainer.log_dir, "checkpoints", "min_val_loss.ckpt")
        print(f"Loading best checkpoint from {ckpt_path}")
        best_model_ = ModelWrapper.load_from_checkpoint(ckpt_path)

        save_dir = self.hparams['model_save_dir']
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, "best_model.pt")
        print(f"Saving only the PyTorch model, without the wrapper properties, to {save_path}")
        torch.save(best_model_.model, save_path)

### Set Training Parameters

In [None]:
MAX_EPOCHS = 10
BATCH_SIZE = 64
lr = 4e-4

params = {
    'batch_size': BATCH_SIZE,
    'lr': lr,
    'sch_step_size': 3,  # StepLR scheduler step size (epochs)
    'sch_gamma': 0.5,  # StepLR scheduler gamma value
    'optimizer': 'adam',  # 'adam' for Adam, anything else for SGD with momentum
    'num_workers': 4,  # number of worker processes for the dataloaders
    'num_level1_classes': 3,  # number of level 1 classes
    'num_level2_classes': 4,  # number of level 2 classes
    'label2_weights': [0.1, 0.1, 0.8, 0.1],  # level 2 class weights, to compensate for the rare COVID-19 class
    'loss_weights': [0.5, 1],  # down-weight the level 1 loss, as it converges faster than level 2
    'img_size': (IMG_RESIZE, IMG_RESIZE),  # image size after resizing
    'image_dir': image_dir,
    'model_save_dir': model_path,
}
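Two of these parameters are easier to judge with numbers. Assuming Lightning's default of stepping the scheduler once per epoch, `lr=4e-4`, `sch_step_size=3` and `sch_gamma=0.5` make `StepLR` halve the learning rate every 3 epochs: 4e-4 in epochs 0-2, 2e-4 in 3-5, 1e-4 in 6-8 and 5e-5 in epoch 9. The `label2_weights` go to `nn.CrossEntropyLoss(weight=...)`, which, with its default mean reduction, divides the weighted sum of per-sample losses by the sum of the weights of the samples' target classes. The plain-Python sketch below reproduces both calculations; `step_lr` and `weighted_cross_entropy` are hypothetical helpers, not PyTorch functions.

```python
import math

def step_lr(base_lr, epoch, step_size=3, gamma=0.5):
    """Learning rate StepLR uses during a given (0-based) epoch."""
    return base_lr * gamma ** (epoch // step_size)

def weighted_cross_entropy(batch_logits, targets, class_weights):
    """Class-weighted cross-entropy with mean reduction over a batch."""
    weighted_sum, weight_total = 0.0, 0.0
    for logits, y in zip(batch_logits, targets):
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        w = class_weights[y]
        weighted_sum += w * (log_sum_exp - logits[y])
        weight_total += w
    return weighted_sum / weight_total

print([step_lr(4e-4, e) for e in range(10)])

# A COVID-19 sample (class 2) counts 8 times as much as a Normal-2 sample (class 0)
weights = [0.1, 0.1, 0.8, 0.1]
batch_logits = [[0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
targets = [2, 0]
print(weighted_cross_entropy(batch_logits, targets, weights))
```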

### Train the Model with PyTorch Lightning

In [None]:
wrapper = ModelWrapper(hparams=params, df_train=train_df, df_test=test_df)

gpu_num = device_count()

checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    monitor='val/loss',
    mode='min',
    filename='min_val_loss'
)

trainer = Trainer(
    default_root_dir=os.getcwd(),
    gpus=gpu_num,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback]
)

trainer.fit(wrapper)

### Load the Best Model
During training, the ModelCheckpoint callback kept the checkpoint with the minimum validation loss, and on_train_end saved the PyTorch model from that checkpoint to model_path/best_model.pt.

In [None]:
best_model = torch.load(os.path.join(model_path, "best_model.pt"))

### Load Test Set for Evaluation

In [None]:
test_transforms = transforms.Compose([
    transforms.Resize([IMG_RESIZE, IMG_RESIZE]),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
test_dataset = CovidDataset(df=test_df, root_dir=image_dir, transform=test_transforms)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Define Evaluation Method

In [None]:
def evaluate(model, dataloader, device_):
    print("Evaluating...")
    model.to(device_).eval()
    with torch.no_grad():

        level1_preds = []
        level1_targets = []

        level2_preds = []
        level2_targets = []

        for batch in dataloader:
            x, y_level1, y_level2 = batch['image'].to(device_), \
                                    batch['target_1'].to(device_), \
                                    batch['target_2'].to(device_)
            logits1, logits2 = model(x)

            batch_level1_preds = torch.argmax(logits1, dim=1)
            batch_level2_preds = torch.argmax(logits2, dim=1)

            level1_preds.extend(batch_level1_preds.tolist())
            level2_preds.extend(batch_level2_preds.tolist())

            level1_targets.extend(y_level1.tolist())
            level2_targets.extend(y_level2.tolist())

    return level1_preds, level1_targets, level2_preds, level2_targets

## Evaluate

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

preds_1, targets_1, preds_2, targets_2 = evaluate(best_model, test_dataloader, device)

### Print Precision, Recall and F1 Scores

In [None]:
print("\t\t***\tLEVEL 1 CLASSIFICATION METRICS\t***")
print(classification_report(targets_1, preds_1, labels=list(level_1_id2label.keys()),
                            target_names=list(level_1_id2label.values()), zero_division=0))
print("\t\t***\tLEVEL 2 CLASSIFICATION METRICS\t***")
print(classification_report(targets_2, preds_2, labels=list(level_2_id2label.keys()),
                            target_names=list(level_2_id2label.values()), zero_division=0))

### Calculate Confusion Matrices

In [None]:
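# row-normalize so that each row shows how the samples of one true class were predicted
level_1_conf_mat = confusion_matrix(targets_1, preds_1, labels=list(level_1_id2label.keys()))
level_1_conf_mat = level_1_conf_mat.astype(float) / level_1_conf_mat.sum(axis=1)[:, np.newaxis]

level_2_conf_mat = confusion_matrix(targets_2, preds_2, labels=list(level_2_id2label.keys()))
level_2_conf_mat = level_2_conf_mat.astype(float) / level_2_conf_mat.sum(axis=1)[:, np.newaxis]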
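Dividing each row of a confusion matrix by its sum turns counts into rates per true class, so the diagonal of each heatmap holds that class's recall. The same counts also give precision (diagonal over column sums) and F1, which is what `classification_report` prints with `zero_division=0`. A plain-Python sketch on a made-up 3-class matrix; the counts are hypothetical, not results from this notebook, and the helper names are illustrative.

```python
def row_normalize(matrix):
    """Divide every row by its sum (rows with no samples stay all zeros)."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([v / total if total else 0.0 for v in row])
    return normalized

def per_class_scores(matrix):
    """(precision, recall, f1) per class; true labels on rows, predictions on columns."""
    n = len(matrix)
    scores = []
    for c in range(n):
        tp = matrix[c][c]
        predicted = sum(matrix[r][c] for r in range(n))  # column sum
        actual = sum(matrix[c])                          # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append((precision, recall, f1))
    return scores

# Hypothetical counts, not results from this notebook
counts = [[8, 2, 0],
          [1, 6, 3],
          [0, 2, 8]]
print(row_normalize(counts))
print(per_class_scores(counts))
```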

#### Plot Confusion Matrix Heatmaps for Level 1 and Level 2 Predictions

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(20, 10))

axs[0].title.set_text("LEVEL 1 CONFUSION MATRIX")
axs[1].title.set_text("LEVEL 2 CONFUSION MATRIX")

sns.heatmap(
    level_1_conf_mat,
    cmap='coolwarm',
    yticklabels=list(level_1_id2label.values()),
    xticklabels=list(level_1_id2label.values()),
    annot=True,
    ax=axs[0]
)

sns.heatmap(
    level_2_conf_mat,
    cmap='coolwarm',
    yticklabels=list(level_2_id2label.values()),
    xticklabels=list(level_2_id2label.values()),
    annot=True,
    ax=axs[1]
)