In [1]:
import sys
import os
sys.path.append("..")
sys.path.append(os.path.abspath('../KimiaNet'))
import numpy as np
import math

# ============================== Torch Imports =====================================

import torch
import torch.utils.data as data_utils
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn


# =================================== Metrics =======================================

from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryF1Score
from torcheval.metrics import BinaryAUROC
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import classification_report

# ============================= Models and Datasets =================================

from model import Attention, GatedAttention, AdditiveAttention
from WSI_dataloader import collate, BreastWSIDataset, BreastEmbeddingDataset

# ========================= TensorBoard and Logging =================================
from torch.utils.tensorboard import SummaryWriter

In [2]:
#================================================ PARAMETERS ====================================================
# Central configuration cell: tunable values read by the later cells.

# Request GPU execution; combined with torch.cuda.is_available() in the initialisation cell.
CUDA = True
# Seed handed to torch.manual_seed / torch.cuda.manual_seed.
SEED = 5

# --------------------------------- Hyperparameters ----------------------------------------

# Adam learning rate.
LR = 0.0001
# Adam weight_decay. NOTE(review): 10e-5 evaluates to 1e-4, not 1e-5 — confirm which was intended.
REG = 10e-5
# Epochs trained per cross-validation fold.
NUM_EPOCHS = 40
# Number of folds produced by the KFold splitter.
KFOLD_SPLITS = 5
# Assumed number of augmented views per loader item (data.shape[0]) — TODO confirm against the dataset.
BAGS_PER_BATCH = 3

# --------------------------------- Metrics Indexes -----------------------------------------
# Positions along the last axis of the (fold, epoch, train/val, metric) results array.

LOSS_INDEX = 0
ACCURACY_INDEX = 1
AUC_INDEX = 2
RECALL_INDEX = 3
F1_INDEX = 4

# -------------------------------------------------------------------------------------------
# True: build the dataset from DATASET_HDF5; False: build it from DATASET_FOLDER.
EMBEDDINGS = True
# True: create a SummaryWriter and log metrics to TENSORBOARD_DIRECTORY.
USE_TENSORBOARD = True

# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
DATASET_HDF5 = "/media/mdastorage/breast_5x_aug_2.h5"
DATASET_FOLDER = "/media/mdastorage/breast_5x_2"
# Device string passed to torch.device.
CUDA_DEVICE = "cuda:0"
TENSORBOARD_DIRECTORY = "../runs/cross-val/Breast5x_Aug2_AdditiveAttention"

# Destination of the saved model state_dict.
MODEL_WEIGHTS_FILE = "../model_weights/additiveAttentionMIL_aug2.pt"
# File name for the pickled indices of the held-out test split.
TEST_SET_INDICES = 'test_indices.pkl'

In [3]:
# =============================================== Initializations ===============================================

# Use the GPU only when it was requested AND is actually present. Falling back to
# the CPU keeps this cell (and every later `.to(device)`) working on hosts without
# CUDA instead of crashing inside torch.cuda.init().
cuda = CUDA and torch.cuda.is_available()
device = torch.device(CUDA_DEVICE if cuda else "cpu")

torch.manual_seed(SEED)
if cuda:
    torch.cuda.init()
    torch.cuda.manual_seed(SEED)
    print('\nGPU is ON!')

# Fold generator consumed by cross_validation(); fixed random_state keeps folds repeatable.
splits = KFold(n_splits = KFOLD_SPLITS, shuffle=True, random_state=42)
#splits = StratifiedKFold(n_splits = KFOLD_SPLITS, shuffle=False, random_state=None)

# A single .to(device) places the models on the configured device; a bare .cuda()
# would first allocate on the default GPU, which may differ from CUDA_DEVICE.
AttentionModel = Attention()
GatedAttentionModel = GatedAttention()
AttentionModel.to(device)
GatedAttentionModel.to(device)

optimizer = optim.Adam(AttentionModel.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=REG)
if USE_TENSORBOARD:
    writer = SummaryWriter(TENSORBOARD_DIRECTORY)

def weights_init(m):
    """Xavier-initialise an ``nn.Linear`` layer in place; other module types are left untouched.

    Written to be passed to ``module.apply(weights_init)``.

    Args:
        m: any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) stores bias as None; skip it instead of raising.
    if m.bias is not None:
        m.bias.data.fill_(0)



GPU is ON!


In [4]:
# =========================================== Dataset ======================================================

# EMBEDDINGS selects the dataset backed by DATASET_HDF5; otherwise the dataset
# is built from DATASET_FOLDER with data augmentations enabled.
dataset = (
    BreastEmbeddingDataset(DATASET_HDF5)
    if EMBEDDINGS
    else BreastWSIDataset(DATASET_FOLDER, 5, (0,370), (0,370), data_augmentations=True)
)

786 786 786
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
Dataset Fetched!


In [6]:
# ========================================== Metrics ======================================================

# Module-level metric objects shared by the calculate_* helpers; each is moved to `device`.
# The three torchmetrics objects are constructed with a 0.5 decision threshold.
auc_metric = BinaryAUROC().to(device)    
accuracy_metric = BinaryAccuracy(threshold=0.5).to(device)
recall_metric = BinaryRecall(threshold=0.5).to(device)
F1_metric = BinaryF1Score(threshold=0.5).to(device)

def calculate_auc(probs, true_labels):
    """Return the AUROC of ``probs`` against ``true_labels`` using the shared ``auc_metric``.

    The metric object is module-level state, so it is reset after every call —
    including calls that raise — to keep samples from one evaluation from
    leaking into the next.

    Args:
        probs: predicted scores, passed to ``auc_metric.update`` unchanged.
        true_labels: ground-truth labels, passed to ``auc_metric.update`` unchanged.

    Returns:
        Whatever ``auc_metric.compute()`` returns for this single update.
    """
    try:
        auc_metric.update(probs, true_labels)
        return auc_metric.compute()
    finally:
        auc_metric.reset()

def calculate_accuracy(probs, true_labels):
    """Evaluate the shared ``accuracy_metric`` on one set of scores and labels."""
    batch_accuracy = accuracy_metric(probs, true_labels)
    return batch_accuracy

def calculate_recall(probs, true_labels):
    """Evaluate the shared ``recall_metric`` on one set of scores and labels."""
    batch_recall = recall_metric(probs, true_labels)
    return batch_recall

def calculate_F1(probs, true_labels):
    """Evaluate the shared ``F1_metric`` on one set of scores and labels."""
    batch_f1 = F1_metric(probs, true_labels)
    return batch_f1


def calculate_metrics(probs, true_labels):
    """Compute accuracy, AUC, F1 and recall for one set of predictions.

    Args:
        probs: predicted probabilities (list, array or tensor).
        true_labels: ground-truth labels, same length as ``probs``.

    Returns:
        Tuple ``(accuracy, auc, f1, recall)`` — note the order, which callers
        unpack positionally.
    """
    # as_tensor builds float32 tensors directly on `device` and, unlike the legacy
    # torch.FloatTensor constructor, also accepts tensors of another dtype.
    probs = torch.as_tensor(probs, dtype=torch.float32, device=device)
    true_labels = torch.as_tensor(true_labels, dtype=torch.float32, device=device)

    accuracy = calculate_accuracy(probs, true_labels)
    auc = calculate_auc(probs, true_labels)
    f1 = calculate_F1(probs, true_labels)
    recall = calculate_recall(probs, true_labels)

    return accuracy, auc, f1, recall

def get_classification_report(probs, true_labels, threshold=0.5):
    """Print a per-class precision/recall/F1 report for binary predictions.

    ``classification_report`` expects discrete class labels, so the predicted
    probabilities are binarised first (``p > threshold``, the same rule the
    threshold-based metrics in this notebook are configured with). Inputs that
    already contain 0/1 values pass through unchanged.

    Args:
        probs: iterable of predicted probabilities (or 0/1 predictions).
        true_labels: iterable of ground-truth labels (0 or 1).
        threshold: decision threshold applied to ``probs``.
    """
    predicted = [int(float(p) > threshold) for p in probs]
    expected = [int(t) for t in true_labels]
    # labels=[0, 1] keeps the two target names valid even if one class is absent.
    report = classification_report(expected, predicted, labels=[0, 1], target_names=["Negative", "Positive"])
    print(report)

def metrics_to_tensorboard(fold_metrics):
    """Write per-fold and fold-averaged metric curves to TensorBoard.

    Does nothing when ``USE_TENSORBOARD`` is false.

    Args:
        fold_metrics: array indexed as ``[fold, epoch, phase, metric]`` where
            phase 0 is training and phase 1 is validation, and ``metric`` is one
            of the ``*_INDEX`` constants.

    Every scalar is read from the slot matching its tag: validation tags read
    phase 1, and the "avg" curves average over all folds for every metric.
    """
    if not USE_TENSORBOARD:
        return

    phases = (("training", 0), ("validation", 1))
    metrics = (
        ("Loss", LOSS_INDEX),
        ("Accuracy", ACCURACY_INDEX),
        ("AUC", AUC_INDEX),
        ("Recall", RECALL_INDEX),
        ("F1", F1_INDEX),
    )
    num_folds, num_epochs = fold_metrics.shape[0], fold_metrics.shape[1]

    # One curve per (phase, metric, fold).
    for fold in range(num_folds):
        for epoch in range(num_epochs):
            for phase_name, phase_idx in phases:
                for metric_name, metric_idx in metrics:
                    writer.add_scalar(
                        '{} {} - Fold {}'.format(phase_name, metric_name, fold),
                        fold_metrics[fold, epoch, phase_idx, metric_idx],
                        epoch,
                    )

    # One curve per (phase, metric) averaged across folds.
    for epoch in range(num_epochs):
        for phase_name, phase_idx in phases:
            for metric_name, metric_idx in metrics:
                writer.add_scalar(
                    '{} {} - Fold {}'.format(phase_name, metric_name, "avg"),
                    np.mean(fold_metrics[:, epoch, phase_idx, metric_idx], axis=0),
                    epoch,
                )

# ======================================= Auxiliary Functions =============================================

def count_labels(label, pos, neg):
    """Return the ``(pos, neg)`` tallies after counting one more ``label``.

    A label equal to 0 bumps the negative tally; anything else bumps the
    positive tally.
    """
    if label == 0:
        return pos, neg + 1
    return pos + 1, neg

In [12]:
# ========================================== Train Functions =============================================

def train_epoch_embeddings(model, dataloader, optimizer, epoch, fold=""):
    """Train ``model`` for one epoch, taking one optimizer step per augmented view.

    Args:
        model: object providing ``train()``, ``parameters()``,
            ``calculate_objective(embedding, label) -> (loss, _)`` and
            ``calculate_classification_error(embedding, label) -> (error, Y_hat, Y_prob)``.
        dataloader: iterable yielding ``(data, coords, label, path)``; each view is
            read as ``data[aug, :, :]`` for ``aug`` in ``range(data.shape[0])``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: unused; kept for call-site compatibility.
        fold: unused; kept for call-site compatibility.

    Returns:
        ``(train_loss, train_correct, probs, true_labels, num_nan)`` where
        ``train_loss`` is the summed loss over the views that were actually
        trained on, ``probs`` / ``true_labels`` have one entry per such view and
        ``num_nan`` counts the views skipped because of NaN values.
    """
    model.train()
    train_loss = 0.
    train_correct = 0
    probs, true_labels = [], []
    num_nan = 0

    for data, _coords, label, _path in dataloader:
        bag_label = label[0]
        for aug in range(data.shape[0]):
            embedding = data[aug,:,:]

            # Reject corrupt inputs before spending a forward pass on them.
            if torch.isnan(embedding).any():
                num_nan += 1
                continue

            for param in model.parameters():
                param.grad = None

            if cuda:
                embedding, bag_label = embedding.to(device), label.to(device)

            loss, _ = model.calculate_objective(embedding, bag_label)
            _, Y_hat, Y_prob = model.calculate_classification_error(embedding, bag_label)

            # Skip the view BEFORE touching any accumulator: adding a single NaN
            # loss would turn the whole epoch's train_loss into NaN.
            if torch.isnan(Y_prob).any() or torch.isnan(loss).any():
                num_nan += 1
                continue

            train_loss += loss.item()
            train_correct += (Y_hat == bag_label).sum().item()

            probs.append(Y_prob.item())
            true_labels.append(label[0].item())

            loss.backward()
            optimizer.step()

    return train_loss, train_correct, probs, true_labels, num_nan


# ========================================== Validation Functions =====================================================

def valid_epoch_embeddings(model, dataloader, epoch, fold="", get_scores=False):
    """Evaluate ``model`` on the first view (``data[0]``) of every bag in ``dataloader``.

    Args:
        model: object providing ``eval()``,
            ``calculate_objective(embedding, label) -> (loss, attention_weights)`` and
            ``calculate_classification_error(embedding, label) -> (error, Y_hat, Y_prob)``.
        dataloader: iterable yielding ``(data, coords, label, path)``.
        epoch: epoch number, or the string ``"validation"`` to additionally log
            loss/AUC/accuracy/F1/recall as TensorBoard text.
        fold: unused; kept for call-site compatibility.
        get_scores: when true, also return ``{path[0]: (coords, attention_weights)}``.

    Returns:
        ``(valid_loss, valid_correct, probs, true_labels)`` and, if
        ``get_scores`` is true, ``attention_scores`` as a fifth element.
        ``valid_loss`` is the summed loss over the bags that were scored; bags
        producing NaN outputs are skipped.
    """
    model.eval()
    valid_loss, valid_correct = 0., 0
    probs, true_labels = [], []
    num_nan = 0  # must exist before the first NaN bag is encountered
    attention_scores = {}

    # No gradients are needed for evaluation; this also keeps the stored
    # attention weights from holding autograd graphs alive.
    with torch.no_grad():
        for data, coords, label, path in dataloader:
            bag_label = label[0]
            embedding = data[0,:,:]
            if cuda:
                embedding, bag_label = embedding.to(device), label.to(device)

            loss, attention_weights = model.calculate_objective(embedding, bag_label)
            _, Y_hat, Y_prob = model.calculate_classification_error(embedding, bag_label)

            if torch.isnan(Y_prob).any() or torch.isnan(loss).any():
                num_nan += 1
                continue

            if get_scores:
                attention_scores[path[0]] = (coords, attention_weights)

            valid_loss += loss.item()
            valid_correct += (Y_hat == bag_label).sum().item()
            probs.append(Y_prob.item())
            true_labels.append(label[0].item())

    if num_nan:
        print("Skipped {} bag(s) with NaN outputs during validation".format(num_nan))

    if epoch == "validation" and USE_TENSORBOARD:
        accuracy, auc, f1, recall = calculate_metrics(probs, true_labels)
        # Average over the bags that contributed to valid_loss (skipped bags excluded).
        writer.add_text('loss', "{:.4f}".format(valid_loss / max(len(probs), 1)))
        writer.add_text('auc', "{:.4f}".format(auc))
        writer.add_text('accuracy',"{:.4f}".format(accuracy))
        writer.add_text('f1',"{:.4f}".format(f1))
        writer.add_text('recall',"{:.4f}".format(recall))

    if get_scores:
        return valid_loss, valid_correct, probs, true_labels, attention_scores
    return valid_loss, valid_correct, probs, true_labels


In [13]:
#============================================ Train/Test Split =====================================


# Hold out 20% of the bags as a test split; the remaining 80% is the training split.
test_size = len(dataset) - int(0.8 * len(dataset))
train_size = len(dataset) - test_size

train_dataset, test_dataset = data_utils.random_split(dataset, [train_size, test_size])

train_loader = data_utils.DataLoader(
    train_dataset,
    batch_size=1,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=20,
    collate_fn=collate,
)
test_loader = data_utils.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=20,
    collate_fn=collate,
)



In [14]:
import matplotlib.pyplot as plt
import statistics

def fold_data(loader, fold, desc):
    """Plot a histogram of bag sizes (instances per bag) for one loader.

    Args:
        loader: iterable yielding ``(data, coords, label, path)``; the bag size is
            taken as ``data[0, :, :].shape[0]``.
        fold: fold identifier, used only in the plot title.
        desc: free-text description (e.g. "train"), used only in the plot title.

    The histogram is keyed by the bag sizes actually observed, so it works for
    any bag size regardless of how many bags the loader holds.
    """
    size_counts = {}
    num_bags = 0
    for data, _coords, _label, _path in loader:
        bag_size = data[0,:,:].shape[0]
        size_counts[bag_size] = size_counts.get(bag_size, 0) + 1
        num_bags += 1

    sizes = sorted(size_counts)
    counts = [size_counts[size] for size in sizes]
    # Number of bags seen and number of distinct bag sizes.
    print(num_bags, len(sizes))

    # A single Axes: plt.subplots(1, 2) would return an array with no .bar method.
    fig, ax = plt.subplots(figsize=(40,20))
    ax.bar(sizes, counts)
    ax.set_title("fold: " + str(fold) + " " + desc)
    ax.set_xlabel("bag size")
    ax.set_ylabel("number of bags")
    plt.show()

def check_fold_data(train_loader, test_loader, fold):
    """Report when a bag path occurs in both the train and the test loader of a fold.

    Both loaders are fully iterated; each item must unpack as
    ``(data, coords, label, path)``. ``path[0]`` identifies the bag and
    ``label`` is used as an index into a two-slot tally. Prints a message
    naming ``fold`` if the two path sets overlap; returns nothing.
    """
    seen_paths = {"train": set(), "test": set()}
    # Per-split tallies indexed by label; collected for optional inspection only.
    label_tally = {"train": [0, 0], "test": [0, 0]}

    for split_name, loader in (("train", train_loader), ("test", test_loader)):
        for _data, _coords, label, path in loader:
            seen_paths[split_name].add(path[0])
            label_tally[split_name][label] += 1

    if seen_paths["train"] & seen_paths["test"]:
        print("Duplicate found ", str(fold))

    # print("train:", label_tally["train"])
    # print("test:", label_tally["test"])


In [15]:
def cross_validation(model, dataset):
    """Run k-fold cross-validation and log the resulting curves to TensorBoard.

    For every fold produced by the module-level ``splits`` a fresh
    ``AdditiveAttention`` model is created, initialised with ``weights_init``
    and trained for ``NUM_EPOCHS`` epochs with its own Adam optimizer.

    Args:
        model: returned unchanged only if ``splits`` yields no folds; otherwise
            it is replaced by the freshly built per-fold model.
        dataset: dataset indexed by the fold indices.

    Returns:
        The model trained in the last fold.
    """
    # [fold, epoch, phase (0=train, 1=validation), metric index]
    fold_metrics = np.empty((KFOLD_SPLITS, NUM_EPOCHS, 2, 5))

    for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(dataset)))):
        model = AdditiveAttention()
        model.to(device)
        model.apply(weights_init)
        print('Fold {}'.format(fold + 1))
        train_sampler = data_utils.SubsetRandomSampler(train_idx)
        test_sampler = data_utils.SubsetRandomSampler(val_idx)
        train_loader = data_utils.DataLoader(dataset, batch_size=1, sampler=train_sampler, num_workers=0, pin_memory=True, collate_fn=collate)
        test_loader = data_utils.DataLoader(dataset, batch_size=1, sampler=test_sampler, num_workers=0, pin_memory=True, collate_fn=collate)

        optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=REG)

        for epoch in range(NUM_EPOCHS):
            train_loss, train_correct, train_probs, train_true_labels, num_nan = train_epoch_embeddings(model,train_loader,optimizer, epoch, fold)
            test_loss, test_correct, test_probs, test_true_labels = valid_epoch_embeddings(model,test_loader, epoch, fold)

            train_accuracy, train_auc, train_f1, train_recall = calculate_metrics(train_probs, train_true_labels)
            test_accuracy, test_auc, test_f1, test_recall = calculate_metrics(test_probs, test_true_labels)

            # Mean loss over the instances that were actually scored: one entry is
            # appended to the probability list per scored instance, so its length is
            # the right denominator whatever the number of augmentations per bag.
            train_mean_loss = train_loss / max(len(train_probs), 1)
            test_mean_loss = test_loss / max(len(test_probs), 1)

            # float() turns metric tensors into plain Python numbers before they
            # are stored in the NumPy array.
            fold_metrics[fold, epoch, 0, LOSS_INDEX] = train_mean_loss
            fold_metrics[fold, epoch, 0, ACCURACY_INDEX] = float(train_accuracy)
            fold_metrics[fold, epoch, 0, AUC_INDEX] = float(train_auc)
            fold_metrics[fold, epoch, 0, F1_INDEX] = float(train_f1)
            fold_metrics[fold, epoch, 0, RECALL_INDEX] = float(train_recall)

            fold_metrics[fold, epoch, 1, LOSS_INDEX] = test_mean_loss
            fold_metrics[fold, epoch, 1, ACCURACY_INDEX] = float(test_accuracy)
            fold_metrics[fold, epoch, 1, AUC_INDEX] = float(test_auc)
            fold_metrics[fold, epoch, 1, F1_INDEX] = float(test_f1)
            fold_metrics[fold, epoch, 1, RECALL_INDEX] = float(test_recall)

    metrics_to_tensorboard(fold_metrics)

    return model

In [16]:
# Cross-validate on the training split; cross_validation returns the model trained in its last fold.
model = cross_validation(AttentionModel, train_dataset)

dataloader = data_utils.DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True, collate_fn=collate)

valid_loss, valid_correct, probs, true_labels, attention_scores = valid_epoch_embeddings(model, dataloader, "validation", get_scores=True)

# Push buffered TensorBoard events to disk; the writer is left open for further use.
if USE_TENSORBOARD:
    writer.flush()

# Mean loss over the bags that were actually scored on the held-out split.
print("External Validation Loss", valid_loss / max(len(probs), 1))

# Make sure the target directory exists so the save cannot fail after training has finished.
os.makedirs(os.path.dirname(MODEL_WEIGHTS_FILE) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_WEIGHTS_FILE)



Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
External Validation Loss 0.2851210172816191


In [17]:
import pickle
# Persist the indices of the held-out test split under the configured file name
# (TEST_SET_INDICES) so the same bags can be identified again later.
test_indices = test_dataset.indices
with open(TEST_SET_INDICES, 'wb') as f:
    pickle.dump(test_indices, f)