In [1]:
import sys
import os
sys.path.append("..")
import numpy as np
import math

# ================================== Torch Imports ==================================

import torch
import torch.utils.data as data_utils
import torch.optim as optim
import torch.nn as nn

from torch.autograd import Variable

# =================================== Metrics =======================================

from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryF1Score
from torcheval.metrics import BinaryAUROC, R2Score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import classification_report

import matplotlib.pyplot as plt
import statistics

# ============================= Models and Datasets =================================

from models.models import Attention, GatedAttention, AdditiveAttention, ModAdditiveAttention
from wsi_datasets import tumor_collate, tumor_pad_collate_fn, gene_collate, gene_pad_collate_fn, GeneExpressionDataset

# ============================ TensorBoard and Logging =============================
from torch.utils.tensorboard import SummaryWriter

import pickle

#================================= Finetuning ======================================

import optuna

ModuleNotFoundError: No module named 'models'

In [40]:
#==================================== PARAMETERS ====================================================

CUDA = True
SEED = 5

# --------------------------------- Hyperparameters ----------------------------------------

LR = 0.000023422110057
REG = 0.000152652510413
NUM_EPOCHS = 35
KFOLD_SPLITS = 5
BATCH_SIZE = 32
uniform = torch.nn.init.xavier_uniform_
uni = "xavier" if uniform == torch.nn.init.xavier_uniform_ else "kaiming"

# --------------------------------- Metrics Indexes -----------------------------------------

LOSS_INDEX = 0
ACCURACY_INDEX = 1
AUC_INDEX = 2
RECALL_INDEX = 3
F1_INDEX = 4
PERCENTAGE_ACCURACY_INDEX = 5

# ------------------------------------- Labels ----------------------------------------------

LABELS = ["tumor", "tp53"]

# -------------------------------- Train/Test Split -----------------------------------------

TRAIN_PERC = 0.8

# -------------------------------------------------------------------------------------------

EMBEDDINGS = True
USE_TENSORBOARD = True

DATASET_HDF5 = "../datasets/gene_expression/BRCA_TP53_10x.hdf5"
CUDA_DEVICE = "cuda:0"

TENSORBOARD_DIRECTORY = ""
MODEL_WEIGHTS_FILE = ""
TEST_SET_INDICES = 'test_indices.pkl'
print(TENSORBOARD_DIRECTORY)

../runs/cross-val/gene_expression/BRCA_10x/baselines/Attention/trial_133_


In [41]:
# ================================ Initializations ===============================================

# CUDA initializations
torch.cuda.init()
torch.cuda.memory_summary(device=None, abbreviated=False)
cuda = CUDA and torch.cuda.is_available()
device = torch.device(CUDA_DEVICE)
torch.manual_seed(SEED)
if cuda:
    torch.cuda.manual_seed(SEED)
    print('\nGPU is ON!')

# Model to train
model_type = Attention
# Cross Validation
splits = KFold(n_splits = KFOLD_SPLITS, shuffle=True, random_state=42)

# Tensorboard Writer
if USE_TENSORBOARD:
    writer = SummaryWriter(TENSORBOARD_DIRECTORY)

# Weights Initialization Function
def weights_init(m):
    if isinstance(m, nn.Linear):
        uniform(m.weight)
        m.bias.data.fill_(0)


GPU is ON!


In [42]:
# ========================================== Metrics ======================================================

# -------------------------------- Metrics Intialization --------------------------

auc_metric = BinaryAUROC().to(device)    
accuracy_metric = BinaryAccuracy(threshold=0.5).to(device)
recall_metric = BinaryRecall(threshold=0.5).to(device)
F1_metric = BinaryF1Score(threshold=0.5).to(device)
R2Score_metric = R2Score()

# ----------------------------- Metrics Calculation Functions -----------------------

def calculate_r2_score(probs, true_labels):
    R2Score_metric.update(probs, true_labels)
    r2score = R2Score_metric.compute()
    R2Score_metric.reset()
    return r2score

def calculate_auc(probs, true_labels):
    auc_metric.update(probs, true_labels)
    auc = auc_metric.compute()
    auc_metric.reset()
    return auc

def calculate_accuracy(probs, true_labels):
    return accuracy_metric(probs, true_labels)

def calculate_recall(probs, true_labels):
    return recall_metric(probs, true_labels)

def calculate_F1(probs, true_labels):
    return F1_metric(probs, true_labels)


def calculate_metrics(probs, true_labels):
    probs = torch.FloatTensor(probs).to(device)
    true_labels = torch.FloatTensor(true_labels).to(device)
    
    accuracy = calculate_accuracy(probs, true_labels)
    auc = calculate_auc(probs, true_labels)
    f1 = calculate_F1(probs, true_labels)
    recall = calculate_recall(probs, true_labels)
    
    return accuracy, auc, f1, recall


# --------------------------- Metrics Presentation Functions ---------------------

def get_classification_report(probs, true_labels):
    report = classification_report(true_labels, probs, target_names=["Negative", "Positive"])
    print(report)

def metrics_to_tensorboard(fold_metrics):
    if USE_TENSORBOARD:
        for fold in range(fold_metrics.shape[0]):
            for epoch in range(fold_metrics.shape[1]):
                writer.add_scalar('{} Loss - Fold {}'.format("training", fold), fold_metrics[fold, epoch, 0, LOSS_INDEX], epoch)
                writer.add_scalar('{} Accuracy - Fold {}'.format("training", fold), fold_metrics[fold, epoch, 0, ACCURACY_INDEX], epoch)
                writer.add_scalar('{} AUC - Fold {}'.format("training", fold), fold_metrics[fold, epoch, 0, AUC_INDEX], epoch)
                writer.add_scalar('{} Recall - Fold {}'.format("training", fold), fold_metrics[fold, epoch, 0, RECALL_INDEX], epoch)
                writer.add_scalar('{} F1 - Fold {}'.format("training", fold), fold_metrics[fold, epoch, 0, F1_INDEX], epoch)
                
                writer.add_scalar('{} Loss - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, LOSS_INDEX], epoch)
                writer.add_scalar('{} Accuracy - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, ACCURACY_INDEX], epoch)
                writer.add_scalar('{} AUC - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, AUC_INDEX], epoch)
                writer.add_scalar('{} Recall - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, RECALL_INDEX], epoch)
                writer.add_scalar('{} F1 - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, F1_INDEX], epoch)
                writer.add_scalar('{} Percentages - Fold {}'.format("validation", fold), fold_metrics[fold, epoch, 1, PERCENTAGE_ACCURACY_INDEX], epoch) 


        for epoch in range(fold_metrics.shape[1]):
            writer.add_scalar('{} Loss - Fold {}'.format("training", "avg"), np.mean(fold_metrics[:, epoch, 0, LOSS_INDEX], axis=0), epoch)
            writer.add_scalar('{} Accuracy - Fold {}'.format("training", "avg"), np.mean(fold_metrics[:, epoch, 0, ACCURACY_INDEX], axis=0), epoch)
            writer.add_scalar('{} AUC - Fold {}'.format("training", "avg"), np.mean(fold_metrics[:, epoch, 0, AUC_INDEX], axis=0), epoch)
            writer.add_scalar('{} Recall - Fold {}'.format("training", "avg"), np.mean(fold_metrics[fold, epoch, 0, RECALL_INDEX]), epoch)
            writer.add_scalar('{} F1 - Fold {}'.format("training", "avg"), np.mean(fold_metrics[fold, epoch, 0, F1_INDEX]), epoch)   
            
            writer.add_scalar('{} Loss - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, LOSS_INDEX], axis=0), epoch)
            writer.add_scalar('{} Accuracy - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, ACCURACY_INDEX], axis=0), epoch)
            writer.add_scalar('{} AUC - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, AUC_INDEX], axis=0), epoch)
            writer.add_scalar('{} Recall - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, RECALL_INDEX]), epoch)
            writer.add_scalar('{} F1 - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, F1_INDEX]), epoch) 
            writer.add_scalar('{} Percentages - Fold {}'.format("validation", "avg"), np.mean(fold_metrics[:, epoch, 1, PERCENTAGE_ACCURACY_INDEX]), epoch)

    return np.mean(fold_metrics[:, -1, 1, LOSS_INDEX]), np.mean(fold_metrics[:, -1, 1, ACCURACY_INDEX]), np.mean(fold_metrics[:, -1, 1, AUC_INDEX])   


def external_validation_to_tensorboard(loss, auc, accuracy, f1, recall, dataloader_len):
    writer.add_text('loss', "{:.4f}".format(loss/dataloader_len), 0)
    writer.add_text('auc', "{:.4f}".format(auc), 0)
    writer.add_text('accuracy',"{:.4f}".format(accuracy), 0)
    writer.add_text('f1',"{:.4f}".format(f1), 0)
    writer.add_text('recall',"{:.4f}".format(recall), 0)


In [43]:
def remove_padding(data):
    mask = np.all(np.array(data) != -np.Inf, axis=1)
    return data[mask]

In [44]:
# =========================================== Dataset ======================================================

# Initialiation
dataset = GeneExpressionDataset(DATASET_HDF5, LABELS[1])

# Train/Test Split

train_size = int(TRAIN_PERC * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
dataloader = data_utils.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True, collate_fn=gene_pad_collate_fn)
pos, neg = 0, 0
for batch_idx, (data, coords, labels, slide_id, case_id, _) in enumerate(dataloader):
    # print(labels)
    if labels[0] == 1:
        pos += 1
    elif labels[0] == 0:
        neg += 1
print(pos, neg)
# -------------------------------------- Auxiliary Functions ---------------------------------------------

def remove_padding(data):
    mask = np.all(np.array(data) != -np.Inf, axis=1)
    return data[mask]


662
331 331


In [45]:
# ===================================== Train Function ========================================

# ------------------------------------ Train Functions ----------------------------------------

# Train for a Single Batch
def train_batch(model, data, labels):
    losses, Y_probs, true_labels = [], [], []
    num_nan = 0
    train_error = 0
    train_correct = 0
    
    for data_element, label in zip(data, labels):
        if torch.isnan(data_element).any():
            num_nan += 1
            continue
        data_element = remove_padding(data_element)
        if cuda:
            data_element, label = data_element.to(device), label.to(device)
        
        loss, attention, scores = model.calculate_objective(data_element, label)
        error, Y_hat, Y_prob    = model.calculate_classification_error(data_element, label)

        if torch.isnan(Y_prob).any():
            num_nan += 1
            continue
        losses.append(loss)
        train_error += error
        train_correct += (Y_hat == label).sum().item()
        Y_probs.append(Y_prob.item())
        true_labels.append(label)

    return losses, train_error, train_correct, Y_probs, true_labels, num_nan


# Full epoch training
def train_epoch_embeddings(model, dataloader, optimizer, epoch, fold=""):
    model.train()
    train_loss, train_error, train_correct = 0., 0., 0.
    Y_probs = []
    true_labels = []
    num_nan = 0
    
    for batch_idx, (data, coords, labels, slide_id, case_id, _) in enumerate(dataloader):
        batch_losses, batch_error, batch_correct, batch_y_probs, batch_labels, batch_nan = train_batch(model, data, labels)
        final_loss = torch.stack(batch_losses)
        train_loss += final_loss.sum()
        train_error += batch_error
        train_correct += batch_correct
        Y_probs += batch_y_probs
        batch_nan += batch_nan
        true_labels += batch_labels

        l2_lambda = 0.001
        l2_norm = sum(p.pow(2.0).sum()
                        for p in model.parameters())

        final_loss = final_loss.mean() #+ l2_norm * l2_lambda
        
        final_loss.backward()
        #nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    return train_loss, train_correct, Y_probs, true_labels, num_nan


In [46]:
# ====================================== Validation Functions =====================================

def valid_batch(model, data, slide_ids, case_ids, labels, augmentation_flags, score_list):
    valid_loss, valid_error, valid_correct = 0, 0, 0
    Y_probs, true_labels = [], []
    num_nan = 0

    for data_element, label, slide_id, is_aug in zip(data, labels, slide_ids, augmentation_flags):
        if is_aug:
            break
        data_element = remove_padding(data_element)

        if cuda:
            data_element, label = data_element.to(device), label.to(device)
      
        loss, attention, scores = model.calculate_objective(data_element, label)
        error, Y_hat, Y_prob    = model.calculate_classification_error(data_element, label)
        
        score_list[slide_id] = scores 
        if torch.any(torch.isnan(Y_prob)):
            num_nan += 1
            continue
        valid_loss += loss.item()
        valid_error += error
        valid_correct += (Y_hat == label).sum().item()
        Y_probs.append(Y_prob.item())
        true_labels.append(label)

    return valid_loss, valid_correct, Y_probs, true_labels,  num_nan, score_list




def valid_epoch_embeddings(model, dataloader, epoch, fold=""):
    model.eval()
    valid_loss, valid_error, valid_correct = 0., 0., 0
    true_labels = []
    Y_probs = []
    num_nan = 0
    score_list = {}

    for batch_idx, (data, coords, labels, slide_ids, case_ids, aug_flags) in enumerate(dataloader):
        batch_loss, batch_correct, batch_y_probs, batch_labels, batch_nan, score_list = valid_batch(model, data, slide_ids, case_ids, labels, aug_flags, score_list)
        valid_loss += batch_loss
        valid_correct += batch_correct
        Y_probs += batch_y_probs
        num_nan += batch_nan
        true_labels += batch_labels
        

    if epoch == "validation" and USE_TENSORBOARD:
        accuracy, auc, f1, recall = calculate_metrics(Y_probs, true_labels)
        external_validation_to_tensorboard(valid_loss, auc, accuracy, f1, recall, len(dataloader))

    return valid_loss, valid_correct, Y_probs, true_labels, score_list

In [47]:
# ============================ Cross Validation ===============================

def model_initialization(model_type):
    model = model_type()
    model = model.to(device)
    model.apply(weights_init)
    return model

def cross_validation(model_type, dataset, optimizer_type, lr, weight_decay, batch_size=1, num_epochs=15):
    history = {'train_loss': [], 'test_loss': [],'train_acc':[],'test_acc':[]}

    fold_metrics = np.empty((KFOLD_SPLITS, num_epochs, 2, 6))


    for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(dataset)))):
        model = model_initialization(model_type)
        optimizer = optimizer_type(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,5,10], gamma=0.5)

        print('Fold {}'.format(fold + 1))
        train_sampler = data_utils.SubsetRandomSampler(train_idx)
        test_sampler = data_utils.SubsetRandomSampler(val_idx)
        train_loader = data_utils.DataLoader(dataset, batch_size, sampler=train_sampler, pin_memory=True, collate_fn=gene_pad_collate_fn)
        test_loader = data_utils.DataLoader(dataset, batch_size, sampler=test_sampler, pin_memory=True, collate_fn=gene_pad_collate_fn)

        train_auc_epochs, train_accuracy_epochs, train_loss_epochs = [], [], []
        test_auc_epochs, test_accuracy_epochs, test_loss = [], [], []
        
        for epoch in range(0, num_epochs):
            train_loss, train_correct, train_probs, train_true_labels, num_nan = train_epoch_embeddings(model,train_loader, optimizer, epoch, fold)
            test_loss, test_correct, test_probs, test_true_labels, _  = valid_epoch_embeddings(model,test_loader, epoch, fold)
            
            train_accuracy, train_auc, train_f1, train_recall = calculate_metrics(train_probs, train_true_labels)
            test_accuracy, test_auc, test_f1, test_recall = calculate_metrics(test_probs, test_true_labels)  


            num_bags_per_sample = train_dataset[0][0].shape[0] 

            fold_metrics[fold, epoch, 0, LOSS_INDEX] = train_loss / (len(train_sampler) * num_bags_per_sample - num_nan)
            fold_metrics[fold, epoch, 0, ACCURACY_INDEX] = train_accuracy
            fold_metrics[fold, epoch, 0, AUC_INDEX] = train_auc
            fold_metrics[fold, epoch, 0, F1_INDEX] = train_f1
            fold_metrics[fold, epoch, 0, RECALL_INDEX] = train_recall
            
            fold_metrics[fold, epoch, 1, LOSS_INDEX] = test_loss / len(test_loader)
            fold_metrics[fold, epoch, 1, ACCURACY_INDEX] = test_accuracy
            fold_metrics[fold, epoch, 1, AUC_INDEX] = test_auc
            fold_metrics[fold, epoch, 1, F1_INDEX] = test_f1
            fold_metrics[fold, epoch, 1, RECALL_INDEX] = test_recall
        
            lr_scheduler.step()
            
    loss, accuracy, auc = metrics_to_tensorboard(fold_metrics)

    return model, loss, accuracy, auc    

In [48]:
# Cross Validation
model, loss, _, auc = cross_validation(model_type, train_dataset, optim.Adam, LR, REG, BATCH_SIZE, NUM_EPOCHS)

print("loss", loss, "\nauc", auc)
# External Validation
dataloader = data_utils.DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True, collate_fn=gene_pad_collate_fn)
valid_loss, valid_correct, probs, true_labels, scores = valid_epoch_embeddings(model, dataloader, "validation")
print("External Validation Loss", valid_loss/len(dataloader))

writer.close()

# Save Weights and Dataset Indices
torch.save(model.state_dict(), MODEL_WEIGHTS_FILE)
test_indices = test_dataset.indices
with open('test_indices.pkl', 'wb') as f:
    pickle.dump(test_indices, f)

Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
loss 0.6917188495397568 
auc 0.55
External Validation Loss 0.6739078834092707


RuntimeError: Parent directory ../model_weights/gene_expression/BRCA_10x/baselines does not exist.