<a href="https://colab.research.google.com/github/jhorapb/covid19-pytorch/blob/master/covid19_finetuning_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
"""
Model for transfer learning from CheXNet by: 
- Pretraining (Feature Extraction): training only 
the output layer (last fully-connected one).
We are using here the "freezing" approach.
- Fine Tuning: updating all the weights of the model.
"""
# PyTorch imports
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import (
    datasets, 
    models, 
    transforms, 
    utils
)

# Image imports
# from skimage import io, transform
# from PIL import Image

# General imports
import os
import re
import time
from shutil import copyfile
from shutil import rmtree
from pathlib import Path

# import pandas as pd
import numpy as np
import csv

# import covid_dataset as COVID_XR
# import eval_model as E


In [3]:
root_path = '/drive/My Drive/1-COVID-19_DeepLearning/'
# os.chdir(root_path)
%cd drive/My\ Drive/1-COVID-19_DeepLearning/
!ls
# !git clone https://github.com/jhorapb/covid19-pytorch.git
# print(os.listdir())

/content/drive/My Drive/1-COVID-19_DeepLearning
covid19-pytorch


In [4]:
%cd covid19-pytorch/covid_models/
# !git pull

/content/drive/My Drive/1-COVID-19_DeepLearning/covid19-pytorch/covid_models


In [5]:

RESULTS_PATH = '../results/fine_tuning/'
use_gpu = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_gpu else "cpu")
# print("Device: " + str(device))
gpu_count = torch.cuda.device_count()
print("Available GPU count: " + str(gpu_count))


def load_checkpoint(PATH_CHECKPOINT, mode='state_dict', fine_tuning=False, 
                    unfrozen_names=None):
    #####
    # The pre-trained model checkpoint from 'reproduce-chexnet' contains:
    # state = {
    #     'model': model,
    #     'best_loss': best_loss,
    #     'epoch': epoch,
    #     'rng_state': torch.get_rng_state(),
    #     'LR': LR
    # }
    #####
    
    # Define new base model
    model_tl = models.densenet121(pretrained=False)
    model_dict = model_tl.state_dict()
    
    # Locate checkpoint
    chexnet_checkpoint = torch.load(PATH_CHECKPOINT)
    if mode == 'state_dict':
        # Load pretrained CheXNet model (mode state_dict)
        state_dict_chexnet = chexnet_checkpoint['state_dict']
        # model_tl = torch.nn.DataParallel(model_tl)
    else:
        # Load pretrained CheXNet model (mode full_model)
        chexnet_model = chexnet_checkpoint['model']
        state_dict_chexnet = chexnet_model.state_dict()
    
    # 1. Filter out unnecessary keys
    state_dict_chexnet = {k: v for k, v in state_dict_chexnet.items() 
                          if k in model_dict}
    # 2. Load the new state dict
    model_tl.load_state_dict(model_dict)    
    
    # epoch = chexnet_checkpoint['epoch']
    # loss = chexnet_checkpoint['loss']
    # LR = chexnet_checkpoint['LR']
    # print('\nPretrained model:', model_tl)
    
    # Freeze the parameters for feature extraction (ony for pretraining), 
    # NO Fine Tuning
    for parameter in model_tl.parameters():
        parameter.requires_grad = False
    # If Fine Tuning, we unfreeze 'unfrozen_names' layers 
    # (with more processing resources we could unfreeze the whole network)
    if fine_tuning:
        # If a list with the name of the layers to be enabled is received
        if unfrozen_names:
          for name, child in model_tl.named_children():
              if name == 'features':
                  for sub_name, sub_child in child.named_children():
                      if sub_name in unfrozen_names:
                          print('=>', sub_name, 'is being unfrozen.')
                          for parameter in sub_child.parameters():
                              parameter.requires_grad = True
    
    del chexnet_checkpoint
    return model_tl


def save_checkpoint(model, best_loss, best_accuracy, epoch, LR, optimizer):
    """
    Saves checkpoint of torchvision model during training.

    Args:
        model: torchvision model to be saved
        best_loss: best val loss achieved so far in training
        epoch: current epoch of training
        LR: current learning rate in training
        optimizer: pytorch optimizer to be saved
    Returns:
        None
    """
    state = {
        'model': model.state_dict(),
        'best_accuracy': best_accuracy,
        'best_loss': best_loss,
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
        'LR': LR,
        'optimizer': optimizer.state_dict(),
    }

    torch.save(state, RESULTS_PATH + 'tl_pretraining_checkpoint')

def show_optimizer_params(optimizer_state_dict):
    # Print optimizer's state_dict
    print("\nOptimizer's state_dict:")
    for var_name in optimizer_state_dict:
        print(var_name, '\t', optimizer_state_dict[var_name])

def train_model(
        model,
        criterion,
        optimizer,
        LR,
        num_epochs,
        dataloaders,
        dataset_sizes,
        weight_decay):
    """
    Fine tunes torchvision model to COVID-19 CXR data.

    Args:
        model: torchvision model to be finetuned (densenet-121 in this case)
        criterion: loss criterion (binary cross entropy loss, BCELoss)
        optimizer: optimizer to use in training (SGD)
        LR: learning rate
        num_epochs: continue training up to this many epochs
        dataloaders: pytorch train and val dataloaders
        dataset_sizes: length of train and val datasets
        weight_decay: weight decay parameter we use in SGD with momentum
    Returns:
        model: trained torchvision model
        best_epoch: epoch on which best model val loss was obtained

    """
    since = time.time()

    start_epoch = 1
    best_loss = 999999
    best_acc = 0.0
    best_epoch = -1
    last_train_loss = -1

    # Iterate over epochs
    for epoch in range(start_epoch, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 17)

        # set model to train or eval mode based on whether we are in train or
        # val; necessary to get correct predictions given batchnorm
        for phase in ['train', 'val']:
            if phase == 'train':
                # model.train(True)
                model.train(True)
            else:
                # model.train(False)
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            
            total_done = 0
            
            # Iterate over dataset (train/val)
            for inputs, labels in dataloaders[phase]:
                batch_size = inputs.shape[0]
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda()) #.long()
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):    
                    # Compute loss
                    outputs = model(inputs)
                    something_else, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                
                # Backward pass: compute gradient and update 
                # parameters in training 
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Computing loss for the whole batch dataset
                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'train':
                last_train_loss = epoch_loss

            print('{} epoch {}=> Loss: {:.4f} | Acc: {:.4f} | Data size: {}'.format(
                phase, epoch, epoch_loss, epoch_acc, dataset_sizes[phase]))

            if phase == 'val':
                # # Decay learning rate if validation loss plateaus in this epoch
                # if epoch_loss > best_loss:
                #     decayed_LR = LR / 10
                #     print('Decay Loss from {} to {} \
                #             as not seeing improvement in val loss'.format(
                #                 str(LR), str(decayed_LR))
                #             )
                    # LR = decayed_LR
                    # # Create new optimizer with lower learning rate
                    # optimizer = optim.Adam(
                    #     filter(
                    #         lambda p: p.requires_grad, 
                    #         model_tl.parameters()), 
                    #     lr=LR, betas=(0.9, 0.999))
                #     print("Created new optimizer with LR " + str(LR))
                
                # Checkpoint model if has best val loss yet
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_epoch = epoch
                    save_checkpoint(model, best_loss, best_acc, 
                                    epoch, LR, optimizer)

                # Log training and validation loss over each epoch
                with open(RESULTS_PATH + '/log_train', 'a') as logfile:
                    logwriter = csv.writer(logfile, delimiter=',')
                    if(epoch == 1):
                        logwriter.writerow(["epoch", "train_loss", "val_loss"])
                    logwriter.writerow([epoch, last_train_loss, epoch_loss])

        total_done += batch_size
        if(total_done % (100 * batch_size) == 0):
            print("completed " + str(total_done) + " so far in epoch")

        # Apply early stopping if there is no val loss improvement in 3 epochs
        if ((epoch - best_epoch) >= 8):
            print("No improvement in the model accuracy in 3 epochs, stop!")
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Loss: {:4f}'.format(best_loss))
    print('Best Validation Accuracy: {:4f}'.format(best_acc))
    print('Best Epoch: {}'.format(best_epoch))

    # Load best model weights to return
    checkpoint_best = torch.load(RESULTS_PATH + 'tl_pretraining_checkpoint')
    model_state_dict = checkpoint_best['model']
    model.load_state_dict(model_state_dict)

    return model, best_epoch, checkpoint_best


def perform_tl_cnn(PATH_TO_IMAGES, CHEXNET_CHECKPOINT, checkpoint_type, 
                   LR, WEIGHT_DECAY, FINE_TUNING, unfrozen_names):
    """
    Trains model to COVID-19 dataset.

    Args:
        PATH_TO_IMAGES: path to COVID-19 image data collection
        LR: learning rate
        WEIGHT_DECAY: weight decay parameter for SGD

    Returns:
        preds: torchvision model predictions on test fold with ground truth for comparison
        aucs: AUCs for each train,test tuple
    """
    NUM_EPOCHS = 12 # 10 # 20
    # Since the COVID-19 dataset at the moment is considerably small, 
    # it makes sense to use Batch Gradient Descent (all the samples 
    # being used to update the model parameters)
    minibatch_gd = False
    BATCH_SIZE = 375 if not minibatch_gd else 10

    # Create path to save model results
    Path(RESULTS_PATH).mkdir(parents=True, exist_ok=True)

    # ImageNet parameters for normalization
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]

    # Binary classifier
    # N_LABELS = 1
    # Multi-class classifier
    N_LABELS = 2

    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(means, stds)
        ]),
    }
    
    image_datasets = {x: datasets.ImageFolder(os.path.join(PATH_TO_IMAGES, x), 
                                              data_transforms[x]) 
                      for x in ['train', 'val']}
    
    # Option num. workers 8
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, 
                                                  shuffle=True, num_workers=1) 
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    print('CLASS NAMES\n:', class_names, '\n')
    
    # Load pre-trained CheXNet model
    model_tl = load_checkpoint(CHEXNET_CHECKPOINT, mode=checkpoint_type, 
                               fine_tuning=FINE_TUNING, 
                               unfrozen_names=unfrozen_names)

    # Verify if GPU is available
    if not use_gpu:
        raise ValueError("GPU is required")
    if use_gpu:
        model_tl = model_tl.cuda()

    # print('Pre-trained Model:\n', model_tl)
    num_ftrs = model_tl.classifier.in_features
    # Size of each output sample.
    model_tl.classifier = nn.Linear(num_ftrs, N_LABELS)
    # model_tl.classifier = nn.Sequential(nn.Linear(num_ftrs, N_LABELS), 
    #                                     nn.Sigmoid())
    model_tl.classifier = model_tl.classifier.cuda()
    print('COVID Model:\n', model_tl)
    # If multiple-class classifier were used, a Sequential
    # container would be necessary. 
    # E.g., nn.Sequential(nn.Linear(num_ftrs, N_LABELS), nn.Softmax())

    # Define Loss Function (Binary Cross-Entropy Loss)
    # criterion = nn.BCELoss()
    # criterion = nn.BCEWithLogitsLoss() # It applies a sigmoid activation internally
    criterion = nn.CrossEntropyLoss()
    # Define optimizer for the new model
    # With Adam Optimizer
    optimizer = optim.Adam(model_tl.parameters(), lr=LR, betas=(0.9, 0.999))
    # With SGD Optimizer
    # Observe that all parameters are being optimized
    # optimizer = optim.SGD(model_tl.parameters(), lr=0.001, momentum=0.9)
    
    # Decay LR by a factor of 0.1 every 7 epochs (when using SGD optimizer).
    # ---> If we use this, then we should not perform Decay LR in the training 
    # function, since we are defining a schedule already in that case.
    # exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    # Train COVID model
    model, best_epoch, checkpoint_best = train_model(
        model_tl, criterion, optimizer, LR, num_epochs=NUM_EPOCHS, 
        dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY
        )
    print('Best Epoch Saved: ', checkpoint_best['epoch'])
    # Print best optimizer saved params
    show_optimizer_params(checkpoint_best['optimizer'])
    # get preds and AUCs on test fold
    # preds, aucs = E.make_pred_multilabel(
    #     data_transforms, model, PATH_TO_IMAGES)
    
    return model


if __name__ == "__main__":
    binary_classifier = True
    if binary_classifier:
        PATH_TO_IMAGES = "../images/cleaned_up/binary_classifier"
    else:
        PATH_TO_IMAGES = "../images/cleaned_up/multiclass_classifier"
    
    checkpoint_type = 'full_model'
    CHEXNET_CHECKPOINT = '../pretrained_chexnet/checkpoint'
    FINE_TUNING = True
    if FINE_TUNING:
        # There are four main denseblocks: 'denseblock1', 'denseblock2', 
        # 'denseblock3', 'denseblock4'
        unfrozen_names = ['denseblock4']
    else:
        unfrozen_names = None
    # Hyperparams for Adam Optimizer: LR=0.001, betas=(0.9, 0.999)
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-4
    best_model = perform_tl_cnn(PATH_TO_IMAGES, CHEXNET_CHECKPOINT, 
                                checkpoint_type, LEARNING_RATE, WEIGHT_DECAY, 
                                FINE_TUNING, unfrozen_names)
    # print(best_model)

Available GPU count: 1
CLASS NAMES
: ['covid', 'no_covid'] 





[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        -3.8676e-07, -3.2509e-06,  6.4241e-06,  7.7353e-06, -5.0019e-06,
         1.2537e-07,  5.8401e-06, -1.1408e-05,  5.4847e-06, -6.8604e-06,
         9.0068e-06, -1.5666e-06, -7.8197e-08,  9.0389e-06, -6.2889e-06,
         1.3967e-05,  9.9197e-07, -5.3254e-06, -5.7541e-06,  3.3670e-06,
         2.3046e-06,  1.0631e-05, -1.6072e-05, -5.1620e-06, -3.3458e-06,
        -3.9052e-06,  1.1710e-05, -7.8793e-06, -2.6179e-05, -1.0661e-05,
        -6.3917e-06,  1.9916e-05, -1.8094e-06, -9.5000e-07,  1.9917e-05,
        -6.2300e-06,  7.0684e-06,  4.5523e-06, -6.3310e-06, -1.5175e-05,
        -5.4962e-06,  5.1826e-06, -8.8620e-06,  8.8471e-06,  9.2087e-06,
         1.1700e-05,  6.6750e-06,  1.9009e-06,  5.2933e-06,  6.3436e-06,
         3.7724e-06,  6.5122e-06, -2.7000e-05,  7.0723e-06, -4.5608e-06,
         2.1292e-05, -9.9308e-06,  1.4838e-06, -4.3579e-06,  2.6168e-06,
        -1.7439e-05,  2.9600e-06,  2.5054e-06,  7.7915e-06,