## Import Required Packages

In [124]:
import tensorflow as tf
import tensorflow_addons as tfa
from tqdm import tqdm
import pandas as pd
import sklearn
from sklearn import metrics
import re
import numpy as np
import pickle as pkl
import PIL
import datetime
import os
import random
import shutil
import statistics
import time

## Load Self-Defined Functions

In [588]:
def get_data_from_tf(tf_path):
    feature = {'height': tf.io.FixedLenFeature([], tf.int64),
               'width': tf.io.FixedLenFeature([], tf.int64),
               'depth': tf.io.FixedLenFeature([], tf.int64),
               'label': tf.io.FixedLenFeature([], tf.int64),
               'image/format': tf.io.FixedLenFeature([], tf.string),
               'image_name': tf.io.FixedLenFeature([], tf.string),
               'image/encoded': tf.io.FixedLenFeature([], tf.string),
               'image_feature': tf.io.FixedLenFeature([], tf.string)}

    tfrecord_dataset = tf.data.TFRecordDataset(tf_path)

    def _parse_image_function(key):
        return tf.io.parse_single_example(key, feature)

    CLAM_dataset = tfrecord_dataset.map(_parse_image_function)

    image_features = list()

    for tfrecord_value in CLAM_dataset:
        img_feature = tf.io.parse_tensor(tfrecord_value['image_feature'], 'float32')
        slide_labels = tfrecord_value['label']
        slide_label = int(slide_labels)
        image_features.append(img_feature)

    return image_features, slide_label

In [126]:
def most_frequent(List):
    mf = max(set(List), key=List.count)
    return mf

In [127]:
def tf_shut_up(no_warn_op=False):
    if no_warn_op:
        tf.get_logger().setLevel('ERROR')
    else:
        print('Are you sure you want to receive the annoying TensorFlow Warning Messages?', \
              '\n', 'If not, check the value of your input prameter for this function and re-run it.')

## Load CLAM Model

### Import None-Gated Attention Network

In [581]:
class NG_Att_Net(tf.keras.Model):
    def __init__(self, dim_features=1024, dim_compress_features=512, n_hidden_units=256, n_classes=2,
                 dropout=False, dropout_rate=.25):
        super(NG_Att_Net, self).__init__()
        self.dim_features = dim_features
        self.dim_compress_features = dim_compress_features
        self.n_hidden_units = n_hidden_units
        self.n_classes = n_classes
        self.dropout = dropout
        self.dropout_rate = dropout_rate

        self.compression_model = tf.keras.models.Sequential()
        self.model = tf.keras.models.Sequential()

        self.fc_compress_layer = tf.keras.layers.Dense(units=dim_compress_features, activation='relu',
                                                       input_shape=(dim_features,), kernel_initializer='glorot_normal',
                                                       bias_initializer='zeros', name='Fully_Connected_Layer')

        self.compression_model.add(self.fc_compress_layer)
        self.model.add(self.fc_compress_layer)

        self.att_layer1 = tf.keras.layers.Dense(units=n_hidden_units, activation='tanh',
                                                input_shape=(dim_compress_features,),
                                                kernel_initializer='glorot_normal', bias_initializer='zeros',
                                                name='Attention_Layer1')

        self.att_layer2 = tf.keras.layers.Dense(units=n_classes, activation='linear', input_shape=(n_hidden_units,),
                                                kernel_initializer='glorot_normal', bias_initializer='zeros',
                                                name='Attention_Layer2')

        self.model.add(self.att_layer1)

        if dropout:
            self.model.add(tf.keras.layers.Dropout(dropout_rate, name='Dropout_Layer'))

        self.model.add(self.att_layer2)

    def att_model(self):
        attention_model = [self.compression_model, self.model]
        return attention_model

    def call(self, x):
        h = list()
        A = list()
        
        for i in x:
            c_imf = self.att_model()[0](i)
            h.append(c_imf)
        
        for i in x:
            a = self.att_model()[1](i)
            A.append(a)
        return h, A

### Import Gated Attention Network

In [582]:
class G_Att_Net(tf.keras.Model):
    def __init__(self, dim_features=1024, dim_compress_features=512, n_hidden_units=256, n_classes=2,
                 dropout=False, dropout_rate=.25):
        super(G_Att_Net, self).__init__()
        self.dim_features = dim_features
        self.dim_compress_features = dim_compress_features
        self.n_hidden_units = n_hidden_units
        self.n_classes = n_classes
        self.dropout = dropout
        self.dropout_rate = dropout_rate

        self.compression_model = tf.keras.models.Sequential()
        self.model1 = tf.keras.models.Sequential()
        self.model2 = tf.keras.models.Sequential()
        self.model = tf.keras.models.Sequential()

        self.fc_compress_layer = tf.keras.layers.Dense(units=dim_compress_features, activation='relu',
                                                       input_shape=(dim_features,), kernel_initializer='glorot_normal',
                                                       bias_initializer='zeros', name='Fully_Connected_Layer')

        self.compression_model.add(self.fc_compress_layer)
        self.model1.add(self.fc_compress_layer)
        self.model2.add(self.fc_compress_layer)

        self.att_layer1 = tf.keras.layers.Dense(units=n_hidden_units, activation='tanh', input_shape=(dim_features,),
                                                kernel_initializer='glorot_normal', bias_initializer='zeros',
                                                name='Attention_Layer1')

        self.att_layer2 = tf.keras.layers.Dense(units=n_hidden_units, activation='sigmoid', input_shape=(dim_features,),
                                                kernel_initializer='glorot_normal', bias_initializer='zeros',
                                                name='Attention_Layer2')

        self.att_layer3 = tf.keras.layers.Dense(units=n_classes, activation='linear', input_shape=(n_hidden_units,),
                                                kernel_initializer='glorot_normal', bias_initializer='zeros',
                                                name='Attention_Layer3')

        self.model1.add(self.att_layer1)
        self.model2.add(self.att_layer2)

        if dropout:
            self.model1.add(tf.keras.layers.Dropout(dropout_rate, name='Dropout_Layer'))
            self.model2.add(tf.keras.layers.Dropout(dropout_rate, name='Dropout_Layer'))

        self.model.add(self.att_layer3)

    def att_model(self):
        attention_model = [self.compression_model, self.model1, self.model2, self.model]
        return attention_model

    def call(self, x):
        h = list()
        A = list()
        
        for i in x:
            c_imf = self.att_model()[0](i)
            h.append(c_imf)
            
        for i in x:
            layer1_output = self.att_model()[1](i)  
            layer2_output = self.att_model()[2](i)  
            a = tf.math.multiply(layer1_output, layer2_output)  
            a = self.att_model()[3](a)  
            A.append(a)

        return h, A

### Import Instance Classifier Model

In [583]:
class Ins(tf.keras.Model):
    def __init__(self, dim_compress_features=512, n_class=2, n_ins=8, mut_ex=False):
        super(Ins, self).__init__()
        self.dim_compress_features = dim_compress_features
        self.n_class = n_class
        self.n_ins = n_ins
        self.mut_ex = mut_ex

        self.ins_model = list()
        self.m_ins_model = tf.keras.models.Sequential()
        self.m_ins_layer = tf.keras.layers.Dense(
            units=self.n_class, activation='linear', input_shape=(self.dim_compress_features,),
            name='Instance_Classifier_Layer'
        )
        self.m_ins_model.add(self.m_ins_layer)

        for i in range(self.n_class):
            self.ins_model.append(self.m_ins_model)

    def ins_classifier(self):
        return self.ins_model
    
    @staticmethod
    def generate_pos_labels(n_pos_sample):
        return tf.fill(dims=[n_pos_sample, ], value=1)

    @staticmethod
    def generate_neg_labels(n_neg_sample):
        return tf.fill(dims=[n_neg_sample, ], value=0)
    
    def in_call(self, ins_classifier, h, A_I):
        pos_label = self.generate_pos_labels(self.n_ins)
        neg_label = self.generate_neg_labels(self.n_ins)
        ins_label_in = tf.concat(values=[pos_label, neg_label], axis=0)
        A_I = tf.reshape(tf.convert_to_tensor(A_I), (1, len(A_I))) 
        
        top_pos_ids = tf.math.top_k(A_I, self.n_ins)[1][-1]  
        pos_index = list()
        for i in top_pos_ids:
            pos_index.append(i)

        pos_index = tf.convert_to_tensor(pos_index)
        top_pos = list()
        for i in pos_index:
            top_pos.append(h[i])
 
        top_neg_ids = tf.math.top_k(-A_I, self.n_ins)[1][-1]
        neg_index = list()
        for i in top_neg_ids:
             neg_index.append(i)

        neg_index = tf.convert_to_tensor(neg_index)
        top_neg = list()
        for i in neg_index:
            top_neg.append(h[i])

        ins_in = tf.concat(values=[top_pos, top_neg], axis=0)
        logits_unnorm_in = list()
        logits_in = list()
        
        for i in range(self.n_class * self.n_ins):
            ins_score_unnorm_in = ins_classifier(ins_in[i])
            logit_in = tf.math.softmax(ins_score_unnorm_in)
            logits_unnorm_in.append(ins_score_unnorm_in)
            logits_in.append(logit_in)

        return ins_label_in, logits_unnorm_in, logits_in
    
    def out_call(self, ins_classifier, h, A_O):
        # get compressed 512-dimensional instance-level feature vectors for following use, denoted by h
        A_O = tf.reshape(tf.convert_to_tensor(A_O), (1, len(A_O)))
        top_pos_ids = tf.math.top_k(A_O, self.n_ins)[1][-1]
        pos_index = list()
        for i in top_pos_ids:
            pos_index.append(i)

        pos_index = tf.convert_to_tensor(pos_index)
        top_pos = list()
        for i in pos_index:
            top_pos.append(h[i])

        # mutually-exclusive -> top k instances w/ highest attention scores ==> false pos = neg
        pos_ins_labels_out = self.generate_neg_labels(self.n_ins)
        ins_label_out = pos_ins_labels_out
        
        logits_unnorm_out = list()
        logits_out = list()
  
        for i in range(self.n_ins):
            ins_score_unnorm_out = ins_classifier(top_pos[i])
            logit_out = tf.math.softmax(ins_score_unnorm_out)
            logits_unnorm_out.append(ins_score_unnorm_out)
            logits_out.append(logit_out)

        return ins_label_out, logits_unnorm_out, logits_out
    
    def call(self, bag_label, h, A):
        for i in range(self.n_class):
            ins_classifier = self.ins_classifier()[i]
            if i == bag_label:
                A_I = list()
                for j in range(len(A)):
                    a_i = A[j][0][i]
                    A_I.append(a_i)
                ins_label_in, logits_unnorm_in, logits_in = self.in_call(ins_classifier, h, A_I)
            else:
                if self.mut_ex:
                    A_O = list()
                    for j in range(len(A)):
                        a_o = A[j][0][i]
                        A_O.append(a_o)
                    ins_label_out, logits_unnorm_out, logits_out = self.out_call(ins_classifier, h, A_O)
                else:
                    continue
                    
        if self.mut_ex:
            ins_labels = tf.concat(values=[ins_label_in, ins_label_out], axis=0)
            ins_logits_unnorm = logits_unnorm_in + logits_unnorm_out
            ins_logits = logits_in + logits_out
        else:
            ins_labels = ins_label_in
            ins_logits_unnorm = logits_unnorm_in
            ins_logits = logits_in
        
        return ins_labels, ins_logits_unnorm, ins_logits

### Import Bag Classifier Model

In [584]:
class S_Bag(tf.keras.Model):
    def __init__(self, dim_compress_features=512, n_class=2):
        super(S_Bag, self).__init__()
        self.dim_compress_features = dim_compress_features
        self.n_class = n_class

        self.s_bag_model = tf.keras.models.Sequential()
        self.s_bag_layer = tf.keras.layers.Dense(
            units=1, activation='linear', input_shape=(self.n_class, self.dim_compress_features),
            name='Bag_Classifier_Layer'
        )
        self.s_bag_model.add(self.s_bag_layer)

    def bag_classifier(self):
        return self.s_bag_model

    def h_slide(self, A, h):
        # compute the slide-level representation aggregated per the attention score distribution for the mth class
        SAR = list()
        for i in range(len(A)):
            sar = tf.linalg.matmul(tf.transpose(A[i]), h[i])  # shape be (2,512)
            SAR.append(sar)
        slide_agg_rep = tf.math.add_n(SAR)   # return h_[slide,m], shape be (2,512)
        
        return slide_agg_rep
    
    def call(self, bag_label, A, h):
        slide_agg_rep = self.h_slide(A, h)
        bag_classifier = self.bag_classifier()
        slide_score_unnorm = bag_classifier(slide_agg_rep)
        slide_score_unnorm = tf.reshape(slide_score_unnorm, (1, self.n_class))
        Y_hat = tf.math.top_k(slide_score_unnorm ,1)[1][-1]
        Y_prob = tf.math.softmax(tf.reshape(slide_score_unnorm, (1, self.n_class)))   #shape be (1,2), predictions for each of the classes
        predict_slide_label = np.argmax(Y_prob.numpy())
        
        Y_true = tf.one_hot([bag_label], 2)

        return slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true

In [589]:
class M_Bag(tf.keras.Model):
    def __init__(self, dim_compress_features=512, n_class=2):
        super(M_Bag, self).__init__()
        self.dim_compress_features = dim_compress_features
        self.n_class = n_class

        self.m_bag_models = list()
        self.m_bag_model = tf.keras.models.Sequential() 
        self.m_bag_layer = tf.keras.layers.Dense(
            units = 1, activation = 'linear', input_shape=(self.dim_compress_features,), name = 'Bag_Classifier_Layer'
        )
        self.m_bag_model.add(self.m_bag_layer)
        for i in range(self.n_class):
            self.m_bag_models.append(self.m_bag_model)
            
    def bag_classifier(self):       
        return self.m_bag_models

    def h_slide(self, A, h):
        # compute the slide-level representation aggregated per the attention score distribution for the mth class
        SAR = list()
        for i in range(len(A)):
            sar = tf.linalg.matmul(tf.transpose(A[i]), h[i])  # shape be (2,512)
            SAR.append(sar)
        slide_agg_rep = tf.math.add_n(SAR)  # return h_[slide,m], shape be (2,512)

        return slide_agg_rep

    def in_call(self, bag_classifier, h_slide_I):
        ssu_in = bag_classifier(h_slide_I)[0][0]

        return ssu_in
    
    def out_call(self, bag_classifier, h_slide_O):
        ssu_out = bag_classifier(h_slide_O)[0][0]
        
        return ssu_out
    
    def call(self, bag_label, A, h):
        slide_agg_rep = self.h_slide(A, h)
        # unnormalized slide-level score (s_[slide,m]) with uninitialized entries, shape be (1,num_of_classes)
        slide_score_unnorm = tf.Variable(np.empty((1, self.n_class)), dtype=tf.float32)
        slide_score_unnorm = tf.reshape(slide_score_unnorm, (1, self.n_class)).numpy()
 
        # return s_[slide,m] (slide-level prediction scores)
        for i in range(self.n_class):
            bag_classifier = self.bag_classifier()[i]
            if i == bag_label:
                h_slide_I = tf.reshape(slide_agg_rep[i], (1, self.dim_compress_features))
                ssu_in = self.in_call(bag_classifier, h_slide_I)
            else:
                h_slide_O = tf.reshape(slide_agg_rep[i], (1, self.dim_compress_features))
                ssu_out = self.out_call(bag_classifier, h_slide_O)
                
        for i in range(self.n_class):
            if i == bag_label:
                slide_score_unnorm[0, i] = ssu_in
            else:
                slide_score_unnorm[0, i] = ssu_out
        slide_score_unnorm = tf.convert_to_tensor(slide_score_unnorm)

        Y_hat = tf.math.top_k(slide_score_unnorm, 1)[1][-1]
        Y_prob = tf.math.softmax(slide_score_unnorm)
        predict_slide_label = np.argmax(Y_prob.numpy())
        
        Y_true = tf.one_hot([bag_label], 2)

        return slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true

### Import CLAM Model

In [586]:
class S_CLAM(tf.keras.Model):
    def __init__(self, att_gate=False, net_size='small', n_ins=8, n_class=2, mut_ex=False, 
                 dropout=False, drop_rate=.25, mil_ins=False, att_only=False):
        super(S_CLAM, self).__init__()
        self.att_gate = att_gate
        self.net_size = net_size
        self.n_ins = n_ins
        self.n_class = n_class
        self.mut_ex = mut_ex
        self.dropout = dropout
        self.drop_rate = drop_rate
        self.mil_ins = mil_ins
        self.att_only = att_only
        
        self.net_shape_dict = {
            'small': [1024, 512, 256],
            'big': [1024, 512, 384]
        }
        self.net_shape = self.net_shape_dict[self.net_size]
        
        if self.att_gate:
            self.att_net = G_Att_Net(dim_features=self.net_shape[0], dim_compress_features=self.net_shape[1], n_hidden_units=self.net_shape[2],
                                    n_classes=self.n_class, dropout=self.dropout, dropout_rate=self.drop_rate)
        else:
            self.att_net = NG_Att_Net(dim_features=self.net_shape[0], dim_compress_features=self.net_shape[1], n_hidden_units=self.net_shape[2],
                                    n_classes=self.n_class, dropout=self.dropout, dropout_rate=self.drop_rate)
        
        self.ins_net = Ins(dim_compress_features=self.net_shape[1], n_class=self.n_class, n_ins=self.n_ins, mut_ex=self.mut_ex)
        
        self.bag_net = S_Bag(dim_compress_features=self.net_shape[1], n_class=self.n_class)
        
    def clam_model(self):
        att_model = self.att_net.att_model()
        ins_classifier = self.ins_net.ins_classifier()
        bag_classifier = self.bag_net.bag_classifier()
        
        clam_model = [att_model, ins_classifier, bag_classifier]
        
        return clam_model

    def call(self, img_features, slide_label):
        """
        Args:
            img_features -> original 1024-dimensional instance-level feature vectors
            slide_label -> ground-truth slide label, could be 0 or 1 for binary classification
        """

        h, A = self.att_net.call(img_features)
        att_score = A  # output from attention network
        A = tf.math.softmax(A)   # softmax onattention scores 

        if self.att_only:
            return att_score
        
        if self.mil_ins:
            ins_labels, ins_logits_unnorm, ins_logits = self.ins_net.call(slide_label, h, A)

        slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = self.bag_net.call(slide_label, A, h)

        return att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, Y_prob, Y_hat, Y_true, predict_slide_label

In [587]:
class M_CLAM(tf.keras.Model):
    def __init__(self, att_gate=False, net_size='small', n_ins=8, n_class=2, mut_ex=False,
                 dropout=False, drop_rate=.25, mil_ins=False, att_only=False):
        super(M_CLAM, self).__init__()
        self.att_gate = att_gate
        self.net_size = net_size
        self.n_ins = n_ins
        self.n_class = n_class
        self.mut_ex = mut_ex
        self.dropout = dropout
        self.drop_rate = drop_rate
        self.mil_ins = mil_ins
        self.att_only = att_only

        self.net_shape_dict = {
            'small': [1024, 512, 256],
            'big': [1024, 512, 384]
        }
        self.net_shape = self.net_shape_dict[self.net_size]

        if self.att_gate:
            self.att_net = G_Att_Net(dim_features=self.net_shape[0], dim_compress_features=self.net_shape[1],
                                     n_hidden_units=self.net_shape[2], n_classes=self.n_class, 
                                     dropout=self.dropout, dropout_rate=self.drop_rate)
        else:
            self.att_net = NG_Att_Net(dim_features=self.net_shape[0], dim_compress_features=self.net_shape[1],
                                      n_hidden_units=self.net_shape[2], n_classes=self.n_class, 
                                      dropout=self.dropout, dropout_rate=self.drop_rate)

        self.ins_net = Ins(dim_compress_features=self.net_shape[1], n_class=self.n_class, 
                           n_ins=self.n_ins, mut_ex=self.mut_ex)
        
        self.bag_net = M_Bag(dim_compress_features=self.net_shape[1], n_class=self.n_class)
        
    def clam_model(self):
        att_model = self.att_net.att_model()
        ins_classifier = self.ins_net.ins_classifier()
        bag_classifier = self.bag_net.bag_classifier()
        
        clam_model = [att_model, ins_classifier, bag_classifier]
        
        return clam_model
    
    def call(self, img_features, slide_label):
        """
        Args:
            img_features -> original 1024-dimensional instance-level feature vectors
            slide_label -> ground-truth slide label, could be 0 or 1 for binary classification
        """

        h, A = self.att_net.call(img_features)
        att_score = A  # output from attention network
        A = tf.math.softmax(A)  # softmax onattention scores

        if self.att_only:
            return att_score

        if self.mil_ins:
            ins_labels, ins_logits_unnorm, ins_logits = self.ins_net.call(slide_label, h, A)

        slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = self.bag_net.call(slide_label, A, h)

        return att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, Y_prob, Y_hat, Y_true, predict_slide_label

## Train CLAM Model

### Train CLAM Model on the Given Training Data

In [167]:
def nb_optimize(img_features, slide_label, i_model, b_model, c_model, i_optimizer, b_optimizer, c_optimizer, 
                i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    
    with tf.GradientTape() as i_tape, tf.GradientTape() as b_tape, tf.GradientTape() as c_tape:
        
        att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
        Y_prob, Y_hat, Y_true, predict_slide_label = c_model.call(img_features, slide_label)

        ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
        ins_loss = list()
        for j in range(len(ins_logits)):
            i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
            ins_loss.append(i_loss)
        if mutual_ex:
            I_Loss = tf.math.add_n(ins_loss) / n_class
        else:
            I_Loss = tf.math.add_n(ins_loss)

        slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = b_model.call(slide_label, A, h)
        
        B_Loss = b_loss_func(Y_true, Y_prob)
        
        T_Loss = c1 * B_Loss + c2 * I_Loss

    i_grad = i_tape.gradient(I_Loss, i_model.trainable_weights)
    i_optimizer.apply_gradients(zip(i_grad, i_model.trainable_weights))

    b_grad = b_tape.gradient(B_Loss, b_model.trainable_weights)
    b_optimizer.apply_gradients(zip(b_grad, b_model.trainable_weights))

    c_grad = c_tape.gradient(T_Loss, c_model.trainable_weights)
    c_optimizer.apply_gradients(zip(c_grad, c_model.trainable_weights))
    
    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [590]:
def b_optimize(batch_size, n_ins, n_samples, img_features, slide_label, i_model, b_model,
               c_model, i_optimizer, b_optimizer, c_optimizer, i_loss_func, b_loss_func,
               n_class, c1, c2, mutual_ex):
    
    step_size = 0
    
    Ins_Loss = list()
    Bag_Loss = list()
    Total_Loss = list()
    
    label_predict = list()
    
    for n_step in range(0, (n_samples // batch_size + 1)):
        if step_size < (n_samples - batch_size):
            with tf.GradientTape() as i_tape, tf.GradientTape() as b_tape, tf.GradientTape() as c_tape:
                att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
                Y_prob, Y_hat, Y_true, predict_label = c_model.call(img_features=img_features[step_size:(step_size + batch_size)],
                                                                    slide_label=slide_label)

                ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
                
                ins_loss = list()
                for j in range(len(ins_logits)):
                    i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                    ins_loss.append(i_loss)
                if mutual_ex:
                    Loss_I = tf.math.add_n(ins_loss) / n_class
                else:
                    Loss_I = tf.math.add_n(ins_loss)

                slide_score_unnorm, Y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
                
                Loss_B = b_loss_func(Y_true, Y_prob)
                
                Loss_T = c1 * Loss_B + c2 * Loss_I

            i_grad = i_tape.gradient(Loss_I, i_model.trainable_weights)
            i_optimizer.apply_gradients(zip(i_grad, i_model.trainable_weights))

            b_grad = b_tape.gradient(Loss_B, b_model.trainable_weights)
            b_optimizer.apply_gradients(zip(b_grad, b_model.trainable_weights))

            c_grad = c_tape.gradient(Loss_T, c_model.trainable_weights)
            c_optimizer.apply_gradients(zip(c_grad, c_model.trainable_weights))
    
        else:
            with tf.GradientTape() as i_tape, tf.GradientTape() as b_tape, tf.GradientTape() as c_tape:
                att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
                Y_prob, Y_hat, Y_true, predict_label = c_model.call(img_features=img_features[(step_size - n_ins):],
                                                                    slide_label=slide_label)

                ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
                
                ins_loss = list()
                for j in range(len(ins_logits)):
                    i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                    ins_loss.append(i_loss)
                if mutual_ex:
                    Loss_I = tf.math.add_n(ins_loss) / n_class
                else:
                    Loss_I = tf.math.add_n(ins_loss)

                slide_score_unnorm, Y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
                
                Loss_B = b_loss_func(Y_true, Y_prob)
                
                Loss_T = c1 * Loss_B + c2 * Loss_I

            i_grad = i_tape.gradient(Loss_I, i_model.trainable_weights)
            i_optimizer.apply_gradients(zip(i_grad, i_model.trainable_weights))

            b_grad = b_tape.gradient(Loss_B, b_model.trainable_weights)
            b_optimizer.apply_gradients(zip(b_grad, b_model.trainable_weights))

            c_grad = c_tape.gradient(Loss_T, c_model.trainable_weights)
            c_optimizer.apply_gradients(zip(c_grad, c_model.trainable_weights))
            
        Ins_Loss.append(float(Loss_I))
        Bag_Loss.append(float(Loss_B))
        Total_Loss.append(float(Loss_T))
        
        label_predict.append(predict_label)
        
        step_size += batch_size
    
    I_Loss = statistics.mean(Ins_Loss)
    B_Loss = statistics.mean(Bag_Loss)
    T_Loss = statistics.mean(Total_Loss)
    
    predict_slide_label = most_frequent(label_predict)
    
    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [591]:
def train_step(i_model, b_model, c_model, train_path, i_optimizer_func, b_optimizer_func,
               c_optimizer_func, i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2, 
               learn_rate, l2_decay, n_ins, batch_size, batch_op):
    
    loss_total = list()
    loss_ins = list()
    loss_bag = list()

    i_optimizer = i_optimizer_func(learning_rate=learn_rate, weight_decay=l2_decay)
    b_optimizer = b_optimizer_func(learning_rate=learn_rate, weight_decay=l2_decay)
    c_optimizer = c_optimizer_func(learning_rate=learn_rate, weight_decay=l2_decay)

    slide_true_label = list()
    slide_predict_label = list()

    train_sample_list = os.listdir(train_path)
    train_sample_list = random.sample(train_sample_list, len(train_sample_list))
    for i in train_sample_list:
        print('=', end="")
        single_train_data = train_path + i
        img_features, slide_label = get_data_from_tf(single_train_data)
        # shuffle the order of img features list in order to reduce the side effects of randomly drop potential 
        # number of patches' feature vectors during training when enable batch training option
        img_features = random.sample(img_features, len(img_features))
        
        if batch_op:
            I_Loss, B_Loss, T_Loss, predict_slide_label = b_optimize(batch_size=batch_size, n_ins=n_ins, n_samples=len(img_features), 
                                                                     img_features=img_features, slide_label=slide_label, 
                                                                     i_model=i_model, b_model=b_model, c_model=c_model, 
                                                                     i_optimizer=i_optimizer, b_optimizer=b_optimizer, 
                                                                     c_optimizer=c_optimizer, i_loss_func=i_loss_func, 
                                                                     b_loss_func = b_loss_func, n_class=n_class, c1=c1, 
                                                                     c2=c2, mutual_ex=mutual_ex)
        else:
            I_Loss, B_Loss, T_Loss, predict_slide_label = nb_optimize(img_features=img_features, slide_label=slide_label,
                                                                      i_model=i_model, b_model=b_model, c_model=c_model, 
                                                                      i_optimizer=i_optimizer, b_optimizer=b_optimizer, 
                                                                      c_optimizer=c_optimizer, i_loss_func=i_loss_func, 
                                                                      b_loss_func=b_loss_func, n_class=n_class, c1=c1, c2=c2, 
                                                                      mutual_ex=mutual_ex)

        loss_total.append(float(T_Loss))
        loss_ins.append(float(I_Loss))
        loss_bag.append(float(B_Loss))

        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)

    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label).ravel()
    train_tn = int(tn)
    train_fp = int(fp)
    train_fn = int(fn)
    train_tp = int(tp)

    train_sensitivity = round(train_tp / (train_tp + train_fn), 2)
    train_specificity = round(train_tn / (train_tn + train_fp), 2)
    train_acc = round((train_tp + train_tn) / (train_tn + train_fp + train_fn + train_tp), 2)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    train_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    train_loss = statistics.mean(loss_total)
    train_ins_loss = statistics.mean(loss_ins)
    train_bag_loss = statistics.mean(loss_bag)

    return train_loss, train_ins_loss, train_bag_loss, train_tn, train_fp, train_fn, train_tp, train_sensitivity, \
           train_specificity, train_acc, train_auc

### Validating CLAM Model

In [592]:
def nb_val(img_features, slide_label, i_model, b_model, c_model, 
           i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    
    att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
    Y_prob, Y_hat, Y_true, predict_slide_label = c_model.call(img_features, slide_label)
 
    ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
    
    ins_loss = list()
    for j in range(len(ins_logits)):
        i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
        ins_loss.append(i_loss)
    if mutual_ex:
        I_Loss = tf.math.add_n(ins_loss) / n_class
    else:
        I_Loss = tf.math.add_n(ins_loss)

    slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = b_model.call(slide_label, A, h)

    B_Loss = b_loss_func(Y_true, Y_prob)
    
    T_Loss = c1 * B_Loss + c2 * I_Loss
    
    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [593]:
def b_val(batch_size, n_ins, n_samples, img_features, slide_label, i_model, b_model,
          c_model, i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    
    step_size = 0
    
    Ins_Loss = list()
    Bag_Loss = list()
    Total_Loss = list()
    
    label_predict = list()
    
    for n_step in range(0, (n_samples // batch_size + 1)):
        if step_size < (n_samples - batch_size):
            att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
            Y_prob, Y_hat, Y_true, predict_label = c_model.call(img_features=img_features[step_size:(step_size + batch_size)],
                                                                slide_label=slide_label)

            ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
            
            ins_loss = list()
            for j in range(len(ins_logits)):
                i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                ins_loss.append(i_loss)
            if mutual_ex:
                Loss_I = tf.math.add_n(ins_loss) / n_class
            else:
                Loss_I = tf.math.add_n(ins_loss)

            slide_score_unnorm, Y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
            
            Loss_B = b_loss_func(Y_true, Y_prob)
            Loss_T = c1 * Loss_B + c2 * Loss_I
            
        else:
            att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
            Y_prob, Y_hat, Y_true, predict_label = c_model.call(img_features=img_features[(step_size - n_ins):],
                                                                slide_label=slide_label)

            ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
            
            ins_loss = list()
            for j in range(len(ins_logits)):
                i_loss = i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                ins_loss.append(i_loss)
            if mutual_ex:
                Loss_I = tf.math.add_n(ins_loss) / n_class
            else:
                Loss_I = tf.math.add_n(ins_loss)

            slide_score_unnorm, Y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
            
            Loss_B = b_loss_func(Y_true, Y_prob)
            
            Loss_T = c1 * Loss_B + c2 * Loss_I
        
        Ins_Loss.append(float(Loss_I))
        Bag_Loss.append(float(Loss_B))
        Total_Loss.append(float(Loss_T))

        label_predict.append(predict_label)
        
        step_size += batch_size
    
    I_Loss = statistics.mean(Ins_Loss)
    B_Loss = statistics.mean(Bag_Loss)
    T_Loss = statistics.mean(Total_Loss)
    
    predict_slide_label = most_frequent(label_predict)

    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [594]:
def val_step(i_model, b_model, c_model, val_path, i_loss_func, b_loss_func, mutual_ex, 
             n_class, c1, c2, n_ins, batch_size, batch_op):
    loss_t = list()
    loss_i = list()
    loss_b = list()

    slide_true_label = list()
    slide_predict_label = list()

    val_sample_list = os.listdir(val_path)
    val_sample_list = random.sample(val_sample_list, len(val_sample_list))
    for i in val_sample_list:
        print('=', end="")
        single_val_data = val_path + i
        img_features, slide_label = get_data_from_tf(single_val_data)
        
        img_features = random.sample(img_features, len(img_features)) # follow the training loop, see details there
                                    
        if batch_op:
            I_Loss, B_Loss, T_Loss, predict_slide_label = b_val(batch_size=batch_size, n_ins=n_ins, n_samples=len(img_features),
                                                                img_features=img_features, slide_label=slide_label, 
                                                                i_model=i_model, b_model=b_model, c_model=c_model, 
                                                                i_loss_func=i_loss_func, b_loss_func=b_loss_func, 
                                                                n_class=n_class, c1=c1, c2=c2, mutual_ex=mutual_ex)
        else:    
            I_Loss, B_Loss, T_Loss, predict_slide_label = nb_val(img_features=img_features, slide_label=slide_label, 
                                                                 i_model=i_model, b_model=b_model, c_model=c_model, 
                                                                 i_loss_func=i_loss_func, b_loss_func=b_loss_func, 
                                                                 n_class=n_class, c1=c1, c2=c2, mutual_ex=mutual_ex)

        loss_t.append(float(T_Loss))
        loss_i.append(float(I_Loss))
        loss_b.append(float(B_Loss))

        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)

    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label).ravel()
    val_tn = int(tn)
    val_fp = int(fp)
    val_fn = int(fn)
    val_tp = int(tp)

    val_sensitivity = round(val_tp / (val_tp + val_fn), 2)
    val_specificity = round(val_tn / (val_tn + val_fp), 2)
    val_acc = round((val_tp + val_tn) / (val_tn + val_fp + val_fn + val_tp), 2)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    val_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    val_loss = statistics.mean(loss_t)
    val_ins_loss = statistics.mean(loss_i)
    val_bag_loss = statistics.mean(loss_b)

    return val_loss, val_ins_loss, val_bag_loss, val_tn, val_fp, val_fn, val_tp, val_sensitivity, val_specificity, \
           val_acc, val_auc

## Test Optimized CLAM Model

In [595]:
def test(i_model, b_model, c_model, test_path, result_path, result_file_name):
    start_time = time.time()

    slide_true_label = list()
    slide_predict_label = list()
    sample_names = list()

    for i in os.listdir(test_path):
        print('>', end="")
        single_test_data = test_path + i
        img_features, slide_label = get_data_from_tf(single_test_data)

        att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
        Y_prob, Y_hat, Y_true, predict_slide_label = c_model.call(img_features, slide_label)
    
        ins_labels, ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
        
        slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = b_model.call(slide_label, A, h)
    
        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)
        sample_names.append(i)

        test_results = pd.DataFrame(list(zip(sample_names, slide_true_label, slide_predict_label)),
                                    columns=['Sample Names', 'Slide True Label', 'Slide Predict Label'])
        test_results.to_csv(os.path.join(result_path, result_file_name), sep='\t', index=False)

    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label).ravel()
    test_tn = int(tn)
    test_fp = int(fp)
    test_fn = int(fn)
    test_tp = int(tp)

    test_sensitivity = round(test_tp / (test_tp + test_fn), 2)
    test_specificity = round(test_tn / (test_tn + test_fp), 2)
    test_acc = round((test_tp + test_tn) / (test_tn + test_fp + test_fn + test_tp), 2)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    test_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    test_run_time = time.time() - start_time

    template = '\n Test Accuracy: {}, Test Sensitivity: {}, Test Specificity: {}, Test Running Time: {}'
    print(template.format(f"{float(test_acc):.4%}",
                          f"{float(test_sensitivity):.4%}",
                          f"{float(test_specificity):.4%}",
                          "--- %s mins ---" % int(test_run_time / 60)))

    return test_tn, test_fp, test_fn, test_tp, test_sensitivity, test_specificity, test_acc, test_auc

## Saving & Restoring CLAM Model

In [596]:
def model_save(i_model, b_model, c_model, i_model_dir, b_model_dir, c_model_dir, n_class, m_bag_op, m_clam_op, g_att_op):
    for i in range(n_class):
        i_model.ins_classifier()[i].save(os.path.join(i_model_dir, 'M_Ins', 'Class_' + str(i)))
        
    if m_bag_op:
        for j in range(n_class):
            b_model.bag_classifier()[j].save(os.path.join(b_model_dir, 'M_Bag', 'Class_' + str(j)))
    else:
        b_model.bag_classifier().save(os.path.join(b_model_dir, 'S_Bag'))
    
    clam_model_names = ['_Att', '_Ins', '_Bag']
                                         
    if m_clam_op:                                 
        if g_att_op:
            att_nets = c_model.clam_model()[0]
            for m in range(len(att_nets)):
                att_nets[m].save(os.path.join(c_model_dir, 'G' + clam_model_names[0], 'Model_' + str(m + 1)))
        else:
            att_nets = c_model.clam_model()[0]
            for m in range(len(att_nets)):
                att_nets[m].save(os.path.join(c_model_dir, 'NG' + clam_model_names[0], 'Model_' + str(m + 1)))                             
                                         
        for n in range(n_class):
            ins_nets = c_model.clam_model()[1]
            bag_nets = c_model.clam_model()[2]
            
            ins_nets[n].save(os.path.join(c_model_dir, 'M' + clam_model_names[1], 'Class_' + str(n)))
            bag_nets[n].save(os.path.join(c_model_dir, 'M' + clam_model_names[2], 'Class_' + str(n)))
    else:
        if g_att_op:
            att_nets = c_model.clam_model()[0]
            for m in range(len(att_nets)):
                att_nets[m].save(os.path.join(c_model_dir, 'G' + clam_model_names[0], 'Model_' + str(m + 1)))
        else:
            att_nets = c_model.clam_model()[0]
            for m in range(len(att_nets)):
                att_nets[m].save(os.path.join(c_model_dir, 'NG' + clam_model_names[0], 'Model_' + str(m + 1)))
                                         
        for n in range(n_class):
            ins_nets = c_model.clam_model()[1]
            ins_nets[n].save(os.path.join(c_model_dir, 'M' + clam_model_names[1], 'Class_' + str(n)))
        
        c_model.clam_model()[2].save(os.path.join(c_model_dir, 'S' + clam_model_names[2]))

In [597]:
def restore_model(i_model_dir, b_model_dir, c_model_dir, n_class, m_bag_op, m_clam_op, g_att_op):
    i_trained_model = list()
    for i in range(n_class):
        m_ins_names = os.listdir(os.path.join(i_model_dir, 'M_Ins'))
        m_ins_names.sort()
        m_ins_name = m_ins_names[i]
        m_ins_model = tf.keras.models.load_model(os.path.join(i_model_dir, 'M_Ins', m_ins_name))
        i_trained_model.append(m_ins_model)
        
    if m_bag_op:
        b_trained_model = list()
        for j in range(n_class):
            m_bag_names = os.listdir(os.path.join(b_model_dir, 'M_Bag'))
            m_bag_names.sort()
            m_bag_name = m_bag_names[j]                                     
            m_bag_model = tf.keras.models.load_model(os.path.join(b_model_dir, 'M_Bag', m_bag_name))
            b_trained_model.append(m_bag_model)
    else:
        s_bag_name = os.listdir(b_model_dir)[0]
        b_trained_model = tf.keras.models.load_model(os.path.join(b_model_dir, s_bag_name))
    
    clam_model_names = ['_Att', '_Ins', '_Bag']
 
    trained_att_net = list()
    trained_ins_classifier = list()
    trained_bag_classifier = list()
    
    c_trained_model = list()
    
    if m_clam_op:
        if g_att_op:
            att_nets_dir = os.path.join(c_model_dir, 'G' + clam_model_names[0])
            for k in range(len(os.listdir(att_nets_dir))):
                att_net = tf.keras.models.load_model(os.path.join(att_nets_dir, 'Model_' + str(k+1)))
                trained_att_net.append(att_net)
        else:
            att_nets_dir = os.path.join(c_model_dir, 'NG' + clam_model_names[0])
            for k in range(len(os.listdir(att_nets_dir))):
                att_net = tf.keras.models.load_model(os.path.join(att_nets_dir, 'Model_' + str(k+1)))
                trained_att_net.append(att_net)
        
        ins_nets_dir = os.path.join(c_model_dir, 'M' + clam_model_names[1])
        bag_nets_dir = os.path.join(c_model_dir, 'M' + clam_model_names[2])
        
        for m in range(n_class):
            ins_net = tf.keras.models.load_model(os.path.join(ins_nets_dir, 'Class_' + str(m)))
            bag_net = tf.keras.models.load_model(os.path.join(bag_nets_dir, 'Class_' + str(m)))
            
            trained_ins_classifier.append(ins_net)
            trained_bag_classifier.append(bag_net)
        
        c_trained_model = [trained_att_net, trained_ins_classifier, trained_bag_classifier]
    else:
        if g_att_op:
            att_nets_dir = os.path.join(c_model_dir, 'G' + clam_model_names[0])
            for k in range(len(os.listdir(att_nets_dir))):
                att_net = tf.keras.models.load_model(os.path.join(att_nets_dir, 'Model_' + str(k + 1)))
                trained_att_net.append(att_net)
        else:
            att_nets_dir = os.path.join(c_model_dir, 'NG' + clam_model_names[0])
            for k in range(len(os.listdir(att_nets_dir))):
                att_net = tf.keras.models.load_model(os.path.join(att_nets_dir, 'Model_' + str(k + 1)))
                trained_att_net.append(att_net)
        
        ins_nets_dir = os.path.join(c_model_dir, 'M' + clam_model_names[1])
        
        for m in range(n_class):
            ins_net = tf.keras.models.load_model(os.path.join(ins_nets_dir, 'Class_' + str(m)))
            trained_ins_classifier.append(ins_net)
            
        bag_nets_dir = os.path.join(c_model_dir, 'S' + clam_model_names[2])
        trained_bag_classifier.append(tf.keras.models.load_model(bag_nets_dir))
        
        c_trained_model = [trained_att_net, trained_ins_classifier, trained_bag_classifier[0]]
    
    return i_trained_model, b_trained_model, c_trained_model

## Saving & Restoring CLAM Model Training Checkpoints

## Optimizing CLAM Model

### Loading Models for Training

In [474]:
ng_att = NG_Att_Net(dim_features=1024, dim_compress_features=512, n_hidden_units=256, n_classes=2,
                 dropout=False, dropout_rate=.25)

g_att = G_Att_Net(dim_features=1024, dim_compress_features=512, n_hidden_units=256, n_classes=2,
                 dropout=False, dropout_rate=.25)

In [475]:
ins = Ins(dim_compress_features=512, n_class=2, n_ins=8, mut_ex=True)

In [476]:
s_bag = S_Bag(dim_compress_features=512, n_class=2)

m_bag = M_Bag(dim_compress_features=512, n_class=2)

In [477]:
s_clam = S_CLAM(att_gate=True, net_size='big', n_ins=8, n_class=2, mut_ex=False,
            dropout=True, drop_rate=.55, mil_ins=True, att_only=False)

m_clam = M_CLAM(att_gate=True, net_size='big', n_ins=8, n_class=2, mut_ex=False,
            dropout=True, drop_rate=.55, mil_ins=True, att_only=False)

### Loading Required Path

In [478]:
train_nis_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/No_Image_Standardization/train/'
val_nis_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/No_Image_Standardization/val/'
test_nis_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/No_Image_Standardization/test/'

In [479]:
train_is_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/Image_Standardization/train/'
val_is_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/Image_Standardization/val/'
test_is_bach = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/BACH/Image_Standardization/test/'

In [480]:
train_nis_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/No_Image_Standardization/train/'
val_nis_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/No_Image_Standardization/val/'
test_nis_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/No_Image_Standardization/test/'
extra_nis_tcga = 'research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/No_Image_Standardization/extra/'

In [481]:
train_is_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/Image_Standardization/train/'
val_is_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/Image_Standardization/val/'
test_is_tcga = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/Image_Standardization/test/'
extra_is_tcga = 'research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/TCGA/Image_Standardization/extra/'

In [482]:
clam_result_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM'

In [483]:
i_trained_model_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/Saved_Model/Ins_Classifier'
b_trained_model_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/Saved_Model/Bag_Classifier'
c_trained_model_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/Quincy/Data/CLAM/Saved_Model/CLAM_Model'

In [484]:
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/' \
'Quincy/Data/CLAM/log/' + current_time + '/train'
val_log_dir = '/research/bsi/projects/PI/tertiary/Hart_Steven_m087494/s211408.DigitalPathology/' \
'Quincy/Data/CLAM/log/' + current_time + '/val'

## Training & Validating CLAM Model

In [485]:
def train_val(train_log, val_log, train_path, val_path, i_model, b_model,
               c_model, i_optimizer_func, b_optimizer_func, c_optimizer_func, 
               i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2, learn_rate, 
               l2_decay, n_ins, batch_size, batch_op, epochs):
    
    train_summary_writer = tf.summary.create_file_writer(train_log)
    val_summary_writer = tf.summary.create_file_writer(val_log)

    for epoch in range(epochs):
        # Training Step
        start_time = time.time()

        train_loss, train_ins_loss, train_bag_loss, train_tn, train_fp, train_fn, train_tp, \
        train_sensitivity, train_specificity, train_acc, train_auc = train_step(
            i_model=i_model, b_model=b_model, c_model=c_model, train_path=train_path,
            i_optimizer_func=i_optimizer_func, b_optimizer_func=b_optimizer_func, 
            c_optimizer_func=c_optimizer_func, i_loss_func=i_loss_func, 
            b_loss_func=b_loss_func, mutual_ex=mutual_ex, n_class=n_class, 
            c1=c1, c2=c2, learn_rate=learn_rate, l2_decay=l2_decay, 
            n_ins=n_ins, batch_size=batch_size, batch_op=batch_op)

        with train_summary_writer.as_default():
            tf.summary.scalar('Total Loss', float(train_loss), step=epoch)
            tf.summary.scalar('Instance Loss', float(train_ins_loss), step=epoch)
            tf.summary.scalar('Bag Loss', float(train_bag_loss), step=epoch)
            tf.summary.scalar('Accuracy', float(train_acc), step=epoch)
            tf.summary.scalar('AUC', float(train_auc), step=epoch)
            tf.summary.scalar('Sensitivity', float(train_sensitivity), step=epoch)
            tf.summary.scalar('Specificity', float(train_specificity), step=epoch)
            tf.summary.histogram('True Positive', int(train_tp), step=epoch)
            tf.summary.histogram('False Positive', int(train_fp), step=epoch)
            tf.summary.histogram('True Negative', int(train_tn), step=epoch)
            tf.summary.histogram('False Negative', int(train_fn), step=epoch)

        # Validation Step
        val_loss, val_ins_loss, val_bag_loss, val_tn, val_fp, val_fn, val_tp, \
        val_sensitivity, val_specificity, val_acc, val_auc = val_step(
            i_model=i_model, b_model=b_model, c_model=c_model, val_path=val_path,
            i_loss_func=i_loss_func, b_loss_func=b_loss_func, mutual_ex=mutual_ex, 
            n_class=n_class, c1=c1, c2=c2, n_ins=n_ins, batch_size=batch_size, batch_op=batch_op)

        with val_summary_writer.as_default():
            tf.summary.scalar('Total Loss', float(val_loss), step=epoch)
            tf.summary.scalar('Instance Loss', float(val_ins_loss), step=epoch)
            tf.summary.scalar('Bag Loss', float(val_bag_loss), step=epoch)
            tf.summary.scalar('Accuracy', float(val_acc), step=epoch)
            tf.summary.scalar('AUC', float(val_auc), step=epoch)
            tf.summary.scalar('Sensitivity', float(val_sensitivity), step=epoch)
            tf.summary.scalar('Specificity', float(val_specificity), step=epoch)
            tf.summary.histogram('True Positive', int(val_tp), step=epoch)
            tf.summary.histogram('False Positive', int(val_fp), step=epoch)
            tf.summary.histogram('True Negative', int(val_tn), step=epoch)
            tf.summary.histogram('False Negative', int(val_fn), step=epoch)

        epoch_run_time = time.time() - start_time
        template = '\n Epoch {},  Train Loss: {}, Train Accuracy: {}, Val Loss: {}, Val Accuracy: {}, Epoch Running ' \
                   'Time: {} '
        print(template.format(epoch + 1,
                              f"{float(train_loss):.8}",
                              f"{float(train_acc):.4%}",
                              f"{float(val_loss):.8}",
                              f"{float(val_acc):.4%}",
                              "--- %s mins ---" % int(epoch_run_time / 60)))

### Main Function to Optimizing and Testing CLAM Model

### Approach 1 - Test the Optimized CLAM Model by Saving the Trained CLAM Model

In [486]:
def clam_optimize(train_log, val_log, train_path, val_path, i_model, b_model,
                  c_model, i_optimizer_func, b_optimizer_func, c_optimizer_func, 
                  i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2, learn_rate, 
                  l2_decay, n_ins, batch_size, batch_op, i_model_dir, b_model_dir, 
                  c_model_dir, m_bag_op, m_clam_op, g_att_op, epochs):
    
    train_val(train_log=train_log, val_log=val_log, train_path=train_path,
               val_path=val_path, i_model=i_model, b_model=b_model, c_model=c_model,
               i_optimizer_func=i_optimizer_func, b_optimizer_func=b_optimizer_func,
               c_optimizer_func=c_optimizer_func, i_loss_func=i_loss_func,
               b_loss_func=b_loss_func, mutual_ex=mutual_ex, n_class=n_class, 
               c1=c1, c2=c2, learn_rate=learn_rate, l2_decay=l2_decay, 
               n_ins=n_ins, batch_size=batch_size, batch_op=batch_op, epochs=epochs)
    
    model_save(i_model=i_model, b_model=b_model, c_model=c_model, 
               i_model_dir=i_model_dir, b_model_dir=b_model_dir, 
               c_model_dir=c_model_dir, n_class=n_class, m_bag_op=m_bag_op, 
               m_clam_op=m_clam_op, g_att_op=g_att_op)

In [487]:
def clam_test(test_path, result_path, result_file_name, 
              i_model_dir, b_model_dir, c_model_dir, 
              n_class, m_bag_op, m_clam_op, g_att_op):
    
    i_trained_model, b_trained_model, c_trained_model = restore_model(i_model_dir=i_model_dir, 
                                                                      b_model_dir=b_model_dir, 
                                                                      c_model_dir=c_model_dir, 
                                                                      n_class=n_class, m_bag_op=m_bag_op, 
                                                                      m_clam_op=m_clam_op, g_att_op=g_att_op)
    
    test_tn, test_fp, test_fn, test_tp, test_sensitivity, test_specificity, \
    test_acc, test_auc = test(i_model=ins, b_model=b_trained_model, 
                              c_model=s_clam, test_path=test_path, 
                              result_path=result_path, result_file_name=result_file_name)

### Approach 2 - Test the Optimized CLAM Model Withought Saving the Trained CLAM Model

In [488]:
def optimize_test(train_log, val_log, train_path, val_path, test_path, result_path, result_file_name,
              i_model, b_model, c_model, i_optimizer_func, b_optimizer_func,
              c_optimizer_func, i_loss_func, b_loss_func, mutual_ex,
              n_class, c1, c2, learn_rate, l2_decay, n_ins, batch_size, batch_op, epochs):
    
    train_val(train_log=train_log, val_log=val_log, train_path=train_path,
               val_path=val_path, i_model=i_model, b_model=b_model, c_model=c_model,
               i_optimizer_func=i_optimizer_func, b_optimizer_func=b_optimizer_func,
               c_optimizer_func=c_optimizer_func, i_loss_func=i_loss_func,
               b_loss_func=b_loss_func, mutual_ex=mutual_ex, n_class=n_class, 
               c1=c1, c2=c2, learn_rate=learn_rate, l2_decay=l2_decay, 
               n_ins=n_ins, batch_size=batch_size, batch_op=batch_op, epochs=epochs)

    test_tn, test_fp, test_fn, test_tp, test_sensitivity, test_specificity, \
    test_acc, test_auc = test(i_model=i_model, b_model=b_model, c_model=c_model, 
                              test_path=test_path, result_path=result_path,
                              result_file_name=result_file_name)

## Start Training, Validating & Testing CLAM Model

In [489]:
tf_shut_up(no_warn_op=True)

In [490]:
clam_optimize(train_log=train_log_dir, val_log=val_log_dir, 
              train_path=train_is_bach, val_path=val_is_bach, 
              i_model=ins, b_model=s_bag, c_model=s_clam, 
              i_optimizer_func=tfa.optimizers.AdamW, 
              b_optimizer_func=tfa.optimizers.AdamW, 
              c_optimizer_func=tfa.optimizers.AdamW, 
              i_loss_func=tf.keras.losses.binary_crossentropy, 
              b_loss_func=tf.keras.losses.binary_crossentropy, 
              mutual_ex=False, n_class=2, c1=0.7, c2=0.3, 
              learn_rate=2e-04, l2_decay=1e-05, n_ins=8, 
              batch_size=2000, batch_op=False, 
              i_model_dir=i_trained_model_dir, 
              b_model_dir=b_trained_model_dir, 
              c_model_dir=c_trained_model_dir, 
              m_bag_op=False, m_clam_op=False, g_att_op=True, epochs=1)

 Epoch 1,  Train Loss: 8.1606448, Train Accuracy: 54.0000%, Val Loss: 9.9709798, Val Accuracy: 50.0000%, Epoch Running Time: --- 1 mins --- 


In [522]:
i_trained_model, b_trained_model, c_trained_model = restore_model(i_model_dir=i_trained_model_dir, 
                                                                  b_model_dir=b_trained_model_dir, 
                                                                  c_model_dir=c_trained_model_dir, 
                                                                  n_class=2, m_bag_op=False, 
                                                                  m_clam_op=False, g_att_op=True)

In [434]:
clam_test(test_path=test_is_bach, result_path=clam_result_dir, 
          result_file_name='test_bach_save_model.tsv', 
          i_model_dir=i_trained_model_dir, 
          b_model_dir=b_trained_model_dir, 
          c_model_dir=c_trained_model_dir, 
          n_class=2, m_bag_op=False, 
          m_clam_op=False, g_att_op=True)

In [400]:
optimize_test(train_log=train_log_dir, val_log=val_log_dir, train_path=train_is_bach,
          val_path=val_is_bach, test_path=test_is_bach, result_path=clam_result_dir, 
          result_file_name='test_bach.tsv', i_model=ins, b_model=s_bag, c_model=s_clam,
          i_optimizer_func=tfa.optimizers.AdamW, b_optimizer_func=tfa.optimizers.AdamW,
          c_optimizer_func=tfa.optimizers.AdamW, i_loss_func=tf.keras.losses.binary_crossentropy,
          b_loss_func=tf.keras.losses.binary_crossentropy, mutual_ex=True, n_class=2,
          c1=0.7, c2=0.3, learn_rate=2e-04, l2_decay=1e-05, n_ins=8, batch_size=1000, batch_op=False, epochs=1)

 Epoch 1,  Train Loss: 7.6590552, Train Accuracy: 50.0000%, Val Loss: 7.6483054, Val Accuracy: 50.0000%, Epoch Running Time: --- 1 mins --- 
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
 Test Accuracy: 50.0000%, Test Sensitivity: 100.0000%, Test Specificity: 0.0000%, Test Running Time: --- 0 mins ---
