In [1]:
%matplotlib inline
import time
import os
import copy
import shutil
import random
import torch
import torchvision
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torch.optim import lr_scheduler

# Record the installed PyTorch version next to the notebook output.
print('Pytorch version  ', torch.__version__, sep=' ')

# Load Data

In [2]:
# Root folder of the radiography dataset.
root_dir = "../input/covid19-radiography-database/COVID-19_Radiography_Dataset"
# Per-class sub-folder names, spelled exactly like the folders the notebook
# reads from further down ("Normal", "Viral Pneumonia", "COVID").
source_dirs = ["Normal", "Viral Pneumonia", "COVID"]

In [3]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Map-style dataset over per-class folders of PNG chest X-ray images.

    Every index in ``range(len(dataset))`` identifies exactly one image: the
    classes are laid out back to back in ``class_names`` order and, inside a
    class, files are ordered by name. Because an index always resolves to the
    same file, subsets produced by ``random_split`` are genuinely disjoint.

    Classes are returned in their natural proportions. If class re-balancing
    is wanted it has to be done by the sampler of the ``DataLoader``.
    """

    def __init__(self, image_dirs, transform):
        """Index the image files of every class.

        Args:
            image_dirs: mapping from class name ('normal', 'viral', 'covid')
                to the directory holding that class's images.
            transform: callable applied to each RGB PIL image in
                ``__getitem__``.
        """
        def get_images(class_name):
            # Sorted because os.listdir returns entries in arbitrary order and
            # the index -> file mapping must be stable between runs.
            images = sorted(
                x for x in os.listdir(image_dirs[class_name])
                if x.lower().endswith('png')
            )
            print(f'Found {len(images)} {class_name} examples')
            return images

        self.images = {}
        self.class_names = ['normal', 'viral', 'covid']

        for class_name in self.class_names:
            self.images[class_name] = get_images(class_name)

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        """Return the total number of images across all classes."""
        return sum(len(self.images[class_name]) for class_name in self.class_names)

    def _locate(self, index):
        """Translate a flat index into ``(class_name, image_name)``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        offset = index + total if index < 0 else index
        if not 0 <= offset < total:
            raise IndexError(f'index {index} is out of range for {total} images')

        # Walk the classes in order, consuming each class's share of the range.
        for class_name in self.class_names:
            class_images = self.images[class_name]
            if offset < len(class_images):
                return class_name, class_images[offset]
            offset -= len(class_images)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the image at ``index``.

        ``label`` is the position of the image's class in ``class_names``.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        class_name, image_name = self._locate(index)
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image), self.class_names.index(class_name)
        

In [4]:
# Settings shared by the training and evaluation pipelines.
_tv = torchvision.transforms
_IMAGE_SIZE = (300, 300)
# Per-channel statistics used for normalisation (presumably the ImageNet
# values the pretrained backbone expects - confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip, tensor conversion, normalise.
train_transform = _tv.Compose([
    _tv.Resize(size=_IMAGE_SIZE),
    _tv.RandomHorizontalFlip(),
    _tv.ToTensor(),
    _tv.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: the same steps without the random flip.
test_transform = _tv.Compose([
    _tv.Resize(size=_IMAGE_SIZE),
    _tv.ToTensor(),
    _tv.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

In [5]:
# Class name -> image folder, derived from `root_dir` so the dataset location
# is configured in exactly one place.
train_dirs = {
    "normal": os.path.join(root_dir, "Normal"),
    "viral": os.path.join(root_dir, "Viral Pneumonia"),
    "covid": os.path.join(root_dir, "COVID"),
}

In [6]:
# Index every class folder; images will go through the training transform.
train_dataset = ChestXRayDataset(image_dirs=train_dirs, transform=train_transform)

In [7]:
# Fixed seed so the train/test partition is the same on every run.
SPLIT_SEED = 42

# Once this cell has run, `train_dataset` is a Subset. Unwrap it so that
# re-running the cell splits the full dataset again rather than a subset.
full_dataset = getattr(train_dataset, "dataset", train_dataset)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size  # 80-20 split train-test

train_dataset, held_out = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# The held-out images must not be randomly flipped: view the same file index
# through `test_transform`. A shallow copy shares the file lists, so the
# directories are not scanned a second time.
eval_dataset = copy.copy(full_dataset)
eval_dataset.transform = test_transform
test_dataset = torch.utils.data.Subset(eval_dataset, held_out.indices)

print("Length of train set   :  ", len(train_dataset))
print("Length of test set    :  ", len(test_dataset))

In [8]:
# Mini-batch size shared by both loaders.
batch_size = 16
# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# len(loader) is the number of batches per pass, not the number of images.
print("Length of training batches", len(train_loader))
print("Length of test batches", len(test_loader))

# Visualize

In [9]:
class_names = ["normal x-ray", "viral x-ray","covid positive"]

def show_images(images, labels, preds):
    """Plot a batch of normalised images in one row with true/predicted captions.

    The x-label of each image is its true class, the y-label its predicted
    class; the y-label is green when the two agree and red otherwise.

    Args:
        images: batch of CPU image tensors, each laid out channels-first
            (transposed to channels-last for display).
        labels: per-image true class indices into ``class_names``.
        preds: per-image predicted class indices into ``class_names``.
    """
    # Must be the same statistics that were passed to transforms.Normalize,
    # otherwise the displayed colours are wrong.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    n_images = len(images)
    if n_images == 0:
        return  # nothing to draw; a zero-column subplot grid is invalid

    plt.figure(figsize=(30, 15))
    for i, image in enumerate(images):
        # One column per image so any batch size fits.
        plt.subplot(1, n_images, i + 1, xticks=[], yticks=[])
        image = image.numpy().transpose((1, 2, 0))
        # Undo the normalisation and clamp to the displayable [0, 1] range.
        image = np.clip(image * std + mean, 0., 1.)
        plt.imshow(image)

        true_idx = int(labels[i])
        pred_idx = int(preds[i])
        colorr = "green" if pred_idx == true_idx else "red"

        plt.xlabel(f'{class_names[true_idx]}')
        plt.ylabel(f'{class_names[pred_idx]}', color=colorr)
    plt.tight_layout()
    plt.show()

In [10]:
def test_predicts(model,test_loader):
    model.to(device)
    resnet18.eval()
    images, labels = next(iter(test_loader))
    images = images.to(device)
    labels = labels.to(device)
    outputs = resnet18(images).cpu()
    _, preds = torch.max(outputs, 1)
    show_images(images.cpu(), labels.cpu(), preds.cpu())

In [11]:
# Preview one training batch. The labels double as "predictions", so every
# caption is drawn in the "correct" colour.
first_batch = next(iter(train_loader))
images, labels = first_batch
show_images(images, labels, preds=labels)

# Model

In [12]:
# Build the network through the module path rather than the imported
# `resnet18` name. This assignment rebinds `resnet18` to the model instance,
# so calling the bare name again on a re-run of this cell would fail.
resnet18 = torchvision.models.resnet18(pretrained = True)
resnet18

In [13]:
# Replace the final fully connected layer with a 3-output layer, one per class.
resnet18.fc = torch.nn.Linear(512, 3)

# Classification loss over the three class scores.
loss_fn = torch.nn.CrossEntropyLoss()

# Adam over every parameter of the network, including the new layer.
optimizer = torch.optim.Adam(params=resnet18.parameters(), lr=3e-5)

# Scale the learning rate by 0.1 after every 7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)

In [14]:
# Phase name -> loader; the held-out loader serves as the 'val' phase.
dataloaders = {
    'train': train_loader,
    'val': test_loader,
}

In [15]:
# Phase name -> number of samples in that phase's dataset.
dataset_sizes = {
    'train': len(train_dataset),
    'val': len(test_dataset),
}

# Train

In [16]:
# Use the first CUDA GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move the network's parameters onto the chosen device.
resnet18 = resnet18.to(device)

## Early stopping

In [17]:
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Call ``step(metric)`` once per epoch; it returns ``True`` when training
    should stop.

    Args:
        mode: 'min' if lower metric values are better, 'max' if higher are.
        min_delta: margin by which a value must beat the best so far to count
            as an improvement; absolute, or a percentage of the best value
            when ``percentage`` is true.
        patience: number of consecutive non-improving steps after which
            ``step`` returns ``True``. ``0`` disables early stopping.
        percentage: interpret ``min_delta`` as a percentage of the best value.

    Raises:
        ValueError: if ``mode`` is neither 'min' nor 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Disabled: never request a stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value and return ``True`` if training should stop."""
        # Check for NaN before anything else. A NaN stored as `best` would
        # make every later comparison false and hide the divergence.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
            print('improvement!')
        else:
            self.num_bad_epochs += 1
            print(f'no improvement, bad_epochs counter: {self.num_bad_epochs}')

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the ``is_better(value, best)`` predicate for the given settings."""
        if mode not in {'min', 'max'}:
            raise ValueError(f'mode {mode} is unknown!')

        if percentage:
            # abs() keeps the margin positive when the best value is negative,
            # so the bar always moves in the "harder to beat" direction.
            margin = lambda best: abs(best) * min_delta / 100
        else:
            margin = lambda best: min_delta

        if mode == 'min':
            self.is_better = lambda a, best: a < best - margin(best)
        else:
            self.is_better = lambda a, best: a > best + margin(best)

In [18]:
def train_model(model,dataloaders, criterion, optimizer, scheduler, num_epochs=10, patience=3):
    """Train ``model`` with alternating train/val phases and early stopping.

    Each epoch runs a 'train' phase (with gradient updates, followed by one
    ``scheduler.step()``) and a 'val' phase. The weights with the highest
    validation accuracy are written to 'best-model-parameters.pt' and loaded
    back into ``model`` before returning. Training stops early when the
    validation loss fails to improve for ``patience`` consecutive epochs.

    Args:
        model: network to train; batches are moved to the device its
            parameters live on.
        dataloaders: mapping with 'train' and 'val' keys, each an iterable of
            ``(inputs, labels)`` batches.
        criterion: loss callable ``criterion(outputs, labels)``; the running
            average assumes it returns the mean loss of the batch.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: maximum number of epochs.
        patience: early-stopping patience on the validation loss.

    Returns:
        ``model`` with the best validation-accuracy weights loaded.

    Raises:
        ValueError: if a phase's loader yields no samples.
    """
    es = EarlyStopping(patience=patience)
    terminate_training = False

    # Follow the model's own device instead of depending on a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            phase_start = time.time()
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            num_samples = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                if is_train:
                    # zero the parameter gradients
                    optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics, weighted by the actual batch size so a short
                # final batch is averaged correctly
                batch_count = inputs.size(0)
                running_loss += loss.item() * batch_count
                running_corrects += int((preds == labels).sum().item())
                num_samples += batch_count

            if is_train:
                scheduler.step()

            if num_samples == 0:
                raise ValueError(f'dataloaders[{phase!r}] yielded no samples')

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            # Duration of this phase only, not time since training started.
            phase_elapsed = time.time() - phase_start
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            print('{} Epoch complete in {:.0f}m {:.0f}s'.format(phase, phase_elapsed // 60, phase_elapsed % 60))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # Checkpoint immediately so an interrupted run keeps the best weights.
                torch.save(best_model_wts, 'best-model-parameters.pt')

            if phase == 'val' and es.step(epoch_loss):
                terminate_training = True
                print('Early Stop')
                break

        print()
        if terminate_training:
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # save and load best model weights
    torch.save(best_model_wts, 'best-model-parameters.pt')
    model.load_state_dict(best_model_wts)
    return model

In [19]:
# Train for at most 10 epochs; `patience` keeps its default.
model = train_model(
    model=resnet18,
    dataloaders=dataloaders,
    criterion=loss_fn,
    optimizer=optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=10,
)

In [20]:
# Plot predictions of the trained model on the first test batch.
test_predicts(model=model, test_loader=test_loader)