In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
from utils import get_num_correct
from custom_dataset import ChestXRayDataset


In [2]:
# choose where tensors and the model will live: the GPU when CUDA is usable, the CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [3]:
# declare transformations
# Both pipelines resize to 224x224, convert to a tensor and normalise each channel with the
# same mean/std; only the training pipeline adds a random horizontal flip as augmentation.
transform = dict(
    train=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    test=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [4]:
# get train and test directories
# 'train' maps each class name to its own image folder; 'test' is a single root folder.
dirs = dict(
    train=dict(
        covid='COVID-19 Radiography Database/train/covid',
        normal='COVID-19 Radiography Database/train/normal',
        viral='COVID-19 Radiography Database/train/viral',
    ),
    test='COVID-19 Radiography Database/test',
)

In [5]:
# prepare the train data-loader
# the dataset is built from the per-class training folders with the training transform
train_set = ChestXRayDataset(dirs['train'], transform['train'])
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,  # draw the training batches in a new random order every epoch
    num_workers=2,
)

found 179 covid examples.
found 1301 normal examples.
found 1305 viral examples.


In [6]:
# prepare the test and validation data-loader
test_set = datasets.ImageFolder(dirs['test'], transform['test'])

valid_size = 0.5  # fraction of test_set to be used as validation set
split_seed = 42  # fixed seed: the validation/test partition must be the same on every run

# obtain test indices that will be used for validation
# A dedicated, seeded generator is used instead of the global np.random state. With an
# unseeded shuffle the partition changes on every kernel restart, so a checkpoint selected
# on one run's validation half could later be scored on a "test" half containing those
# very images.
num_test = len(test_set)
indices = np.random.default_rng(split_seed).permutation(num_test).tolist()
split = int(np.floor(valid_size*num_test))
test_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining test and validation batches
# (each sampler only ever yields indices from its own, disjoint half of test_set)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# prepare the data loaders
valid_loader = torch.utils.data.DataLoader(test_set, batch_size=8, sampler=valid_sampler, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=8, sampler=test_sampler, num_workers=2)

In [7]:
# Build the resnet18 architecture without downloading weights, then fill it from the local
# checkpoint (the pretrained resnet18 model). Passing pretrained=True to the constructor is
# the online alternative to the local file.
resnet18 = torchvision.models.resnet18(pretrained=False)
# map_location places the stored tensors on the selected device while loading
resnet18.load_state_dict(torch.load('../models/resnet18.pth', map_location=device))


<All keys matched successfully>

In [8]:
# change the last fc layer so that it could output 3 classes
# The input width is read from the layer being replaced rather than hard-coded (512 for
# resnet18), so the cell stays correct if a different backbone is substituted. Re-running
# the cell is safe: the replacement layer reports the same in_features.
resnet18.fc = torch.nn.Linear(in_features=resnet18.fc.in_features, out_features=3)
resnet18.to(device)  # move to GPU (if available); kept as the last expression so the model is displayed

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Fine-tune resnet18 for num_epochs epochs. After every epoch the model is evaluated on the
# validation and test loaders, per-sample average losses and accuracies are written to
# TensorBoard, and the weights are saved whenever the validation loss does not get worse.
criterion = nn.CrossEntropyLoss()  # loss function (categorical cross-entropy)
optimizer = optim.Adam(resnet18.parameters(), lr=3e-5)

comment = '-resnet18-covid'  # will be used for naming the run
tb = SummaryWriter(comment=comment)

# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial minimum to infinity (the np.Inf alias was removed in NumPy 2.0)
num_epochs = 3  # number of epochs used for training

# number of samples in each split, used to turn running sums into per-sample averages
len_train = len(train_set)
len_val = len(valid_loader.sampler)
len_test = len(test_loader.sampler)

for epoch in range(num_epochs):
    train_loss, train_correct = 0, 0  # will be used to track the running loss and correct
    #######################
    # fine-tune the model #
    #######################
    train_loop = tqdm(train_loader)
    resnet18.train()  # set the model to train mode

    for batch in train_loop:
        images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
        preds = resnet18(images)  # forward pass
        loss = criterion(preds, labels)  # calculate loss
        optimizer.zero_grad()  # clear the accumulated gradients from the previous pass
        loss.backward()  # backward pass
        optimizer.step()  # perform a single optimization step

        # criterion returns the batch mean, so weight it by the batch size to keep an exact sum
        train_loss += loss.item() * labels.size(0)  # update the running loss
        # get_num_correct comes from utils; presumably the count of correct predictions in the batch
        train_correct += get_num_correct(preds, labels)  # update running num correct

        train_loop.set_description(f'Epoch [{epoch+1:2d}/{num_epochs}]')
        # acc is correct-so-far over the FULL training-set size, so it climbs during the
        # epoch and only equals the epoch accuracy once the last batch has been processed
        train_loop.set_postfix(loss=loss.item(), acc=train_correct/len_train)

    # calculate average training loss over the epoch BEFORE logging, so TensorBoard shows
    # the same per-sample value that is printed (not a sum that scales with dataset size)
    train_loss = train_loss/len_train

    # add train loss and train accuracy for the current epoch to tensorboard
    tb.add_scalar('Train Loss', train_loss, epoch)
    tb.add_scalar('Train Accuracy', train_correct/len_train, epoch)


    resnet18.eval()  # set the model to evaluation mode
    with torch.no_grad():  # turn off grad tracking, as we don't need gradients for validation

        valid_loss, valid_correct = 0, 0  # will be used to track the running validation loss and correct
        ######################
        # validate the model #
        ######################
        for batch in valid_loader:
            images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
            preds = resnet18(images)  # forward pass
            loss = criterion(preds, labels)  # calculate the loss

            valid_loss += loss.item() * labels.size(0)  # update the running loss
            valid_correct += get_num_correct(preds, labels)  # update running num correct

        # calculate average validation loss over the epoch (same scale as the training loss)
        valid_loss = valid_loss/len_val

        # add validation loss and validation accuracy for the current epoch to tensorboard
        tb.add_scalar('Validation Loss', valid_loss, epoch)
        tb.add_scalar('Validation Accuracy', valid_correct/len_val, epoch)


        # print training/validation statistics
        train_loop.write(f'\t\tAvg training loss: {train_loss:.6f}\tAvg validation loss: {valid_loss:.6f}')


        # save model if validation loss has decreased
        # ("<=" means a tie with the best loss so far also overwrites the checkpoint)
        if valid_loss <= valid_loss_min:
            train_loop.write(f'\t\tvalid_loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f})  saving model...')
            torch.save(resnet18.state_dict(), f'./model/lr3e-5{comment}.pth')
            valid_loss_min = valid_loss


        test_loss, test_correct = 0, 0  # will be used to track the running test loss and correct
        ##################
        # test the model #
        ##################
        for batch in test_loader:
            images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to available device
            preds = resnet18(images)  # forward pass
            loss = criterion(preds, labels)  # calculate the loss

            test_loss += loss.item() * labels.size(0)  # update the running loss
            test_correct += get_num_correct(preds, labels)  # update running num correct

        # add test loss and test accuracy for the current epoch to tensorboard
        # (test_loss itself stays a running sum; only the logged value is averaged)
        tb.add_scalar('Test Loss', test_loss/len_test, epoch)
        tb.add_scalar('Test Accuracy', test_correct/len_test, epoch)

# flush any buffered events to disk and release the event file now that training is done
tb.close()

Epoch [ 1/3]: 100%|██████████| 349/349 [08:58<00:00,  1.54s/it, acc=0.92, loss=1.31]
		Avg training loss: 0.235437	Avg validation loss: 0.083287
		valid_loss decreased (inf --> 0.083287)  saving model...
Epoch [ 2/3]: 100%|██████████| 349/349 [08:45<00:00,  1.51s/it, acc=0.964, loss=0.885]
		Avg training loss: 0.111461	Avg validation loss: 0.097972
Epoch [ 3/3]: 100%|██████████| 349/349 [08:38<00:00,  1.49s/it, acc=0.983, loss=1.03]
		Avg training loss: 0.061282	Avg validation loss: 0.061795
		valid_loss decreased (0.083287 --> 0.061795)  saving model...
