In [1]:
# !touch __init__.py
# !cp 'drive/MyDrive/COVID-19_Radiography_Database.zip' .
# !unzip COVID-19_Radiography_Database.zip
# !rm COVID-19_Radiography_Database.zip
# !find "/content/COVID-19_Radiography_Dataset" -type f | wc -l  # 21170

In [2]:
# !python split_dataset.py

In [3]:
# !find "/content/COVID-19_Radiography_Dataset" -type f | wc -l  # 21170
# !zip -r temp.zip "COVID-19_Radiography_Dataset/"
# !mv temp.zip drive/MyDrive/

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import WeightedRandomSampler

import numpy as np
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
transform = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
}

In [4]:
dirs = {
    'train': 'data/COVID-19_Radiography_Dataset/train',
    'val': 'data/COVID-19_Radiography_Dataset/val'
}
train_set = datasets.ImageFolder(root=dirs['train'], transform=transform['train'])
val_set = datasets.ImageFolder(root=dirs['val'], transform=transform['val'])

In [5]:
class_freq = torch.as_tensor(train_set.targets).bincount()
weight = 1 / class_freq
samples_weight = weight[train_set.targets]
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, sampler=sampler)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=32)

In [21]:
# count = dict((c, 0) for c in train_set.classes)
# idx_to_class = dict((v, k) for (k, v) in train_set.class_to_idx.items())
# for _, l in torch.utils.data.DataLoader(train_set, sampler=sampler):
#     count[idx_to_class[l.item()]] += 1

# count

{'covid_19': 5270, 'lung_opacity': 5235, 'normal': 5256, 'pneumonia': 5244}

In [6]:
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(in_features=512, out_features=4)
resnet18.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet18.parameters(), lr=3e-5)

val_loss_min = np.Inf
num_epochs = 10

len_train = len(train_set)
len_val = len(val_set)

for epoch in range(num_epochs):
    train_loss, train_correct = 0, 0
    train_loop = tqdm(train_loader)
    resnet18.train()

    for batch in train_loop:
        images, labels = batch[0].to(device), batch[1].to(device)
        preds = resnet18(images)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * labels.size(0)
        train_correct += get_num_correct(preds, labels)

        train_loop.set_description(f'Epoch [{epoch+1:2d}/{num_epochs}]')
        train_loop.set_postfix(loss=loss.item(), acc=train_correct/len_train)


    resnet18.eval()
    with torch.no_grad():

        val_loss = 0
        for batch in val_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            preds = resnet18(images)
            loss = criterion(preds, labels)
            val_loss += loss.item() * labels.size(0)

        train_loss = train_loss/len_train
        val_loss = val_loss/len_val
        train_loop.write(f'\t\tAvg training loss: {train_loss:.6f}\tAvg validation loss: {val_loss:.6f}')

        # save model if validation loss has decreased
        if val_loss <= val_loss_min:
            train_loop.write(f'\t\tval_loss decreased ({val_loss_min:.6f} --> {val_loss:.6f})  saving model...')
            torch.save(resnet18.state_dict(), 'models/lr3e-5_resnet18_gpu.pth')
            val_loss_min = val_loss

Epoch [ 1/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.912, loss=0.12]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.244246	Avg validation loss: 0.056423
		val_loss decreased (inf --> 0.056423)  saving model...


Epoch [ 2/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.956, loss=0.0205]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.126862	Avg validation loss: 0.049716
		val_loss decreased (0.056423 --> 0.049716)  saving model...


Epoch [ 3/10]: 100%|██████████| 1313/1313 [04:28<00:00,  4.89it/s, acc=0.967, loss=0.0337]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.096389	Avg validation loss: 0.039880
		val_loss decreased (0.049716 --> 0.039880)  saving model...


Epoch [ 4/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.87it/s, acc=0.975, loss=0.00805]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.073555	Avg validation loss: 0.026281
		val_loss decreased (0.039880 --> 0.026281)  saving model...


Epoch [ 5/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.87it/s, acc=0.98, loss=0.0165]
Epoch [ 6/10]:   0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.063953	Avg validation loss: 0.032789


Epoch [ 6/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.86it/s, acc=0.98, loss=0.026]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.058138	Avg validation loss: 0.043833


Epoch [ 7/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.85it/s, acc=0.984, loss=0.028]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.048760	Avg validation loss: 0.027824


Epoch [ 8/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.85it/s, acc=0.986, loss=0.00897]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.040949	Avg validation loss: 0.022712
		val_loss decreased (0.026281 --> 0.022712)  saving model...


Epoch [ 9/10]: 100%|██████████| 1313/1313 [04:30<00:00,  4.86it/s, acc=0.988, loss=0.000634]
  0%|          | 0/1313 [00:00<?, ?it/s]

		Avg training loss: 0.036118	Avg validation loss: 0.035958


Epoch [10/10]: 100%|██████████| 1313/1313 [04:29<00:00,  4.88it/s, acc=0.99, loss=0.0028]


		Avg training loss: 0.031628	Avg validation loss: 0.016979
		val_loss decreased (0.022712 --> 0.016979)  saving model...
