In [2]:
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.tensorboard import SummaryWriter


In [3]:
class ResNet18_Scratch(nn.Module):
    """ResNet-18 trained from scratch (random init) with a dropout + linear head.

    Args:
        num_classes: number of output logits produced by the head (default 2).
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes=2, dropout=0.5):
        super().__init__()
        # weights=None is the current torchvision spelling of the deprecated
        # pretrained=False: the backbone starts from random initialisation.
        self.model = models.resnet18(weights=None)
        num_ftrs = self.model.fc.in_features
        # Replace the backbone's default classifier with dropout + a linear layer
        # sized for this task (e.g. num_classes=len(class_names)).
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_ftrs, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class logits for the input batch ``x``."""
        return self.model(x)
        



In [4]:
class ResNet18(nn.Module):
    """ResNet-18 initialised from torchvision's default pretrained weights,
    with its classifier replaced by a dropout + linear head for fine-tuning.

    Args:
        num_classes: number of output logits produced by the head (default 2).
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes=2, dropout=0.5):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_ftrs = self.model.fc.in_features
        # Swap the pretrained classifier for a freshly initialised head sized
        # for this task; the rest of the backbone keeps its pretrained weights.
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_ftrs, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class logits for the input batch ``x``."""
        return self.model(x)



In [10]:
# model = ResNet18()
# print(model)

In [11]:
import time 
import copy

def train_model(model, dataloaders, dataset_sizes, device, criterion, optimizer, scheduler, writer, num_epochs=25):
    """Train ``model`` with a train/val loop and return it loaded with the best-val weights.

    Each epoch runs a 'train' phase (gradients enabled, optimizer stepped, then
    the scheduler stepped once) followed by a 'val' phase (gradients disabled).
    The state_dict with the highest validation accuracy seen so far is kept and
    restored into ``model`` before returning.

    Args:
        model: the nn.Module to train; it is modified in place.
        dataloaders: mapping with 'train' and 'val' iterables yielding (inputs, labels).
        dataset_sizes: mapping with the sample counts for 'train' and 'val', used to
            turn summed loss / correct counts into per-sample averages.
        device: torch.device each batch is moved to.
        criterion: loss taking (outputs, labels); assumed to return a per-batch mean,
            since it is multiplied back by the batch size when accumulating.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        writer: object with ``add_scalar(tag, value, step)`` (e.g. a SummaryWriter).
        num_epochs: number of epochs to run.

    Returns:
        The same ``model`` object with the best validation weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(True) for the training phase, train(False) == eval() otherwise.
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Only build the autograd graph during training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Undo the per-batch mean so the epoch average is per sample.
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the counter a plain Python int instead of a device tensor.
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()

            # max(..., 1) avoids ZeroDivisionError on an empty split.
            n_samples = max(dataset_sizes[phase], 1)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep a copy of the weights with the best validation accuracy.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            # Tags are unchanged ("Loss/train_batch", "Accuracy/val_batch", ...) so
            # existing TensorBoard runs stay comparable; values are per-epoch.
            writer.add_scalar(f"Loss/{phase}_batch", epoch_loss, epoch)
            writer.add_scalar(f"Accuracy/{phase}_batch", epoch_acc, epoch)

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model

In [12]:
def test(model, device, test_loader, epoch, writer):
    torch.cuda.empty_cache()
    model.eval()
    correct, total = 0, 0
    total_loss = 0.0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            data, target = batch[0].to(device), batch[-1].to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
            running_acc = 100 * correct / total
            total_loss += loss.item()
            running_loss = total_loss / (batch_idx + 1)
            if batch_idx % 10 == 0 or batch_idx == len(test_loader) - 1:
                print("Test [{}/{}], Loss: {:.6f}, Acc: {:.2f}".format(
                    total, len(test_loader.dataset), running_loss, running_acc))
                
        writer.add_scalar("Loss/test_batch", running_loss, epoch)
        writer.add_scalar("Accuracy/test_batch", running_acc, epoch)

def load_model_from_checkpoint(model: nn.Module, load_path, map_location=None):
    """Load a saved ``state_dict`` from ``load_path`` into ``model`` in place.

    Args:
        model: module whose parameters and buffers are overwritten.
        load_path: path to a file written by ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` lets a checkpoint
            saved from a GPU model load on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        ``model`` itself, for convenience.

    Note: ``torch.load`` unpickles the file, so only load checkpoints you trust.
    """
    state_dict = torch.load(load_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# model = ResNet18_Scratch().to(device)


In [13]:
def data_loader(batch_size, data_dir='chest_xray/', num_workers=8):
    """Build ImageFolder datasets and DataLoaders for the train/val/test splits.

    Transforms are adapted from the PyTorch transfer-learning tutorial:
    random crop / flip / colour jitter for training, a deterministic resize +
    centre crop for evaluation, and ImageNet mean/std normalisation everywhere.

    Args:
        batch_size: samples per batch for every split.
        data_dir: root directory containing 'train', 'val' and 'test'
            subdirectories in ImageFolder layout (one subfolder per class).
        num_workers: worker processes per DataLoader.

    Returns:
        (dataloaders, dataset_sizes, class_names): dicts keyed by split name, plus
        the class names of the training split.
    """
    splits = ['train', 'val', 'test']
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # val and test share one evaluation pipeline so the accuracy used for model
    # selection (val) is measured the same way as the final test accuracy.
    # (Previously val used Resize(224) while test used Resize(256).)
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            # NOTE(review): if the X-rays are grayscale, saturation/hue jitter
            # likely has no visible effect and only brightness changes them — confirm.
            transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': eval_transform,
        'test': eval_transform,
    }

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
        for split in splits
    }
    # Only the training split needs shuffling; evaluation metrics do not depend on order.
    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size,
                                           shuffle=(split == 'train'), num_workers=num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}
    class_names = image_datasets['train'].classes

    return dataloaders, dataset_sizes, class_names

In [22]:
# Sanity-check the split sizes. In the recorded run the validation split has only
# 16 images, so val accuracy moves in steps of 1/16 and best-model selection on it is noisy.
preview_loaders, preview_sizes, preview_class_names = data_loader(16)
preview_sizes

{'train': 5216, 'val': 16, 'test': 624}

In [14]:
import argparse

def run_main(FLAGS):
    """Train, checkpoint and test a ResNet-18 pneumonia classifier as configured by FLAGS.

    FLAGS must provide: ``mode`` ('scratch' -> ResNet18_Scratch, anything else ->
    pretrained ResNet18), ``learning_rate``, ``batch_size``, ``num_epochs`` and
    optionally ``log_dir`` (TensorBoard directory, used only when it is a string).

    Side effects: writes TensorBoard event files and saves the best weights to "model.pth".
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Torch device selected: ", device)

    # Initialize the model and send it to the device.
    if FLAGS.mode == 'scratch':
        model = ResNet18_Scratch().to(device)
    else:
        model = ResNet18().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate, weight_decay=0.0001)
    # train_model steps this once per epoch: LR is divided by 10 every 7 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # --log_dir used to be ignored. The isinstance guard also covers the old
    # argparse default, which passed the `str` type itself rather than a path.
    log_dir = getattr(FLAGS, 'log_dir', None)
    writer = SummaryWriter(log_dir=log_dir if isinstance(log_dir, str) else None)

    # try/finally so event files are flushed and closed even if training fails.
    try:
        dataloaders, dataset_sizes, class_names = data_loader(FLAGS.batch_size)

        best_model = train_model(model, dataloaders, dataset_sizes, device, criterion, optimizer, scheduler,
                                 writer, num_epochs=FLAGS.num_epochs)

        torch.save(best_model.state_dict(), "model.pth")

        # Reload from disk so the test metrics come from the saved checkpoint.
        load_model_from_checkpoint(best_model, "model.pth")
        test(best_model, device, dataloaders['test'], 1, writer)
    finally:
        writer.close()

    print("Training and evaluation finished")
    
    
    
if __name__ == '__main__':
    # Time the whole train + evaluate run.
    start = time.time()

    # Command-line options for the ResNet pneumonia classifier.
    parser = argparse.ArgumentParser('ResNet used to detect pneumonia in chest x-rays.')
    parser.add_argument('--mode',
                        type=str, default='pretrained',
                        choices=['scratch', 'pretrained'],
                        help="Select between 'scratch' and 'pretrained'.")
    parser.add_argument('--learning_rate',
                        type=float, default=0.001,
                        help='Initial learning rate.')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=30,
                        help='Number of epochs to run trainer.')
    parser.add_argument('--batch_size',
                        type=int, default=16,
                        help='Batch size. The last batch of each split may be smaller.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,  # None -> SummaryWriter's default runs/ directory
                        help='Directory to put logging.')

    # parse_known_args (rather than parse_args) presumably tolerates the extra
    # arguments a Jupyter kernel is launched with.
    FLAGS, unparsed = parser.parse_known_args()

    run_main(FLAGS)
    print("Total Running time = {:.3f} seconds".format(time.time() - start))

Torch device selected:  cuda
Epoch 0/29
----------
train Loss: 0.4354 Acc: 0.8158


val Loss: 0.7000 Acc: 0.5000

Epoch 1/29
----------
train Loss: 0.3541 Acc: 0.8520


val Loss: 0.4268 Acc: 0.7500

Epoch 2/29
----------
train Loss: 0.3076 Acc: 0.8758


val Loss: 0.4137 Acc: 0.8750

Epoch 3/29
----------
train Loss: 0.2944 Acc: 0.8777


val Loss: 0.5297 Acc: 0.7500

Epoch 4/29
----------
train Loss: 0.2971 Acc: 0.8771


val Loss: 0.5083 Acc: 0.7500

Epoch 5/29
----------
train Loss: 0.2875 Acc: 0.8831


val Loss: 0.7968 Acc: 0.6250

Epoch 6/29
----------
train Loss: 0.2704 Acc: 0.8875


val Loss: 0.4042 Acc: 0.8125

Epoch 7/29
----------
train Loss: 0.2169 Acc: 0.9114


val Loss: 0.5519 Acc: 0.7500

Epoch 8/29
----------
train Loss: 0.2029 Acc: 0.9191


val Loss: 0.3241 Acc: 0.8125

Epoch 9/29
----------
train Loss: 0.1960 Acc: 0.9202


val Loss: 0.2432 Acc: 0.9375

Epoch 10/29
----------
train Loss: 0.1835 Acc: 0.9302


val Loss: 0.4042 Acc: 0.8125

Epoch 11/29
----------
train Loss: 