![](https://i.ibb.co/tbW6pxF/cover2-01.jpg)

<p style='text-align: center;'><span style="color: #000508; font-family: Segoe UI; font-size: 2.5em; font-weight: 300;">[TPU] RSNA Keras 3D CNN Voxel Classifier Train</span></p>

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Overview</span>
<br>
<br>
&ensp;✔️&ensp;Fast Training with Large Voxel Batch Sizes on TPU<br>
&ensp;✔️&ensp;TPU Mixed Precision<br>
&ensp;✔️&ensp;Custom Keras Training Loop<br>
&ensp;✔️&ensp;3D Augmentations on TPU<br>
&emsp;&emsp;&emsp;✔️&ensp;Flip<br>
&emsp;&emsp;&emsp;✔️&ensp;Gamma<br>
&emsp;&emsp;&emsp;✔️&ensp;Brightness<br>
&emsp;&emsp;&emsp;✔️&ensp;Contrast<br>
&emsp;&emsp;&emsp;✔️&ensp;Cutout<br>
&emsp;&emsp;&emsp;✔️&ensp;Rotate<br>
&emsp;&emsp;&emsp;✔️&ensp;Random Resized Crop<br>
&emsp;&emsp;&emsp;✔️&ensp;Blur<br>
&ensp;‚úîÔ∏è&ensp;Experiment Tracking & Interactive Visualizations with Weights & Biases<br>

<br>

This competition presents a unique challenge: predicting the status of MGMT (O⁶-methylguanine-DNA methyltransferase), a genetic biomarker that matters for the success of brain cancer treatment because it is a favorable prognostic factor and a strong predictor of responsiveness to chemotherapy.
<br>
<br>
With this challenge, we could potentially build solutions that use MRI imaging techniques to minimize the number of exploratory invasive procedures needed for a diagnosis, and to target and refine the required treatments.
<br>
<br>
Multiple exploratory surgeries to characterize a tumour's genetic biomarkers are harmful in themselves and may delay the right treatment by up to several weeks, eating into the average span of 5 years for people diagnosed with brain cancer.
<br>
<br>
Let's look at how we can train on 3D MRI scan voxels with TPUs. The following processing steps were performed to create the voxels:<br>

&ensp;✔️&ensp;Align Scans Across the Volume & Crop<br>
&ensp;✔️&ensp;Filter Out Slices with Less Information<br>
&ensp;✔️&ensp;CLAHE Contrast Enhancement & Normalization Across the 3D Volume<br>
&ensp;✔️&ensp;Resample to 64 x 256 x 256 for Optimal Training
<br>
<br>
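The preprocessing itself lives in the dataset generation notebook linked below; this notebook only consumes its output. To give a feel for the last step, here is a minimal sketch of depth resampling by nearest-slice selection, assuming a `(depth, height, width)` uint8 volume. `resample_depth` is a hypothetical helper and not the author's code; the real pipeline may interpolate or resize in-plane as well.

```python
import numpy as np

def resample_depth(volume, target_depth=64):
    """Pick `target_depth` evenly spaced slices (nearest-slice resampling).

    Hypothetical helper: `volume` has shape (depth, height, width).
    Slices are repeated when the scan has fewer than `target_depth` slices.
    """
    src_depth = volume.shape[0]
    # Evenly spaced fractional positions over the source slices, rounded to an index
    idx = np.round(np.linspace(0, src_depth - 1, target_depth)).astype(int)
    return volume[idx]

# Example: a scan with 155 slices becomes 64 slices; the in-plane size is untouched here
scan = np.random.randint(0, 256, size=(155, 256, 256), dtype=np.uint8)
voxel = resample_depth(scan, target_depth=64)
print(voxel.shape, voxel.dtype)
```

Nearest-slice selection keeps the original intensities exactly (no blending between slices), at the cost of dropping or duplicating whole slices.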

Augmentations preserve the geometric alignment between planes: every operation is applied uniformly, with the same parameters, to every slice of a voxel. Please refer to the dataset generation and augmentation notebook for more details.
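In the TF code, a `(64, 256, 256, 1)` voxel is handed to `tf.image` ops as if it were a batch of 64 images, and each random parameter is drawn once per voxel with `tf.random.uniform(())`. That combination is what keeps the planes aligned. A small NumPy sketch of the same idea (the helper names are illustrative, not part of the notebook):

```python
import numpy as np

def flip_left_right_all_slices(voxel):
    """Mirror every slice of a (depth, height, width, channels) voxel along the width axis."""
    return voxel[:, :, ::-1, :]

def adjust_contrast_all_slices(voxel, factor):
    """Scale each slice around its own per-channel mean by one shared factor."""
    mean = voxel.mean(axis=(1, 2), keepdims=True)
    return (voxel - mean) * factor + mean

def augment_voxel(voxel, rng, flip_p=0.5, contrast_range=(0.7, 1.3)):
    # Draw each random parameter ONCE per voxel ...
    do_flip = rng.random() < flip_p
    factor = rng.uniform(*contrast_range)
    # ... then apply it to all depth slices at once, so the planes stay aligned
    if do_flip:
        voxel = flip_left_right_all_slices(voxel)
    return adjust_contrast_all_slices(voxel, factor)

rng = np.random.default_rng(0)
voxel = rng.random((8, 32, 32, 1), dtype=np.float32)
augmented = augment_voxel(voxel, rng)
print(augmented.shape)  # (8, 32, 32, 1)
```

Sampling a separate parameter per slice instead would make neighbouring slices disagree, which breaks the 3D structure the network is supposed to learn from.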
<br>
<br>
TPUs train faster and have room for larger batches than the small volumes and batches that fit on GPUs. This matters even more here, because each sample is an entire 3D volume rather than a single image.
<br>
<br>
We also use mixed precision to make training faster and consume less memory, performing 16-bit floating-point calculations (bfloat16 on TPU) wherever appropriate.
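To put numbers on the memory argument, here is a quick back-of-the-envelope calculation for one input batch at the configured size (32 voxels of 64 x 256 x 256 x 1). It is only a sense of scale: the decoder below feeds float32 inputs, and the real savings from mixed precision come from the activations inside the network, which are much larger than the input.

```python
def batch_mib(bytes_per_element, batch_size=32, shape=(64, 256, 256, 1)):
    """Size in MiB of one batch of voxels stored with the given element width."""
    elements = batch_size
    for dim in shape:
        elements *= dim
    return elements * bytes_per_element / 2**20

print(f"float32 : {batch_mib(4):.0f} MiB")  # 512 MiB
print(f"bfloat16: {batch_mib(2):.0f} MiB")  # 256 MiB
```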
<br>
<br>


<p style='text-align: left;'><span style="color: #000508; font-family: Segoe UI; font-size: 1.1em; font-weight: 600;"> üè∑Ô∏è Notebook with Dataset Generation & Augmentations</span></p>

&emsp;&emsp;[RSNA 3D CLAHE Voxels + TPU 3D Augmentations🌃🚅](https://www.kaggle.com/sreevishnudamodaran/rsna-3d-clahe-voxels-tpu-3d-augmentations)<br>

<p style='text-align: left;'><span style="color: #000508; font-family: Segoe UI; font-size: 1.1em; font-weight: 600;"> üè∑Ô∏è Dataset with Processed 64 x 256 x 256 Voxels</span></p>

&emsp;&emsp;[RSNA Processed Voxels 64x256x256](https://www.kaggle.com/sreevishnudamodaran/rsna-processed-voxels-64x256x256)<br>

<p style='text-align: left;'><span style="color: #000508; font-family: Segoe UI; font-size: 1.1em; font-weight: 600;"> üè∑Ô∏è Dataset with Processed 64 x 256 x 256 Voxels with CLAHE</span></p>

&emsp;&emsp;[RSNA Processed Voxels 64x256x256 CLAHE](https://www.kaggle.com/sreevishnudamodaran/rsna-processed-voxels-64x256x256-clahe)

<br>
<br>
<br>

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">⚗️ Imports & Setup</span>

In [None]:
!pip install /kaggle/input/kerasapplications -q
!pip install /kaggle/input/keras-3d-model-and-3d-augmentation/classification_models_3D-1.0.2-py3-none-any.whl -q

In [None]:
import itertools
import os
import random
import gc
import glob
import re
import math
import time

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pprint
import matplotlib.pylab as plt
import matplotlib.ticker as mtick
import seaborn as sns

from sklearn.model_selection import train_test_split, GroupKFold
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras
import tensorflow_addons as tfa
from tensorflow.keras.utils import Progbar

import tensorflow.experimental.numpy as tnp

from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient
from IPython.display import IFrame, display, HTML
import imageio

In [None]:
class Config:
    seed = 42
    job = 1
    num_classes = 1
    #     dataset = '/kaggle/input/rsna-processed-voxels-64x256x256' # No CLAHE
    dataset = '/kaggle/input/rsna-processed-voxels-64x256x256-clahe'
    input_dims = (64, 256, 256)
    voxel_dtype = tf.uint8
    scan_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
    batch_size = 32
    kfold = 5
    n_epochs = 15
    lr = 0.00005
    fp16 = True  # enables the 'mixed_bfloat16' policy on TPU
    label_smoothing = 0.1
    wandb_project = 'RSNA-train-public'
    
cfg = Config()

In [None]:
def seed_everything(SEED):
    os.environ['PYTHONHASHSEED']=str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'  # boolean flag, not a seed

seed_everything(cfg.seed)

In [None]:
split = 'train'
train_voxels = sorted(glob.glob(f"{cfg.dataset}/voxels/*/*.npy"))

df_train = pd.DataFrame(train_voxels, columns=['voxel_paths'])
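# os.path.splitext removes the '.npy' suffix; str.strip('.npy') would strip characters, not the suffix
df_train['BraTS21ID'] = df_train.voxel_paths.map(lambda path: os.path.splitext(os.path.basename(path))[0])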
df_train['scan_type'] = df_train.voxel_paths.map(lambda path:path.split('/')[-2])
df_train.sample(3)
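File paths follow the layout `voxels/<scan_type>/<BraTS21ID>.npy`. A small sketch of parsing them with `os.path`; the helper names `voxel_id` and `voxel_scan_type` are illustrative. It also shows why `str.strip('.npy')` is fragile: it removes any of the characters `.`, `n`, `p`, `y` from both ends, which only happens to work for purely numeric IDs.

```python
import os

def voxel_id(path):
    """Return the file stem, e.g. '00012' for '.../FLAIR/00012.npy'."""
    return os.path.splitext(os.path.basename(path))[0]

def voxel_scan_type(path):
    """Return the parent directory name, e.g. 'FLAIR'."""
    return os.path.basename(os.path.dirname(path))

path = "/kaggle/input/rsna-processed-voxels-64x256x256-clahe/voxels/FLAIR/00012.npy"
print(voxel_id(path), voxel_scan_type(path))  # 00012 FLAIR

# str.strip treats its argument as a set of characters, not as a suffix
print("brain_scan.npy".strip(".npy"))  # brain_sca
```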

In [None]:
df_train_labels = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv',
                             dtype={'BraTS21ID': str,  # keep leading zeros; np.object was removed in NumPy 1.24
                                   'MGMT_value':np.int32})
df_train = df_train.set_index('BraTS21ID').join(df_train_labels.set_index('BraTS21ID'), on='BraTS21ID', how='left')
df_train = df_train.reset_index()
df_train.to_csv("/kaggle/working/train_df_meta.csv")
df_train.head(3)

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">K-fold Split</span>

In [None]:
scan_type = 'FLAIR'

df_kfold = df_train.drop_duplicates().reset_index(drop=True)

kfold = cfg.kfold
df_kfold['fold'] = -1
group_kfold  = GroupKFold(n_splits = kfold)

for fold, (train_index, val_index) in enumerate(group_kfold.split(df_kfold[['BraTS21ID',
                                                                            'scan_type',
                                                                           'voxel_paths']],
                                                                  df_kfold['MGMT_value'],
                                                                  groups=df_kfold.BraTS21ID.tolist())):
    df_kfold.loc[val_index, 'fold'] = fold
    
display(df_kfold.head(3))
df_kfold.to_csv("/kaggle/working/df_kfold.csv")

for fold in range(cfg.kfold):
    train_paths = df_kfold[(df_kfold.scan_type==scan_type) & \
                           (df_kfold['fold'] != fold)].voxel_paths.unique()
    valid_paths = df_kfold[(df_kfold.scan_type==scan_type) & \
                           (df_kfold['fold'] == fold)].voxel_paths.unique()

    print(f"Fold:{fold}\nSplit Counts\nTrain Voxels:\t\t{len(train_paths)}\nVal Voxels:\t\t{len(valid_paths)}",
          end='\n\n')
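Grouping by `BraTS21ID` keeps all four scans of a patient in the same fold, so no patient appears in both training and validation. The sketch below uses a simplified round-robin assignment to illustrate the invariant; it is not sklearn's algorithm (`GroupKFold` additionally tries to balance fold sizes), and the patient IDs are made up.

```python
from collections import defaultdict

def group_folds(groups, n_splits=5):
    """Assign each distinct group (patient ID) to exactly one fold (round-robin)."""
    fold_of_group = {g: i % n_splits for i, g in enumerate(sorted(set(groups)))}
    return [fold_of_group[g] for g in groups]

# Hypothetical rows: one row per (patient, scan type), as in df_kfold
patients = ["00000", "00002", "00003", "00005", "00006", "00008"]
rows = [(p, s) for p in patients for s in ["FLAIR", "T1w", "T1wCE", "T2w"]]
folds = group_folds([p for p, _ in rows], n_splits=3)

# Every patient maps to a single fold, so its scans never straddle train/validation
patient_folds = defaultdict(set)
for (p, _), f in zip(rows, folds):
    patient_folds[p].add(f)
print(all(len(fs) == 1 for fs in patient_folds.values()))  # True
```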

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Create Dataloader with TPU Augmentations</span>

In [None]:
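def get_npy_header_cnt(sample_voxel, dtype=tf.uint8):
    """Return the size in bytes of the .npy header that precedes the array data.

    The header is measured in bytes regardless of `dtype`; slicing it off after
    `decode_raw(..., voxel_dtype)` is only exact for 1-byte dtypes such as uint8.
    """
    img = np.load(sample_voxel)
    raw_bytes = tf.io.decode_raw(tf.io.read_file(sample_voxel), tf.uint8)
    # File size in bytes minus the array payload in bytes
    return raw_bytes.shape[0] - img.nbytes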

npy_header_size = get_npy_header_cnt(df_kfold.voxel_paths[0],
                                     cfg.voxel_dtype)
npy_header_size
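The decoder below reads each file with `tf.io.decode_raw` and simply drops the first `npy_header_size` bytes, which assumes every file shares the same header (same shape and dtype) and a C-ordered uint8 payload. The same trick in NumPy, on an in-memory file; `npy_data_offset` and `decode_voxel` are illustrative helpers, and the header is parsed with `numpy.lib.format` to cross-check the "file size minus payload" arithmetic.

```python
import io
import numpy as np

def npy_data_offset(raw: bytes) -> int:
    """Byte offset where the array data starts in a .npy file."""
    fp = io.BytesIO(raw)
    version = np.lib.format.read_magic(fp)       # 6-byte magic string + 2-byte version
    if version == (1, 0):
        np.lib.format.read_array_header_1_0(fp)  # 2-byte length + header dict
    else:
        np.lib.format.read_array_header_2_0(fp)  # 4-byte length + header dict
    return fp.tell()

def decode_voxel(raw, shape=(64, 256, 256)):
    """Mimic the TF decoder: skip the header, view bytes as uint8, scale to [0, 1], add a channel axis."""
    offset = npy_data_offset(raw)
    voxel = np.frombuffer(raw, dtype=np.uint8, offset=offset).astype(np.float32) / 255.0
    return voxel.reshape(shape)[..., np.newaxis]

# Round-trip a small synthetic uint8 voxel through an in-memory .npy file
original = np.random.default_rng(0).integers(0, 256, size=(4, 8, 8), dtype=np.uint8)
buf = io.BytesIO()
np.save(buf, original)
raw = buf.getvalue()

offset = npy_data_offset(raw)
print(offset, len(raw) - original.nbytes)  # the two agree
restored = decode_voxel(raw, shape=original.shape)
print(restored.shape)  # (4, 8, 8, 1)
```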

In [None]:
voxel_dtype = cfg.voxel_dtype
seed = cfg.seed

FLIP = 0.5 # @params: probability
CONTRAST = (0.7, 1.3, 0.5) # @params: (minval, maxval, probability)
BRIGHTNESS = (0.2, 0.4) # @params: (delta, probability)
GAMMA = (0.8, 1.2, 0.5) # @params: (minval, maxval, probability)
ROTATE = (20, 0.5) # @params: (maxangle, probability)
RANDOM_CROP = (180, 180, 0.4) # @params: (min_width, min_height, probability)
CUTOUT = ((15, 15), 15, 0.4) # @params: ((mask_dim0, mask_dim1), max_num_holes, probability)
BLUR = ([5, 5], 10, 0.4) # @params: ((filter_dim0, filter_dim1), max_sigma, probability)

def build_decoder(with_labels=True, target_size=(64, 256, 256), ext='npy'):
    def decode(path):
        if ext == 'npy':
            voxel = tf.io.read_file(path)
            voxel = tf.io.decode_raw(voxel, voxel_dtype)
            voxel = voxel[npy_header_size:]
        else:
            raise ValueError("voxel extension not supported")
        
        voxel = tf.cast(voxel, tf.float32)/255.0
        voxel = tf.reshape(voxel, target_size)
        voxel = tf.expand_dims(voxel, axis=-1)

        return voxel
    
    def decode_with_labels(path, label):
        return decode(path), label
    
    return decode_with_labels if with_labels else decode

def random_rotate3D(voxel, limit=90, p=0.5):
    if tf.random.uniform(()) < p:
        angle = tf.random.uniform((), minval=-limit, maxval=limit,
                                  dtype=tf.int32)
        voxel = tfa.image.rotate(voxel, tf.cast(angle,
                                                tf.float32)*(math.pi/180),
                                 interpolation='nearest',
                                 fill_mode='constant',
                                 fill_value=0.0)
    return voxel

def random_resized_crop3D(voxel, min_width, min_height, p=0.5):
    if tf.random.uniform(()) < p:
        voxel_shape = voxel.shape
        assert voxel_shape[1] >= min_height
        assert voxel_shape[2] >= min_width
        
        width = tf.random.uniform((), minval=min_width,
                                  maxval=voxel_shape[2],
                                  dtype=tf.int32)
        height = tf.random.uniform((), minval=min_height,
                                   maxval=voxel_shape[1],
                                   dtype=tf.int32)
        x = tf.random.uniform((), minval=0,
                              maxval=voxel_shape[2] - width,
                              dtype=tf.int32)
        y = tf.random.uniform((), minval=0,
                              maxval=voxel_shape[1] - height,
                              dtype=tf.int32)
        voxel = voxel[:, y:y+height, x:x+width, :]
        voxel = tf.image.resize(voxel,
                                voxel_shape[1:3],
                                method='lanczos5')
    return voxel

def random_cutout3D(voxel, mask_shape=(10, 10), num_holes=20, p=0.5):
    if tf.random.uniform(()) < p:
        voxel_shape = voxel.shape
        assert voxel_shape[1] >= mask_shape[0]
        assert voxel_shape[2] >= mask_shape[1]

        holes = tf.random.uniform((), minval=1, maxval=num_holes,
                                  dtype=tf.int32)
        mask_size = tf.constant([mask_shape[0], mask_shape[1]])
        mask = tf.Variable((lambda : tf.ones(voxel_shape)),
                           trainable=False)
        
        for i in tf.range(holes):
            x = tf.random.uniform((), minval=0,
                                  maxval=voxel_shape[2],
                                  dtype=tf.int32)
            y = tf.random.uniform((), minval=0,
                                  maxval=voxel_shape[1],
                                  dtype=tf.int32)
            mask_endx = tf.add(x, mask_size[1])
            mask_endy = tf.add(y, mask_size[0])
            # Rows (height) are indexed by y and columns (width) by x
            mask[:, y:mask_endy,
                 x:mask_endx, :].assign(tf.zeros_like(mask[:, y:mask_endy,
                                                        x:mask_endx, :]))
        voxel = tf.multiply(voxel, mask)
        mask.assign(tf.ones(voxel_shape))
    return voxel

def _get_gaussian_kernel(sigma, filter_shape):
    x = tf.range(-filter_shape // 2 + 1, filter_shape // 2 + 1)
    x = tf.cast(x ** 2, sigma.dtype)
    x = tf.nn.softmax(-x / (2.0 * (sigma ** 2)))
    return x

def random_gaussian_blur3D(voxel, filter_shape=[5, 5], max_sigma=3, p=0.5):
    if tf.random.uniform(()) < p:
        sigma = tf.random.uniform((), minval=3, maxval=max_sigma,
                          dtype=tf.int32)
        filter_shape = tf.constant(filter_shape)
        channels = voxel.shape[-1]
        sigma = tf.cast(sigma, voxel.dtype)
        gaussian_kernel_x = _get_gaussian_kernel(sigma,
                                                 filter_shape[1])
        gaussian_kernel_x = gaussian_kernel_x[tf.newaxis, :]
        gaussian_kernel_y = _get_gaussian_kernel(sigma,
                                                 filter_shape[0])        
        gaussian_kernel_y = gaussian_kernel_y[:, tf.newaxis]
        gaussian_kernel_2d = tf.matmul(gaussian_kernel_y,
                                       gaussian_kernel_x)
        gaussian_kernel_2d = gaussian_kernel_2d[:, :, tf.newaxis,
                                                tf.newaxis]
        gaussian_kernel_2d = tf.tile(gaussian_kernel_2d,
                                     tf.constant([1, 1, channels, 1]))
        voxel = tf.nn.depthwise_conv2d(input=voxel,
                                       filter=gaussian_kernel_2d,
                                       strides=(1, 1, 1, 1),
                                       padding="SAME",
                                       )
        voxel = tf.cast(voxel, voxel.dtype)
    return voxel

def build_augmenter(with_labels=True):
    '''
    Each random parameter is sampled once per voxel, and the tf.image ops treat
    the depth axis as a batch axis, so the same transformation is applied to
    every slice of the voxel.
    '''
    def augment(voxel):
        aug_seed = tf.random.uniform((2,), minval=1, maxval=9999, dtype=tf.int32)
        if tf.random.uniform(()) < FLIP:
            if tf.random.uniform(()) < 0.5:
                voxel = tf.image.flip_up_down(voxel)
            else:
                voxel = tf.image.flip_left_right(voxel)
                
        if tf.random.uniform(()) < BRIGHTNESS[1]:
            voxel = tf.image.adjust_brightness(
                voxel, tf.random.uniform((), minval=0.0,
                                         maxval=BRIGHTNESS[0],
                                         seed=seed))
        if tf.random.uniform(()) < CONTRAST[2]:
            voxel = tf.image.adjust_contrast(
                voxel, tf.random.uniform((), minval=CONTRAST[0],
                                         maxval=CONTRAST[1],
                                         seed=seed))
        if tf.random.uniform(()) < GAMMA[2]:
            voxel = tf.image.adjust_gamma(
                voxel, tf.random.uniform((), minval=GAMMA[0],
                                         maxval=GAMMA[1],
                                         seed=seed))
        voxel = random_rotate3D(voxel, limit=ROTATE[0],
                                p=ROTATE[1])
        voxel = random_resized_crop3D(voxel, RANDOM_CROP[0],
                                      RANDOM_CROP[1],
                                      p=RANDOM_CROP[2])
        voxel = random_cutout3D(voxel, mask_shape=CUTOUT[0],
                                num_holes=CUTOUT[1],
                                p=CUTOUT[2])
        voxel = random_gaussian_blur3D(voxel,
                                       filter_shape=BLUR[0],
                                       max_sigma=BLUR[1],
                                       p=BLUR[2])

        # NaNs most likely come from adjust_contrast (factor > 1) pushing dark
        # pixels below 0, after which adjust_gamma raises those negative values
        # to a fractional power. Replace NaNs with 0 and clip back to [0, 1].
        voxel = tf.where(tf.math.is_nan(voxel),
                         tf.zeros_like(voxel), voxel)
        voxel = tf.clip_by_value(voxel, 0.0, 1.0)
        voxel = tf.cast(voxel, tf.float32)
        return voxel
    
    def augment_with_labels(voxel, label):
        return augment(voxel), label
    
    return augment_with_labels if with_labels else augment

def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=128,
                  seed=None, cache_dir=""):
    
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)
    
    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)
    
    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)
    
    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)
    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    
    ## Map the functions to perform Augmentations
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle, seed=seed) if shuffle else dset
    dset = dset.batch(bsize, drop_remainder=True).prefetch(AUTO)
#     dset = dset.batch(bsize).prefetch(AUTO)
    return dset
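`_get_gaussian_kernel` builds the 1D Gaussian as a softmax over `-x² / (2σ²)`, which is the same as exponentiating and normalizing to sum 1; `random_gaussian_blur3D` then takes the outer product of two such kernels. A NumPy sketch of that construction (the function names are illustrative):

```python
import numpy as np

def gaussian_kernel_1d(sigma, size):
    """Softmax over -x^2 / (2 sigma^2) for integer offsets x centred on 0."""
    x = np.arange(-size // 2 + 1, size // 2 + 1, dtype=np.float64)
    logits = -(x ** 2) / (2.0 * sigma ** 2)
    weights = np.exp(logits - logits.max())  # numerically stable softmax
    return weights / weights.sum()

def gaussian_kernel_2d(sigma, shape=(5, 5)):
    """Outer product of the row and column kernels, as in random_gaussian_blur3D."""
    ky = gaussian_kernel_1d(sigma, shape[0])[:, np.newaxis]
    kx = gaussian_kernel_1d(sigma, shape[1])[np.newaxis, :]
    return ky @ kx

for sigma in (3.0, 9.0):
    k = gaussian_kernel_2d(sigma)
    print(sigma, round(k[0, 0] / k[2, 2], 3))  # corner weight relative to the centre
```

With the configured 5 x 5 filter and integer sigmas drawn from [3, 10), the corner-to-centre ratio is already about 0.64 at sigma 3 and about 0.95 at sigma 9, so the larger draws behave close to a box blur.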

In [None]:
scan_type = 'FLAIR'
label_cols = 'MGMT_value'
train_paths = df_kfold.loc[(df_kfold.scan_type==scan_type) & (df_kfold['fold'] != fold), :].voxel_paths.values
train_labels = df_kfold.loc[(df_kfold.scan_type==scan_type) & (df_kfold['fold'] != fold), :].MGMT_value.values

decoder = build_decoder(with_labels=True, target_size=cfg.input_dims, ext='npy')

train_dataset = build_dataset(
    train_paths, train_labels, bsize=32, decode_fn=decoder,
    augment=True, repeat=False, shuffle=False
)

train_voxels, _ = next(iter(train_dataset))

sqrt = math.ceil(math.sqrt(train_voxels.shape[0]))
fig = plt.figure(figsize=(10, 10))

for idx, voxel in enumerate(train_voxels):
    ax = fig.add_subplot(int(sqrt), int(sqrt), idx+1)
    ax.imshow(voxel.numpy()[voxel.shape[0]//2], cmap='gray')
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
fig.suptitle("Middle Slice of Each Voxel", fontsize=22)
fig.tight_layout()
plt.savefig(f'samples.png')
fig.show()
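`random_cutout3D` builds its mask with a `tf.Variable` and a loop over holes. A variable-free alternative is to compute the mask by broadcasting coordinate comparisons; the sketch below shows the idea in NumPy (`cutout_mask` is a hypothetical helper, and porting it to `tf.range`/`tf.reduce_any` is an untested assumption here). Multiplying the voxel by one `(height, width)` mask zeroes the same rectangles on every slice.

```python
import numpy as np

def cutout_mask(height, width, corners, mask_shape=(15, 15)):
    """Return a (height, width) float mask with zeros inside each rectangle.

    `corners` is a list of (y, x) top-left positions; rectangles are clipped at the border.
    """
    rows = np.arange(height)[np.newaxis, :, np.newaxis]    # (1, H, 1)
    cols = np.arange(width)[np.newaxis, np.newaxis, :]     # (1, 1, W)
    ys = np.array([c[0] for c in corners])[:, None, None]  # (holes, 1, 1)
    xs = np.array([c[1] for c in corners])[:, None, None]
    in_rows = (rows >= ys) & (rows < ys + mask_shape[0])   # (holes, H, 1)
    in_cols = (cols >= xs) & (cols < xs + mask_shape[1])   # (holes, 1, W)
    holes = np.any(in_rows & in_cols, axis=0)              # (H, W)
    return (~holes).astype(np.float32)

rng = np.random.default_rng(0)
voxel = rng.random((8, 64, 64, 1), dtype=np.float32)
corners = [(5, 10), (40, 40), (60, 60)]  # the last hole is clipped at the border
mask = cutout_mask(64, 64, corners)
out = voxel * mask[np.newaxis, :, :, np.newaxis]  # same holes on every slice
print(int((mask == 0).sum()))  # 15*15 + 15*15 + 4*4 = 466
```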

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Interactive Visualizations with Weights & Biases</span>

In [None]:
!pip install wandb --upgrade -q

import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
key = user_secrets.get_secret("wandb_key")
wandb.login(key=key)

In [None]:
run = wandb.init(project='RSNA-public', job_type='augmentations-viz',
                name='dataloader-samples')
os.makedirs('gifs/', exist_ok=True)

for i, voxel in tqdm(enumerate(train_voxels), total=train_voxels.shape[0]):
    gifs = (voxel*255.).numpy().clip(0, 255).astype('uint8')
    imageio.mimsave(f'gifs/voxel_{i}.gif', gifs)    

wandb.log({'images': [wandb.Image(f'gifs/voxel_{i}.gif') for i in range(train_voxels.shape[0])]})
run.finish()

In [None]:
display(HTML(f'<a href={run.get_url()} style="font-family: Segoe UI;font-size: 1.5em; font-weight: 300;">View the Full Dashboard Here 🔎</a>'))
print()
display(HTML(f"<div style='width: 780px; height: 500px; padding: 0; overflow: hidden;'><iframe src='{run.get_url()}' width=1010 height=650 style='-ms-zoom: 0.75; -moz-transform: scale(0.75); -moz-transform-origin: 0 0; -o-transform: scale(0.75); -o-transform-origin: 0 0; -webkit-transform: scale(0.75); -webkit-transform-origin: 0 0;'></iframe></div>"))

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Learning Rate Schedule</span>

In [None]:
START_LR = 1e-6
MAX_LR = cfg.lr
MIN_LR = 1e-8
LR_RAMP = 3
LR_SUSTAIN = 5
LR_DECAY = 0.90

def lrfn(epoch):
    if LR_RAMP > 0 and epoch < LR_RAMP:
        lr = (MAX_LR-START_LR)/(LR_RAMP*1.0)*epoch+START_LR
    elif epoch < LR_RAMP+LR_SUSTAIN:
        lr = MAX_LR
    else: # exponential decay from MAX_LR to MIN_LR
        lr = (MAX_LR-MIN_LR)*LR_DECAY**(epoch-LR_RAMP-LR_SUSTAIN)+MIN_LR
    return lr
    
@tf.function
def lrfn_tffun(epoch):
    return lrfn(epoch)

fig = plt.figure(figsize=(14, 5))
ax = fig.add_subplot(111)
rng = [i for i in range(cfg.n_epochs)]
plt.plot(rng, [lrfn(x) for x in rng],
         marker='o')
plt.ticklabel_format(axis="y", style="plain")
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
plt.show()
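The schedule above is a linear warm-up, a plateau at `MAX_LR`, then exponential decay. Inside the training loop, `LRSchedule` converts the optimizer step to an epoch index before calling `lrfn`. A plain-Python sketch of both pieces, with the defaults copied from the cell above; `warmup_sustain_decay_lr` and `lr_at_step` are illustrative names, and the `steps_per_epoch` of 14 is a made-up figure.

```python
import math

def warmup_sustain_decay_lr(epoch, start_lr=1e-6, max_lr=5e-5, min_lr=1e-8,
                            ramp=3, sustain=5, decay=0.90):
    """Linear warm-up, constant plateau, then exponential decay toward min_lr."""
    if ramp > 0 and epoch < ramp:
        return (max_lr - start_lr) / ramp * epoch + start_lr
    if epoch < ramp + sustain:
        return max_lr
    return (max_lr - min_lr) * decay ** (epoch - ramp - sustain) + min_lr

def lr_at_step(step, steps_per_epoch):
    """Map an optimizer step to an epoch index first, as LRSchedule.__call__ does."""
    return warmup_sustain_decay_lr(step // steps_per_epoch)

for epoch in range(15):
    print(epoch, f"{warmup_sustain_decay_lr(epoch):.2e}")
```

Note that the first decay epoch (epoch 8) still uses `decay ** 0 = 1`, so the learning rate actually stays at `MAX_LR` for epochs 3 through 8.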

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Setup Mixed Precision</span>

In [None]:
from tensorflow.keras import mixed_precision

if cfg.fp16:
    # Using 'mixed_bfloat16' compatible with TPUs
    mixed_precision.set_global_policy('mixed_bfloat16')

    print('Mixed precision enabled')

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">Define Model Architecture</span>

We use the models from the awesome [classification_models_3D](https://github.com/ZFTurbo/classification_models_3D) library by [@ZFTurbo](https://www.kaggle.com/zfturbo)

* If mixed precision is enabled, the model's output layer should compute in float32 for numerical stability
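Why float32 at the output? bfloat16 keeps float32's exponent range but only 7 explicit mantissa bits, so values near 1 are coarsely spaced (steps of 2⁻⁸ just below 1). NumPy has no native bfloat16, so this sketch emulates round-to-nearest-even by keeping the top 16 bits of the float32 bit pattern; `to_bfloat16` is a hypothetical helper. A sigmoid output of 0.999 rounds to exactly 1.0, where the `log(1 - p)` term of binary cross-entropy blows up.

```python
import numpy as np

def to_bfloat16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32.

    Finite inputs only; bfloat16 is the upper 16 bits of the float32 bit pattern.
    """
    bits = np.atleast_1d(np.asarray(x, dtype=np.float32)).view(np.uint32).astype(np.uint64)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    rounded = ((bits + rounding_bias) & 0xFFFF0000).astype(np.uint32)
    return rounded.view(np.float32)

p = np.array([0.5, 0.9, 0.999], dtype=np.float32)
p_bf16 = to_bfloat16(p)
print(p_bf16)  # 0.999 rounds to exactly 1.0
with np.errstate(divide="ignore"):
    print(np.log1p(-p_bf16))  # log(1 - p) is -inf for the rounded value
```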

<br>

#### Supported Nets:

<table class="tableizer-table" style="align:left; display:block">
<thead><tr class="tableizer-firstrow"><th></th><th>&nbsp;</th><th>&nbsp;</th><th>&nbsp;</th></tr></thead><tbody>
 <tr><td>resnet18</td><td>seresnet34</td><td>senet154</td><td>densenet169</td></tr>
 <tr><td>resnet34</td><td>seresnet50</td><td>resnext50</td><td>densenet201</td></tr>
 <tr><td>resnet50</td><td>seresnet101</td><td>resnext101</td><td>inceptionresnetv2</td></tr>
 <tr><td>resnet101</td><td>seresnet152</td><td>vgg16</td><td>inceptionv3</td></tr>
 <tr><td>resnet152</td><td>seresnext50</td><td>vgg19</td><td>mobilenet</td></tr>
 <tr><td>seresnet18</td><td>seresnext101</td><td>densenet121</td><td>mobilenetv2</td></tr>
</tbody></table>

<br>

In [None]:
from classification_models_3D.tfkeras import Classifiers

model_arch = 'seresnet50'

def create_model(input_shape, num_classes):
    inputs = tf.keras.layers.Input((*input_shape, 1), name='inputs')
    x = tf.keras.layers.Conv3D(3, (3, 3, 3), strides=(1, 1, 1), 
                          padding='same', use_bias=True)(inputs)
    
    net, preprocess_input = Classifiers.get(model_arch)
    x = net(input_shape=(*input_shape, 3), include_top=False,
                   weights='imagenet')(x)
    x = tf.keras.layers.GlobalAveragePooling3D()(x)
    x = tf.keras.layers.Dropout(rate=0.5)(x)
    
    # Cast output to float32 for numerical stability
    outputs = tf.keras.layers.Dense(num_classes, activation='sigmoid',
                                   dtype='float32')(x)
    model  = tf.keras.Model(inputs, outputs)
    model.summary()  # prints itself; wrapping it in print() would also print None
    
    return model

# model = create_model(cfg.input_dims, 1)

<span style="color: #1e009c; font-family: Segoe UI; font-size: 1.8em; font-weight: 300;">üß™ Train with Custom Training Loop</span>

In [None]:
def auto_select_accelerator():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)  # stable API since TF 2.3
        print("Running on TPU:", tpu.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    
    return strategy

In [None]:
fold = 0
DISPLAY_PLOT = True

oof_aucs = dict()

for scan_type in cfg.scan_types:
    print("\nScan Type:", scan_type)
    # Define TPU strategy and clear TPU
    strategy = auto_select_accelerator()
    
    user_secrets = UserSecretsClient()
    user_credential = user_secrets.get_gcloud_credential()
    user_secrets.set_tensorflow_credential(user_credential)
    
    GCS_DS_PATH = KaggleDatasets().get_gcs_path(cfg.dataset.split('/')[-1])
    print("GCS_DS_PATH", GCS_DS_PATH)

    train_paths = df_kfold.loc[(df_kfold.scan_type==scan_type) & \
                               (df_kfold['fold'] != fold), :].voxel_paths.map( \
        lambda path: os.path.join(GCS_DS_PATH+'/' , *path.split('/')[4:])).values
    valid_paths = df_kfold.loc[(df_kfold.scan_type==scan_type) & \
                               (df_kfold['fold'] == fold), :].voxel_paths.map( \
        lambda path: os.path.join(GCS_DS_PATH+'/' , *path.split('/')[4:])).values
    
    train_labels = df_kfold.loc[(df_kfold.scan_type==scan_type) & \
                                (df_kfold['fold'] != fold), :].MGMT_value.values.reshape(-1,1)
    valid_labels = df_kfold.loc[(df_kfold.scan_type==scan_type) & \
                                (df_kfold['fold'] == fold), :].MGMT_value.values.reshape(-1,1)

    # Converting global config class object to a dictionary to log using Wandb
    config_dict = dict(vars(Config))
    config_dict = {k:(v if type(v)==int else str(v)) for (k,v) in config_dict.items() if '__' not in k}
    config_dict['fold'] = fold
    config_dict['model_arch'] = model_arch
    config_dict['job_name'] = f"{config_dict['model_arch']}_fold{fold}_{scan_type}_job{config_dict['job']}"
    print("Train Job:", config_dict['job_name'], "\nConfig")
    pprint.pprint(config_dict)

    run = wandb.init(project=cfg.wandb_project, name=config_dict['job_name'],
               config=config_dict)

    decoder = build_decoder(with_labels=True, target_size=cfg.input_dims, ext='npy')

    train_dataset = build_dataset(
        train_paths, train_labels, bsize=cfg.batch_size,
        decode_fn=decoder, augment=True, shuffle=256
    )
    
    val_batch_size = 8
    valid_dataset = build_dataset(
        valid_paths, valid_labels, bsize=val_batch_size,
        decode_fn=decoder, repeat=False, shuffle=False, augment=False
    )

    train_steps_per_epoch = train_paths.shape[0]//cfg.batch_size
    valid_steps_per_epoch = valid_paths.shape[0]//val_batch_size
    print()

    with strategy.scope():
        model = create_model(cfg.input_dims, cfg.num_classes)

        # Custom learning rate schedule
        class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __call__(self, step):
                return lrfn_tffun(epoch=step//train_steps_per_epoch)

        optimizer = tf.keras.optimizers.Adam(learning_rate=LRSchedule())
#         optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.lr)

        train_accuracy = tf.keras.metrics.BinaryAccuracy(name="acc")
        train_precision = tf.keras.metrics.Precision(name='precision')
        train_recall = tf.keras.metrics.Recall(name='recall')
        train_roc_auc = tf.keras.metrics.AUC(curve='ROC', name="auc")
        train_map = tf.keras.metrics.AUC(curve='PR', name="map")

        valid_accuracy = tf.keras.metrics.BinaryAccuracy(name="val_acc")
        valid_precision = tf.keras.metrics.Precision(name='val_precision')
        valid_recall = tf.keras.metrics.Recall(name='val_recall')
        valid_roc_auc = tf.keras.metrics.AUC(curve='ROC', name="val_auc")
        valid_map = tf.keras.metrics.AUC(curve='PR', name="val_map")
        
        # Apply label smoothing to training
        loss_fn = lambda labels, probabilities: tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                labels,
                probabilities,
                label_smoothing=cfg.label_smoothing))
        
        val_loss_fn = lambda labels, probabilities: tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(labels,
                                                probabilities))
        
        @tf.function
        def train_step(images, labels):
            with tf.GradientTape() as tape:
                probabilities = model(images, training=True)
                loss = loss_fn(labels, probabilities)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            train_accuracy.update_state(labels, probabilities)
            train_precision.update_state(labels, probabilities)
            train_recall.update_state(labels, probabilities)
            train_roc_auc.update_state(labels, probabilities)
            train_map.update_state(labels, probabilities)
            return loss

        @tf.function
        def valid_step(images, labels):
            probabilities = model(images, training=False)
            loss = val_loss_fn(labels, probabilities)
            valid_accuracy.update_state(labels, probabilities)
            valid_precision.update_state(labels, probabilities)
            valid_recall.update_state(labels, probabilities)
            valid_roc_auc.update_state(labels, probabilities)
            valid_map.update_state(labels, probabilities)
            return loss

        # Distribute the dataset according to the strategy
        train_dist_ds = strategy.experimental_distribute_dataset(train_dataset)
        valid_dist_ds = strategy.experimental_distribute_dataset(valid_dataset)
        
        bar_stateful_metrics = ["lr", "loss", "acc", "auc",
                                "precision", "recall", "map",
                                "val_loss", "val_acc",
                                "val_precision", "val_recall",
                                "val_auc", "val_map"]
        train_prog_bar = Progbar(train_steps_per_epoch, width=60,
                           stateful_metrics=bar_stateful_metrics)

        epoch = 0
        train_losses=[]
        history = []
        start_time = epoch_start_time = time.time()
        best_metric = 0.0
        metric = 'val_auc'
        
        for step, (voxels, labels) in enumerate(train_dist_ds):
            loss = strategy.run(train_step, args=(voxels, labels))
            loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
            train_losses.append(loss)
            
            train_prog_bar.add(1,
                               values=[("lr",lrfn(epoch)),
                                       ("loss", loss.numpy()),
                                       ('acc',train_accuracy.result().numpy()),
                                       ('precision',train_precision.result().numpy()),
                                       ('recall',train_recall.result().numpy()),
                                       ('auc',train_roc_auc.result().numpy()),
                                       ('map',train_map.result().numpy())])
            
            if ((step+1) // train_steps_per_epoch) > epoch:
                valid_prog_bar = Progbar(valid_steps_per_epoch, width=60,
                           stateful_metrics=["lr", "loss", "acc",
                                             "auc", "precision",
                                             "recall", "map", "val_acc",
                                             "val_precision",
                                             "val_recall", "val_auc",
                                             "val_map"])
                
                valid_loss = []
                for voxel, label in valid_dist_ds:
                    batch_loss  = strategy.run(valid_step, args=(voxel, label)) # just one batch
                    batch_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                 batch_loss, axis=None)
                    valid_loss.append(batch_loss.numpy())
                    
                    valid_prog_bar.add(
                        1,
                        values=[("lr",lrfn(epoch)),
                                ('acc',train_accuracy.result().numpy()),
                                ('precision',train_precision.result().numpy()),
                                ('recall',train_recall.result().numpy()),
                                ('auc',train_roc_auc.result().numpy()),
                                ('map',train_map.result().numpy()),
                                ('val_loss',batch_loss.numpy()),
                                ('val_acc',valid_accuracy.result().numpy()),
                                ('val_precision',valid_precision.result().numpy()),
                                ('val_recall',valid_recall.result().numpy()),
                                ('val_auc',valid_roc_auc.result().numpy()),
                                ('val_map',valid_map.result().numpy())
                               ])
                    
                valid_loss = np.mean(valid_loss)
                
                epoch = (step+1) // train_steps_per_epoch
                epoch_time = time.time() - epoch_start_time
                print(f"\nEpoch: {epoch}/{cfg.n_epochs} Done.")
                print("-"*60)
                     
                val_log_dict = {
                                'epoch': epoch,
                                'lr': lrfn(epoch),
                                'time': epoch_time,
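                                # mean of this epoch's per-step losses, not just the last batch
                                'loss': float(np.mean([x.numpy() for x in train_losses[-train_steps_per_epoch:]])),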
                                'acc': train_accuracy.result().numpy(),
                                'precision': train_precision.result().numpy(),
                                'recall': train_recall.result().numpy(),
                                'auc': train_roc_auc.result().numpy(),
                                'map': train_map.result().numpy(),
                                'val_loss': round(valid_loss, 4),
                                'val_acc': valid_accuracy.result().numpy(),
                                'val_precision': valid_precision.result().numpy(),
                                'val_recall': valid_recall.result().numpy(),
                                'val_auc': valid_roc_auc.result().numpy(),
                                'val_map': valid_map.result().numpy(),
                               }

                # Save the best model
                if val_log_dict[metric]>best_metric:
                    best_metric = val_log_dict[metric]
                    model.save(f'model_fold{fold}_{scan_type}.h5')
                    wandb.save(f'model_fold{fold}_{scan_type}.h5')

                val_log_dict['best_val_auc'] = best_metric
                wandb.log(val_log_dict, step=step)
                
                history.append(val_log_dict)
                epoch_start_time = time.time()
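                # Reset every streaming metric so the next epoch starts fresh
                train_accuracy.reset_states()
                train_precision.reset_states()
                train_recall.reset_states()
                train_roc_auc.reset_states()
                train_map.reset_states()
                valid_accuracy.reset_states()
                valid_precision.reset_states()
                valid_recall.reset_states()
                valid_roc_auc.reset_states()
                valid_map.reset_states()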

                if epoch >= cfg.n_epochs:
                    break
                
                print()
                train_prog_bar = Progbar(train_steps_per_epoch, width=60,
                                         stateful_metrics=bar_stateful_metrics)
            else:
                log_dict = {
                        'epoch': epoch,
                        'lr': lrfn(epoch),
                        'loss': loss.numpy(),
                        'acc': train_accuracy.result().numpy(),
                        'precision': train_precision.result().numpy(),
                        'recall': train_recall.result().numpy(),
                        'auc': train_roc_auc.result().numpy(),
                        'map': train_map.result().numpy(),
                       }
                wandb.log(log_dict, step=step)

        training_time = time.time() - start_time
        print("\nTotal Training Time: {:0.1f}s".format(training_time))
    
    history_df = pd.DataFrame(history)
    history_df.to_csv(f'history{fold}_{scan_type}.csv')

    del model, decoder, train_dataset, valid_dataset
    gc.collect()

    oof_aucs[scan_type] = float(np.max(history_df['val_auc']))
    print("oof_aucs", oof_aucs)
    
    run.finish()
    
    # Plot Training History
    if DISPLAY_PLOT:
        print("\n\n")
        plt.figure(figsize=(15,5))
        plt.plot(np.arange(len(history_df['auc'])), history_df['auc'],
                 '-o', label='Train auc', color='#ff7f0e')
        plt.plot(np.arange(len(history_df['val_auc'])), history_df['val_auc'],
                 '-o', label='Val auc', color='#1f77b4')
        x = np.argmax(history_df['val_auc'])
        y = np.max(history_df['val_auc'])
        xdist = plt.xlim()[1] - plt.xlim()[0]
        ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x,y,s=200, color='#1f77b4')
        plt.text(x-0.03*xdist, y-0.13*ydist, 'Max AUC\n%.2f'%y, size=10)
        plt.ylabel('auc', size=14)
        plt.xlabel('Epoch', size=14)
        plt.legend(loc=2)
        
        plt2 = plt.gca().twinx()
        plt2.plot(np.arange(len(history_df['loss'])),
                  history_df['loss'],'-o', label='Train Loss', color='#f98b88')
        plt2.plot(np.arange(len(history_df['val_loss'])),
                  history_df['val_loss'],'-o', label='Val Loss', color='#3c1361')
        x = np.argmin(history_df['val_loss'])
        y = np.min(history_df['val_loss'])
        ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x, y, s=200, color='#3c1361')
        plt.text(x-0.03*xdist, y+0.05*ydist,'Min Loss', size=10)
        plt.ylabel('Loss', size=14)
        plt.title(f"{config_dict['job_name']} Size - {cfg.input_dims}")
        plt.legend(loc=3)
        plt.savefig(f'fig{fold}_{scan_type}.png')
        plt.show()        
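
The loop above iterates over a single repeated dataset and detects epoch boundaries from the global step counter instead of nesting a per-epoch inner loop. The sketch below isolates that bookkeeping in plain Python so it can be checked without a TPU. It assumes a hypothetical `run_training` helper that receives per-step losses and per-epoch validation AUCs as ordinary lists (standing in for the reduced TPU loss and `valid_roc_auc.result()`); no TensorFlow, Keras, or W&B calls are involved. It shows how `(step + 1) // steps_per_epoch` triggers the end-of-epoch block, how an epoch-mean loss is taken from the last `steps_per_epoch` entries, and how the best checkpoint is chosen by validation AUC.

```python
from typing import List


def run_training(step_losses: List[float],
                 val_aucs: List[float],
                 steps_per_epoch: int,
                 n_epochs: int) -> List[dict]:
    """Hypothetical stand-in for the custom loop's epoch bookkeeping.

    step_losses: one training loss per step (replaces the reduced TPU loss).
    val_aucs: one validation AUC per epoch (replaces valid_roc_auc.result()).
    """
    epoch = 0
    best_metric = 0.0
    train_losses: List[float] = []
    history: List[dict] = []

    for step, loss in enumerate(step_losses):
        train_losses.append(loss)

        # Same test as the notebook: an epoch ends once the number of
        # completed steps reaches the next multiple of steps_per_epoch.
        if (step + 1) // steps_per_epoch > epoch:
            epoch = (step + 1) // steps_per_epoch
            epoch_losses = train_losses[-steps_per_epoch:]
            val_auc = val_aucs[epoch - 1]

            # Stand-in for model.save(...): only an improved AUC counts as best.
            is_best = val_auc > best_metric
            if is_best:
                best_metric = val_auc

            history.append({
                "epoch": epoch,
                "loss": sum(epoch_losses) / len(epoch_losses),
                "val_auc": val_auc,
                "best_val_auc": best_metric,
                "saved": is_best,
            })

            # Stop once the configured number of epochs is done,
            # even if the (repeated) dataset could yield more steps.
            if epoch >= n_epochs:
                break

    return history


hist = run_training([1.0, 0.8, 0.6, 0.4, 0.3, 0.2], [0.55, 0.62, 0.60],
                    steps_per_epoch=2, n_epochs=3)
for row in hist:
    print(row)
```

Driving the loop by a global step keeps a single iterator over the distributed dataset alive for the whole run, which avoids re-creating the TPU input pipeline at each epoch.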

In [None]:
display(HTML('<a href="https://wandb.ai/sreevishnu-damodaran/RSNA-train-public" style="font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Last Run Stats - View the Full Dashboard Here 🔎</a>'))
print()
display(HTML("<div style='width: 780px; height: 500px; padding: 0; overflow: hidden;'><iframe src='https://wandb.ai/sreevishnu-damodaran/RSNA-train-public/runs/2jwjlf5f' width=1010 height=650 style='-ms-zoom: 0.75; -moz-transform: scale(0.75); -moz-transform-origin: 0 0; -o-transform: scale(0.75); -o-transform-origin: 0 0; -webkit-transform: scale(0.75); -webkit-transform-origin: 0 0;'></iframe></div>"))

In [None]:
!find ./ -name "*.gif" -type f -delete

<p style='text-align: center;'><span style="color: #000508; font-family: Segoe UI; font-size: 2.0em; font-weight: 300; letter-spacing:3px">HAVE A GREAT DAY!</span></p>

<p style='text-align: center;'><span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Let me know if you have any suggestions!</span></p>