## Import Required Packages

In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
from tqdm import tqdm
import pandas as pd
import sklearn
from sklearn import metrics
import re
import numpy as np
import pickle as pkl
import PIL
import datetime
import os
import random
import shutil
import statistics
import time
import import_ipynb

## Import Required Functions and Classes from Other Notebooks

In [2]:
from util import *

importing Jupyter notebook from util.ipynb
importing Jupyter notebook from model.ipynb
importing Jupyter notebook from optimize_test.ipynb


In [3]:
from model import *

In [4]:
from optimize_test import *

## Saving & Restoring CLAM Model Training Checkpoints

### Building Models for Training

In [5]:
# Fix every RNG seed before any model is constructed, so that weight
# initialization, dropout (enabled on s_clam below) and other stochastic steps
# are reproducible under Restart & Run All. GPU kernels may still introduce
# some non-determinism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Constructor settings shared by the non-gated and gated attention networks.
# NOTE(review): with dropout=False, dropout_rate is presumably ignored — confirm.
att_net_kwargs = dict(dim_features=1024, dim_compress_features=512,
                      n_hidden_units=256, n_classes=2,
                      dropout=False, dropout_rate=.25)

# NOTE(review): neither ng_att nor g_att is passed to clam_optimize/clam_test in
# the cells shown; confirm they are still needed.
ng_att = NG_Att_Net(**att_net_kwargs)
g_att = G_Att_Net(**att_net_kwargs)

In [6]:
# Instance-level model, passed as i_model to clam_optimize.
# NOTE(review): built with mut_ex=True, whereas s_clam/m_clam (mut_ex=False),
# clam_optimize (mutual_ex=False) and clam_test (mut_ex=False) all use the
# non-mutually-exclusive setting. Confirm this difference is intentional.
ins = Ins(
    dim_compress_features=512,
    n_class=2,
    n_ins=8,
    mut_ex=True,
)

In [7]:
# Both bag-level models take identical constructor arguments.
bag_kwargs = dict(dim_compress_features=512, n_class=2)

# s_bag is passed as b_model to clam_optimize. m_bag is not referenced in the
# cells shown (training runs with m_bag_op=False).
s_bag = S_Bag(**bag_kwargs)
m_bag = M_Bag(**bag_kwargs)

In [8]:
# CLAM settings shared by both variants. att_gate, att_only, mil_ins, mut_ex,
# n_ins and n_class are also passed to clam_test at the end of the notebook;
# keep the two in sync.
clam_kwargs = dict(
    att_gate=True,
    net_size='big',
    n_ins=8,
    n_class=2,
    mut_ex=False,
    dropout=True,
    drop_rate=.55,
    mil_ins=True,
    att_only=False,
)

# s_clam is passed as c_model to clam_optimize. m_clam is not referenced in the
# cells shown (training runs with m_clam_op=False).
s_clam = S_CLAM(**clam_kwargs)
m_clam = M_CLAM(**clam_kwargs)

### Setting Required Paths

In [10]:
# Data directories for the train / validation / test splits. train and val are
# passed to clam_optimize and test to clam_test. '/path/' is a placeholder to
# be replaced with the real locations.
(train_is_bach,
 val_is_bach,
 test_is_bach) = ('/path/', '/path/', '/path/')

In [13]:
# Output directory handed to clam_test as result_path, together with
# result_file_name. '/path/' is a placeholder.
clam_result_dir = '/path/'

In [14]:
# Model directories for the instance (i), bag (b) and CLAM (c) models. The same
# three directories are passed to both clam_optimize and clam_test.
# '/path/' is a placeholder.
(i_trained_model_dir,
 b_trained_model_dir,
 c_trained_model_dir) = ('/path/',) * 3

In [15]:
# Each run logs under its own timestamped directory, with separate train/val
# sub-folders ('/path/' is a placeholder root).
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_log_dir = f'/path/{current_time}'
train_log_dir = f'{run_log_dir}/train'
val_log_dir = f'{run_log_dir}/val'

## Training, Validating & Testing the CLAM Model

In [16]:
# Presumably silences TensorFlow logging/warnings before the long training run.
# tf_shut_up comes from one of the star-imported helper notebooks, and its
# behavior is not visible here.
tf_shut_up(
    no_warn_op=True,
)

In [18]:
# All three sub-models use the same optimizer class, learning rate and l2 decay.
# The instance and bag losses are both binary cross-entropy.
shared_optimizer_func = tfa.optimizers.AdamW
shared_loss_func = tf.keras.losses.binary_crossentropy
shared_learn_rate = 2e-04
shared_l2_decay = 1e-05

clam_optimize(
    # Log directories and data splits.
    train_log=train_log_dir, val_log=val_log_dir,
    train_path=train_is_bach, val_path=val_is_bach,
    # The i_/b_/c_ prefixes correspond to the ins, s_bag and s_clam models.
    i_model=ins, b_model=s_bag, c_model=s_clam,
    i_optimizer_func=shared_optimizer_func,
    b_optimizer_func=shared_optimizer_func,
    c_optimizer_func=shared_optimizer_func,
    i_loss_func=shared_loss_func,
    b_loss_func=shared_loss_func,
    # NOTE(review): mutual_ex=False here, but `ins` was constructed with
    # mut_ex=True — confirm which is intended.
    # c1/c2 presumably weight two loss terms — confirm in clam_optimize.
    mutual_ex=False, n_class=2, c1=0.7, c2=0.3,
    i_learn_rate=shared_learn_rate,
    b_learn_rate=shared_learn_rate,
    c_learn_rate=shared_learn_rate,
    i_l2_decay=shared_l2_decay,
    b_l2_decay=shared_l2_decay,
    c_l2_decay=shared_l2_decay,
    # NOTE(review): batch_size may be ignored when batch_op=False — confirm.
    n_ins=8, batch_size=2000, batch_op=False,
    # The same model directories are passed to clam_test afterwards.
    i_model_dir=i_trained_model_dir,
    b_model_dir=b_trained_model_dir,
    c_model_dir=c_trained_model_dir,
    # m_bag_op/m_clam_op=False match passing s_bag/s_clam above.
    # g_att_op=True matches att_gate=True on s_clam.
    m_bag_op=False, m_clam_op=False, g_att_op=True, epochs=200,
)

In [16]:
clam_test(
    # Model configuration; must match the s_clam construction above.
    n_class=2,
    n_ins=8,
    att_gate=True,
    att_only=False,
    mil_ins=True,
    mut_ex=False,
    # Test data and where the results table is written.
    test_path=test_is_bach,
    result_path=clam_result_dir,
    result_file_name='test_bach_model_save.tsv',
    # Same model directories that were passed to clam_optimize.
    i_model_dir=i_trained_model_dir,
    b_model_dir=b_trained_model_dir,
    c_model_dir=c_trained_model_dir,
    m_bag_op=False,
    m_clam_op=False,
)