Load Dataset

In [1]:
import argparse
import time
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from train_test import load_model_from_checkpoint, per_class_accuracy
from models import ResNet18_Scratch, ResNet18
from torch.utils.tensorboard import SummaryWriter
from train_test import test, load_model_from_checkpoint


def data_loader(batch_size, data_dir='chest_xray/', num_workers=8):
    '''Build the train/val/test dataloaders for an image-folder dataset.

    Adapted from the PyTorch transfer-learning tutorial.

    Args:
        batch_size: Number of images per batch; used for every split.
        data_dir: Root directory holding one sub-directory per split
            ('train', 'val', 'test'), each read with ``datasets.ImageFolder``.
        num_workers: Number of worker processes given to each DataLoader.

    Returns:
        A tuple ``(dataloaders, dataset_sizes, class_names)``. The first two
        are dicts keyed by split name; ``class_names`` is the ``classes``
        list of the training ImageFolder.
    '''
    splits = ['train', 'val', 'test']

    # Channel mean/std conventionally used with ImageNet-pretrained
    # torchvision models; identical for every split.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Random augmentation is applied to the training split only; the
    # evaluation splits use a deterministic resize + centre crop.
    # NOTE(review): 'val' resizes to 224 while 'test' resizes to 256 before
    # the same 224 centre crop — confirm this difference is intentional.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        'test': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split),
                                    data_transforms[split])
        for split in splits
    }

    # Shuffle only the training split: a fixed order for 'val' and 'test'
    # keeps the per-batch evaluation logs reproducible from run to run.
    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split],
                                           batch_size=batch_size,
                                           shuffle=(split == 'train'),
                                           num_workers=num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}
    class_names = image_datasets['train'].classes

    return dataloaders, dataset_sizes, class_names

# Only the dict of per-split DataLoaders is needed below; the dataset sizes
# and class names returned alongside it are discarded.
dataloader, _, _ = data_loader(32)

Test the model trained from scratch and display per-class accuracies

In [2]:
# Use the GPU only when one is really present. `is_available` must be
# *called*: the bare function object is always truthy, which would select
# 'cuda' unconditionally and fail on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): 0.1 is the constructor's only argument — presumably a dropout
# rate; confirm against models.py.
model_scratch = ResNet18_Scratch(0.1).to(device)

# Presumably restores the trained weights from the checkpoint file into the
# model in place (return value is unused) — verify in train_test.py.
load_model_from_checkpoint(model_scratch, 'scratch-model.pth')
print('===========Scratch Model===========')
test(model_scratch, device, dataloader['test'])
print()
per_class_accuracy(model_scratch, device, dataloader['test'])





Test [32/624], Loss: 0.081643, Acc: 100.00
Test [352/624], Loss: 0.209811, Acc: 94.03
Test [624/624], Loss: 0.232008, Acc: 92.63

Accuracy for Normal Class: 83.33%
Accuracy for Pneumonia Class: 98.21%


Test the trained 'pre-trained' model and display per-class accuracies

In [4]:
# Use the GPU only when one is really present. `is_available` must be
# *called*: the bare function object is always truthy, which would select
# 'cuda' unconditionally and fail on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): 0.1 is the constructor's only argument — presumably a dropout
# rate; confirm against models.py.
model_pre = ResNet18(0.1).to(device)

# Presumably restores the trained weights from the checkpoint file into the
# model in place (return value is unused) — verify in train_test.py.
load_model_from_checkpoint(model_pre, 'pretrained-model.pth')
print('==========Pretrained Model==========')
test(model_pre, device, dataloader['test'])
print()
per_class_accuracy(model_pre, device, dataloader['test'])



Test [32/624], Loss: 0.337024, Acc: 90.62
Test [352/624], Loss: 0.205154, Acc: 93.75
Test [624/624], Loss: 0.208152, Acc: 93.91

Accuracy for Normal Class: 84.19%
Accuracy for Pneumonia Class: 99.74%
