## Import Required Packages

In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
from tqdm import tqdm
import pandas as pd
import sklearn
from sklearn import metrics
import re
import numpy as np
import pickle as pkl
import PIL
import datetime
import os
import random
import shutil
import statistics
import time

## Import Required Functions or Methods from Other Files

In [15]:
import import_ipynb
from util import *

In [16]:
from model import *

## Train CLAM Model

### Train CLAM Model on the Given Training Data

In [17]:
def nb_optimize(img_features, slide_label, i_model, b_model, c_model, i_optimizer, b_optimizer, c_optimizer,
                i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    """Take one optimization step on a whole slide, without mini-batching.

    ``c_model`` is run on all of ``img_features`` at once. Its attention output ``A``
    and features ``h`` are then fed to ``i_model`` and ``b_model``. Three gradient
    tapes record the forward pass, so each model is updated by its own optimizer
    from its own objective:

    * ``i_model`` <- instance loss
    * ``b_model`` <- bag loss
    * ``c_model`` <- ``c1 * bag_loss + c2 * instance_loss``

    Args:
        img_features: patch feature vectors of one slide, passed as-is to ``c_model.call``.
        slide_label: slide-level label, forwarded to every model call.
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        i_optimizer, b_optimizer, c_optimizer: optimizers paired with the three models.
        i_loss_func: called as ``i_loss_func(one_hot(label, 2), logits)`` for each entry
            of the instance model's outputs.
        b_loss_func: called as ``b_loss_func(Y_true, Y_prob)`` on ``b_model``'s outputs.
        n_class: divisor applied to the summed instance loss when ``mutual_ex`` is true.
        c1, c2: weights of the bag and instance losses in the total loss.
        mutual_ex: if true, the summed instance loss is divided by ``n_class``.

    Returns:
        Tuple ``(I_Loss, B_Loss, T_Loss, predict_slide_label)``; the label comes from ``b_model``.
    """
    with tf.GradientTape() as i_tape, tf.GradientTape() as b_tape, tf.GradientTape() as c_tape:
        # Only A and h are needed from the CLAM pass. Its own instance/bag outputs
        # are superseded by the dedicated i_model and b_model calls below.
        (_att_score, A, h, _ins_labels, _ins_logits_unnorm, _ins_logits, _score_unnorm,
         _y_prob, _y_hat, _y_true, _clam_label) = c_model.call(img_features, slide_label)

        ins_labels, _ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
        per_entry_losses = [i_loss_func(tf.one_hot(ins_labels[k], 2), ins_logits[k])
                            for k in range(len(ins_logits))]
        summed_ins_loss = tf.math.add_n(per_entry_losses)
        I_Loss = summed_ins_loss / n_class if mutual_ex else summed_ins_loss

        _score_unnorm, _y_hat, Y_prob, predict_slide_label, Y_true = b_model.call(slide_label, A, h)
        B_Loss = b_loss_func(Y_true, Y_prob)

        T_Loss = c1 * B_Loss + c2 * I_Loss

    # Each tape is non-persistent, so each one is asked for gradients exactly once.
    # Updates are applied in the order instance -> bag -> CLAM.
    updates = ((i_tape, I_Loss, i_model, i_optimizer),
               (b_tape, B_Loss, b_model, b_optimizer),
               (c_tape, T_Loss, c_model, c_optimizer))
    for tape, objective, net, optimizer in updates:
        grads = tape.gradient(objective, net.trainable_weights)
        optimizer.apply_gradients(zip(grads, net.trainable_weights))

    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [18]:
def b_optimize(batch_size, n_ins, n_samples, img_features, slide_label, i_model, b_model,
               c_model, i_optimizer, b_optimizer, c_optimizer, i_loss_func, b_loss_func,
               n_class, c1, c2, mutual_ex):
    """Optimize the models on one slide by iterating over mini-batches of its patches.

    The slide's patch features are cut into ``ceil(n_samples / batch_size)``
    consecutive chunks (at least one):

    * Every chunk except the last is exactly ``batch_size`` long.
    * The last chunk starts ``n_ins`` positions before its nominal start, so it holds
      at least ``n_ins`` features whenever the slide has that many. Presumably this
      gives the instance branch enough candidates -- TODO confirm against the model code.

    Each chunk gets its own forward pass and its own three-optimizer update:
    ``i_model`` from the instance loss, ``b_model`` from the bag loss, and
    ``c_model`` from ``c1 * bag_loss + c2 * instance_loss``.

    Args:
        batch_size: number of patch features per chunk.
        n_ins: backward overlap added to the final chunk.
        n_samples: number of features in ``img_features`` (callers pass ``len(img_features)``).
        img_features: sliceable sequence of patch feature vectors for one slide.
        slide_label: slide-level label, forwarded to every model call.
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        i_optimizer, b_optimizer, c_optimizer: optimizers paired with the three models.
        i_loss_func: called as ``i_loss_func(one_hot(label, 2), logits)`` per instance-output entry.
        b_loss_func: called as ``b_loss_func(Y_true, Y_prob)`` on ``b_model``'s outputs.
        n_class: divisor for the summed instance loss when ``mutual_ex`` is true.
        c1, c2: weights of the bag and instance losses in the total loss.
        mutual_ex: if true, the summed instance loss is divided by ``n_class``.

    Returns:
        Tuple ``(I_Loss, B_Loss, T_Loss, predict_slide_label)``. The three losses are
        chunk averages (Python floats); the label is the most frequent per-chunk
        prediction, as returned by ``most_frequent``.
    """
    Ins_Loss = list()
    Bag_Loss = list()
    Total_Loss = list()
    label_predict = list()

    # ceil(n_samples / batch_size) chunks, but always at least one. Using
    # ``n_samples // batch_size + 1`` instead adds a duplicate step over the last
    # n_ins features whenever n_samples is an exact multiple of batch_size.
    n_steps = max(1, -(-n_samples // batch_size))
    for n_step in range(n_steps):
        step_size = n_step * batch_size
        if step_size < (n_samples - batch_size):
            batch_features = img_features[step_size:(step_size + batch_size)]
        else:
            # Final chunk. Clamp at 0: a negative start (slide smaller than one batch)
            # would otherwise wrap around and keep only the last n_ins features.
            batch_features = img_features[max(step_size - n_ins, 0):]

        with tf.GradientTape() as i_tape, tf.GradientTape() as b_tape, tf.GradientTape() as c_tape:
            # Only A and h are reused from the CLAM pass.
            _att_score, A, h, _ins_labels, _ins_logits_unnorm, _ins_logits, _score_unnorm, \
            _y_prob, _y_hat, _y_true, _clam_label = c_model.call(img_features=batch_features,
                                                                  slide_label=slide_label)

            ins_labels, _ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
            ins_loss = [i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                        for j in range(len(ins_logits))]
            Loss_I = tf.math.add_n(ins_loss)
            if mutual_ex:
                Loss_I = Loss_I / n_class

            _score_unnorm, _y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
            Loss_B = b_loss_func(Y_true, Y_prob)

            Loss_T = c1 * Loss_B + c2 * Loss_I

        i_grad = i_tape.gradient(Loss_I, i_model.trainable_weights)
        i_optimizer.apply_gradients(zip(i_grad, i_model.trainable_weights))

        b_grad = b_tape.gradient(Loss_B, b_model.trainable_weights)
        b_optimizer.apply_gradients(zip(b_grad, b_model.trainable_weights))

        c_grad = c_tape.gradient(Loss_T, c_model.trainable_weights)
        c_optimizer.apply_gradients(zip(c_grad, c_model.trainable_weights))

        Ins_Loss.append(float(Loss_I))
        Bag_Loss.append(float(Loss_B))
        Total_Loss.append(float(Loss_T))
        label_predict.append(predict_label)

    I_Loss = statistics.mean(Ins_Loss)
    B_Loss = statistics.mean(Bag_Loss)
    T_Loss = statistics.mean(Total_Loss)

    predict_slide_label = most_frequent(label_predict)

    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [19]:
def train_step(i_model, b_model, c_model, train_path, i_optimizer_func, b_optimizer_func,
               c_optimizer_func, i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2,
               i_learn_rate, b_learn_rate, c_learn_rate, i_l2_decay, b_l2_decay, c_l2_decay,
               n_ins, batch_size, batch_op):
    """Run one training epoch over every slide file found in ``train_path``.

    Slide files are visited in random order. For each slide, the patch features are
    shuffled and the models are updated with ``b_optimize`` (``batch_op`` true) or
    ``nb_optimize`` (otherwise). Slide-level metrics are computed from the hard
    predicted labels collected over the epoch. Labels are assumed to be 0/1, with 1
    as the positive class (consistent with ``pos_label=1`` below).

    NOTE(review): the three optimizers are constructed inside this function, so they
    are rebuilt -- and any optimizer state is lost -- every time it is called.
    NOTE(review): AUC is computed from hard 0/1 predictions rather than scores.

    Args:
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        train_path: directory whose entries are individual slide data files.
        i_optimizer_func, b_optimizer_func, c_optimizer_func: optimizer constructors,
            called as ``f(learning_rate=..., weight_decay=...)``.
        i_loss_func, b_loss_func: instance and bag loss functions.
        mutual_ex, n_class, c1, c2: loss-combination settings forwarded to the optimize step.
        i_learn_rate, b_learn_rate, c_learn_rate: learning rates for the three optimizers.
        i_l2_decay, b_l2_decay, c_l2_decay: weight decay for the three optimizers.
        n_ins, batch_size, batch_op: mini-batching settings (``batch_op`` enables batching).

    Returns:
        ``(train_loss, train_ins_loss, train_bag_loss, train_tn, train_fp, train_fn,
        train_tp, train_sensitivity, train_specificity, train_acc, train_auc)``.
        Losses are per-slide means. Rates are rounded to 2 decimals and are ``nan``
        when their denominator is 0.

    Raises:
        ValueError: if ``train_path`` has no entries.
    """
    loss_total = list()
    loss_ins = list()
    loss_bag = list()

    i_optimizer = i_optimizer_func(learning_rate=i_learn_rate, weight_decay=i_l2_decay)
    b_optimizer = b_optimizer_func(learning_rate=b_learn_rate, weight_decay=b_l2_decay)
    c_optimizer = c_optimizer_func(learning_rate=c_learn_rate, weight_decay=c_l2_decay)

    slide_true_label = list()
    slide_predict_label = list()

    train_sample_list = os.listdir(train_path)
    if not train_sample_list:
        raise ValueError(f"no training samples found in {train_path!r}")
    train_sample_list = random.sample(train_sample_list, len(train_sample_list))
    for i in train_sample_list:
        print('=', end="")
        # os.path.join works whether or not train_path ends with a separator.
        single_train_data = os.path.join(train_path, i)
        img_features, slide_label = get_data_from_tf(single_train_data)
        # Shuffle the patch features so that, when batch training is enabled, the
        # chunking does not always favour or drop the same patches.
        img_features = random.sample(img_features, len(img_features))

        if batch_op:
            I_Loss, B_Loss, T_Loss, predict_slide_label = b_optimize(batch_size=batch_size, n_ins=n_ins, n_samples=len(img_features),
                                                                     img_features=img_features, slide_label=slide_label,
                                                                     i_model=i_model, b_model=b_model, c_model=c_model,
                                                                     i_optimizer=i_optimizer, b_optimizer=b_optimizer,
                                                                     c_optimizer=c_optimizer, i_loss_func=i_loss_func,
                                                                     b_loss_func=b_loss_func, n_class=n_class, c1=c1,
                                                                     c2=c2, mutual_ex=mutual_ex)
        else:
            I_Loss, B_Loss, T_Loss, predict_slide_label = nb_optimize(img_features=img_features, slide_label=slide_label,
                                                                      i_model=i_model, b_model=b_model, c_model=c_model,
                                                                      i_optimizer=i_optimizer, b_optimizer=b_optimizer,
                                                                      c_optimizer=c_optimizer, i_loss_func=i_loss_func,
                                                                      b_loss_func=b_loss_func, n_class=n_class, c1=c1, c2=c2,
                                                                      mutual_ex=mutual_ex)

        loss_total.append(float(T_Loss))
        loss_ins.append(float(I_Loss))
        loss_bag.append(float(B_Loss))

        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)

    # Fix the label set so the matrix is always 2x2, even if an epoch only contains
    # (or only predicts) one class; otherwise the 4-way unpack below fails.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label,
                                                      labels=[0, 1]).ravel()
    train_tn = int(tn)
    train_fp = int(fp)
    train_fn = int(fn)
    train_tp = int(tp)

    def _rate(numerator, denominator):
        # nan instead of ZeroDivisionError when a class is absent from the epoch.
        return round(numerator / denominator, 2) if denominator else float('nan')

    train_sensitivity = _rate(train_tp, train_tp + train_fn)
    train_specificity = _rate(train_tn, train_tn + train_fp)
    train_acc = _rate(train_tp + train_tn, train_tn + train_fp + train_fn + train_tp)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    train_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    train_loss = statistics.mean(loss_total)
    train_ins_loss = statistics.mean(loss_ins)
    train_bag_loss = statistics.mean(loss_bag)

    return train_loss, train_ins_loss, train_bag_loss, train_tn, train_fp, train_fn, train_tp, train_sensitivity, \
           train_specificity, train_acc, train_auc

### Validating CLAM Model

In [20]:
def nb_val(img_features, slide_label, i_model, b_model, c_model,
           i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    """Compute validation losses and the predicted label for one slide in a single pass.

    This is the same forward computation as ``nb_optimize``, but no gradients are
    recorded and no weights are updated.

    NOTE(review): the models are invoked through ``.call`` exactly as during training.
    Whether layers such as dropout run in inference mode here depends on those models -- confirm.

    Args:
        img_features: patch feature vectors of one slide, passed as-is to ``c_model.call``.
        slide_label: slide-level label, forwarded to every model call.
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        i_loss_func: called as ``i_loss_func(one_hot(label, 2), logits)`` per instance-output entry.
        b_loss_func: called as ``b_loss_func(Y_true, Y_prob)`` on ``b_model``'s outputs.
        n_class: divisor for the summed instance loss when ``mutual_ex`` is true.
        c1, c2: weights of the bag and instance losses in the total loss.
        mutual_ex: if true, the summed instance loss is divided by ``n_class``.

    Returns:
        Tuple ``(I_Loss, B_Loss, T_Loss, predict_slide_label)``.
    """
    # Only A and h are needed from the CLAM pass.
    (_att_score, A, h, _ins_labels, _ins_logits_unnorm, _ins_logits, _score_unnorm,
     _y_prob, _y_hat, _y_true, _clam_label) = c_model.call(img_features, slide_label)

    ins_labels, _ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
    ins_total = tf.math.add_n([i_loss_func(tf.one_hot(ins_labels[idx], 2), ins_logits[idx])
                               for idx in range(len(ins_logits))])
    I_Loss = ins_total / n_class if mutual_ex else ins_total

    _score_unnorm, _y_hat, Y_prob, predict_slide_label, Y_true = b_model.call(slide_label, A, h)
    B_Loss = b_loss_func(Y_true, Y_prob)

    return I_Loss, B_Loss, c1 * B_Loss + c2 * I_Loss, predict_slide_label

In [21]:
def b_val(batch_size, n_ins, n_samples, img_features, slide_label, i_model, b_model,
          c_model, i_loss_func, b_loss_func, n_class, c1, c2, mutual_ex):
    """Compute validation losses for one slide by iterating over mini-batches of its patches.

    Chunking matches ``b_optimize``: there are ``ceil(n_samples / batch_size)`` chunks
    (at least one). Every chunk except the last is ``batch_size`` long; the last one
    starts ``n_ins`` positions early (clamped at 0). No gradients are recorded and no
    weights are updated.

    Args:
        batch_size: number of patch features per chunk.
        n_ins: backward overlap added to the final chunk.
        n_samples: number of features in ``img_features`` (callers pass ``len(img_features)``).
        img_features: sliceable sequence of patch feature vectors for one slide.
        slide_label: slide-level label, forwarded to every model call.
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        i_loss_func: called as ``i_loss_func(one_hot(label, 2), logits)`` per instance-output entry.
        b_loss_func: called as ``b_loss_func(Y_true, Y_prob)`` on ``b_model``'s outputs.
        n_class: divisor for the summed instance loss when ``mutual_ex`` is true.
        c1, c2: weights of the bag and instance losses in the total loss.
        mutual_ex: if true, the summed instance loss is divided by ``n_class``.

    Returns:
        Tuple ``(I_Loss, B_Loss, T_Loss, predict_slide_label)``. The losses are chunk
        averages (Python floats); the label is the most frequent per-chunk prediction,
        as returned by ``most_frequent``.
    """
    Ins_Loss = list()
    Bag_Loss = list()
    Total_Loss = list()
    label_predict = list()

    # ceil(n_samples / batch_size) chunks, but always at least one. Using
    # ``n_samples // batch_size + 1`` instead adds a duplicate step over the last
    # n_ins features whenever n_samples is an exact multiple of batch_size.
    n_steps = max(1, -(-n_samples // batch_size))
    for n_step in range(n_steps):
        step_size = n_step * batch_size
        if step_size < (n_samples - batch_size):
            batch_features = img_features[step_size:(step_size + batch_size)]
        else:
            # Final chunk. Clamp at 0: a negative start (slide smaller than one batch)
            # would otherwise wrap around and keep only the last n_ins features.
            batch_features = img_features[max(step_size - n_ins, 0):]

        # Only A and h are reused from the CLAM pass.
        _att_score, A, h, _ins_labels, _ins_logits_unnorm, _ins_logits, _score_unnorm, \
        _y_prob, _y_hat, _y_true, _clam_label = c_model.call(img_features=batch_features,
                                                              slide_label=slide_label)

        ins_labels, _ins_logits_unnorm, ins_logits = i_model.call(slide_label, h, A)
        ins_loss = [i_loss_func(tf.one_hot(ins_labels[j], 2), ins_logits[j])
                    for j in range(len(ins_logits))]
        Loss_I = tf.math.add_n(ins_loss)
        if mutual_ex:
            Loss_I = Loss_I / n_class

        _score_unnorm, _y_hat, Y_prob, predict_label, Y_true = b_model.call(slide_label, A, h)
        Loss_B = b_loss_func(Y_true, Y_prob)
        Loss_T = c1 * Loss_B + c2 * Loss_I

        Ins_Loss.append(float(Loss_I))
        Bag_Loss.append(float(Loss_B))
        Total_Loss.append(float(Loss_T))
        label_predict.append(predict_label)

    I_Loss = statistics.mean(Ins_Loss)
    B_Loss = statistics.mean(Bag_Loss)
    T_Loss = statistics.mean(Total_Loss)

    predict_slide_label = most_frequent(label_predict)

    return I_Loss, B_Loss, T_Loss, predict_slide_label

In [22]:
def val_step(i_model, b_model, c_model, val_path, i_loss_func, b_loss_func, mutual_ex,
             n_class, c1, c2, n_ins, batch_size, batch_op):
    """Evaluate the models on every slide file found in ``val_path`` (no weight updates).

    Slides and their patch features are shuffled the same way as in ``train_step``.
    Each slide is scored with ``b_val`` (``batch_op`` true) or ``nb_val``. Labels are
    assumed to be 0/1, with 1 as the positive class (consistent with ``pos_label=1`` below).

    NOTE(review): AUC is computed from hard 0/1 predictions rather than scores.

    Args:
        i_model, b_model, c_model: instance classifier, bag classifier and CLAM model.
        val_path: directory whose entries are individual slide data files.
        i_loss_func, b_loss_func: instance and bag loss functions.
        mutual_ex, n_class, c1, c2: loss-combination settings forwarded to the val step.
        n_ins, batch_size, batch_op: mini-batching settings (``batch_op`` enables batching).

    Returns:
        ``(val_loss, val_ins_loss, val_bag_loss, val_tn, val_fp, val_fn, val_tp,
        val_sensitivity, val_specificity, val_acc, val_auc)``. Losses are per-slide
        means. Rates are rounded to 2 decimals and are ``nan`` when their denominator is 0.

    Raises:
        ValueError: if ``val_path`` has no entries.
    """
    loss_t = list()
    loss_i = list()
    loss_b = list()

    slide_true_label = list()
    slide_predict_label = list()

    val_sample_list = os.listdir(val_path)
    if not val_sample_list:
        raise ValueError(f"no validation samples found in {val_path!r}")
    val_sample_list = random.sample(val_sample_list, len(val_sample_list))
    for i in val_sample_list:
        print('=', end="")
        # os.path.join works whether or not val_path ends with a separator.
        single_val_data = os.path.join(val_path, i)
        img_features, slide_label = get_data_from_tf(single_val_data)

        # Shuffle patch features exactly as the training loop does.
        img_features = random.sample(img_features, len(img_features))

        if batch_op:
            I_Loss, B_Loss, T_Loss, predict_slide_label = b_val(batch_size=batch_size, n_ins=n_ins, n_samples=len(img_features),
                                                                img_features=img_features, slide_label=slide_label,
                                                                i_model=i_model, b_model=b_model, c_model=c_model,
                                                                i_loss_func=i_loss_func, b_loss_func=b_loss_func,
                                                                n_class=n_class, c1=c1, c2=c2, mutual_ex=mutual_ex)
        else:
            I_Loss, B_Loss, T_Loss, predict_slide_label = nb_val(img_features=img_features, slide_label=slide_label,
                                                                 i_model=i_model, b_model=b_model, c_model=c_model,
                                                                 i_loss_func=i_loss_func, b_loss_func=b_loss_func,
                                                                 n_class=n_class, c1=c1, c2=c2, mutual_ex=mutual_ex)

        loss_t.append(float(T_Loss))
        loss_i.append(float(I_Loss))
        loss_b.append(float(B_Loss))

        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)

    # Fix the label set so the matrix is always 2x2, even if the split only contains
    # (or only predicts) one class; otherwise the 4-way unpack below fails.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label,
                                                      labels=[0, 1]).ravel()
    val_tn = int(tn)
    val_fp = int(fp)
    val_fn = int(fn)
    val_tp = int(tp)

    def _rate(numerator, denominator):
        # nan instead of ZeroDivisionError when a class is absent from the split.
        return round(numerator / denominator, 2) if denominator else float('nan')

    val_sensitivity = _rate(val_tp, val_tp + val_fn)
    val_specificity = _rate(val_tn, val_tn + val_fp)
    val_acc = _rate(val_tp + val_tn, val_tn + val_fp + val_fn + val_tp)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    val_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    val_loss = statistics.mean(loss_t)
    val_ins_loss = statistics.mean(loss_i)
    val_bag_loss = statistics.mean(loss_b)

    return val_loss, val_ins_loss, val_bag_loss, val_tn, val_fp, val_fn, val_tp, val_sensitivity, val_specificity, \
           val_acc, val_auc

## Test the Optimized CLAM Model

In [23]:
def test_step(n_class, n_ins, att_gate, att_only, mil_ins, mut_ex, i_model, b_model, c_model,
              test_path, result_path, result_file_name):
    """Evaluate restored models on every slide file in ``test_path`` and report metrics.

    For each slide:

    * the CLAM sub-networks ``c_model[0]``, ``c_model[1]`` and ``c_model[2]`` are run via
      ``s_clam_call``;
    * the resulting ``A``/``h`` are passed to ``ins_call`` and ``s_bag_call``;
    * the label predicted by ``s_bag_call`` is kept.

    The per-slide table (sample name, true label, predicted label) is written once,
    tab-separated, to ``os.path.join(result_path, result_file_name)``. Accuracy,
    sensitivity, specificity and run time are printed. Labels are assumed to be 0/1,
    with 1 as the positive class (consistent with ``pos_label=1`` below).

    NOTE(review): AUC is computed from hard 0/1 predictions rather than scores.

    Returns:
        ``(test_tn, test_fp, test_fn, test_tp, test_sensitivity, test_specificity,
        test_acc, test_auc)``. Rates are rounded to 2 decimals and are ``nan`` when
        their denominator is 0.

    Raises:
        ValueError: if ``test_path`` has no entries.
    """
    start_time = time.time()

    slide_true_label = list()
    slide_predict_label = list()
    sample_names = list()

    test_sample_list = os.listdir(test_path)
    if not test_sample_list:
        raise ValueError(f"no test samples found in {test_path!r}")

    for i in test_sample_list:
        print('>', end="")
        # os.path.join works whether or not test_path ends with a separator.
        single_test_data = os.path.join(test_path, i)
        img_features, slide_label = get_data_from_tf(single_test_data)

        att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, \
        Y_prob, Y_hat, Y_true, predict_slide_label = s_clam_call(att_net=c_model[0],
                                                                 ins_net=c_model[1],
                                                                 bag_net=c_model[2],
                                                                 img_features=img_features,
                                                                 slide_label=slide_label,
                                                                 n_class=n_class, n_ins=n_ins,
                                                                 att_gate=att_gate, att_only=att_only,
                                                                 mil_ins=mil_ins, mut_ex=mut_ex)

        # NOTE(review): ins_call's outputs are not used for the slide prediction below.
        # The call is kept in case it matters for side effects -- confirm and drop if not.
        ins_labels, ins_logits_unnorm, ins_logits = ins_call(m_ins_classifier=i_model,
                                                             bag_label=slide_label,
                                                             h=h, A=A, n_class=n_class,
                                                             n_ins=n_ins, mut_ex=mut_ex)

        slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = s_bag_call(bag_classifier=b_model,
                                                                                    bag_label=slide_label,
                                                                                    A=A, h=h, n_class=n_class)

        slide_true_label.append(slide_label)
        slide_predict_label.append(predict_slide_label)
        sample_names.append(i)

    # Build and write the per-slide table once. Rebuilding and rewriting it inside
    # the loop made the export quadratic in the number of slides.
    test_results = pd.DataFrame(list(zip(sample_names, slide_true_label, slide_predict_label)),
                                columns=['Sample Names', 'Slide True Label', 'Slide Predict Label'])
    test_results.to_csv(os.path.join(result_path, result_file_name), sep='\t', index=False)

    # Fix the label set so the matrix is always 2x2, even for a single-class test set.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(slide_true_label, slide_predict_label,
                                                      labels=[0, 1]).ravel()
    test_tn = int(tn)
    test_fp = int(fp)
    test_fn = int(fn)
    test_tp = int(tp)

    def _rate(numerator, denominator):
        # nan instead of ZeroDivisionError when a class is absent from the test set.
        return round(numerator / denominator, 2) if denominator else float('nan')

    test_sensitivity = _rate(test_tp, test_tp + test_fn)
    test_specificity = _rate(test_tn, test_tn + test_fp)
    test_acc = _rate(test_tp + test_tn, test_tn + test_fp + test_fn + test_tp)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(slide_true_label, slide_predict_label, pos_label=1)
    test_auc = round(sklearn.metrics.auc(fpr, tpr), 2)

    test_run_time = time.time() - start_time

    template = '\n Test Accuracy: {}, Test Sensitivity: {}, Test Specificity: {}, Test Running Time: {}'
    print(template.format(f"{float(test_acc):.4%}",
                          f"{float(test_sensitivity):.4%}",
                          f"{float(test_specificity):.4%}",
                          "--- %s mins ---" % int(test_run_time / 60)))

    return test_tn, test_fp, test_fn, test_tp, test_sensitivity, test_specificity, test_acc, test_auc

## Optimizing CLAM Model

## Training & Validating CLAM Model

In [24]:
def _write_epoch_summaries(writer, epoch, total_loss, ins_loss, bag_loss, acc, auc,
                           sensitivity, specificity, tp, fp, tn, fn):
    """Log one epoch's metrics to ``writer`` under the tag names shared by the train and val logs."""
    with writer.as_default():
        tf.summary.scalar('Total Loss', float(total_loss), step=epoch)
        tf.summary.scalar('Instance Loss', float(ins_loss), step=epoch)
        tf.summary.scalar('Bag Loss', float(bag_loss), step=epoch)
        tf.summary.scalar('Accuracy', float(acc), step=epoch)
        tf.summary.scalar('AUC', float(auc), step=epoch)
        tf.summary.scalar('Sensitivity', float(sensitivity), step=epoch)
        tf.summary.scalar('Specificity', float(specificity), step=epoch)
        # NOTE(review): each confusion-matrix count is a single integer, so these
        # histograms hold one value per epoch; tf.summary.scalar may be more readable.
        tf.summary.histogram('True Positive', int(tp), step=epoch)
        tf.summary.histogram('False Positive', int(fp), step=epoch)
        tf.summary.histogram('True Negative', int(tn), step=epoch)
        tf.summary.histogram('False Negative', int(fn), step=epoch)
    # Push events to disk every epoch so TensorBoard shows progress during long runs.
    writer.flush()


def train_val(train_log, val_log, train_path, val_path, i_model, b_model,
              c_model, i_optimizer_func, b_optimizer_func, c_optimizer_func,
              i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2,
              i_learn_rate, b_learn_rate, c_learn_rate,
              i_l2_decay, b_l2_decay, c_l2_decay, n_ins,
              batch_size, batch_op, epochs):
    """Alternate ``train_step`` and ``val_step`` for ``epochs`` epochs, logging to TensorBoard.

    Training metrics go to an event-file writer on ``train_log`` and validation
    metrics to one on ``val_log``, using the same tag names with the epoch index as
    the step. A one-line summary is printed per epoch. Both writers are flushed every
    epoch and closed when the loop ends, including when an epoch raises.

    NOTE(review): ``train_step`` constructs new optimizers on every call, so optimizer
    state does not carry over between epochs.

    Args:
        train_log, val_log: log directories for the train/val summary writers.
        train_path, val_path: directories of slide data files.
        epochs: number of train+validate rounds.
        Remaining arguments are forwarded unchanged to ``train_step`` / ``val_step``.
    """
    train_summary_writer = tf.summary.create_file_writer(train_log)
    val_summary_writer = tf.summary.create_file_writer(val_log)

    try:
        for epoch in range(epochs):
            start_time = time.time()

            # Training step
            train_loss, train_ins_loss, train_bag_loss, train_tn, train_fp, train_fn, train_tp, \
            train_sensitivity, train_specificity, train_acc, train_auc = train_step(
                i_model=i_model, b_model=b_model, c_model=c_model, train_path=train_path,
                i_optimizer_func=i_optimizer_func, b_optimizer_func=b_optimizer_func,
                c_optimizer_func=c_optimizer_func, i_loss_func=i_loss_func,
                b_loss_func=b_loss_func, mutual_ex=mutual_ex, n_class=n_class,
                c1=c1, c2=c2, i_learn_rate=i_learn_rate, b_learn_rate=b_learn_rate,
                c_learn_rate=c_learn_rate, i_l2_decay=i_l2_decay, b_l2_decay=b_l2_decay,
                c_l2_decay=c_l2_decay, n_ins=n_ins, batch_size=batch_size, batch_op=batch_op)

            _write_epoch_summaries(train_summary_writer, epoch, train_loss, train_ins_loss, train_bag_loss,
                                   train_acc, train_auc, train_sensitivity, train_specificity,
                                   train_tp, train_fp, train_tn, train_fn)

            # Validation step
            val_loss, val_ins_loss, val_bag_loss, val_tn, val_fp, val_fn, val_tp, \
            val_sensitivity, val_specificity, val_acc, val_auc = val_step(
                i_model=i_model, b_model=b_model, c_model=c_model, val_path=val_path,
                i_loss_func=i_loss_func, b_loss_func=b_loss_func, mutual_ex=mutual_ex,
                n_class=n_class, c1=c1, c2=c2, n_ins=n_ins, batch_size=batch_size, batch_op=batch_op)

            _write_epoch_summaries(val_summary_writer, epoch, val_loss, val_ins_loss, val_bag_loss,
                                   val_acc, val_auc, val_sensitivity, val_specificity,
                                   val_tp, val_fp, val_tn, val_fn)

            epoch_run_time = time.time() - start_time
            template = '\n Epoch {},  Train Loss: {}, Train Accuracy: {}, Val Loss: {}, Val Accuracy: {}, Epoch Running ' \
                       'Time: {} '
            print(template.format(epoch + 1,
                                  f"{float(train_loss):.8}",
                                  f"{float(train_acc):.4%}",
                                  f"{float(val_loss):.8}",
                                  f"{float(val_acc):.4%}",
                                  "--- %s mins ---" % int(epoch_run_time / 60)))
    finally:
        # Release the event files even if an epoch raises.
        train_summary_writer.close()
        val_summary_writer.close()

### Main Functions for Optimizing and Testing the CLAM Model

### Train and Save the CLAM Model, Then Restore and Test the Saved Model

In [25]:
def clam_optimize(train_log, val_log, train_path, val_path, i_model, b_model,
                  c_model, i_optimizer_func, b_optimizer_func, c_optimizer_func,
                  i_loss_func, b_loss_func, mutual_ex, n_class, c1, c2,
                  i_learn_rate, b_learn_rate, c_learn_rate, i_l2_decay, b_l2_decay,
                  c_l2_decay, n_ins, batch_size, batch_op, i_model_dir, b_model_dir,
                  c_model_dir, m_bag_op, m_clam_op, g_att_op, epochs):
    """Train and validate the three CLAM models, then save them to disk.

    This is a thin driver with two phases:

    1. Every training/validation argument is forwarded unchanged to ``train_val``,
       which runs all ``epochs``.
    2. The (now trained) model objects are handed to ``model_save``, together with
       the output directories and the ``n_class`` / ``m_bag_op`` / ``m_clam_op`` /
       ``g_att_op`` flags.

    Returns:
        None.
    """
    # Phase 1: per-epoch training and validation.
    train_val(
        train_log=train_log,
        val_log=val_log,
        train_path=train_path,
        val_path=val_path,
        i_model=i_model,
        b_model=b_model,
        c_model=c_model,
        i_optimizer_func=i_optimizer_func,
        b_optimizer_func=b_optimizer_func,
        c_optimizer_func=c_optimizer_func,
        i_loss_func=i_loss_func,
        b_loss_func=b_loss_func,
        mutual_ex=mutual_ex,
        n_class=n_class,
        c1=c1,
        c2=c2,
        i_learn_rate=i_learn_rate,
        b_learn_rate=b_learn_rate,
        c_learn_rate=c_learn_rate,
        i_l2_decay=i_l2_decay,
        b_l2_decay=b_l2_decay,
        c_l2_decay=c_l2_decay,
        n_ins=n_ins,
        batch_size=batch_size,
        batch_op=batch_op,
        epochs=epochs,
    )

    # Phase 2: persist the trained models.
    model_save(
        i_model=i_model,
        b_model=b_model,
        c_model=c_model,
        i_model_dir=i_model_dir,
        b_model_dir=b_model_dir,
        c_model_dir=c_model_dir,
        n_class=n_class,
        m_bag_op=m_bag_op,
        m_clam_op=m_clam_op,
        g_att_op=g_att_op,
    )

In [26]:
def clam_test(n_class, n_ins, att_gate, att_only, mil_ins, mut_ex, test_path,
              result_path, result_file_name, i_model_dir, b_model_dir, c_model_dir,
              m_bag_op, m_clam_op):
    """Restore the saved CLAM models and evaluate them on the slides in ``test_path``.

    ``att_gate`` is used twice: it is forwarded to ``restore_model`` as ``g_att_op``
    and to ``test_step`` as ``att_gate``. Per-slide results go to
    ``result_path``/``result_file_name`` through ``test_step``.

    Returns:
        None.
    """
    restored = restore_model(
        i_model_dir=i_model_dir,
        b_model_dir=b_model_dir,
        c_model_dir=c_model_dir,
        n_class=n_class,
        m_bag_op=m_bag_op,
        m_clam_op=m_clam_op,
        g_att_op=att_gate,
    )
    i_trained, b_trained, c_trained = restored

    test_step(
        n_class=n_class, n_ins=n_ins,
        att_gate=att_gate, att_only=att_only,
        mil_ins=mil_ins, mut_ex=mut_ex,
        i_model=i_trained, b_model=b_trained, c_model=c_trained,
        test_path=test_path,
        result_path=result_path,
        result_file_name=result_file_name,
    )