In [None]:
import argparse
import os
import pickle
import random
import warnings
import pandas as pd
import numpy as np
import time
import shutil
import datetime
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from skimage.io import imread
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score
from scipy.stats import pearsonr
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from utils.logging import *
from data.datasets import PatchPathDataset
from data.process import *

import model.ViT as ViT
import model.ClusterViT as ClusterViT
import model.PREViT as PREViT

## Parameters

In [None]:
# Run configuration for the whole notebook. It is stored in a pandas Series so
# that entries can be read as attributes (e.g. ``args.lr``) by the cells below.
args = pd.Series({
    'checkpoint':'checkpoint/',  # root directory for checkpoints, saved features and clusters
    'version': '1.0',  # forms the `BaselineResNet{version}` path component
    'image_dir': 'Data/',
    'patch_label': 'Metadata/PatchLabels.csv',  # CSV read with pd.read_csv(..., index_col=0)
    'predicting_var': 'response',  # column of the patch-label CSV used as the target
    'cohort': 'Cohort1',
    'magnification': '10X',  # rows of the patch-label CSV are filtered to this magnification
    'num_classes': 1,
    'prediction': 'binary classification', # one of: 'regression', 'binary classification', 'classification'
    'att_model': 'ClusterPREViT',  # one of: 'ViT', 'PREViT', 'ClusterViT', 'ClusterPREViT'
    'att_version': '1.0',
    'base_epoch': 19,  # features/clusters are read from `epoch_{base_epoch}` directories
    'max_num_patches': 10000,  # upper bound on the per-slide patch cap used by the attention model
    'normalize_clusters': True,  # selects the `normalized_` cluster pickle when True
    'n_clusters': 4,
    'upsample': True,
    'train_val_split': 0.7,
    'random_crop': False,
    'features': 2048,  # passed as `patch_dim` to the attention model
    'loss': 'bce',  # one of: 'ce', 'mse', 'bce'
    'epochs': 20, 
    'start_epoch': 0,
    'lr': 0.1,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'batch_size': 256,
    'workers': 4,
    'evaluate': False,  # when True, main_worker only runs validation and returns
    'seed': 0,
    'gpu': 0,
    'log_interval': 50,  # progress is displayed every `log_interval` slides
    'log': False  # when True, TensorBoard writers are created and scalars/figures are logged
})


if args.log:
    # TensorBoard runs are grouped by baseline version and attention-model version.
    tensorboard_dir = f'logs/tensorboard/BaselineResNet{args.version}/{args.att_model}{args.att_version}/'
    # exist_ok=True replaces the exists()/makedirs() pair, which can race with
    # another process creating the same directory.
    os.makedirs(tensorboard_dir, exist_ok=True)

    # Train and validation scalars go to sibling run directories that share one
    # timestamp, so the two curves of a run can be matched up.
    current_time = datetime.datetime.now().strftime("%d%m%Y-%H:%M:%S")
    train_log_dir = tensorboard_dir + 'train/' + current_time
    val_log_dir = tensorboard_dir + 'val/' + current_time
    train_summary_writer = SummaryWriter(log_dir=train_log_dir)
    val_summary_writer = SummaryWriter(log_dir=val_log_dir)

### Functions to define metrics and save figures

In [None]:
def save_result_figure(img_name, fig,
                       base_dir='/well/rittscher/users/axs296/Code/FromScratch/Results/Figures'):
    """Save a matplotlib figure under the results directory of the current model version.

    Args:
        img_name: file name (without directory) passed on to ``fig.savefig``.
        fig: object exposing ``savefig(path)``.
        base_dir: root directory for figures. Defaults to the location this notebook
            has always written to; pass another path to run on a different machine.

    The figure is written to
    ``{base_dir}/BaselineResNet{args.version}/{args.att_model}{args.att_version}/{img_name}``,
    where ``args`` is the notebook-level configuration.
    """
    img_save_dir = os.path.join(base_dir,
                                f'BaselineResNet{args.version}/{args.att_model}{args.att_version}')
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(img_save_dir, exist_ok=True)
    img_save_path = os.path.join(img_save_dir, img_name)
    fig.savefig(img_save_path)

    
def save_checkpoint(state, is_best, epoch, args, filename='checkpoint.pth.tar'):
    """Serialise ``state`` into this epoch's folder and optionally mark it as the best model.

    The checkpoint is written with ``torch.save`` to
    ``{args.checkpoint}/BaselineResNet{args.version}/{args.att_model}{args.att_version}/epoch_{epoch}/{filename}``.
    When ``is_best`` is truthy the written file is additionally copied to
    ``model_best.pth.tar`` in the parent (model-level) directory.
    """
    model_root = os.path.join(args.checkpoint,
                              f'BaselineResNet{args.version}',
                              f'{args.att_model}{args.att_version}')
    target_dir = os.path.join(model_root, f'epoch_{epoch}')

    # Only announce the directory the first time it is created.
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        print(f'Created new directory for saving models at {target_dir}')

    checkpoint_file = os.path.join(target_dir, filename)
    torch.save(state, checkpoint_file)

    if not is_best:
        return
    shutil.copyfile(checkpoint_file, os.path.join(model_root, 'model_best.pth.tar'))


def metrics(output, target, predicting_var=None):
    """Compute the two epoch-level metrics for the task selected by ``args.prediction``.

    Args:
        output: stacked model outputs, one row per slide.
        target: stacked targets, one row per slide.
        predicting_var: accepted only because ``train``/``validate`` pass
            ``args.predicting_var`` as a third positional argument; it is not used.

    Returns:
        ``(first_acc, second_acc)``:
          * 'binary classification': (weighted ROC AUC, accuracy)
          * 'classification':        (macro one-vs-one ROC AUC, accuracy)
          * 'regression':            (mean absolute error, Pearson correlation)

    Raises:
        IOError: if ``args.prediction`` is none of the three supported values.
    """
    with torch.no_grad():
        output = output.detach().cpu()
        target = target.detach().cpu()

        if args.prediction == 'binary classification':
            # nn.BCELoss only accepts probabilities, so with loss == 'bce' the model
            # output has already been through a sigmoid. Applying a second sigmoid
            # maps everything into [0.5, 0.73], which makes round() always return 1.
            prob_output = output if args.loss == 'bce' else torch.sigmoid(output)
            y_true = target.flatten().numpy()
            y_prob = prob_output.flatten()
            first_acc = roc_auc_score(y_true, y_prob.numpy(), multi_class='ovr', average='weighted')
            second_acc = accuracy_score(y_true, torch.round(y_prob).numpy())
        elif args.prediction == 'classification':
            # Multiclass roc_auc_score expects per-class probabilities that sum to 1.
            y_true = target.flatten().long().numpy()
            class_probs = torch.softmax(output, dim=1)
            first_acc = roc_auc_score(y_true, class_probs.numpy(), multi_class='ovo', average='macro')
            # argmax must run over the class dimension; without `dim` it returns a
            # single index into the flattened tensor.
            second_acc = accuracy_score(y_true, torch.argmax(output, dim=1).numpy())
        elif args.prediction == 'regression':
            # Flatten both sides: comparing (N,) against (N, 1) would broadcast to
            # (N, N) in L1Loss, and pearsonr requires 1-D inputs.
            y_pred = output.flatten()
            y_true = target.flatten()
            first_acc = nn.L1Loss()(y_pred, y_true).item()
            second_acc, _ = pearsonr(y_true.numpy(), y_pred.numpy())
        else:
            raise IOError(f'Metrics not defined for prediction format: {args.prediction}')
        return first_acc, second_acc

### Functions to load features and clusters

In [None]:
def load_slide_features(slide):
    """Load the saved feature file of one slide.

    Args:
        slide: slide name; it is also the file name inside
            ``{args.checkpoint}/BaselineResNet{args.version}/Features/epoch_{args.base_epoch}``.

    Returns:
        Whatever object was stored with ``torch.save`` for that slide. Callers in
        this notebook index it with ``'slide_embeddings'`` and ``'patch_paths'``.

    Raises:
        FileNotFoundError: if no feature file exists for ``slide``. Previously a
            warning was printed and the function then failed with an
            ``UnboundLocalError`` on the return statement.
    """
    feature_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', 'Features',
                                f'epoch_{args.base_epoch}')
    slide_feature_path = os.path.join(feature_path, slide)
    if not os.path.exists(slide_feature_path):
        raise FileNotFoundError(
            f'No features found for slide {slide} at {slide_feature_path}. '
            'Run "Save ResNet Feature Embeddings" notebook first.')

    # Fall back to CPU so the loader also works on machines without CUDA.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    return torch.load(slide_feature_path, map_location=device)

def get_target(slide, patch_labels, gpu):
    """Build the ``(1, 1)`` float target tensor for one slide.

    Args:
        slide: slide name to look up in ``patch_labels.slide``.
        patch_labels: DataFrame with a ``slide`` column and the column named by
            ``args.predicting_var``.
        gpu: CUDA device the tensor is moved to when CUDA is available. The
            previous implementation ignored this argument and read ``args.gpu``.

    Returns:
        A float32 tensor of shape ``(1, 1)`` holding the value of
        ``args.predicting_var`` in the first row belonging to ``slide``.
    """
    slide_rows = patch_labels[patch_labels.slide == slide]
    # The first row is used as the slide's label.
    label = slide_rows[args.predicting_var].iloc[0]
    target = torch.tensor([[label]], dtype=torch.float32)
    if torch.cuda.is_available():
        target = target.cuda(gpu, non_blocking=True)
    return target

def add_one_to_clusters(clusters):
    """Replace every value of ``clusters`` with ``value + 1`` and return the same mapping.

    The mapping is updated in place; each value is rebound to the result of
    ``value + 1`` (the original value objects themselves are not modified).
    """
    for slide_key in list(clusters):
        clusters[slide_key] = clusters[slide_key] + 1
    return clusters

def load_clusters(args, stage):
    """Load the pickled cluster assignments of one stage and shift all labels by +1.

    Args:
        args: configuration; ``checkpoint``, ``version``, ``base_epoch``,
            ``n_clusters`` and ``normalize_clusters`` are read.
        stage: sub-directory name under ``Clusters/epoch_{base_epoch}``
            (this notebook passes ``'Train'`` and ``'Validation'``).

    Returns:
        The unpickled mapping after ``add_one_to_clusters`` has been applied.

    Raises:
        FileNotFoundError: if the cluster directory does not exist (this was an
            ``assert`` before, which is skipped under ``python -O``).
    """
    norm_str = 'normalized_' if args.normalize_clusters else ''

    cluster_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', 'Clusters',
                                f'epoch_{args.base_epoch}', stage)
    if not os.path.exists(cluster_path):
        raise FileNotFoundError(f'Cluster directory not found: {cluster_path}')

    cluster_file = os.path.join(cluster_path, f'{args.n_clusters}_{norm_str}clusters.p')
    # NOTE(review): pickle.load executes arbitrary code from the file; only load
    # cluster files produced by this project.
    with open(cluster_file, 'rb') as handle:  # context manager closes the file handle
        clusters = pickle.load(handle)
    print('Loaded clusters.')

    # add one to cluster labels
    return add_one_to_clusters(clusters)

### Functions for predicting with attention model on slide level

In [None]:
def limit_input_size(slide_embeddings, clusters, patch_paths, max_num_patches, slide=None):
    """Randomly subsample a slide's patches when it has more than ``max_num_patches``.

    Args:
        slide_embeddings: tensor whose first dimension indexes patches.
        clusters: per-patch cluster labels, aligned with ``slide_embeddings``.
        patch_paths: per-patch paths, aligned with ``slide_embeddings``.
        max_num_patches: slides with more rows than this are reduced to
            ``max_num_patches - 1`` randomly chosen patches.
        slide: optional slide name, used only in the log message. The previous
            implementation read an undefined name ``slide`` here.

    Returns:
        ``(slide_embeddings, clusters, patch_paths)``. The inputs are returned
        untouched when no subsampling is needed; otherwise a new tensor and two
        new lists are returned, all restricted to the same chosen patches.
    """
    if slide_embeddings.size()[0] <= max_num_patches:
        return slide_embeddings, clusters, patch_paths

    num_kept = max_num_patches - 1
    cluster_list = list(clusters)
    path_list = list(patch_paths)
    # Mirror zip(): only positions present in all three inputs are candidates.
    usable = min(slide_embeddings.size()[0], len(cluster_list), len(path_list))

    # Shuffle indices instead of a list of (row, cluster, path) tuples: this
    # draws the same permutation from `random` without unpacking every row.
    order = list(range(usable))
    random.shuffle(order)
    chosen = order[:num_kept]

    index = torch.tensor(chosen, dtype=torch.long, device=slide_embeddings.device)
    slide_embeddings = slide_embeddings.index_select(0, index)
    clusters = [cluster_list[i] for i in chosen]
    patch_paths = [path_list[i] for i in chosen]

    slide_desc = f' {slide}' if slide is not None else ''
    print(f'Taking {num_kept} random patches from slide{slide_desc}' +
          ' because image too large')
    return slide_embeddings, clusters, patch_paths

def predict(att_model, att_model_name, slide_embeddings, clusters=None, patch_paths=None):
    """Call ``att_model`` on one slide with the arguments its architecture expects.

    Args:
        att_model: callable model.
        att_model_name: one of 'ViT', 'PREViT', 'ClusterViT', 'ClusterPREViT';
            decides which extra inputs are forwarded.
        slide_embeddings: tensor for one slide; a leading batch dimension of
            size 1 is added before the call.
        clusters: forwarded to the 'ClusterViT' and 'ClusterPREViT' models.
        patch_paths: forwarded to the 'PREViT' and 'ClusterPREViT' models.

    Returns:
        The model's output.

    Raises:
        ValueError: if ``att_model_name`` is not a known architecture (this used
            to surface as an ``UnboundLocalError`` on the return statement).
    """
    batch = slide_embeddings.unsqueeze(0)
    if att_model_name == 'ViT':
        return att_model(batch)
    if att_model_name == 'PREViT':
        return att_model(batch, patch_paths)
    if att_model_name == 'ClusterViT':
        return att_model(batch, clusters)
    if att_model_name == 'ClusterPREViT':
        return att_model(batch, clusters, patch_paths)
    raise ValueError(f'No prediction call defined for attention model: {att_model_name}')

def num_patches_in_slide(slide, patch_labels):
    """Return how many rows of ``patch_labels`` have ``slide`` in their ``slide`` column."""
    belongs_to_slide = patch_labels.slide == slide
    return int(belongs_to_slide.sum())

### Train and validate functions

In [None]:
def train(train_patch_labels, upsampled_train_slides, train_clusters, att_model, criterion, optimizer, 
          epoch, args, first_metric, second_metric, max_num_patches):
    """Run one training epoch, taking one optimizer step per slide.

    Args:
        train_patch_labels: DataFrame of patch labels; ``get_target`` looks up
            each slide's target in it.
        upsampled_train_slides: indexable sequence of slide names; every entry
            produces one forward/backward pass.
        train_clusters: mapping from slide name to that slide's cluster labels.
        att_model: attention model; switched to train mode here.
        criterion: loss callable applied to ``(output, target)``.
        optimizer: optimizer stepped after each slide.
        epoch: epoch index, used in the progress prefix and as TensorBoard step.
        args: run configuration (``gpu``, ``att_model``, ``log_interval``,
            ``log`` and ``loss`` are read here).
        first_metric, second_metric: display names for the two values returned
            by ``metrics``.
        max_num_patches: per-slide patch cap forwarded to ``limit_input_size``.
    """
    batch_time = AverageMeter('Time', ':4.3f')
    data_time = AverageMeter('Data', ':4.3f')
    losses = AverageMeter('Loss', ':4.2f')
    metric = AverageMeter(first_metric, ':4.2f')
    metric2 = AverageMeter(second_metric, ':4.2f')
    progress = ProgressMeter(
        len(upsampled_train_slides),
        [batch_time, data_time, losses, metric, metric2],
        prefix="Epoch: [{}]".format(epoch),
        summary_prefix='Training:')

    att_model.train()
    end = time.time()

    outputs = []
    targets = []

    for i in range(len(upsampled_train_slides)):
        data_time.update(time.time() - end)

        slide = upsampled_train_slides[i]

        target = get_target(slide, train_patch_labels, args.gpu)
        clusters = train_clusters[slide]

        slide_embeddings_paths = load_slide_features(slide)
        slide_embeddings = slide_embeddings_paths['slide_embeddings']
        patch_paths = slide_embeddings_paths['patch_paths']
        del slide_embeddings_paths

        slide_embeddings, clusters, patch_paths = limit_input_size(slide_embeddings, clusters, patch_paths,
                                                                   max_num_patches)

        output = predict(att_model, args.att_model, slide_embeddings, clusters=clusters, patch_paths=patch_paths)

        loss = criterion(output, target)
        # The running loss is weighted by the number of patches of the slide.
        losses.update(loss.item(), len(patch_paths))

        # Keep only detached copies for the epoch-level metrics; storing the raw
        # outputs would keep every slide's autograd graph alive until epoch end.
        outputs.extend(output.detach())
        targets.extend(target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0:
            progress.display(i)

        # Release the per-slide tensors before the next slide is loaded.
        del slide_embeddings, target, output

    outputs = torch.stack(outputs, dim=0)
    targets = torch.stack(targets, dim=0)

    # metrics() takes (output, target); the extra args.predicting_var that used
    # to be passed here raised a TypeError.
    acc1, acc2 = metrics(outputs, targets)
    metric.update(acc1, len(upsampled_train_slides))
    metric2.update(acc2, len(upsampled_train_slides))

    del outputs, targets

    progress.display_summary()

    if args.log:
        train_summary_writer.add_scalar(f'Loss/{args.loss}', losses.avg, epoch)
        train_summary_writer.add_scalar(first_metric, metric.avg, epoch)
        train_summary_writer.add_scalar(second_metric, metric2.avg, epoch)

    
def validate(val_patch_labels, val_clusters, att_model, criterion, epoch, args, first_metric, 
             second_metric, max_num_patches):
    """Evaluate the model on every validation slide and return the first metric.

    Args:
        val_patch_labels: DataFrame of validation patch labels; its unique
            ``slide`` values define the slides that are evaluated.
        val_clusters: mapping from slide name to that slide's cluster labels.
        att_model: attention model; switched to eval mode here.
        criterion: loss callable applied to ``(output, target)``.
        epoch: epoch index used as the TensorBoard step.
        args: run configuration (``gpu``, ``att_model``, ``log_interval``,
            ``log`` and ``loss`` are read here).
        first_metric, second_metric: display names for the two values returned
            by ``metrics``.
        max_num_patches: per-slide patch cap forwarded to ``limit_input_size``.

    Returns:
        ``metric1.avg``: the running average of the first metric after its
        single update for this pass.
    """
    val_slides = val_patch_labels.slide.unique()

    batch_time = AverageMeter('Time', ':4.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':4.2f', Summary.NONE)
    metric1 = AverageMeter(first_metric, ':4.2f', Summary.AVERAGE)
    metric2 = AverageMeter(second_metric, ':4.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_slides),
        [batch_time, losses, metric1, metric2],
        prefix='Validation: ',
        summary_prefix='Validation: ')

    att_model.eval()

    end = time.time()

    outputs = []
    targets = []

    with torch.no_grad():
        for i in range(len(val_slides)):
            slide = val_slides[i]

            target = get_target(slide, val_patch_labels, args.gpu)
            clusters = val_clusters[slide]

            slide_embeddings_paths = load_slide_features(slide)
            slide_embeddings = slide_embeddings_paths['slide_embeddings']
            patch_paths = slide_embeddings_paths['patch_paths']
            del slide_embeddings_paths

            slide_embeddings, clusters, patch_paths = limit_input_size(slide_embeddings, clusters, patch_paths,
                                                                       max_num_patches)

            output = predict(att_model, args.att_model, slide_embeddings, clusters=clusters, patch_paths=patch_paths)

            loss = criterion(output, target)
            # The running loss is weighted by the number of patches of the slide.
            losses.update(loss.item(), len(patch_paths))

            outputs.extend(output)
            targets.extend(target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_interval == 0:
                progress.display(i)

            # Release the per-slide tensors before the next slide is loaded.
            del slide_embeddings, target, output

    outputs = torch.stack(outputs, dim=0)
    targets = torch.stack(targets, dim=0)

    # metrics() takes (output, target); the extra args.predicting_var that used
    # to be passed here raised a TypeError.
    acc1, acc2 = metrics(outputs, targets)
    metric1.update(acc1, len(val_slides))
    metric2.update(acc2, len(val_slides))
    del outputs, targets

    progress.display_summary()

    if args.log:
        val_summary_writer.add_scalar(f'Loss/{args.loss}', losses.avg, epoch)
        val_summary_writer.add_scalar(first_metric, metric1.avg, epoch)
        val_summary_writer.add_scalar(second_metric, metric2.avg, epoch)

    return metric1.avg

# Train Attention Model

In [None]:
def main_worker(gpu, args):
    """Build the data split, model, loss and optimizer, then train and checkpoint.

    Args:
        gpu: CUDA device index (or None); stored into ``args.gpu``.
        args: run configuration (the notebook-level Series).

    Returns:
        The validation cases returned by ``split_train_val`` after training, or
        ``None`` when ``args.evaluate`` is set (only one validation pass is run).

    Raises:
        IOError: for an unsupported ``args.prediction``, ``args.att_model`` or
            ``args.loss``.

    Side effects: resets and updates the module-level ``best_acc1``, writes one
    checkpoint per epoch through ``save_checkpoint`` and, when ``args.log`` is
    set, writes to the module-level TensorBoard writers.
    """
    global best_acc1
    args.gpu = gpu

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
        torch.cuda.set_device(args.gpu)

    # The names must follow the order in which metrics() returns its two values:
    # (AUC, accuracy) for classification and (MAE, Pearson correlation) for
    # regression. The regression names used to be swapped.
    if args.prediction in ['classification', 'binary classification']:
        metric = 'AUC'
        second_metric = 'Accuracy'
    elif args.prediction == 'regression':
        metric = 'MAE'
        second_metric = 'Pearson correlation'
    else:
        raise IOError(f'Metrics not defined for prediction format: {args.prediction}')

    # Start every run from a neutral "best" value so that a previous run in the
    # same kernel cannot block checkpoint selection; MAE is minimised.
    best_acc1 = float('inf') if metric == 'MAE' else 0

    # Load data
    patch_labels = pd.read_csv(args.patch_label, index_col=0)
    patch_labels = patch_labels[patch_labels.magnification == args.magnification]
    # drop NAs in column trying to predict
    patch_labels = patch_labels.dropna(subset=[args.predicting_var])

    train_patch_labels, val_patch_labels, val_cases, upsampled_train_slides = split_train_val(patch_labels,
                                                                                              args.cohort,
                                                                                              args.train_val_split,
                                                                                              args.seed,
                                                                                              args.prediction,
                                                                                              args.predicting_var,
                                                                                              args.upsample)

    # Check that the validation cases match those saved by the baseline model.
    # `saved_val_cases` is not defined anywhere in this notebook, so it is looked
    # up defensively instead of raising a NameError when it is absent.
    baseline_val_cases = globals().get('saved_val_cases')
    if baseline_val_cases is None:
        warnings.warn('saved_val_cases is not defined; skipping the check that the validation '
                      'split matches the baseline model.')
    else:
        assert (val_cases == baseline_val_cases).all()

    train_slides = train_patch_labels.slide.unique()
    print(f'{len(train_slides)} training slides')
    val_slides = val_patch_labels.slide.unique()
    print(f'{len(val_slides)} validation slides')

    patch_distn = [num_patches_in_slide(slide, patch_labels) for slide in patch_labels.slide.unique()]
    print(f'Max number of patches over all slides in dataset is {max(patch_distn)}')
    max_num_patches = max(patch_distn)
    if args.max_num_patches < max_num_patches:
        max_num_patches = args.max_num_patches
        print(f'Setting max number of patches to {max_num_patches}')

    train_clusters = load_clusters(args, 'Train')
    val_clusters = load_clusters(args, 'Validation')

    # Define attention model
    if args.att_model == 'ViT':
        att_model = ViT.ViT(num_classes=args.num_classes, dim=512, patch_dim=args.features, depth=4, heads=4,
                            mlp_dim=512, max_num_patches=max_num_patches, pool='cls', dim_head=64, dropout=0.3,
                            emb_dropout=0.3)
    elif args.att_model == 'PREViT':
        att_model = PREViT.PREViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                  patch_dim=args.features, pool='cls', dim_head=64, dropout=0.3, emb_dropout=0.3)
    elif args.att_model == 'ClusterViT':
        att_model = ClusterViT.ClusterViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                          patch_dim=args.features, max_num_patches=max_num_patches,
                                          n_clusters=args.n_clusters, pool='cls', dim_head=64, dropout=0.3,
                                          emb_dropout=0.3)
    elif args.att_model == 'ClusterPREViT':
        att_model = PREViT.ClusterPREViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                         patch_dim=args.features, n_clusters=args.n_clusters, pool='cls',
                                         dim_head=64, dropout=0.3, emb_dropout=0.3)
    else:
        # `IOException` does not exist in Python; use the exception type this
        # notebook raises elsewhere for unsupported settings.
        raise IOError(f"No model defined for {args.att_model}")

    # define loss function (criterion), optimizer, and learning rate scheduler
    if args.loss == 'ce':  # use in classification
        criterion = nn.CrossEntropyLoss()
    elif args.loss == 'mse':  # use in cts case (on either normalized or standardized or raw data)
        criterion = nn.MSELoss()
    elif args.loss == 'bce':  # use in binary case
        criterion = nn.BCELoss()  # bce without sigmoid as already have sigmoid in ViT
    else:
        raise IOError(f'No loss defined for {args.loss}')

    # Only move to the GPU when one is available, matching the CPU message above.
    if use_cuda:
        att_model = att_model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    optimizer = torch.optim.SGD(att_model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Sets the learning rate to the initial LR decayed by 10 every 30 epochs.
    # NOTE(review): with the default `epochs` of 20 the decay never triggers.
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    cudnn.benchmark = True

    if args.evaluate:
        # No epoch loop has run on this path, so log under the start epoch
        # (`epoch` used to be undefined here).
        validate(val_patch_labels, val_clusters, att_model, criterion, args.start_epoch, args, metric,
                 second_metric, max_num_patches)
        return

    best_epoch = -1
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        random.shuffle(upsampled_train_slides)
        train(train_patch_labels, upsampled_train_slides, train_clusters, att_model, criterion,
              optimizer, epoch, args, metric, second_metric, max_num_patches)

        # evaluate on validation set
        acc1 = validate(val_patch_labels, val_clusters, att_model, criterion, epoch, args, metric,
                        second_metric, max_num_patches)

        scheduler.step()

        # remember best value of the first metric and save checkpoint
        # (`metric == 'AUC' or 'Pearson correlation'` was always truthy).
        if metric in ('AUC', 'Pearson correlation'):
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
        elif metric == 'MAE':
            is_best = acc1 < best_acc1
            best_acc1 = min(acc1, best_acc1)
        else:
            raise IOError(f'best metric scenario not defined for metric {metric}')

        if is_best:
            best_epoch = epoch

        save_checkpoint({
            'epoch': epoch,
            'arch': args.att_model,
            'state_dict': att_model.state_dict(),
            'best_metric': best_acc1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }, is_best, epoch, args)

    if args.log:
        metric_dict = {f'Best/{metric}': best_acc1, 'Best/Epoch': best_epoch}
        val_summary_writer.add_hparams(hparam_dict=args.to_dict(), metric_dict=metric_dict)

        # The validation writer stays open: evaluate() still logs to it.
        train_summary_writer.close()

    del att_model
    return val_cases

In [None]:
# Best value of the first validation metric seen so far; updated by main_worker.
best_acc1 = 0

def main():
    """Seed the random number generators and start training on ``args.gpu``.

    Reads the notebook-level ``args`` configuration. When ``args.seed`` is set,
    Python's ``random``, NumPy's global RNG and PyTorch are seeded and cuDNN is
    switched to deterministic mode.

    Returns:
        Whatever ``main_worker(args.gpu, args)`` returns.
    """
    if args.seed is not None:
        random.seed(args.seed)
        # NumPy's global RNG was previously left unseeded, so any np.random use
        # during the run was not reproducible.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    return main_worker(args.gpu, args)

In [None]:
# Run training. main() returns main_worker's result: the validation cases of the
# split, or None when args.evaluate is set.
val_cases = main()

# Training done

### Evaluate model by generating predictions from best epoch

In [None]:
def load_best_model(args, max_num_patches=None):
    """Rebuild the attention model and load the weights of the best checkpoint.

    Args:
        args: run configuration; selects the checkpoint path and architecture.
        max_num_patches: value passed as ``max_num_patches`` to the 'ViT' and
            'ClusterViT' constructors. Defaults to ``args.max_num_patches``.
            NOTE(review): training may have used a smaller value (the largest
            slide in the dataset); pass that value explicitly if loading with
            ``strict=True`` fails on a size mismatch.

    Returns:
        The model with the best checkpoint's ``state_dict`` loaded, moved to
        ``args.gpu`` when CUDA is available and ``args.gpu`` is not None.

    Raises:
        IOError: if ``args.att_model`` is not a known architecture.
    """
    if max_num_patches is None:
        # `max_num_patches` used to be read from an undefined global here.
        max_num_patches = args.max_num_patches

    state_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', f'{args.att_model}{args.att_version}',
                              'model_best.pth.tar')
    # Load onto the CPU first so a checkpoint written on a GPU can be opened on
    # any machine; the model is moved to the GPU below when one is available.
    best_state = torch.load(state_path, map_location='cpu')
    print('Best model at epoch: ', best_state['epoch'])
    best_state_dict = best_state['state_dict']

    if args.att_model == 'ViT':
        att_model = ViT.ViT(num_classes=args.num_classes, dim=512, patch_dim=args.features, depth=4, heads=4,
                            mlp_dim=512, max_num_patches=max_num_patches, pool='cls', dim_head=64, dropout=0.3,
                            emb_dropout=0.3)
    elif args.att_model == 'PREViT':
        att_model = PREViT.PREViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                  patch_dim=args.features, pool='cls', dim_head=64, dropout=0.3, emb_dropout=0.3)
    elif args.att_model == 'ClusterViT':
        att_model = ClusterViT.ClusterViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                          patch_dim=args.features, max_num_patches=max_num_patches,
                                          n_clusters=args.n_clusters, pool='cls', dim_head=64, dropout=0.3,
                                          emb_dropout=0.3)
    elif args.att_model == 'ClusterPREViT':
        att_model = PREViT.ClusterPREViT(num_classes=args.num_classes, dim=512, depth=4, heads=4, mlp_dim=512,
                                         patch_dim=args.features, n_clusters=args.n_clusters, pool='cls',
                                         dim_head=64, dropout=0.3, emb_dropout=0.3)
    else:
        # `IOException` does not exist in Python.
        raise IOError(f"No model defined for {args.att_model}")
    att_model.load_state_dict(best_state_dict, strict=True)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        att_model = att_model.cuda(args.gpu)

    return att_model

In [None]:
def validation_data(args, val_cases):
    """Return the validation cases, their patch labels and the validation clusters.

    The patch-label CSV is filtered to ``args.magnification``, rows without a
    value for ``args.predicting_var`` are dropped, and ``select_cohort`` is
    applied with ``args.cohort``. When ``val_cases`` is None the cases are
    shuffled with ``random`` seeded by ``args.seed`` and everything after the
    first ``ceil(len(cases) * args.train_val_split)`` cases is used as the
    validation set.

    NOTE(review): whether this fallback reproduces the split made by
    ``split_train_val`` cannot be seen from this notebook; confirm.
    """
    cohort_labels = pd.read_csv(args.patch_label, index_col=0)
    cohort_labels = cohort_labels[cohort_labels.magnification == args.magnification]
    cohort_labels = cohort_labels.dropna(subset=[args.predicting_var])
    cohort_labels = select_cohort(cohort_labels, args.cohort)

    val_clusters = load_clusters(args, 'Validation')

    if val_cases is None:
        # Derive the validation cases from a seeded shuffle of all cases.
        all_cases = cohort_labels.case.unique()
        first_val_idx = int(np.ceil(len(all_cases) * args.train_val_split))
        random.seed(args.seed)
        random.shuffle(all_cases)
        val_cases = all_cases[first_val_idx:]

    in_validation = cohort_labels.case.isin(val_cases)
    val_patch_labels = cohort_labels[in_validation].reset_index(drop=True)
    return val_cases, val_patch_labels, val_clusters

In [None]:
def validate_slides(model, val_cases, max_num_patches, args):
    """Predict every validation slide with ``model`` and collect the results.

    Args:
        model: attention model; switched to eval mode here.
        val_cases: validation cases forwarded to ``validation_data`` (None lets
            it derive them itself).
        max_num_patches: per-slide patch cap forwarded to ``limit_input_size``.
        args: run configuration.

    Returns:
        ``(targets, outputs, predictions, slides)``, four parallel lists with one
        entry per slide: the value of ``args.predicting_var`` in the slide's
        first row, the model output as a Python number, the rounded output, and
        the slide name.
    """
    # Seed Python's `random` as well as torch: limit_input_size subsamples large
    # slides with `random`, so seeding torch alone did not make that choice
    # reproducible when explicit val_cases are passed.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True
    model.eval()

    outputs = []
    targets = []
    predictions = []
    slides = []

    _, val_patch_labels, val_clusters = validation_data(args, val_cases)
    val_slides = val_patch_labels.slide.unique()
    print(f'{len(val_slides)} slides')

    with torch.no_grad():
        for slide in val_slides:
            slide_patch_labels = val_patch_labels[val_patch_labels.slide==slide].reset_index(drop=True)
            print(f'{slide} has {len(slide_patch_labels)} patches')
            clusters = val_clusters[slide]

            slide_embeddings_paths = load_slide_features(slide)
            slide_embeddings = slide_embeddings_paths['slide_embeddings']
            patch_paths = slide_embeddings_paths['patch_paths']
            del slide_embeddings_paths

            slide_embeddings, clusters, patch_paths = limit_input_size(slide_embeddings, clusters, patch_paths,
                                                                       max_num_patches)
            output = predict(model, args.att_model, slide_embeddings, clusters=clusters, patch_paths=patch_paths)

            # The first row is used as the slide's label.
            slide_target = slide_patch_labels[args.predicting_var].iloc[0]

            outputs.append(output.item())
            # Rounding the raw output gives the predicted class; this assumes the
            # model already outputs a probability.
            predictions.append(torch.round(output).item())
            targets.append(slide_target)
            slides.append(slide)

    return targets, outputs, predictions, slides

In [None]:
def plot_confusion_matrix(targets, predictions, label_names=['0', '1'], save=True):
    """Draw the confusion matrix of ``predictions`` against ``targets`` as a heatmap.

    ``label_names`` labels both the rows (actual values) and the columns
    (predicted values). With ``save`` set, the figure is stored through
    ``save_result_figure`` and, when ``args.log`` is enabled, also added to the
    validation TensorBoard writer. The plot is shown in every case.
    """
    counts = pd.DataFrame(confusion_matrix(targets, predictions),
                          index=label_names,
                          columns=label_names)

    plt.figure(figsize=(7,5))
    heatmap_axes = sns.heatmap(counts, annot=True, fmt='g', cmap='Blues')
    fig = heatmap_axes.get_figure()
    plt.title(f'Confusion Matrix v{args.version}')
    plt.ylabel('Actual Values')
    plt.xlabel('Predicted Values')

    if save:
        print('Saving figure')
        save_result_figure('ValidationConfusionMatrix', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Confusion Matrix', fig)

    plt.show()
    
# binary classification only 
def density_plot(targets, outputs, label_names=['0', '1'], save=True):
    """Plot kernel density estimates of ``outputs``, one curve per true label (0 and 1).

    ``label_names[0]`` labels the curve of entries whose target is 0 and
    ``label_names[1]`` the curve of entries whose target is 1. With ``save``
    set, the figure is stored through ``save_result_figure`` and, when
    ``args.log`` is enabled, also added to the validation TensorBoard writer.
    """
    palette = list(colors.TABLEAU_COLORS.values())
    # Split the outputs by the true label at the same position.
    negative_probs = [outputs[pos] for pos, label in enumerate(targets) if label == 0]
    positive_probs = [outputs[pos] for pos, label in enumerate(targets) if label == 1]

    sns.set_style('whitegrid')
    g = sns.kdeplot(negative_probs, bw_adjust=0.5, label=label_names[0], c=palette[0])
    g = sns.kdeplot(positive_probs, bw_adjust=0.5, label=label_names[1], c=palette[1])
    g.set(xlim=(0, 1))
    g.set_title('Density plot of predicted probabilties across true labels')
    g.legend()
    fig = g.get_figure()

    if save:
        print('Saving figure')
        save_result_figure('ValidationDensityPlot', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Density Plot', fig)

    plt.show()
    
def plot_prediction_scatter(targets, outputs, save=True):
    """Scatter ``outputs`` against ``targets`` and overlay a least-squares line.

    Args:
        targets: sequence of target values (x axis).
        outputs: sequence of model outputs (y axis), aligned with ``targets``.
        save: when True, store the figure through ``save_result_figure`` and,
            if ``args.log`` is enabled, add it to the validation TensorBoard
            writer.

    Both axes are limited to [0, 1].
    """
    palette = list(colors.TABLEAU_COLORS.values())
    x = np.asarray(targets, dtype=float)
    y = np.asarray(outputs, dtype=float)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(x, y, '.', alpha=1.0, color=palette[0])
    plt.xlabel("targets")
    plt.ylabel("predictions")
    plt.xlim([0,1])
    plt.ylim([0,1])
    plt.title(f'Predicted vs Target')
    # Fit the line to the values that are plotted. The fit used to read a global
    # named `predictions` instead of the `outputs` argument.
    m, b = np.polyfit(x=x, y=y, deg=1)
    plt.plot(x, m * x + b, color=palette[1], alpha=0.6)

    if save:
        print('Saving figure')
        save_result_figure('ValidationScatterPlot', fig)
        if args.log:
            val_summary_writer.add_figure('Validation Scatter Plot', fig)

    plt.show()

In [None]:
def evaluate(args, val_cases=None):
    """Evaluate the best checkpoint on the validation slides and report metrics.

    Args:
        args: run configuration.
        val_cases: validation cases forwarded to ``validate_slides``. The default
            (None) keeps the previous behaviour of letting ``validation_data``
            derive the cases itself.
            NOTE(review): pass the cases returned by training to be sure the
            evaluation uses the same split.

    Returns:
        ``(targets, outputs, predictions, slides, auc, acc, weighted_acc, f1,
        precision, recall)``. The six metric values are None when
        ``args.prediction`` is not 'binary classification'.
    """
    best_model = load_best_model(args)
    targets, outputs, predictions, slides = validate_slides(best_model, val_cases=val_cases,
                                                            max_num_patches=args.max_num_patches, args=args)

    # Defined up front so the return statement also works on the branch that
    # computes no metrics (it used to raise UnboundLocalError there).
    auc = acc = weighted_acc = f1 = precision = recall = None

    if args.prediction == 'binary classification':
        # Draw each plot once; figures are saved only when logging is enabled
        # (previously they were drawn twice in that case).
        save_figures = bool(args.log)
        plot_confusion_matrix(targets, predictions, save=save_figures)
        density_plot(targets, outputs, save=save_figures)

        # slide-level accuracy and AUC
        auc = roc_auc_score(targets, outputs, average='weighted')
        acc = accuracy_score(targets, predictions)
        weighted_acc = balanced_accuracy_score(targets, predictions)
        f1 = f1_score(targets, predictions, average='weighted')
        precision = precision_score(targets, predictions, average='weighted')
        recall = recall_score(targets, predictions, average='weighted')

        print(f'{args.version}:')
        print('- AUC', auc)
        print('- Accuracy', acc)
        print('- Balanced accuracy', weighted_acc)
        print('- F1 score', f1)
        print('- Precision', precision)
        print('- Recall', recall)

        if args.log:
            val_summary_writer.add_scalar('Best/Slide-level AUC', auc)
            val_summary_writer.add_scalar('Best/Slide-level accuracy', acc)
    else:
        print(f'Implement evaluation for {args.prediction}')

    # The writer only exists when logging is enabled.
    if args.log:
        val_summary_writer.close()

    return targets, outputs, predictions, slides, auc, acc, weighted_acc, f1, precision, recall

In [None]:
# Evaluate the best checkpoint; the per-slide results and summary metrics are
# kept as notebook-level variables for the "Save metrics" / "Save predictions" cells.
targets, outputs, predictions, slides, auc, acc, weighted_acc, f1, precision, recall = evaluate(args)

### Save metrics

In [None]:
# One-row table of the slide-level metrics for this seed. It is named
# `metrics_df` so that it does not overwrite the `metrics()` function, which
# train() and validate() call if training is re-run in the same kernel.
metrics_df = pd.DataFrame(data=[(auc, acc, weighted_acc, f1, precision, recall)],
                          index=[f'round_{args.seed}'],
                          columns=['AUC', 'Accuracy', 'Balanced accuracy', 'F1', 'Precision', 'Recall'])

mets_save_dir = os.path.join('Results/Metrics', f'{args.att_model}',
                             f'BaselineResNet{args.version}_{args.att_model}{args.att_version}')
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(mets_save_dir, exist_ok=True)

metrics_df.to_csv(os.path.join(mets_save_dir, f'round_{args.seed}'))

### Save predictions

In [None]:
# Per-slide target, raw model output and rounded prediction for this seed.
results = pd.DataFrame(data=list(zip(targets, outputs, predictions)),
                       index=slides, columns=[f'Round{args.seed}_Target', f'Round{args.seed}_Output',
                                              f'Round{args.seed}_Prediction'])

# Use the same `{att_model}{att_version}` naming as the metrics and checkpoint
# directories. The name used to be 'ClusterPREViT' followed by args.att_model,
# which hard-coded one architecture and dropped the attention-model version.
preds_save_dir = os.path.join('Results/Predictions',
                              f'BaselineResNet{args.version}_{args.att_model}{args.att_version}')
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(preds_save_dir, exist_ok=True)

results.to_csv(os.path.join(preds_save_dir, f'round_{args.seed}'))