In [None]:
import argparse
import os
import random
import warnings
import pandas as pd
import numpy as np
import time
import shutil
import datetime
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from skimage.io import imread
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from scipy.stats import pearsonr
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from utils.logging import *
from data.datasets import PatchPathDataset
from data.process import *

### Define Parameters

In [None]:
# Run configuration. Stored as a pd.Series so settings read as attributes
# (args.lr) and can be exported with args.to_dict() for TensorBoard hparams.
args = pd.Series(dict(
    checkpoint='checkpoint/',
    version='1.0',
    image_dir='Data/',
    patch_label='Metadata/PatchLabels.csv',
    predicting_var='response',
    cohort='Cohort1',
    magnification='10X',
    num_classes=1,
    prediction='binary classification',  # one of ['regression', 'binary classification', 'classification']
    upsample=True,
    train_val_split=0.7,
    random_crop=False,
    features=2048,
    loss='bce',
    epochs=20,
    start_epoch=0,
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    batch_size=256,
    workers=4,
    evaluate=False,
    seed=0,
    gpu=0,
    log_interval=50,
    log=False,
))

# TensorBoard writers are only created when logging is enabled; later cells
# guard every use of train_summary_writer / val_summary_writer with args.log.
if args.log:
    sub_dir = 'tensorboard'
    tensorboard_dir = f'logs/{sub_dir}/BaselineResNet{args.version}/'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    # One timestamped sub-folder per run, shared by the train and val writers.
    current_time = datetime.datetime.now().strftime("%d%m%Y-%H:%M:%S")
    train_log_dir = f'{tensorboard_dir}train/{current_time}'
    val_log_dir = f'{tensorboard_dir}val/{current_time}'
    train_summary_writer = SummaryWriter(log_dir=train_log_dir)
    val_summary_writer = SummaryWriter(log_dir=val_log_dir)

### Metrics and save functions

In [None]:
def metrics(output, target, prediction=None):
    """Compute the two epoch-level metrics for the configured prediction type.

    Args:
        output: model outputs accumulated over an epoch, as produced by
            train/validate. Presumably raw logits of shape (N, num_classes);
            note that train/validate already apply a sigmoid when args.loss == 'mse'.
        target: ground-truth labels/values, one per row of ``output``.
        prediction: 'binary classification', 'classification' or 'regression'.
            Defaults to ``args.prediction`` from the config cell.

    Returns:
        (first, second) as plain floats:
          - binary / multi-class classification -> (AUC, accuracy)
          - regression -> (Pearson correlation, MAE)
        This order matches the (metric, second_metric) names chosen in main_worker,
        which selects the best model on the first value.

    Raises:
        IOError: if ``prediction`` is not one of the supported formats.
    """
    if prediction is None:
        prediction = args.prediction

    with torch.no_grad():
        y_true = target.detach().cpu()
        if prediction == 'binary classification':
            # Single logit per sample: flatten (N, 1) -> (N,) so sklearn gets 1-d scores.
            prob_output = torch.sigmoid(output.detach().float()).reshape(-1).cpu()
            first_acc = roc_auc_score(y_true.numpy(), prob_output.numpy())
            second_acc = accuracy_score(y_true.numpy(), torch.round(prob_output).numpy())
        elif prediction == 'classification':
            # Multi-class roc_auc_score requires per-class probabilities that sum to 1,
            # so convert logits with softmax over the class dimension.
            prob_output = torch.softmax(output.detach().float(), dim=1).cpu()
            first_acc = roc_auc_score(y_true.numpy(), prob_output.numpy(), multi_class='ovo', average='macro')
            # argmax per sample (dim=1); without dim it returns one index for the whole batch.
            second_acc = accuracy_score(y_true.numpy(), torch.argmax(prob_output, dim=1).numpy())
        elif prediction == 'regression':
            pred = output.detach().reshape(-1).float().cpu()
            true_vals = y_true.reshape(-1).float()
            mae = nn.L1Loss()(pred, true_vals).item()
            pearson_r, _ = pearsonr(true_vals.numpy(), pred.numpy())
            first_acc, second_acc = float(pearson_r), mae
        else:
            raise IOError(f'Metrics not defined for prediction format: {prediction}')
    return first_acc, second_acc
    
def save_checkpoint(state, is_best, epoch, args, filename='checkpoint.pth.tar'):
    """Persist ``state`` under <checkpoint>/BaselineResNet<version>/epoch_<epoch>/<filename>.

    When ``is_best`` is true, the saved file is also copied to
    <checkpoint>/BaselineResNet<version>/model_best.pth.tar, replacing any previous best.
    """
    run_dir = os.path.join(args.checkpoint, f'BaselineResNet{args.version}')
    target_dir = os.path.join(run_dir, f'epoch_{epoch}')
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        print(f'Created new directory for saving models at {target_dir}')

    checkpoint_path = os.path.join(target_dir, filename)
    torch.save(state, checkpoint_path)

    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(run_dir, 'model_best.pth.tar'))

### Training and validation functions

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, args, first_metric, second_metric):
    """Run one training epoch, then score the whole epoch's outputs with ``metrics``.

    Args:
        train_loader: DataLoader yielding ``(images, target, _)`` batches.
        model: network being trained (switched to train mode here).
        criterion: loss applied to the model outputs (trailing dim squeezed) and targets.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used for the progress prefix and the TensorBoard step.
        args: run configuration (reads gpu, loss, log_interval, log).
        first_metric, second_metric: display names for the two values returned by ``metrics``.
    """
    batch_time = AverageMeter('Time', ':4.3f')
    data_time = AverageMeter('Data', ':4.3f')
    losses = AverageMeter('Loss', ':4.2f')
    metric = AverageMeter(first_metric, ':4.2f')
    metric2 = AverageMeter(second_metric, ':4.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, metric, metric2],
        prefix="Epoch: [{}]".format(epoch),
        summary_prefix='Training:')

    sig = nn.Sigmoid()
    use_cuda = (args.gpu is not None) and torch.cuda.is_available()

    outputs = []
    targets = []

    model.train()
    end = time.time()
    train_size = len(train_loader)

    for i, (images, target, _) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if use_cuda:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        output = model(images)
        if args.loss == 'mse':
            output = sig(output)  # scale between 0 and 1
            output = output.type(torch.FloatTensor)
            target = target.type(torch.FloatTensor)
        # Squeeze only the trailing output dim: a bare squeeze() turns a final
        # batch of one sample into a 0-d tensor whose shape no longer matches target.
        loss = criterion(output.squeeze(-1), target)

        losses.update(loss.item(), images.size(0))
        # Store detached copies for the epoch metrics so the kept outputs do not
        # hold references into this batch's autograd graph.
        outputs.append(output.detach())
        targets.append(target.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0:
            progress.display(i)

    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)

    # Metrics are computed once over the whole epoch, so each meter gets a single update.
    acc1, acc2 = metrics(outputs, targets)
    metric.update(acc1, train_size)
    metric2.update(acc2, train_size)

    del outputs, targets

    progress.display_summary()
    if args.log:
        train_summary_writer.add_scalar(f'Loss/{args.loss}', losses.avg, epoch)
        train_summary_writer.add_scalar(first_metric, metric.avg, epoch)
        train_summary_writer.add_scalar(second_metric, metric2.avg, epoch)


def validate(val_loader, model, criterion, epoch, args, first_metric, second_metric):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        val_loader: DataLoader yielding ``(images, target, _)`` batches.
        model: network to evaluate (switched to eval mode here).
        criterion: loss applied to the model outputs (trailing dim squeezed) and targets.
        epoch: epoch index, used as the TensorBoard step.
        args: run configuration (reads gpu, loss, log_interval, log).
        first_metric, second_metric: display names for the two values returned by ``metrics``.

    Returns:
        The epoch-level value of the first metric, which main_worker uses for model selection.
    """
    batch_time = AverageMeter('Time', ':4.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':4.2f', Summary.NONE)
    metric1 = AverageMeter(first_metric, ':4.2f', Summary.AVERAGE)
    metric2 = AverageMeter(second_metric, ':4.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, metric1, metric2],
        prefix='Validation: ',
        summary_prefix='Validation: ')

    use_cuda = (args.gpu is not None) and torch.cuda.is_available()
    outputs = []
    targets = []

    model.eval()
    val_size = len(val_loader)

    with torch.no_grad():
        sig = nn.Sigmoid()
        end = time.time()
        for i, (images, target, _) in enumerate(val_loader):
            if use_cuda:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            if args.loss == 'mse':
                output = sig(output)  # scale between 0 and 1
                output = output.type(torch.FloatTensor)
                target = target.type(torch.FloatTensor)
            # Squeeze only the trailing output dim so a final batch of one sample
            # keeps shape (1,) and still matches target.
            loss = criterion(output.squeeze(-1), target)

            losses.update(loss.item(), images.size(0))
            outputs.append(output)
            targets.append(target)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_interval == 0:
                progress.display(i)

    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)

    # Metrics are computed once over the whole validation set.
    acc1, acc2 = metrics(outputs, targets)
    metric1.update(acc1, val_size)
    metric2.update(acc2, val_size)

    progress.display_summary()

    if args.log:
        val_summary_writer.add_scalar(f'Loss/{args.loss}', losses.avg, epoch)
        val_summary_writer.add_scalar(first_metric, metric1.avg, epoch)
        val_summary_writer.add_scalar(second_metric, metric2.avg, epoch)

    return metric1.avg

# Train ResNet

In [None]:
def _build_criterion(args):
    """Return the loss module selected by ``args.loss`` ('ce', 'mse' or 'bce').

    Raises:
        IOError: for any other value of ``args.loss``.
    """
    if args.loss == 'ce':  # use in classification
        return nn.CrossEntropyLoss().cuda(args.gpu)
    if args.loss == 'mse':  # use in cts case (on either normalized or standardized or raw data)
        return nn.MSELoss().cuda(args.gpu)
    if args.loss == 'bce':  # use in binary case
        return nn.BCEWithLogitsLoss().cuda(args.gpu)  # bce with sigmoid, so don't have to add to model
    raise IOError(f'Loss not defined: {args.loss}')


def _build_data_loaders(args):
    """Read the patch labels, split train/val, and wrap both splits in DataLoaders.

    Returns:
        (train_loader, val_loader, val_cases), where val_cases is the third value
        returned by ``split_train_val``.
    """
    patch_labels = pd.read_csv(args.patch_label, index_col=0)
    patch_labels = patch_labels[patch_labels.magnification == args.magnification]
    patch_labels = patch_labels.dropna(subset=[args.predicting_var])

    train_patch_labels, val_patch_labels, val_cases, _ = split_train_val(patch_labels, args.cohort,
                                                                         args.train_val_split, args.seed,
                                                                         args.prediction, args.predicting_var,
                                                                         args.upsample)

    full_transforms, lim_transforms = image_transforms(args.random_crop)

    train_dataset = PatchPathDataset(patch_labels=train_patch_labels, image_folder=args.image_dir,
                                     predicting_var=args.predicting_var, transform=full_transforms)
    val_dataset = PatchPathDataset(patch_labels=val_patch_labels, image_folder=args.image_dir,
                                   predicting_var=args.predicting_var, transform=lim_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True, sampler=None)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)
    return train_loader, val_loader, val_cases


def main_worker(gpu, args):
    """Fine-tune an ImageNet-pretrained ResNet-50 and checkpoint it after every epoch.

    Args:
        gpu: CUDA device index to use (stored into ``args.gpu``), or None.
        args: the run configuration Series from the config cell.

    Returns:
        The validation cases from the train/val split, or None when
        ``args.evaluate`` is set (a single validation pass is run instead).

    Side effects:
        Updates the module-level ``best_acc1`` and writes checkpoints via ``save_checkpoint``.
    """
    global best_acc1
    args.gpu = gpu

    print('Using ResNet pre-trained on ImageNet')
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = nn.Linear(in_features=args.features, out_features=args.num_classes, bias=True)

    use_cuda = torch.cuda.is_available() and args.gpu is not None
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    if args.prediction in ['classification', 'binary classification']:
        metric = 'AUC'
        second_metric = 'Accuracy'
    elif args.prediction == 'regression':
        metric = 'Pearson correlation'
        second_metric = 'MAE'
    else:
        raise IOError(f'Metrics not defined for prediction format: {args.prediction}')

    # Direction in which the model-selection metric improves.
    if metric in ('AUC', 'Pearson correlation'):
        higher_is_better = True
    elif metric == 'MAE':
        higher_is_better = False
    else:
        raise IOError(f'best metric scenario not defined for metric {metric}')
    # Reset per run: a second main() call in the same kernel must not inherit the
    # previous run's best, and a fixed start of 0 would never be beaten by a negative correlation.
    best_acc1 = -float('inf') if higher_is_better else float('inf')

    criterion = _build_criterion(args)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Decay the learning rate by 10x every 30 epochs (only takes effect when args.epochs > 30).
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # cuDNN benchmark mode picks algorithms by timing, which can vary between runs;
    # keep it off when a seed is set so main()'s deterministic setting holds.
    cudnn.benchmark = args.seed is None

    train_loader, val_loader, val_cases = _build_data_loaders(args)

    if args.log:
        # Batches are (images, target, _) tuples; only the images are logged.
        images = next(iter(train_loader))[0]
        grid = torchvision.utils.make_grid(images, nrow=16)
        train_summary_writer.add_image(f'Input images/v{args.version}', grid, 0)
        graph_input = images.cuda(args.gpu) if use_cuda else images
        train_summary_writer.add_graph(model, graph_input)

    if args.evaluate:
        validate(val_loader, model, criterion, args.start_epoch, args, metric, second_metric)
        return

    best_epoch = -1
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, metric, second_metric)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, metric, second_metric)

        scheduler.step()

        # remember best metric and save checkpoint
        if higher_is_better:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
        else:
            is_best = acc1 < best_acc1
            best_acc1 = min(acc1, best_acc1)

        if is_best:
            best_epoch = epoch

        save_checkpoint({
            'epoch': epoch,
            'arch': 'resnet50',
            'state_dict': model.state_dict(),
            'best_metric': best_acc1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_cases': val_cases
        }, is_best, epoch, args)

    if args.log:
        metric_dict = {f'Best/{metric}': best_acc1, 'Best/Epoch': best_epoch}
        val_summary_writer.add_hparams(hparam_dict=args.to_dict(), metric_dict=metric_dict)

        train_summary_writer.close()
        # val_summary_writer stays open: evaluate() logs slide-level results to it later.

    del model

    return val_cases

In [None]:
# Module-level best validation metric; main_worker updates it through `global best_acc1`.
best_acc1 = 0

def main():
    """Seed the RNGs (when args.seed is set), warn about GPU pinning, and run training.

    Returns:
        Whatever main_worker returns: the validation cases, or None when
        args.evaluate is set.
    """
    if args.seed is not None:
        random.seed(args.seed)
        # Presumably helpers such as split_train_val or the data transforms may draw
        # from numpy's global RNG, so seed it alongside `random` and torch.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    return main_worker(args.gpu, args)

In [None]:
# Run training. main() returns main_worker's result: the validation cases, or None when args.evaluate is set.
val_cases = main()

Check which cases were held out for validation

In [None]:
# Inspect which cases were held out for validation (reused by evaluate() below).
print(val_cases)

# Training Done

## Run the cells below to plot final predictions: density plot and confusion matrix (binary classification only)

Load the best model checkpoint and evaluate it on the validation slides

In [None]:
def load_best_model(args):
    """Rebuild the ResNet-50 and load the weights saved as model_best.pth.tar.

    Args:
        args: run configuration (reads checkpoint, version, features, num_classes, gpu).

    Returns:
        The model with the best checkpoint's weights, moved to ``args.gpu`` when CUDA is available.
    """
    state_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', 'model_best.pth.tar')
    # Load tensors onto the CPU first so a checkpoint written on a GPU machine also
    # opens on a CPU-only one; the model is moved to the GPU below when possible.
    # NOTE(review): the checkpoint also pickles non-tensor entries such as 'val_cases';
    # newer torch versions default torch.load to weights_only=True, which may reject them.
    best_state = torch.load(state_path, map_location='cpu')
    print('Best model at epoch: ', best_state['epoch'])
    best_state_dict = best_state['state_dict']

    # Every parameter and buffer is overwritten by the strict load below, so there
    # is no need to download the ImageNet weights first.
    model = torchvision.models.resnet50(pretrained=False)
    model.fc = nn.Linear(in_features=args.features, out_features=args.num_classes, bias=True)
    model.load_state_dict(best_state_dict, strict=True)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    return model

In [None]:
def validation_data(args, val_cases):
    """Load the patch labels and keep only the rows belonging to the validation cases.

    Args:
        args: run configuration (reads patch_label, magnification, predicting_var,
            cohort, train_val_split, seed).
        val_cases: case identifiers to keep. If None, a split is recomputed here.

    Returns:
        (val_cases, val_patch_labels): the cases used, and their patch-label rows
        with a fresh 0..n-1 index.
    """
    patch_labels = pd.read_csv(args.patch_label, index_col=0)
    patch_labels = patch_labels[patch_labels.magnification == args.magnification]
    patch_labels = patch_labels.dropna(subset=[args.predicting_var])
    patch_labels = select_cohort(patch_labels, args.cohort)

    if val_cases is None:
        # NOTE(review): this fallback reseeds `random` and shuffles the unique cases
        # itself; it may not reproduce the split made by split_train_val during
        # training. Confirm before relying on it.
        cases = patch_labels.case.unique()
        num_train_cases = int(np.ceil(len(cases) * args.train_val_split))
        random.seed(args.seed)
        random.shuffle(cases)
        val_cases = cases[num_train_cases:]

    val_patch_labels = patch_labels[patch_labels.case.isin(val_cases)].reset_index(drop=True)
    return val_cases, val_patch_labels

In [None]:
def validate_slides(model, val_cases, args):
    """Score each validation slide by averaging its patches' sigmoid probabilities.

    Args:
        model: trained network (set to eval mode here).
        val_cases: validation case identifiers, forwarded to ``validation_data``.
        args: run configuration.

    Returns:
        (targets, outputs, predictions), one entry per slide:
          - targets: the label from the slide's first patch row (assumes every
            patch of a slide shares the label; TODO confirm).
          - outputs: mean sigmoid probability over the slide's patches (float).
          - predictions: that probability rounded to 0.0 / 1.0.
    """
    model.eval()
    use_cuda = (args.gpu is not None) and torch.cuda.is_available()

    targets, outputs, predictions = [], [], []

    val_cases, val_patch_labels = validation_data(args, val_cases)

    val_slides = val_patch_labels.slide.unique()
    print(f'{len(val_slides)} slides')

    _, lim_transforms = image_transforms(args.random_crop)

    with torch.no_grad():
        for slide in val_slides:
            slide_patch_labels = val_patch_labels[val_patch_labels.slide == slide].reset_index(drop=True)
            print(f'{slide} has {len(slide_patch_labels)} patches')

            slide_dataset = PatchPathDataset(patch_labels=slide_patch_labels, image_folder=args.image_dir,
                                             predicting_var=args.predicting_var, transform=lim_transforms)
            slide_loader = torch.utils.data.DataLoader(slide_dataset, batch_size=args.batch_size, shuffle=False,
                                                       num_workers=args.workers, pin_memory=True, sampler=None)

            patch_outputs = []
            for images, _, _ in slide_loader:
                # Match load_best_model: only move to the GPU when CUDA is actually available.
                if use_cuda:
                    images = images.cuda(args.gpu, non_blocking=True)
                patch_outputs.append(model(images))

            del slide_dataset, slide_loader

            slide_output = torch.mean(torch.sigmoid(torch.cat(patch_outputs, dim=0)))
            slide_target = slide_patch_labels[args.predicting_var].iloc[0]

            outputs.append(slide_output.item())
            targets.append(slide_target)
            predictions.append(torch.round(slide_output).item())

    return targets, outputs, predictions

In [None]:
def save_result_figure(img_name, fig):
    """Save ``fig`` as Results/Figures/BaselineResNet<version>/<img_name>, creating the folder if needed."""
    out_dir = os.path.join('Results/Figures', f'BaselineResNet{args.version}')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig(os.path.join(out_dir, img_name))
    

def plot_confusion_matrix(targets, predictions, label_names=('0', '1'), save=True):
    """Draw (and optionally save/log) a confusion-matrix heatmap.

    Args:
        targets: true labels, assumed to be encoded as integers 0..len(label_names)-1.
        predictions: predicted labels in the same encoding (e.g. rounded probabilities).
        label_names: display names for each class, in label order.
        save: write the figure under Results/Figures and, when args.log is set,
            to the validation TensorBoard writer.
    """
    label_names = list(label_names)
    # Pin the label set so the matrix is always len(label_names) x len(label_names),
    # even when only one class appears (e.g. a small validation set).
    y_true = np.asarray(targets).astype(int)
    y_pred = np.asarray(predictions).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(label_names))))
    cm_df = pd.DataFrame(cm, index=label_names, columns=label_names)

    fig, ax = plt.subplots(figsize=(7, 5))
    sns.heatmap(cm_df, annot=True, fmt='g', cmap='Blues', ax=ax)
    ax.set_title(f'Confusion Matrix v{args.version}')
    ax.set_ylabel('Actual Values')
    ax.set_xlabel('Predicted Values')

    if save:
        print('Saving figure')
        save_result_figure('ValidationConfusionMatrix', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Confusion Matrix', fig)

    plt.show()
    
# binary classification only
def density_plot(targets, outputs, label_names=('0', '1'), save=True):
    """Overlay KDEs of the predicted probabilities, split by true label (0 vs 1).

    Args:
        targets: true labels (0/1), one per entry of ``outputs``.
        outputs: predicted probabilities.
        label_names: legend names for label 0 and label 1.
        save: write the figure under Results/Figures and, when args.log is set,
            to the validation TensorBoard writer.
    """
    label_names = list(label_names)
    palette = list(colors.TABLEAU_COLORS.values())
    negative_probs = [prob for label, prob in zip(targets, outputs) if label == 0]
    positive_probs = [prob for label, prob in zip(targets, outputs) if label == 1]

    sns.set_style('whitegrid')
    # Draw on a fresh figure rather than whatever axes happen to be current.
    fig, ax = plt.subplots()
    sns.kdeplot(negative_probs, bw_adjust=0.5, label=label_names[0], color=palette[0], ax=ax)
    sns.kdeplot(positive_probs, bw_adjust=0.5, label=label_names[1], color=palette[1], ax=ax)
    ax.set(xlim=(0, 1))
    ax.set_title('Density plot of predicted probabilities across true labels')
    ax.legend()

    if save:
        print('Saving figure')
        save_result_figure('ValidationDensityPlot', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Density Plot', fig)

    plt.show()
    
def plot_prediction_scatter(targets, outputs, save=True):
    """Scatter predicted vs. target values and overlay a least-squares trend line.

    Args:
        targets: true values; the axes are fixed to [0, 1].
        outputs: predicted values, same length as ``targets``.
        save: write the figure under Results/Figures and, when args.log is set,
            to the validation TensorBoard writer.
    """
    palette = list(colors.TABLEAU_COLORS.values())
    x = np.asarray(targets, dtype=float)
    y = np.asarray(outputs, dtype=float)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(targets, outputs, '.', alpha=1.0, color=palette[0])
    plt.xlabel("targets")
    plt.ylabel("predictions")
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.title('Predicted vs Target')
    # Fit the trend to `outputs`; there is no `predictions` argument in this function.
    slope, intercept = np.polyfit(x, y, deg=1)
    plt.plot(targets, slope * x + intercept, color=palette[1], alpha=0.6)

    if save:
        print('Saving figure')
        save_result_figure('ValidationScatterPlot', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Scatter Plot', fig)

    plt.show()

### Generate predictions from best model epoch over validation set

In [None]:
def evaluate(args):
    """Evaluate the best checkpoint at slide level and plot/log the results.

    Uses the notebook-level ``val_cases`` produced by the training cell. For binary
    classification, draws the confusion matrix and density plot and prints the
    slide-level AUC and accuracy. Other prediction types are not implemented yet.
    """
    best_model = load_best_model(args)
    targets, outputs, predictions = validate_slides(best_model, val_cases, args)

    if args.prediction == 'binary classification':
        if args.log:
            # Save/log copies first; the figures are then redrawn below for inline display
            # (presumably because add_figure closes the figure it logs).
            plot_confusion_matrix(targets, predictions)
            density_plot(targets, outputs)

        plot_confusion_matrix(targets, predictions, save=False)
        density_plot(targets, outputs, save=False)

        # AUC must rank continuous scores (the mean probabilities), not hard 0/1
        # predictions, which collapse the ROC curve to a single operating point.
        if len(set(targets)) > 1:
            slide_level_auc = roc_auc_score(targets, outputs)
        else:
            print('Only one class present among validation slides; AUC is undefined.')
            slide_level_auc = float('nan')
        slide_level_acc = accuracy_score(targets, predictions)

        print('Validation AUC:', slide_level_auc)
        print('Validation Accuracy:', slide_level_acc)

        if args.log:
            val_summary_writer.add_scalar('Best/Slide-level AUC', slide_level_auc)
            val_summary_writer.add_scalar('Best/Slide-level accuracy', slide_level_acc)
    else:
        print(f'Implement evaluation for {args.prediction}')

    # The writer only exists when logging was enabled in the config cell.
    if args.log:
        val_summary_writer.close()

In [None]:
# Evaluate the best checkpoint on the validation slides (uses val_cases from the training cell).
evaluate(args)