In [None]:
# libraries required

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, ConcatDataset, Subset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

In [None]:
# X-Ray COVID Positive: create the datasets 

class _CovidChestXrayBase(Dataset):
    """Shared implementation of the COVID chest X-ray dataset variants.

    The annotations CSV is read once at construction time.  Every item is
    loaded with PIL, resized to 256x256, converted to a 3-channel grayscale
    image, passed through the variant-specific augmentations, converted to a
    tensor and normalised with mean 0.5 / std 0.5 on each channel.

    Args:
        annotations_file: Path of the CSV file read with ``pd.read_csv``.
        img_dir: Directory that the file names taken from the CSV are joined
            onto.
    """

    # Positional column indices into the annotations CSV.
    # NOTE(review): presumably the image file name and the class label --
    # the CSV schema is not visible from this file, confirm against the
    # metadata file before changing these.
    _FILENAME_COLUMN = 24
    _LABEL_COLUMN = 5

    def __init__(self, annotations_file, img_dir):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        # The pipeline is identical for every item, so it is built once on
        # first access instead of on every __getitem__ call.
        self._transform = None

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def _augmentations(self):
        """Return the transforms inserted between Grayscale and ToTensor."""
        return []

    def _get_transform(self):
        """Return the preprocessing pipeline, building it on first use."""
        if self._transform is None:
            self._transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.Grayscale(num_output_channels=3),
                ]
                + list(self._augmentations())
                + [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        return self._transform

    def __getitem__(self, idx):
        """Load the image of row ``idx`` and return ``(image_tensor, label)``."""
        img_path = os.path.join(
            self.img_dir, self.img_labels.iloc[idx, self._FILENAME_COLUMN]
        )
        label = self.img_labels.iloc[idx, self._LABEL_COLUMN]
        # Transform while the file is still open: PIL loads pixel data lazily.
        with Image.open(img_path) as image:
            return self._get_transform()(image), label


class Covid_ChestXray_Dataset(_CovidChestXrayBase):
    """COVID chest X-rays with the base preprocessing only (no augmentation)."""


class HFlipped_Covid_ChestXray_Dataset(_CovidChestXrayBase):
    """COVID chest X-rays, horizontally flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomHorizontalFlip(p=1.0)]


class VFlipped_Covid_ChestXray_Dataset(_CovidChestXrayBase):
    """COVID chest X-rays, vertically flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomVerticalFlip(p=1.0)]


class Rotated_Covid_ChestXray_Dataset(_CovidChestXrayBase):
    """COVID chest X-rays passed through ``RandomRotation(15)``.

    The rotation is random, so repeated access to the same index can return
    differently rotated images.
    """

    def _augmentations(self):
        return [transforms.RandomRotation(15)]


class Translated_Covid_ChestXray_Dataset(_CovidChestXrayBase):
    """COVID chest X-rays passed through ``RandomAffine(0, (0.2, 0.2))``.

    The shift is random, so repeated access to the same index can return
    differently translated images.
    """

    def _augmentations(self):
        return [transforms.RandomAffine(0, (0.2, 0.2))]

from torch.utils.data import DataLoader

# where the COVID chest X-ray metadata CSV and the image files live
metadata_path = 'dataset/covid-chestxray-dataset/processed_metadata.csv'
images_path = 'dataset/covid-chestxray-dataset/images/'

# one dataset per augmentation, all backed by the same metadata and images
initial_covid_chestxray_dataset = Covid_ChestXray_Dataset(
    annotations_file=metadata_path, img_dir=images_path
)
h_flipped_covid_chestxray_dataset = HFlipped_Covid_ChestXray_Dataset(
    annotations_file=metadata_path, img_dir=images_path
)
v_flipped_covid_chestxray_dataset = VFlipped_Covid_ChestXray_Dataset(
    annotations_file=metadata_path, img_dir=images_path
)
rotated_covid_chestxray_dataset = Rotated_Covid_ChestXray_Dataset(
    annotations_file=metadata_path, img_dir=images_path
)
translated_covid_chestxray_dataset = Translated_Covid_ChestXray_Dataset(
    annotations_file=metadata_path, img_dir=images_path
)

# chain the original and the augmented views into a single dataset
combined_covid_chestxray_dataset = ConcatDataset([
    initial_covid_chestxray_dataset,
    h_flipped_covid_chestxray_dataset,
    v_flipped_covid_chestxray_dataset,
    rotated_covid_chestxray_dataset,
    translated_covid_chestxray_dataset,
])

# keep 2500 distinct, randomly chosen items of the combined dataset
covid_chestxray_dataset = Subset(
    combined_covid_chestxray_dataset,
    np.random.choice(
        len(combined_covid_chestxray_dataset), size=2500, replace=False
    ),
)

In [None]:
# X-Ray Healthy Class: create datasets 

class _HealthyChestXray8Base(Dataset):
    """Shared implementation of the healthy-class Chest X-ray8 variants.

    The annotations CSV is read once and reduced to the rows whose
    ``'Finding Labels'`` value equals 0; every item is returned with label 0.
    Each image is loaded with PIL, resized to 256x256, converted to a
    3-channel grayscale image, passed through the variant-specific
    augmentations, converted to a tensor and normalised with mean 0.5 /
    std 0.5 on each channel.

    Args:
        annotations_file: Path of the CSV file read with ``pd.read_csv``; it
            must contain a ``'Finding Labels'`` column.
        img_dir: Directory that the file names taken from the CSV are joined
            onto.
    """

    # Value of 'Finding Labels' that is kept, and the label of every item.
    _LABEL = 0
    # Positional index of the column joined onto img_dir.
    # NOTE(review): presumably the image file name -- the CSV schema is not
    # visible from this file, confirm against the metadata file.
    _FILENAME_COLUMN = 1

    def __init__(self, annotations_file, img_dir):
        table = pd.read_csv(annotations_file)
        self.img_labels = table[table['Finding Labels'] == self._LABEL]
        self.img_dir = img_dir
        # The pipeline is identical for every item, so it is built once on
        # first access instead of on every __getitem__ call.
        self._transform = None

    def __len__(self):
        """Return the number of rows left after the label filter."""
        return len(self.img_labels)

    def _augmentations(self):
        """Return the transforms inserted between Grayscale and ToTensor."""
        return []

    def _get_transform(self):
        """Return the preprocessing pipeline, building it on first use."""
        if self._transform is None:
            self._transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.Grayscale(num_output_channels=3),
                ]
                + list(self._augmentations())
                + [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        return self._transform

    def __getitem__(self, idx):
        """Load the image of filtered row ``idx``; return ``(tensor, 0)``."""
        # iloc is positional, so idx counts rows of the filtered table.
        img_path = os.path.join(
            self.img_dir, self.img_labels.iloc[idx, self._FILENAME_COLUMN]
        )
        # Transform while the file is still open: PIL loads pixel data lazily.
        with Image.open(img_path) as image:
            return self._get_transform()(image), self._LABEL


class healthy_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays with the base preprocessing only."""


class healthy_HFlipped_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays, horizontally flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomHorizontalFlip(p=1.0)]


class healthy_VFlipped_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays, vertically flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomVerticalFlip(p=1.0)]


class healthy_Rotated_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays passed through ``RandomRotation(15)``."""

    def _augmentations(self):
        return [transforms.RandomRotation(15)]


class healthy_HFlipped_Rotated_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays, horizontally flipped, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.RandomRotation(15),
        ]


class healthy_VFlipped_Rotated_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays, vertically flipped, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomVerticalFlip(p=1.0),
            transforms.RandomRotation(15),
        ]


class healthy_Translated_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays passed through ``RandomAffine(0, (0.2, 0.2))``."""

    def _augmentations(self):
        return [transforms.RandomAffine(0, (0.2, 0.2))]


class healthy_Translated_Rotated_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays: ``RandomAffine(0, (0.2, 0.2))``, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomAffine(0, (0.2, 0.2)),
            transforms.RandomRotation(15),
        ]


class healthy_Perspective_Chest_Xray8(_HealthyChestXray8Base):
    """Healthy chest X-rays passed through ``RandomPerspective(0.6, 1.0)``."""

    def _augmentations(self):
        return [transforms.RandomPerspective(0.6, 1.0)]

# metadata CSV and image directory of the Chest X-ray8 data
metadata_path = 'processed_metadata.csv'
images_path = 'dataset/chest-xray8/images'

# un-augmented healthy images
healthy_chest_xray8_dataset = healthy_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)

# one dataset per augmentation, all backed by the same metadata and images
healthy_h_flipped_chest_xray8_dataset = healthy_HFlipped_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_v_flipped_chest_xray8_dataset = healthy_VFlipped_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_rotated_chest_xray8_dataset = healthy_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_h_flipped_rotated_chest_xray8_dataset = healthy_HFlipped_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_v_flipped_rotated_chest_xray8_dataset = healthy_VFlipped_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_translated_chest_xray8_dataset = healthy_Translated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_translated_rotated_chest_xray8_dataset = healthy_Translated_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
healthy_perspective_chest_xray8_dataset = healthy_Perspective_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)

# chain the original and the augmented views into a single dataset
healthy_combined_chest_xray8_dataset = ConcatDataset([
    healthy_chest_xray8_dataset,
    healthy_h_flipped_chest_xray8_dataset,
    healthy_v_flipped_chest_xray8_dataset,
    healthy_rotated_chest_xray8_dataset,
    healthy_h_flipped_rotated_chest_xray8_dataset,
    healthy_v_flipped_rotated_chest_xray8_dataset,
    healthy_translated_chest_xray8_dataset,
    healthy_translated_rotated_chest_xray8_dataset,
    healthy_perspective_chest_xray8_dataset,
])

# keep 2500 distinct, randomly chosen items; from here on the name
# healthy_chest_xray8_dataset refers to this subset, not the un-augmented set
healthy_chest_xray8_dataset = Subset(
    healthy_combined_chest_xray8_dataset,
    np.random.choice(
        len(healthy_combined_chest_xray8_dataset), size=2500, replace=False
    ),
)

In [None]:
# X-Ray Other Class: create datasets 

class _OtherChestXray8Base(Dataset):
    """Shared implementation of the other-class Chest X-ray8 variants.

    The annotations CSV is read once and reduced to the rows whose
    ``'Finding Labels'`` value equals 2; every item is returned with label 2.
    Each image is loaded with PIL, resized to 256x256, converted to a
    3-channel grayscale image, passed through the variant-specific
    augmentations, converted to a tensor and normalised with mean 0.5 /
    std 0.5 on each channel.

    Args:
        annotations_file: Path of the CSV file read with ``pd.read_csv``; it
            must contain a ``'Finding Labels'`` column.
        img_dir: Directory that the file names taken from the CSV are joined
            onto.
    """

    # Value of 'Finding Labels' that is kept, and the label of every item.
    _LABEL = 2
    # Positional index of the column joined onto img_dir.
    # NOTE(review): presumably the image file name -- the CSV schema is not
    # visible from this file, confirm against the metadata file.
    _FILENAME_COLUMN = 1

    def __init__(self, annotations_file, img_dir):
        table = pd.read_csv(annotations_file)
        self.img_labels = table[table['Finding Labels'] == self._LABEL]
        self.img_dir = img_dir
        # The pipeline is identical for every item, so it is built once on
        # first access instead of on every __getitem__ call.
        self._transform = None

    def __len__(self):
        """Return the number of rows left after the label filter."""
        return len(self.img_labels)

    def _augmentations(self):
        """Return the transforms inserted between Grayscale and ToTensor."""
        return []

    def _get_transform(self):
        """Return the preprocessing pipeline, building it on first use."""
        if self._transform is None:
            self._transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.Grayscale(num_output_channels=3),
                ]
                + list(self._augmentations())
                + [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        return self._transform

    def __getitem__(self, idx):
        """Load the image of filtered row ``idx``; return ``(tensor, 2)``."""
        # iloc is positional, so idx counts rows of the filtered table.
        img_path = os.path.join(
            self.img_dir, self.img_labels.iloc[idx, self._FILENAME_COLUMN]
        )
        # Transform while the file is still open: PIL loads pixel data lazily.
        with Image.open(img_path) as image:
            return self._get_transform()(image), self._LABEL


class other_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays with the base preprocessing only."""


class other_HFlipped_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays, horizontally flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomHorizontalFlip(p=1.0)]


class other_VFlipped_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays, vertically flipped (flip probability 1.0)."""

    def _augmentations(self):
        return [transforms.RandomVerticalFlip(p=1.0)]


class other_Rotated_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays passed through ``RandomRotation(15)``."""

    def _augmentations(self):
        return [transforms.RandomRotation(15)]


class other_HFlipped_Rotated_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays, horizontally flipped, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.RandomRotation(15),
        ]


class other_VFlipped_Rotated_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays, vertically flipped, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomVerticalFlip(p=1.0),
            transforms.RandomRotation(15),
        ]


class other_Translated_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays passed through ``RandomAffine(0, (0.2, 0.2))``."""

    def _augmentations(self):
        return [transforms.RandomAffine(0, (0.2, 0.2))]


class other_Translated_Rotated_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays: ``RandomAffine(0, (0.2, 0.2))``, then ``RandomRotation(15)``."""

    def _augmentations(self):
        return [
            transforms.RandomAffine(0, (0.2, 0.2)),
            transforms.RandomRotation(15),
        ]


class other_Perspective_Chest_Xray8(_OtherChestXray8Base):
    """Other-class chest X-rays passed through ``RandomPerspective(0.6, 1.0)``."""

    def _augmentations(self):
        return [transforms.RandomPerspective(0.6, 1.0)]

# metadata CSV and image directory of the Chest X-ray8 data
metadata_path = 'processed_metadata.csv'
images_path = 'dataset/chest-xray8/images'

# un-augmented other-class images
other_chest_xray8_dataset = other_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)

# one dataset per augmentation, all backed by the same metadata and images
other_h_flipped_chest_xray8_dataset = other_HFlipped_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_v_flipped_chest_xray8_dataset = other_VFlipped_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_rotated_chest_xray8_dataset = other_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_h_flipped_rotated_chest_xray8_dataset = other_HFlipped_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_v_flipped_rotated_chest_xray8_dataset = other_VFlipped_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_translated_chest_xray8_dataset = other_Translated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_translated_rotated_chest_xray8_dataset = other_Translated_Rotated_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)
other_perspective_chest_xray8_dataset = other_Perspective_Chest_Xray8(
    annotations_file=metadata_path, img_dir=images_path
)

# chain the original and the augmented views into a single dataset
other_combined_chest_xray8_dataset = ConcatDataset([
    other_chest_xray8_dataset,
    other_h_flipped_chest_xray8_dataset,
    other_v_flipped_chest_xray8_dataset,
    other_rotated_chest_xray8_dataset,
    other_h_flipped_rotated_chest_xray8_dataset,
    other_v_flipped_rotated_chest_xray8_dataset,
    other_translated_chest_xray8_dataset,
    other_translated_rotated_chest_xray8_dataset,
    other_perspective_chest_xray8_dataset,
])

# keep 2500 distinct, randomly chosen items; from here on the name
# other_chest_xray8_dataset refers to this subset, not the un-augmented set
other_chest_xray8_dataset = Subset(
    other_combined_chest_xray8_dataset,
    np.random.choice(
        len(other_combined_chest_xray8_dataset), size=2500, replace=False
    ),
)

In [None]:
# CT Scan Positive COVID-19 Class: create dataset

class _SarsCov2CTBase(Dataset):
    """Shared implementation of the SARS-CoV-2 CT scan dataset variants.

    The file names are not discovered on disk: they are generated as
    ``'Covid (1).png'`` ... ``'Covid (1250).png'`` and joined onto
    ``img_dir``.  Every item is returned with label 1.  Each image is loaded
    with PIL, resized to 256x256, converted to a 3-channel grayscale image,
    passed through the variant-specific augmentations, converted to a tensor
    and normalised with mean 0.5 / std 0.5 on each channel.

    Args:
        img_dir: Directory expected to contain the generated file names.
    """

    # Label returned with every item.
    _LABEL = 1
    # Number of consecutively numbered files, starting at 1.
    _NUM_IMAGES = 1250

    def __init__(self, img_dir):
        self.img_labels = [
            'Covid (' + str(number) + ').png'
            for number in range(1, self._NUM_IMAGES + 1)
        ]
        self.img_dir = img_dir
        # The pipeline is identical for every item, so it is built once on
        # first access instead of on every __getitem__ call.
        self._transform = None

    def __len__(self):
        """Return the number of generated file names."""
        return len(self.img_labels)

    def _augmentations(self):
        """Return the transforms inserted between Grayscale and ToTensor."""
        return []

    def _get_transform(self):
        """Return the preprocessing pipeline, building it on first use."""
        if self._transform is None:
            self._transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.Grayscale(num_output_channels=3),
                ]
                + list(self._augmentations())
                + [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        return self._transform

    def __getitem__(self, idx):
        """Load image number ``idx`` (0-based); return ``(tensor, 1)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        # Transform while the file is still open: PIL loads pixel data lazily.
        with Image.open(img_path) as image:
            return self._get_transform()(image), self._LABEL


class Sars_Cov_2_CT(_SarsCov2CTBase):
    """COVID-positive CT scans with the base preprocessing only."""


class Rotated_Sars_Cov_2_CT(_SarsCov2CTBase):
    """COVID-positive CT scans passed through ``RandomRotation(15)``.

    The rotation is random, so repeated access to the same index can return
    differently rotated images.
    """

    def _augmentations(self):
        return [transforms.RandomRotation(15)]

# directory holding the COVID-positive CT scans
images_path = 'dataset/sars-cov-2-ct-scan/images/COVID/'

# the plain scans and their randomly rotated counterpart
Sars_Cov_2_CT_dataset = Sars_Cov_2_CT(img_dir=images_path)
rotated_Sars_Cov_2_CT_dataset = Rotated_Sars_Cov_2_CT(img_dir=images_path)

# chain both (2 x 1250 = 2500 images exactly); from here on the name
# Sars_Cov_2_CT_dataset refers to the concatenation, not the plain scans
Sars_Cov_2_CT_dataset = ConcatDataset([
    Sars_Cov_2_CT_dataset,
    rotated_Sars_Cov_2_CT_dataset,
])

In [None]:
# CT Scan Healthy Class: create the dataset

from os import listdir

# specify the image path
# NOTE: this module-level name is read directly (as a global, via
# listdir(images_path)) inside VFlipped_rotated_Covid_CT.__init__ further
# down in this cell, so it has to be assigned before that class is
# instantiated and must not be rebound to another directory in between.
images_path = 'dataset/covid-ct/images/'
class Covid_CT(Dataset):
    def __init__(self, img_dir):
        self.img_labels = [f for f in listdir(images_path)]
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                #transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # , transforms.ConvertImageDtype(torch.uint8)    
            ])

            image = transform(image)

            return image, label

class HFlipped_Covid_CT(Dataset):
    def __init__(self, img_dir):
        self.img_labels = [f for f in listdir(images_path)]
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                #transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # , transforms.ConvertImageDtype(torch.uint8)    
            ])

            image = transform(image)

            return image, label

class VFlipped_Covid_CT(Dataset):
    def __init__(self, img_dir):
        self.img_labels = [f for f in listdir(images_path)]
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                #transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomVerticalFlip(p=1.0),
                transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # , transforms.ConvertImageDtype(torch.uint8)    
            ])

            image = transform(image)

            return image, label

class Rotated_Covid_CT(Dataset):
    """CT-scan image dataset that returns every image randomly rotated.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, passed through
    ``RandomRotation(15)``, turned into a tensor and normalised with
    mean 0.5 / std 0.5 per channel. The label is always ``0``. The rotation
    is sampled on every access, so repeated reads of one index differ.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

class HFlipped_rotated_Covid_CT(Dataset):
    """CT-scan image dataset: horizontal flip followed by a random rotation.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, mirrored left-to-right,
    passed through ``RandomRotation(15)``, turned into a tensor and
    normalised with mean 0.5 / std 0.5 per channel. The label is always
    ``0``. The rotation is sampled on every access.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                # p=1.0 makes the "random" flip unconditional.
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

class VFlipped_rotated_Covid_CT(Dataset):
    """CT-scan image dataset: vertical flip followed by a random rotation.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, flipped top-to-bottom,
    passed through ``RandomRotation(15)``, turned into a tensor and
    normalised with mean 0.5 / std 0.5 per channel. The label is always
    ``0``. The rotation is sampled on every access.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                # p=1.0 makes the "random" flip unconditional.
                transforms.RandomVerticalFlip(p=1.0),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

class Translated_Covid_CT(Dataset):
    """CT-scan image dataset that returns every image randomly translated.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, passed through
    ``RandomAffine(0, (0.2, 0.2))`` (no rotation, translation fractions of
    0.2), turned into a tensor and normalised with mean 0.5 / std 0.5 per
    channel. The label is always ``0``. The shift is sampled on every
    access.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomAffine(0, (0.2, 0.2)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

class Translated_Rotated_Covid_CT(Dataset):
    """CT-scan image dataset: random translation followed by random rotation.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, passed through
    ``RandomAffine(0, (0.2, 0.2))`` and then ``RandomRotation(15)``, turned
    into a tensor and normalised with mean 0.5 / std 0.5 per channel. The
    label is always ``0``. Both augmentations are sampled on every access.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomAffine(0, (0.2, 0.2)),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

class Perspective_Covid_CT(Dataset):
    """CT-scan image dataset that applies a perspective distortion.

    Each item is an ``(image, label)`` pair: the image is resized to
    256x256, converted to 3-channel grayscale, passed through
    ``RandomPerspective(0.6, 1.0)``, turned into a tensor and normalised
    with mean 0.5 / std 0.5 per channel. The label is always ``0``. The
    distortion is sampled on every access.
    """

    def __init__(self, img_dir):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory holding the image files. Every entry returned
                by ``os.listdir`` is treated as an image.
        """
        # List the directory that was passed in, not the notebook-global
        # ``images_path``, so the dataset always matches its own ``img_dir``.
        self.img_labels = os.listdir(img_dir)
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        with Image.open(img_path) as image:
            label = 0

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.Grayscale(num_output_channels=3),
                # Distortion scale 0.6, applied with probability 1.0.
                transforms.RandomPerspective(0.6, 1.0),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            image = transform(image)

            return image, label

# Build the plain dataset plus one dataset per augmentation, all reading
# the same image directory.
covid_ct_dataset = Covid_CT(images_path)

h_flipped_covid_ct_dataset = HFlipped_Covid_CT(images_path)
v_flipped_covid_ct_dataset = VFlipped_Covid_CT(images_path)
rotated_covid_ct_dataset = Rotated_Covid_CT(images_path)
h_flipped_rotated_covid_ct_dataset = HFlipped_rotated_Covid_CT(images_path)
v_flipped_rotated_covid_ct_dataset = VFlipped_rotated_Covid_CT(images_path)
translated_covid_ct_dataset = Translated_Covid_CT(images_path)
translated_rotated_covid_ct_dataset = Translated_Rotated_Covid_CT(images_path)
perspective_covid_ct_dataset = Perspective_Covid_CT(images_path)

# Chain the original and augmented datasets into one indexable dataset.
combined_covid_ct_dataset = ConcatDataset([
    covid_ct_dataset,
    h_flipped_covid_ct_dataset,
    v_flipped_covid_ct_dataset,
    rotated_covid_ct_dataset,
    h_flipped_rotated_covid_ct_dataset,
    v_flipped_rotated_covid_ct_dataset,
    translated_covid_ct_dataset,
    translated_rotated_covid_ct_dataset,
    perspective_covid_ct_dataset,
])

# Keep 2500 distinct, randomly chosen samples. np.random.choice raises if the
# combined dataset holds fewer than 2500 items.
covid_ct_subsample_indices = np.random.choice(len(combined_covid_ct_dataset), 2500, replace=False)
covid_ct_dataset = Subset(combined_covid_ct_dataset, covid_ct_subsample_indices)

In [None]:
# CT Scan Other Class: create dataset

class other_Sars_Cov_2_CT(Dataset):
    """Image-folder dataset whose items all carry label ``2``.

    Every file in the directory is served as a 256x256, 3-channel grayscale
    tensor normalised with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, img_dir):
        """Remember ``img_dir`` and the names of the entries inside it."""
        self.img_dir = img_dir
        self.img_labels = os.listdir(img_dir)

    def __len__(self):
        """Return how many files were found in the directory."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image_tensor, 2)`` for the ``idx``-th file."""
        file_name = self.img_labels[idx]
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        with Image.open(os.path.join(self.img_dir, file_name)) as raw_image:
            return pipeline(raw_image), 2

class other_Rotated_Sars_Cov_2_CT(Dataset):
    """Image-folder dataset with label ``2`` and a random rotation.

    Every file in the directory is served as a 256x256, 3-channel grayscale
    tensor passed through ``RandomRotation(15)`` and normalised with
    mean 0.5 / std 0.5 per channel. The rotation is sampled on every access.
    """

    def __init__(self, img_dir):
        """Remember ``img_dir`` and the names of the entries inside it."""
        self.img_dir = img_dir
        self.img_labels = os.listdir(img_dir)

    def __len__(self):
        """Return how many files were found in the directory."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image_tensor, 2)`` for the ``idx``-th file."""
        file_name = self.img_labels[idx]
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        with Image.open(os.path.join(self.img_dir, file_name)) as raw_image:
            return pipeline(raw_image), 2

class other_HFlip_Sars_Cov_2_CT(Dataset):
    """Image-folder dataset with label ``2`` and a horizontal flip.

    Every file in the directory is served as a 256x256, 3-channel grayscale
    tensor mirrored left-to-right and normalised with mean 0.5 / std 0.5
    per channel.
    """

    def __init__(self, img_dir):
        """Remember ``img_dir`` and the names of the entries inside it."""
        self.img_dir = img_dir
        self.img_labels = os.listdir(img_dir)

    def __len__(self):
        """Return how many files were found in the directory."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image_tensor, 2)`` for the ``idx``-th file."""
        file_name = self.img_labels[idx]
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(num_output_channels=3),
            # p=1.0: the flip is always applied.
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        with Image.open(os.path.join(self.img_dir, file_name)) as raw_image:
            return pipeline(raw_image), 2


# Directory holding the non-COVID CT images
images_path = 'dataset/sars-cov-2-ct-scan/images/non-COVID/'

# Create the plain dataset and its two augmented variants
other_Sars_Cov_2_CT_dataset = other_Sars_Cov_2_CT(images_path)
other_rotated_Sars_Cov_2_CT_dataset = other_Rotated_Sars_Cov_2_CT(images_path)
other_HFlip_Sars_Cov_2_CT_dataset = other_HFlip_Sars_Cov_2_CT(images_path)

# Chain the three datasets into one
other_Sars_Cov_2_CT_dataset = ConcatDataset([
    other_Sars_Cov_2_CT_dataset,
    other_rotated_Sars_Cov_2_CT_dataset,
    other_HFlip_Sars_Cov_2_CT_dataset,
])

# Keep 2500 distinct, randomly chosen samples. np.random.choice raises if the
# combined dataset holds fewer than 2500 items.
other_ct_subsample_indices = np.random.choice(len(other_Sars_Cov_2_CT_dataset), 2500, replace=False)
other_Sars_Cov_2_CT_dataset = Subset(other_Sars_Cov_2_CT_dataset, other_ct_subsample_indices)

In [None]:
# Create the final dataloaders 

from torch.utils.data import DataLoader, random_split
from torchvision.utils import make_grid

# x-ray datasets
covid_xray_dataset = covid_chestxray_dataset
healthy_xray_dataset = healthy_chest_xray8_dataset
other_xray_dataset = other_chest_xray8_dataset

# CT datasets
# NOTE(review): ``covid_ct_dataset`` is used as the *healthy* CT set here and
# ``Sars_Cov_2_CT_dataset`` as the COVID-positive one - confirm the names
# match the labels those datasets emit.
covid_positive_ct_dataset = Sars_Cov_2_CT_dataset
healthy_ct_dataset = covid_ct_dataset
other_ct_dataset = other_Sars_Cov_2_CT_dataset

# lists of the datasets that go into each combined dataset
all_datasets = [covid_xray_dataset, healthy_xray_dataset, other_xray_dataset, covid_positive_ct_dataset, healthy_ct_dataset, other_ct_dataset]
x_ray_datasets = [covid_xray_dataset, healthy_xray_dataset, other_xray_dataset]
ct_datasets = [covid_positive_ct_dataset, healthy_ct_dataset, other_ct_dataset]

# create combined datasets
combined_dataset = ConcatDataset(all_datasets)
x_ray_combined = ConcatDataset(x_ray_datasets)
ct_combined = ConcatDataset(ct_datasets)

# split combined dataset into train (70%), validation (20%) and test splits
total_size = len(combined_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.2 * total_size)
# The test split takes whatever is left. Truncating all three fractions with
# int() can leave their sum short of the dataset length, and random_split
# rejects size lists that do not add up exactly.
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(combined_dataset, [train_size, val_size, test_size])

# create train, validation and test dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

print('Total Dataset Size: ', len(combined_dataset))

In [None]:
# Show an example batch from the dataloader

def imshow(inp, title=None, mean=0.5, std=0.5):
    """Display a channel-first image tensor with matplotlib.

    The datasets in this notebook normalise images with mean 0.5 / std 0.5,
    so pixel values arrive in roughly [-1, 1]. matplotlib clips float RGB
    data to [0, 1], which would render every negative value as black; the
    normalisation is therefore undone before plotting.

    Args:
        inp: Image tensor laid out as (channels, height, width).
        title: Optional title drawn above the image.
        mean: Mean that was subtracted when the image was normalised.
        std: Standard deviation the image was divided by when normalised.
    """
    # Channel-first -> channel-last, which is what plt.imshow expects.
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo Normalize(mean, std) and clamp to the displayable range.
    inp = np.clip(std * inp + mean, 0, 1)
    plt.axis('off')
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    # Brief pause so the figure has a chance to refresh.
    plt.pause(0.001)

def show_databatch(inputs, classes):
    """Tile a batch of images into a single grid and display it.

    Args:
        inputs: Batch of image tensors handed to ``make_grid``.
        classes: Labels of the batch; accepted but not used.
    """
    imshow(make_grid(inputs))

# Fetch the first batch of the x-ray training loader and display it as a grid.
# NOTE(review): x_ray_train_healthy_or_not_dataloader is not created in the
# cells shown here - confirm it is defined earlier in the notebook.
example_batches = iter(x_ray_train_healthy_or_not_dataloader)
inputs, classes = next(example_batches)
show_databatch(inputs, classes)

In [None]:
# Functions: Training loop, testing loop, model evaluation, heatmap production

import time
import copy
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix


def train_model(model, criterion, optimizer, training_dataloader, validating_dataloader, training_dataset, validating_dataset, num_epochs=10, use_gpu=False):
    """Train ``model`` and return it loaded with its best-validation weights.

    Every epoch trains on the first half (rounded up) of the batches yielded
    by ``training_dataloader`` and then evaluates on all batches of
    ``validating_dataloader`` with gradient tracking disabled. The weights
    from the epoch with the highest validation accuracy are restored before
    the model is returned.

    Args:
        model: Network to train; updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepping ``model``'s parameters.
        training_dataloader: Iterable of ``(inputs, labels)`` training batches.
        validating_dataloader: Iterable of ``(inputs, labels)`` validation
            batches.
        training_dataset: Dataset behind the training loader; only its length
            is used, to scale the reported loss and accuracy.
        validating_dataset: Dataset behind the validation loader; only its
            length is used.
        num_epochs: Number of epochs to run.
        use_gpu: When true, every batch is moved to CUDA before the forward
            pass (the model must already live on the GPU).

    Returns:
        ``model`` with the best validation-accuracy weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    train_batches = len(training_dataloader)
    val_batches = len(validating_dataloader)
    # Only half of the training batches are used per epoch; rounding up
    # matches the ``i >= train_batches / 2`` cut-off.
    half_batches = (train_batches + 1) // 2

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs))
        print('-' * 10)

        loss_train = 0.0
        loss_val = 0.0
        acc_train = 0
        acc_val = 0

        model.train(True)

        for i, (inputs, labels) in enumerate(training_dataloader):
            if i % 100 == 0:
                print("\rTraining batch {}/{}".format(i, half_batches), end='', flush=True)

            # Use half training dataset
            if i >= half_batches:
                break

            if use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs.detach(), 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # .item() keeps plain Python numbers instead of accumulating
            # tensors across batches.
            loss_train += loss.item()
            acc_train += torch.sum(preds == labels).item()

            del inputs, labels, outputs, preds
            if use_gpu:
                torch.cuda.empty_cache()

        print()
        # * 2 as we only used half of the dataset
        avg_loss = loss_train * 2 / len(training_dataset)
        avg_acc = acc_train * 2 / len(training_dataset)

        model.eval()

        # ``Variable(..., volatile=True)`` no longer switches autograd off;
        # torch.no_grad() is what keeps validation from building graphs.
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(validating_dataloader):
                if i % 100 == 0:
                    print("\rValidation batch {}/{}".format(i, val_batches), end='', flush=True)

                if use_gpu:
                    inputs, labels = inputs.cuda(), labels.cuda()

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                loss_val += loss.item()
                acc_val += torch.sum(preds == labels).item()

                del inputs, labels, outputs, preds
                if use_gpu:
                    torch.cuda.empty_cache()

        avg_loss_val = loss_val / len(validating_dataset)
        avg_acc_val = acc_val / len(validating_dataset)

        print()
        print("Epoch {} result: ".format(epoch))
        print("Avg loss (train): {:.4f}".format(avg_loss))
        print("Avg acc (train): {:.4f}".format(avg_acc))
        print("Avg loss (val): {:.4f}".format(avg_loss_val))
        print("Avg acc (val): {:.4f}".format(avg_acc_val))
        print('-' * 10)
        print()

        # Snapshot the weights whenever validation accuracy improves.
        if avg_acc_val > best_acc:
            best_acc = avg_acc_val
            best_model_wts = copy.deepcopy(model.state_dict())

    elapsed_time = time.time() - since
    print()
    print("Training completed in {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
    print("Best acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

def test_model(model, criterion, optimizer_ft, test_dataloader, test_dataset, use_gpu):
    model.train(False)
    model.eval()

    loss_test = 0
    acc_test = 0

    test_labels = []
    test_preds = []

    test_batches = len(test_dataloader)
    for i, data in enumerate(test_dataloader):
        if i % 100 == 0:
            print("\Test batch {}/{}".format(i, test_batches), end='', flush=True)

        inputs, labels = data
        
        if use_gpu:
            inputs, labels = Variable(inputs.cuda(), volatile=True), Variable(labels.cuda(), volatile=True)
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        optimizer_ft.zero_grad()

        outputs = model(inputs)

        _, preds = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)

        all_labels = labels.data.cpu()
        all_preds = preds.cpu()
        for i in range(len(labels.data.cpu())):
            test_labels.append(all_labels[i])
            test_preds.append(all_preds[i])

        loss_test += loss.data
        acc_test += torch.sum(preds == labels.data)

        del inputs, labels, outputs, preds
        torch.cuda.empty_cache()

    avg_loss_test = loss_test / len(test_dataset)
    avg_acc_test = acc_test / len(test_dataset)

    print()
    print("Avg loss (test): {:.4f}".format(avg_loss_test))
    print("Avg acc (test): {:.4f}".format(avg_acc_test))
    print('-' * 10)
    print()
    
    return test_labels, test_preds
    
from sklearn.metrics import ConfusionMatrixDisplay, classification_report

# produce confusion matrix and classification report
def evaluate_model(test_labels, test_preds):
    """Report classification quality for 3-class predictions.

    Prints the classification report and the confusion matrix, plots the
    confusion matrix, and plots a COVID-vs-rest ROC curve.

    Args:
        test_labels: True class ids (0, 1 or 2), one per sample; ints or
            0-dim tensors.
        test_preds: Predicted class ids, aligned with ``test_labels``.
    """
    # RocCurveDisplay is not covered by the sklearn imports visible at file
    # level, so it is imported here.
    from sklearn.metrics import RocCurveDisplay

    # Plain ints work for both scikit-learn and the comparisons below.
    y_true = [int(label) for label in test_labels]
    y_pred = [int(pred) for pred in test_preds]

    # Fixing the label set keeps the matrix 3x3 (matching the three display
    # labels) even when a class is missing from the test data.
    conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

    disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,
                                  display_labels=['healthy', 'covid', 'other'])

    classif_report = classification_report(y_true, y_pred)
    print(classif_report)

    print(disp.plot())
    print(conf_matrix)

    # A ROC curve needs a binary target and a score that grows with the
    # positive class. Raw class ids would rank class 2 above class 1, so both
    # sides are reduced to "is class 1" indicators.
    covid_true = [int(label == 1) for label in y_true]
    covid_pred = [int(pred == 1) for pred in y_pred]
    RocCurveDisplay.from_predictions(covid_true, covid_pred, pos_label=1)
    
# pick the sample used for the heatmap, preferring a COVID positive image
def _find_covid_sample(test_dataloader):
    """Return ``(inputs, labels, index)`` locating a label-1 sample.

    Batches are scanned in loader order and the first sample whose label is 1
    wins. When no batch contains label 1, a warning is printed and the first
    sample of the first batch is returned instead.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches.
    """
    first_batch = None
    for inputs, labels in test_dataloader:
        if first_batch is None:
            first_batch = (inputs, labels)
        positives = torch.nonzero(torch.as_tensor(labels) == 1).flatten()
        if positives.numel() > 0:
            return inputs, labels, int(positives[0])

    if first_batch is None:
        raise ValueError("test_dataloader yielded no batches")

    print("Warning: no COVID positive (label 1) sample found; using the first test sample instead")
    return first_batch[0], first_batch[1], 0


# produce a heatmap for the given model on a sample from the dataloader
def show_heatmap(model, test_dataloader, vgg16=False, densenet=False, darknet=False, efficientnet=False):
    """Build a Grad-CAM heatmap and ROAD scores for one test image.

    Exactly one architecture flag should be true; it selects the layer whose
    activations Grad-CAM explains. The model is expected to live on the GPU,
    because the inputs are moved to CUDA before being passed to it.

    Args:
        model: Trained classifier to explain.
        test_dataloader: Iterable of ``(inputs, labels)`` batches to draw the
            sample from.
        vgg16: Use ``model.features[-1]`` as target layer.
        densenet: Use ``model.features[-1]`` as target layer.
        darknet: Use ``model.conv5[-4]`` as target layer.
        efficientnet: Use ``model.features[-1]`` as target layer.

    Returns:
        A tuple ``(out, label, prediction, scores, evaluation_out)``:
        ``out`` is a PIL image showing the input, the heatmap and their blend
        side by side; ``label`` is the sample's label as a NumPy array;
        ``prediction`` is the model's predicted class tensor; ``scores`` are
        the ROAD most-relevant-first scores; ``evaluation_out`` is a PIL image
        of the perturbed input.

    Raises:
        ValueError: If no architecture flag is set or the loader is empty.
    """
    if vgg16 or densenet:
        target_layer = model.features[-1]
    elif darknet:
        target_layer = model.conv5[-4]
    elif efficientnet:
        target_layer = model.features[-1]
    else:
        # Without this branch target_layer would be unbound further down.
        raise ValueError("show_heatmap needs one of vgg16, densenet, darknet or efficientnet set to True")

    # ensure its a heatmap of a COVID positive image (when one exists)
    inputs, labels, index = _find_covid_sample(test_dataloader)

    label = np.array(labels[index])
    # Add the batch dimension the model and Grad-CAM expect.
    input_tensor = inputs[index].unsqueeze(0).clone()

    from pytorch_grad_cam import GradCAM
    # NOTE(review): newer pytorch-grad-cam releases appear to have dropped the
    # ``use_cuda`` argument - confirm against the installed version.
    cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=True)
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    # Explain the score of class 1.
    targets = [ClassifierOutputTarget(1)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

    # ROAD Most Relevant First calculation
    from pytorch_grad_cam.metrics.road import ROADMostRelevantFirst, ROADLeastRelevantFirst
    # alter the percentile argument to test at different percentiles of pixel importances
    cam_metric = ROADMostRelevantFirst(percentile=90)
    scores, perturbation_visualizations = cam_metric(input_tensor.cuda(),
    grayscale_cam, targets, model, return_visualization=True)

    # specify the metric and the percentiles for ROAD Combined
    from pytorch_grad_cam.metrics.road import ROADCombined
    cam_metric = ROADCombined(percentiles=[50, 60, 70, 80, 90])
    road_combined_scores = cam_metric(input_tensor.cuda(), grayscale_cam, targets, model)
    print(f"Combined metric avg confidence increase with ROAD accross 5 thresholds (positive is better): {road_combined_scores[0]}")

    # produce the image after pixels perturbated: undo the 0.5/0.5
    # normalisation, scale to 0-255 and move channels last for PIL
    perturbation_visualizations = (perturbation_visualizations+1)*0.5
    perturbation_visualizations = (np.array(perturbation_visualizations[0].cpu() * 255)).astype(np.uint8)
    perturbation_visualizations = perturbation_visualizations.transpose((1, 2, 0))
    evaluation_out = Image.fromarray(perturbation_visualizations)

    input_copy = input_tensor.detach().clone().cuda()
    output = model(input_copy)
    _, prediction = torch.max(output.data, 1)

    # Drop the batch dimension of the CAM.
    grayscale_cam = grayscale_cam[0, :]

    from pytorch_grad_cam.utils.image import show_cam_on_image
    import cv2

    # overlay the heatmap onto the original image
    new_input_tensor = input_tensor[0,:]

    img = np.array(new_input_tensor)

    new_img = img.transpose((1, 2, 0))

    # Undo the 0.5/0.5 normalisation and scale to 0-255.
    new_img = (new_img+1)*0.5

    new_img = np.uint8(255*new_img)

    transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.Grayscale(num_output_channels=3),
                ])

    new_img = np.array(transform(new_img))

    # NOTE(review): cv2.applyColorMap returns BGR while PIL reads arrays as
    # RGB; the original left the BGR->RGB conversion disabled - confirm the
    # resulting colours are intended.
    heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_HOT)

    image_weight = 0.7

    # Blend: 70% image, 30% heatmap.
    overlay = (1 - image_weight) * heatmap + image_weight * new_img

    overlay = np.uint8(overlay)

    out = Image.fromarray(np.hstack((new_img, heatmap, overlay)))

    return out, label, prediction, scores, evaluation_out

In [None]:
# VGG16

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg16

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

# Build an untrained VGG16. From here on the name ``vgg16`` refers to the
# model instance rather than the imported constructor.
vgg16 = vgg16()

# Swap the final classifier layer for a 3-class linear head.
num_features = vgg16.classifier[6].in_features
features = [*vgg16.classifier.children()][:-1]
features.append(nn.Linear(num_features, 3))
vgg16.classifier = nn.Sequential(*features)

vgg16.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(vgg16.parameters(), lr=0.001, momentum=0.9)

# Uncomment the next two lines to retrain and save instead of loading weights.
# vgg16 = train_model(vgg16, criterion, optimizer_ft, train_dataloader, val_dataloader, train_dataset, val_dataset, num_epochs=100, use_gpu=use_gpu)
# torch.save(vgg16.state_dict(), './trained_models/VGG16/VGG16.pt')
vgg16.load_state_dict(torch.load('./trained_models/VGG16/VGG16.pt'))

import time

# Time the evaluation of the trained model on the test set.
start = time.time()
test_labels, test_preds = test_model(vgg16, criterion, optimizer_ft, test_dataloader, test_dataset, True)
end = time.time()

# Confusion matrix and classification report.
evaluate_model(test_labels, test_preds)

print('time taken: ', end - start)

# Grad-CAM heatmap, ROAD scores and the perturbed image.
heatmap, label, prediction, scores, perturbation_visualizations = show_heatmap(vgg16, test_dataloader, vgg16=True)
print(scores)
print(perturbation_visualizations)

# Save the perturbed image.
file = './trained_models/VGG16/evaluation_heatmap.jpg'
perturbation_visualizations.save(file)

# Save the heatmap.
file = './trained_models/VGG16/heatmap.jpg'
heatmap.save(file)

In [None]:
# DenseNet201

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import densenet201

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

# Build an untrained DenseNet201.
densenet = densenet201()

# Rebuild the classifier as a Sequential ending in a 3-class linear head.
# The wrapper shape must stay as is: the saved checkpoint's keys depend on it.
num_features = densenet.classifier.in_features
features = [*densenet.classifier.children()][:-1]
features.append(nn.Linear(num_features, 3))
densenet.classifier = nn.Sequential(*features)

densenet.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(densenet.parameters(), lr=0.001, momentum=0.9)

# Uncomment the next two lines to retrain and save instead of loading weights.
# densenet = train_model(densenet, criterion, optimizer_ft, train_dataloader, val_dataloader, train_dataset, val_dataset, num_epochs=100, use_gpu=use_gpu)
# torch.save(densenet.state_dict(), './trained_models/DenseNet201/DenseNet201.pt')
densenet.load_state_dict(torch.load('./trained_models/DenseNet201/DenseNet201.pt'))

import time

# Time the evaluation of the trained model on the test set.
start = time.time()
test_labels, test_preds = test_model(densenet, criterion, optimizer_ft, test_dataloader, test_dataset, True)
end = time.time()

# Confusion matrix and classification report.
evaluate_model(test_labels, test_preds)
print('time taken: ', end - start)

# Grad-CAM heatmap, ROAD scores and the perturbed image.
heatmap, label, prediction, scores, perturbation_visualizations = show_heatmap(densenet, test_dataloader, densenet=True)
print(scores)
print(perturbation_visualizations)

# Save the perturbed image.
file = './trained_models/DenseNet201/evaluation_heatmap.jpg'
perturbation_visualizations.save(file)

# Save the heatmap.
file = './trained_models/DenseNet201/heatmap.jpg'
heatmap.save(file)

In [None]:
# DarkNet19

import torch
import torch.nn as nn
import torch.optim as optim

# Specify Global Avg Pooling layer required for DarkNet19
class GlobalAvgPool2d(nn.Module):
    """Average every channel over its whole spatial extent.

    Turns an input of shape (N, C, H, W) into an output of shape (N, C).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the per-channel spatial mean of ``x`` as an (N, C) tensor."""
        batch, channels, height, width = (x.size(dim) for dim in range(4))
        # A pooling window covering the full H x W plane leaves one value
        # per channel.
        pooled = nn.functional.avg_pool2d(x, (height, width))
        return pooled.view(batch, channels)

# The DarkNet19 architecture
class DarkNet19(nn.Module):
    """DarkNet19-style convolutional classifier with 3 output classes.

    Six convolutional stages (``conv1`` .. ``conv6``) feed a 1x1 convolution
    with 3 output channels followed by global average pooling, so the
    forward pass yields one 3-value row per input image.

    NOTE(review): the third conv/batch-norm pair of ``conv5`` and ``conv6``
    and the last pair of ``conv6`` have no LeakyReLU after them. The layout
    is kept exactly as is because saved checkpoints and the
    ``model.conv5[-4]`` lookup in ``show_heatmap`` depend on the module
    indices - confirm whether the omission is intended.
    """

    def __init__(self):
        super().__init__()

        def conv_bn(c_in, c_out, kernel, activate=True):
            """Return [Conv2d, BatchNorm2d] plus LeakyReLU when ``activate``."""
            stage = [
                # kernel // 2 gives padding 1 for 3x3 and 0 for 1x1 convs.
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=1, padding=kernel // 2, bias=False),
                nn.BatchNorm2d(c_out),
            ]
            if activate:
                stage.append(nn.LeakyReLU(0.1, inplace=True))
            return stage

        def downsample():
            """Return the 2x2, stride-2 max-pool that ends a stage."""
            return nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv1 = nn.Sequential(*conv_bn(3, 32, 3), downsample())

        self.conv2 = nn.Sequential(*conv_bn(32, 64, 3), downsample())

        self.conv3 = nn.Sequential(
            *conv_bn(64, 128, 3),
            *conv_bn(128, 64, 1),
            *conv_bn(64, 128, 3),
            downsample())

        self.conv4 = nn.Sequential(
            *conv_bn(128, 256, 3),
            *conv_bn(256, 128, 1),
            *conv_bn(128, 256, 3),
            downsample())

        self.conv5 = nn.Sequential(
            *conv_bn(256, 512, 3),
            *conv_bn(512, 256, 1),
            *conv_bn(256, 512, 3, activate=False),
            *conv_bn(512, 256, 1),
            *conv_bn(256, 512, 3),
            downsample())

        self.conv6 = nn.Sequential(
            *conv_bn(512, 1024, 3),
            *conv_bn(1024, 512, 1),
            *conv_bn(512, 1024, 3, activate=False),
            *conv_bn(1024, 512, 1),
            *conv_bn(512, 1024, 3, activate=False))

        # 1x1 convolution to 3 class maps, then one average per map.
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 3, kernel_size=(1,1), stride=(1,1)),
            GlobalAvgPool2d(),
        )

    def forward(self, x):
        """Run ``x`` through the six stages and the classifier head."""
        for stage in (self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.conv6):
            x = stage(x)
        return self.classifier(x)
    
import torch
import torch.nn as nn
import torch.optim as optim

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

# Build an untrained DarkNet19 and move it to the GPU.
darknet = DarkNet19()

darknet.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(darknet.parameters(), lr=0.001, momentum=0.9)

# Uncomment the next two lines to retrain and save instead of loading weights.
# darknet = train_model(darknet, criterion, optimizer_ft, train_dataloader, val_dataloader, train_dataset, val_dataset, num_epochs=100, use_gpu=use_gpu)
# torch.save(darknet.state_dict(), './trained_models/DarkNet19/DarkNet19.pt')
darknet.load_state_dict(torch.load('./trained_models/DarkNet19/DarkNet19.pt'))

import time

# Time the evaluation of the trained model on the test set.
start = time.time()
test_labels, test_preds = test_model(darknet, criterion, optimizer_ft, test_dataloader, test_dataset, True)
end = time.time()

# Confusion matrix and classification report.
evaluate_model(test_labels, test_preds)
print('time taken: ', end - start)

# Grad-CAM heatmap, ROAD scores and the perturbed image.
heatmap, label, prediction, scores, perturbation_visualizations = show_heatmap(darknet, test_dataloader, darknet=True)
print(scores)
print('Label here: ', label)
print(perturbation_visualizations)

# Save the perturbed image.
file = './trained_models/DarkNet19/evaluation_heatmap.jpg'
perturbation_visualizations.save(file)

# Save the heatmap.
file = './trained_models/DarkNet19/heatmap.jpg'
heatmap.save(file)

In [None]:
# EfficientNetB0

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import efficientnet_b0

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

# Build an untrained EfficientNetB0.
efficientnet = efficientnet_b0()

# Swap the final classifier layer for a 3-class linear head.
num_features = efficientnet.classifier[1].in_features
features = [*efficientnet.classifier.children()][:-1]
features.append(nn.Linear(num_features, 3))
efficientnet.classifier = nn.Sequential(*features)

efficientnet.cuda()

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(efficientnet.parameters(), lr=0.001, momentum=0.9)

# Uncomment the next two lines to retrain and save instead of loading weights.
# efficientnet = train_model(efficientnet, criterion, optimizer_ft, train_dataloader, val_dataloader, train_dataset, val_dataset, num_epochs=100, use_gpu=use_gpu)
# torch.save(efficientnet.state_dict(), './trained_models/EfficientNetB0/EfficientNetB0.pt')
efficientnet.load_state_dict(torch.load('./trained_models/EfficientNetB0/EfficientNetB0.pt'))

import time

# Time the evaluation of the trained model on the test set.
start = time.time()
test_labels, test_preds = test_model(efficientnet, criterion, optimizer_ft, test_dataloader, test_dataset, True)
end = time.time()

# Confusion matrix and classification report.
evaluate_model(test_labels, test_preds)
print('time taken: ', end - start)

# Grad-CAM heatmap, ROAD scores and the perturbed image.
heatmap, label, prediction, scores, perturbation_visualizations = show_heatmap(efficientnet, test_dataloader, efficientnet=True)
print(scores)
print(perturbation_visualizations)

# Save the perturbed image.
file = './trained_models/EfficientNetB0/evaluation_heatmap.jpg'
perturbation_visualizations.save(file)

# Save the heatmap.
file = './trained_models/EfficientNetB0/heatmap.jpg'
heatmap.save(file)

In [None]:
# Proposed Ensemble Model
#
# Rebuild the two base networks (VGG16 and EfficientNetB0), each with a
# 3-output classifier head, and restore their saved weights. Both models are
# combined into the ensemble further down in this cell.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import efficientnet_b0
# Imported under an alias so that binding the model to `vgg16` below does not
# overwrite the factory function it was created with.
from torchvision.models import vgg16 as build_vgg16

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

# Fall back to the CPU when CUDA is unavailable, instead of failing on
# `.cuda()` or on loading a checkpoint that was saved from GPU tensors.
target_device = torch.device("cuda" if use_gpu else "cpu")

# The VGG16 model

vgg16 = build_vgg16()
num_features = vgg16.classifier[6].in_features
features = list(vgg16.classifier.children())[:-1]  # Remove last layer
features.append(nn.Linear(num_features, 3))  # Change output size to 3
vgg16.classifier = nn.Sequential(*features)  # Replace the model classifier

# VGG16 is not moved to a device here; it is moved together with the ensemble
# that wraps it later in this cell.
vgg16.load_state_dict(torch.load('./trained_models/VGG16/VGG16.pt', map_location=target_device))

# The EfficientNetB0 model

efficientnet = efficientnet_b0()
num_features = efficientnet.classifier[1].in_features
features = list(efficientnet.classifier.children())[:-1]  # Remove last layer
features.append(nn.Linear(num_features, 3))  # Change output size to 3
efficientnet.classifier = nn.Sequential(*features)  # Replace the model classifier

efficientnet.to(target_device)

efficientnet.load_state_dict(
    torch.load('./trained_models/EfficientNetB0/EfficientNetB0.pt', map_location=target_device)
)
# Define the ensemble model structure
class MyEnsemble(nn.Module):
    """Combine two models by summing their outputs and applying a linear layer.

    Both wrapped models receive the same input. Their outputs are added
    element-wise and the sum is passed through ``fc1``, a square
    ``num_classes`` x ``num_classes`` linear layer.

    Args:
        modelA: First wrapped module; its output's last dimension must be
            ``num_classes``.
        modelB: Second wrapped module; its output must be addable to
            ``modelA``'s output.
        num_classes: Size of each wrapped model's output and of the final
            output. Defaults to 3, the size used by the stored checkpoint.
    """

    def __init__(self, modelA: nn.Module, modelB: nn.Module, num_classes: int = 3) -> None:
        super().__init__()
        # Attribute names are part of the state_dict keys ("modelA.*",
        # "modelB.*", "fc1.*"), so they must stay as they are for existing
        # checkpoints to load.
        self.modelA = modelA
        self.modelB = modelB

        self.fc1 = nn.Linear(num_classes, num_classes)

    def forward(self, x):
        """Return ``fc1(modelA(x) + modelB(x))``."""
        out1 = self.modelA(x)
        out2 = self.modelB(x)

        out = out1 + out2

        return self.fc1(out)

import time

# Wrap the two restored base models in the ensemble, restore the ensemble's
# saved weights, evaluate it on the test set and save the heatmap images.

# Fall back to the CPU when CUDA is unavailable, instead of failing on
# `.cuda()` or on loading a checkpoint that was saved from GPU tensors.
target_device = torch.device("cuda" if use_gpu else "cpu")

ensemblenet = MyEnsemble(vgg16, efficientnet)
ensemblenet.to(target_device)  # moves both wrapped models and fc1

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(ensemblenet.parameters(), lr=0.001, momentum=0.9)

# Uncomment the next two lines to retrain and overwrite the checkpoint
# instead of loading the stored one.
# ensemblenet = train_model(ensemblenet, criterion, optimizer_ft, train_dataloader, val_dataloader, train_dataset, val_dataset, num_epochs=100, use_gpu=use_gpu)
# torch.save(ensemblenet.state_dict(), './latest_ensemble_method/ensemblenet.pt')
ensemblenet.load_state_dict(
    torch.load('./latest_ensemble_method/ensemblenet.pt', map_location=target_device)
)

# evaluate the trained model on the test set
# NOTE(review): the last positional argument is hard-coded to True; if it is
# test_model's GPU flag it should be `use_gpu` - confirm against its signature.
start = time.time()
test_labels, test_preds = test_model(ensemblenet, criterion, optimizer_ft, test_dataloader, test_dataset, True)
end = time.time()

# produce the confusion matrix and classification report
evaluate_model(test_labels, test_preds)

print('time taken: ', end - start)

# produce the perturbed image and ROAD scores
heatmap, label, prediction, scores, perturbation_visualizations = show_heatmap(ensemblenet, test_dataloader, ensemblenet=True)
print('scores: ', scores)
print('Label here: ', label)
print(perturbation_visualizations)
file = './latest_ensemble_method/evaluation_heatmap.jpg'
perturbation_visualizations.save(file)

# save the heatmap
file = './latest_ensemble_method/heatmap.jpg'
heatmap.save(file)