In [1]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# pytorch imports
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.autograd import Variable
from torch.optim import lr_scheduler
from ignite.metrics import EpochMetric
from torchvision import datasets, transforms
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.contrib.metrics import roc_auc, ROC_AUC, RocCurve
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Precision, Recall, ConfusionMatrix, Fbeta, Loss, MetricsLambda

##### Functions to help set the final classifier layer and other basic model features

In [2]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [3]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    # variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

##### Pre-process the data to make sure it is properly formatted for training and evaluation

In [4]:
# pre-process the image data for training and validation sets

def preprocess_data(train_data_dir = '/Users/jacksimonson/Documents/CBIS-DDSM Train',
                    test_data_dir = '/Users/jacksimonson/Documents/CBIS-DDSM Val'):

    # transform the train and validation data
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # standard pytorch normalization values
        ]),
        'val': transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }


    # collect in a dictionary
    image_datasets = {}
    image_datasets['train'] = datasets.ImageFolder(train_data_dir, data_transforms['train'])
    image_datasets['val'] = datasets.ImageFolder(test_data_dir, data_transforms['val'])


    # create data loaders
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['val'].classes

    return class_names, dataset_sizes, dataloaders_dict

##### This function is the global one. It initializes the model, the parameters that need to be trained/tuned, processes the data, and runs the trainier and executor.

In [5]:
def build_and_train(model_name = 'alexnet', use_pretrained = True, num_classes = 2, num_epochs = 25, device = torch.device("cpu"),
                    feature_extract = True, criterion = nn.CrossEntropyLoss()):
    # Initialize the model for this run
    model = models.resnet18(pretrained=use_pretrained) if model_name == 'resnet' else models.resnet18(pretrained=use_pretrained)
    for param in model.parameters():
        param.requires_grad = False
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    input_size = 224
    model, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

    # Print the model we just instantiated
    print(model)

    # get data downloaders
    class_names, dataset_sizes, dataloaders_dict = preprocess_data()
    
    # Send the model to GPU/CPU
    model = model.to(device)

    # Gather the parameters to be optimized/updated in this run
    params_to_update = model.parameters()
    print("Params to learn:")
    if feature_extract:
        params_to_update = []
        for name,param in model.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
                print("\t",name)
    else:
        for name,param in model.named_parameters():
            if param.requires_grad == True:
                print("\t",name)

    # Observe that all parameters are being optimized
    optimizer_ft = optim.Adam(params_to_update, lr=0.001)
    
    train_dl = dataloaders_dict['train']
    val_dl = dataloaders_dict['val']
    
    
    # create a trainer and evaluator using the ignite package to train and evaluate the models
    trainer = create_supervised_trainer(model, optimizer_ft, criterion, device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'confusion': ConfusionMatrix(num_classes = 2),
                                                     'loss': Loss(criterion)},
                                            device=device)
    
    def thresholded_output_transform(output):
        y_pred, y = output
        y_pred = torch.sigmoid(y_pred)
        return y_pred, y
             
     # attach additional metrics
    precision = Precision(average=False)
    recall = Recall(average=False)
    F1 = (precision * recall * 2 / (precision + recall)).mean()

    precision.attach(evaluator, 'precision')
    recall.attach(evaluator, 'recall')
    F1.attach(evaluator, 'F1')
    
    # why won't this work??? bad input shape
#     ROC_AUC(output_transform=thresholded_output_transform).attach(evaluator, 'ROC')

    
    
    val_accuracy = []
    train_accuracy = []
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_dl)
        metrics = evaluator.state.metrics
        train_accuracy.append(metrics['accuracy'])
#         print(
#             f"Training Results   - Epoch: {trainer.state.epoch}  "
#             f"accuracy: {metrics['accuracy']:.2f} "
#             f"loss: {metrics['loss']:.2f} "
#             f"prec: {metrics['precision'].cpu()} "
#             f"recall: {metrics['recall'].cpu()} "
#             f"F1: {metrics['F1']}")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_dl)
        metrics = evaluator.state.metrics
        val_accuracy.append(metrics['accuracy'])
#         print(
#             f"Validation Results - Epoch: {trainer.state.epoch}  "
#             f"accuracy: {metrics['accuracy']:.2f} "
#             f"loss: {metrics['loss']:.2f} "
#             f"prec: {metrics['precision'].cpu()} "
#             f"recall: {metrics['recall'].cpu()} "
#             f"F1: {metrics['F1']}")

    trainer.run(train_dl, max_epochs=50)
    return trainer, evaluator, train_accuracy, val_accuracy

In [6]:
# run the analysis on alexnet and resnet
models_final = {}
for model in ['resnet']:
    trainer, evaluator, train_accuracy, val_accuracy = build_and_train(model)
    models_final[model] = (trainer, evaluator, train_accuracy, val_accuracy)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
models_final['resnet'][2]

[0.5761458846722524,
 0.6816165598817151,
 0.6037456875308034,
 0.6155741744701824,
 0.6954164613109907,
 0.7097092163627403,
 0.6515524889107935,
 0.6407097092163627,
 0.7195662888122227,
 0.7304090685066535,
 0.7121734844751109,
 0.6663380975850173,
 0.7102020699852144,
 0.6323311976343026,
 0.6752094627895515,
 0.7225234105470675,
 0.7259733859043864,
 0.6919664859536717,
 0.7299162148841793,
 0.7220305569245934,
 0.6407097092163627,
 0.6461310990635781,
 0.7249876786594381,
 0.7146377525874815,
 0.6791522917693446,
 0.7220305569245934,
 0.7387875800887137,
 0.6274026614095614,
 0.6939379004435683,
 0.720551996057171,
 0.7387875800887137,
 0.681123706259241,
 0.7161163134549039,
 0.7171020206998522,
 0.6136027599802859,
 0.7392804337111878,
 0.7353376047313948,
 0.6885165105963529,
 0.7028092656481025,
 0.7294233612617053,
 0.7230162641695417,
 0.7111877772301627,
 0.6382454411039921,
 0.7348447511089207,
 0.7136520453425332,
 0.700837851158206,
 0.7037949728930508,
 0.6239526860522

In [8]:
models_final['resnet'][3]

[0.6833333333333333,
 0.5833333333333334,
 0.7083333333333334,
 0.65,
 0.6083333333333333,
 0.6,
 0.6583333333333333,
 0.675,
 0.6333333333333333,
 0.625,
 0.6416666666666667,
 0.5,
 0.6166666666666667,
 0.675,
 0.5166666666666667,
 0.6166666666666667,
 0.625,
 0.6583333333333333,
 0.625,
 0.5833333333333334,
 0.6583333333333333,
 0.49166666666666664,
 0.6166666666666667,
 0.65,
 0.6833333333333333,
 0.5666666666666667,
 0.6333333333333333,
 0.65,
 0.55,
 0.5833333333333334,
 0.6166666666666667,
 0.7083333333333334,
 0.6333333333333333,
 0.5416666666666666,
 0.6916666666666667,
 0.5666666666666667,
 0.6333333333333333,
 0.6833333333333333,
 0.7,
 0.6333333333333333,
 0.6083333333333333,
 0.65,
 0.6916666666666667,
 0.65,
 0.5583333333333333,
 0.65,
 0.6833333333333333,
 0.6666666666666666,
 0.6916666666666667,
 0.55]

In [9]:
# plt.plot(range(len(train_acc)-1), train_acc)
# plt.plot(range(len(val_acc)-1), val_acc)
# plt.show()

In [10]:
# val_acc