[Inference Notebook](https://www.kaggle.com/venkat555/ranzcr-clip-tpu-densenet-with-kfold-inference/)

**Credits** 
* Flowers TPU Notebook 
* Fellow Kagglers - All the amazing posts and kernels to learn from 
* Uses TFRecords of various image sizes that are already stratified by InstanceUID, from https://www.kaggle.com/prateek0x/creating-stratified-groupkfold-tfrecords-256x256
* https://www.kaggle.com/prateek0x/stratified-groupkfold-with-efn-tfrecords

## Header Imports

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import math, re, os
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
import warnings, gc, random, math, os, re
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE
from sklearn.model_selection import train_test_split 
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow.keras.backend as K




## Bootstrap 

In [None]:
# Run-mode switch: False trains every fold, True is a quick first-fold-only run.
FFO = False

# Fold count and epoch budget are identical in both modes.
FOLDS = 3
EPOCHS = 10

# The quick mode stops after the first fold and picks a smaller backbone index
# (EFN indexes the `backbones` list used by get_model).
FIRST_FOLD_ONLY = bool(FFO)
EFN = 1 if FIRST_FOLD_ONLY else 3

# Square images: one edge length shared by all three names.
IMG_SIZE_WIDTH = HEIGHT = WIDTH = 512

print(" folds={} epochs={} ffo={} efn={} save={}".format(FOLDS,EPOCHS,FIRST_FOLD_ONLY, EFN, FFO))

In [None]:
def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set the hash/determinism env vars."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


# Single seed reused by the augmentation ops and the KFold splits below.
SEED = 555
seed = SEED
seed_everything(SEED)

## TPU Setup

In [None]:
# Choose the tf.distribute strategy used for model creation and training:
# a TPU strategy when a TPU cluster can be resolved, otherwise MirroredStrategy.
# NOTE(review): newer TensorFlow releases expose tf.distribute.TPUStrategy
# (non-experimental) - confirm against the installed version before changing.
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError: # no TPU found, detect GPUs
    strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    #strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

# The replica count scales BATCH_SIZE and LR_MAX further down.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

## Setup to read data

In [None]:
IMAGE_SIZE = [WIDTH, HEIGHT] # At this size, a GPU will run out of memory. Use the TPU.

# For GPU training, please select 224 x 224 px image size.
# Training TFRecords live in a dataset named after the (square) image size.
GCS_DS_PATH = KaggleDatasets().get_gcs_path("ranzcr-{0}x{1}".format(IMG_SIZE_WIDTH,IMG_SIZE_WIDTH)) # you can list the bucket with "!gsutil ls $GCS_DS_PATH"
print( "GCS_DS_PATH :{} ".format(GCS_DS_PATH))

# Global batch size: 16 images per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
AUG_BATCH = BATCH_SIZE
CHANNELS=3

# Glob patterns for the training shards (kept as a list so more patterns can be added).
training_filenames = [GCS_DS_PATH + '/*.tfrec']
TRAINING_FILENAMES = tf.io.gfile.glob(training_filenames)

# The competition dataset name is fixed; it has no size placeholders to format.
GCS_TEST_DS_PATH = KaggleDatasets().get_gcs_path("ranzcr-clip-catheter-line-classification")
TEST_FILENAMES = tf.io.gfile.glob(GCS_TEST_DS_PATH + '/test_tfrecords/*.tfrec')
print(" [file names] train {} , test : {}".format(TRAINING_FILENAMES,TEST_FILENAMES))
# Target names, in the same order as the label vector built in
# read_labeled_tfrecord and the columns of the submission header.
# inferred from
# label_num_to_disease_map.json
CLASSES = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present']

## Setup a learning rate scheduler

In [None]:
# Using an LR ramp up because fine-tuning a pre-trained model.
# Starting with a high LR would break the pre-trained weights.

LR_START = 1e-5                                    # rate at epoch 0
LR_MAX = 5e-5 * strategy.num_replicas_in_sync      # peak rate, scaled by replica count
LR_MIN = 1e-5                                      # floor the decay approaches
LR_RAMPUP_EPOCHS = 5                               # length of the linear ramp
LR_SUSTAIN_EPOCHS = 0                              # epochs held at LR_MAX
LR_EXP_DECAY = 0.8                                 # per-epoch decay factor

def lrfn(epoch):
    """Return the learning rate for `epoch`: linear ramp, optional plateau, exponential decay.

    Reads the module-level LR_* constants.
    """
    # Phase 1: linear ramp from LR_START towards LR_MAX.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    # Phase 2: hold at the peak.
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX
    # Phase 3: exponential decay towards LR_MIN.
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    
# Keras callback that applies lrfn at the start of every epoch.
# NOTE(review): create_callbacks() below does not include this callback, so the
# schedule is not applied by the training loop in this file - confirm intent.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose = True)

## Data Augmentation 


In [None]:
# data augmentation @cdeotte kernel: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def transform_rotation(image, height, rotation):
    """Rotate one square image by a random fraction of `rotation` degrees.

    Args:
        image: a single image of size [dim,dim,3] (not a batch).
        height: edge length `dim` of the square image.
        rotation: angle bound in degrees; it is multiplied by a
            tf.random.uniform([1]) draw before use.

    Returns:
        A tensor reshaped to [height, height, 3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    # Scale the requested angle by a random factor drawn per call.
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    # Coordinates are centred on the image; z is the homogeneous coordinate.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    # Integer cast + clip keeps every source coordinate inside the image.
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Shift centred coordinates back to array indices, then gather the pixels.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_shear(image, height, shear):
    """Shear one square image by a random fraction of `shear` degrees.

    Args:
        image: a single image of size [dim,dim,3] (not a batch).
        height: edge length `dim` of the square image.
        shear: shear bound in degrees; it is multiplied by a
            tf.random.uniform([1]) draw before use.

    Returns:
        A tensor reshaped to [height, height, 3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    # Scale the requested shear by a random factor, then convert degrees to radians.
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    # Coordinates are centred on the image; z is the homogeneous coordinate.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS (via the shear matrix)
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    # Integer cast + clip keeps every source coordinate inside the image.
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Shift centred coordinates back to array indices, then gather the pixels.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

In [None]:
def data_augment(image, label):
    """Apply a random chain of augmentations to one image; the label passes through.

    Each transform family is gated by its own uniform draw in [0, 1). The
    image is resized back to [HEIGHT, WIDTH] at the end because the crop
    branches change its size.

    Returns:
        (augmented image, unchanged label)
    """
    # One independent gate per transform family.
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Shear
    # Skipped when p_shear <= .2; otherwise the sign of the bound depends on the draw.
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
            
    # Rotation
    # Same gating scheme as shear, with a +/-45 degree bound.
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
    
    
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270 degrees
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180 degrees
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90 degrees
        
    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
        
    # Crops
    # p_crop > .7: central crop at a fixed fraction; .4 < p_crop <= .7: random square crop.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
            
    # Restore the fixed output size after any crop.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    return image, label

## Dataset Transformation utilities

In [None]:
def onehot(image,label):
    """Return the image unchanged and `label` one-hot encoded over len(CLASSES) classes."""
    depth = len(CLASSES)
    encoded = tf.one_hot(label, depth)
    return image, encoded

In [None]:
def to_float32_2(image, label):
    """Cast the image to float32 and binarize the label.

    Entries equal to the maximum along the last axis become 1, all others 0;
    the result is returned as int32.
    """
    row_max = tf.reduce_max(label, axis=-1, keepdims=True)
    is_max = tf.equal(label, row_max)
    binarized = tf.where(is_max, tf.ones_like(label), tf.zeros_like(label))
    return tf.cast(image, tf.float32), tf.cast(binarized, tf.int32)

def to_float32(image, label):
    """Cast the image to float32; the label is returned untouched."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, label

def decode_image(image_data):
    """Decode a JPEG byte string into a float32 RGB image scaled to [0, 1].

    The result is reshaped to [IMG_SIZE_WIDTH, IMG_SIZE_WIDTH, 3].
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    side = IMG_SIZE_WIDTH
    # explicit size needed for TPU
    return tf.reshape(scaled, [side, side, 3])

def decode_test_image(image_data):
    """Decode a test JPEG byte string into a float32 RGB image scaled to [0, 1].

    The result is reshaped to a square of edge 2 * IMG_SIZE_WIDTH.
    NOTE(review): presumably the test TFRecords store images at that size - confirm.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    side = IMG_SIZE_WIDTH*2
    # explicit size needed for TPU
    return tf.reshape(scaled, [side, side, 3])

# Create a dictionary describing the features.


def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, label) pair.

    Args:
        example: a serialized tf.train.Example carrying "StudyInstanceUID",
            "image" (JPEG bytes) and one int64 feature per target.

    Returns:
        image: decoded image resized to IMAGE_SIZE[0] x IMAGE_SIZE[0].
        label: list of 11 float32 scalars, one per target, in the order of
            `label_names` below.
    """
    # Single source of truth for both the feature spec and the label order.
    label_names = (
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    )

    # Create a dictionary describing the features.
    feature_spec = {
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    for name in label_names:
        feature_spec[name] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(example['image'])
    image = tf.image.resize(image, [IMAGE_SIZE[0], IMAGE_SIZE[0]])

    # Kept as a Python list of scalars, exactly as before.
    label = [tf.cast(example[name], tf.float32) for name in label_names]
    return image, label

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, StudyInstanceUID) pair."""
    feature_spec = {
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    side = IMAGE_SIZE[0]
    resized = tf.image.resize(decode_test_image(parsed['image']), [side, side])
    return resized, parsed['StudyInstanceUID']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a parsed tf.data pipeline over the given TFRecord files.

    Args:
        filenames: TFRecord paths to read.
        labeled: parse with read_labeled_tfrecord when True, otherwise with
            read_unlabeled_tfrecord.
        ordered: when False, deterministic ordering is switched off to
            increase read speed.

    Returns:
        A dataset of (image, label) pairs if labeled, else (image, id) pairs.
    """
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False # disable order, increase speed

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    # Reads from multiple files at once and interleaves them automatically.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(read_options).map(parse_fn, num_parallel_calls=AUTO)

def data_augment_old(image, label):
    """Legacy augmentation: seeded flips followed by brightness/saturation changes.

    NOTE(review): no call site for this function is visible in this file.

    Returns:
        (augmented image, unchanged label)
    """
    # data augmentation. Thanks to the dataset.prefetch(AUTO) statement in the next function (below),
    # this happens essentially for free on TPU. Data pipeline code is executed on the "CPU" part
    # of the TPU while the TPU itself is computing gradients.
    # RandomCrop, VFlip, HFilp, RandomRotate
    #image = tf.image.rot90(image,k=np.random.randint(4))
    image = tf.image.random_flip_left_right(image , seed=SEED)
    image= tf.image.random_flip_up_down(image, seed=SEED)
    # Only referenced by the commented-out pad/crop lines below.
    IMG_SIZE=IMAGE_SIZE[0]
    # Add 6 pixels of padding
    #image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6) 
    # Random crop back to the original size
    #image = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])
    image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness
    image = tf.image.random_saturation(image, 0, 2, seed=SEED)
    # A fixed saturation factor is applied on top of the random one.
    image = tf.image.adjust_saturation(image, 3)
    
    #image = tf.image.central_crop(image, central_fraction=0.5)
    return image, label   

def get_training_dataset(dataset=None, do_aug=True, do_onehot=False):
    """Turn a parsed (image, label) dataset into an endless, shuffled training stream.

    This replaces two same-named definitions, the first of which was shadowed
    by the second. Calling with no arguments behaves like the first one: the
    dataset is loaded from TRAINING_FILENAMES.

    Args:
        dataset: parsed dataset of (image, label) pairs; loaded from
            TRAINING_FILENAMES when None.
        do_aug: accepted for backward compatibility but not consulted;
            data_augment is always applied, as before.
        do_onehot: when True, labels are passed through `onehot` batch-wise.

    Returns:
        A repeated, augmented, shuffled dataset batched to BATCH_SIZE.
    """
    if dataset is None:
        dataset = load_dataset(TRAINING_FILENAMES, labeled=True)

    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs

    if do_onehot:
        # One-hot is applied per batch; the batch -> unbatch round trip is only
        # needed for that, so it is skipped otherwise (it yields the same elements).
        dataset = dataset.batch(AUG_BATCH)
        dataset = dataset.map(onehot, num_parallel_calls=AUTO)
        dataset = dataset.unbatch()

    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(ordered=False):
    """Return the batched, cached and prefetched labeled validation dataset.

    NOTE(review): VALIDATION_FILENAMES is not assigned anywhere in the visible
    part of this file - confirm it exists before calling.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    )

def get_test_dataset(ordered=False):
    """Return the batched and prefetched dataset of (image, id) pairs from TEST_FILENAMES."""
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    )

def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The number of data items is written in the name of the .tfrec files,
    i.e. flowers00-230.tfrec = 230 data items.

    Args:
        filenames: iterable of TFRecord paths.

    Returns:
        The total item count as a NumPy integer (0 for an empty iterable).

    Raises:
        ValueError: if a file name carries no "-<digits>." count.
    """
    # Require at least one digit so a stray "-." earlier in the path is skipped.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                "cannot read item count from TFRecord file name: {!r}".format(filename))
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the result integral even when `counts` is empty.
    return np.sum(counts, dtype=np.int64)
# Item totals are read from the shard file names (see count_data_items).
NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
# Steps per epoch are derived from ALL training shards, not from one fold's share.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
print("STEPS_PER_EPOCH {}".format(STEPS_PER_EPOCH))
#NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
#print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
print('Dataset: {} training images,  {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_TEST_IMAGES))

## Print Data Shapes

#### Train

In [None]:
# data dump
# Sanity check: print image/label batch shapes for three training batches.
print("Training data shapes:")
training_raw_dataset=load_dataset(TRAINING_FILENAMES, labeled=True)

# NOTE: do_aug=False does not disable data_augment in get_training_dataset.
for image,label in get_training_dataset(training_raw_dataset, do_aug=False , do_onehot=False).take(3):
    print(image.numpy().shape, label.numpy().shape)


In [None]:
# Peek at training data
# Re-batch to 20 images so one batch fills a display grid.
training_dataset = get_training_dataset(training_raw_dataset , do_aug=False , do_onehot=False ).map(to_float32)
training_dataset = training_dataset.unbatch().batch(20)
train_batch = iter(training_dataset)

In [None]:
# Print one training batch's image shape together with its raw label values and type.
for image,label in get_training_dataset(training_raw_dataset, do_aug=False , do_onehot=False).take(1):
    print(image.numpy().shape, label.numpy() , type(label.numpy()))

## Visualization utilities

In [None]:
# numpy and matplotlib defaults
# Summarize arrays with more than 15 elements and wrap printed lines at 80 characters.
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    """Convert an (images, labels) batch into NumPy images and labels.

    When the labels array has object dtype (binary strings, i.e. image IDs,
    as is the case for test data), a list of None of the same length is
    returned in its place.
    """
    batch_images, batch_labels = data
    images_np = batch_images.numpy()
    labels_np = batch_labels.numpy()
    if labels_np.dtype == object:
        # IDs are not labels: hide them from the plotting code.
        labels_np = [None] * len(images_np)
    return images_np, labels_np

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a predicted class index.

    Returns:
        (title, correct): with no target the title is the class name and
        `correct` is True; otherwise the title is tagged OK, or NO followed
        by an arrow and the expected class name.
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected_name = 'OK', '', ''
    else:
        verdict, arrow, expected_name = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected_name), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image array accepted by plt.imshow.
        title: title text; nothing is drawn for an empty title.
        subplot: (rows, cols, index) triple for plt.subplot.
        red: draw the title in red with a slightly smaller font.
        titlesize: base font size of the title.

    Returns:
        (rows, cols, index + 1)
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize/1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(title, fontsize=font_size, color=font_color, fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    rows, cols, position = subplot[0], subplot[1], subplot[2]
    return (rows, cols, position+1)
    
def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images in a square-ish grid.

    This will work with:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    display_batch_of_images((images, ids))   # IDs are not shown as titles

    Args:
        databatch: an (images, labels) pair of tensors, as unpacked by
            batch_to_numpy_images_and_labels.
        predictions: optional per-image predicted class indices; when given,
            titles come from title_from_label_and_target.

    An empty batch draws nothing.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # Nothing to draw; this also avoids dividing by zero rows below.
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # display
    shown_labels = labels[:rows*cols]
    for i, (image, label) in enumerate(zip(images[:rows*cols], shown_labels)):
        # The title is the class with the largest label entry.
        title = '' if label is None else CLASSES[np.argmax(label,axis=-1)]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    # Decide on spacing from the labels themselves rather than from the
    # loop variable left over after the loop.
    untitled = all(shown is None for shown in shown_labels)
    if untitled and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    """Plot a confusion matrix with CLASSES as tick labels and an optional metric summary.

    Args:
        cmat: confusion matrix passed to matshow.
        score, precision, recall: optional floats; each one that is not None
            adds a line to the summary text.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    tick_positions = range(len(CLASSES))
    tick_font = {'fontsize': 7}
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(CLASSES, fontdict=tick_font)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(CLASSES, fontdict=tick_font)
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Collect one fragment per metric that was supplied.
    summary_parts = []
    if score is not None:
        summary_parts.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        summary_parts.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        summary_parts.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(summary_parts)
    if titlestring:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    

In [None]:
# Show the next 20-image training batch; titles come from the largest label entry.
display_batch_of_images(next(train_batch))

In [None]:
# Peek at test data
# Re-batch to 20 images so one batch fills a display grid.
test_dataset = get_test_dataset()
test_dataset = test_dataset.unbatch().batch(20)
test_batch = iter(test_dataset)

In [None]:
# Show the next test batch; test batches carry IDs, so no titles are drawn.
display_batch_of_images(next(test_batch), None)

#### Test

In [None]:
# Sanity check: print the image batch shape of three test batches.
print("Test data shapes:")
for image, image_name in get_test_dataset().take(3):
    print(image.numpy().shape )

## Modelling

In [None]:
!pip install -q efficientnet

In [None]:
from tensorflow import keras
from tensorflow.keras import regularizers
import efficientnet.tfkeras as efn
# EfficientNet constructors ordered B0..B6; get_model selects one with backbones[EFN].
backbones = [efn.EfficientNetB0, efn.EfficientNetB1, efn.EfficientNetB2, efn.EfficientNetB3, 
        efn.EfficientNetB4, efn.EfficientNetB5, efn.EfficientNetB6
            ]

def get_model():
    """Build and compile the multi-label classifier on top of an EfficientNet backbone.

    The backbone is backbones[EFN] with 'noisy-student' weights and is fully
    trainable. The head is global average pooling, two BatchNorm/Dense/Dropout
    stages and a sigmoid output with one unit per entry of CLASSES.

    Model, metric, optimizer and compile() are all created inside
    strategy.scope() so that every variable belongs to the distribution
    strategy.

    Returns:
        (model, layer_frame): the compiled Keras model and a DataFrame listing
        each backbone layer with its name and trainable flag.
    """
    with strategy.scope():
        pretrained_model = backbones[EFN](input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),weights='noisy-student', include_top=False)
        # True = fine-tuning (False would be transfer learning with a frozen backbone).
        pretrained_model.trainable = True

        layers = [(layer, layer.name, layer.trainable) for layer in pretrained_model.layers]
        layer_frame = pd.DataFrame(layers, columns=['Layer Type', 'Layer Name', 'Layer Trainable'])

        model = tf.keras.Sequential([
            pretrained_model,
            tf.keras.layers.GlobalAveragePooling2D(),

            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1024, activation='relu'),
            tf.keras.layers.Dropout(0.3),

            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dropout(0.3),

            # Independent sigmoid per target: the task is multi-label.
            tf.keras.layers.Dense(len(CLASSES), activation='sigmoid', dtype='float32')
        ])

        # The metric name 'auc' is what the callbacks monitor as 'val_auc'.
        auc = keras.metrics.AUC(name='auc')

        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        opt = tf.keras.optimizers.Adam(learning_rate=0.00001)
        model.compile(
            optimizer=opt,
            loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05),
            metrics=[auc]
        )
    return model, layer_frame

# Build one model up front for inspection; layer_frame lists the backbone layers.
base_model , layer_frame = get_model()
# Let's take a look to see how many layers are in the base model
# (base_model is the Sequential wrapper, so this counts its top-level layers)
print("Number of layers in the base model: ", len(base_model.layers))


In [None]:
# Show the last ten backbone layers recorded by get_model().
layer_frame.tail(10)

In [None]:
# Tabulate the top-level layers of the Sequential model (the backbone counts as one layer).
base_layers = [(layer, layer.name, layer.trainable) for layer in base_model.layers]
base_layer_frame=pd.DataFrame(base_layers, columns=['Layer Type', 'Layer Name', 'Layer Trainable']) 
base_layer_frame

In [None]:
# Print the Keras summary of the Sequential model.
base_model.summary()

## Setup for KFold 

In [None]:
def get_validation_dataset_for_kfold(dataset, do_onehot=True):
    """Batch, optionally one-hot encode, cache and prefetch a fold's validation dataset.

    Args:
        dataset: parsed dataset of (image, label) pairs.
        do_onehot: pass each batch through `onehot` when True.
    """
    batched = dataset.batch(BATCH_SIZE)
    if do_onehot:
        # we must use one hot like augmented train data
        batched = batched.map(onehot, num_parallel_calls=AUTO)
    # prefetch next batch while training (autotune prefetch buffer size)
    return batched.cache().prefetch(AUTO)


In [None]:
def create_callbacks(model_save_path,fold,verbose=1 ):
    """Build the per-fold Keras callbacks, all of which watch 'val_auc' (higher is better).

    Args:
        model_save_path: directory for the best-weights checkpoint file.
        fold: fold index, used in the checkpoint file name.
        verbose: any value greater than 0 enables callback logging.

    Returns:
        [ReduceLROnPlateau, EarlyStopping, ModelCheckpoint]
    """
    verbose = int(verbose>0)
    watch = {'monitor': 'val_auc', 'mode': 'max'}
    weights_path = "{}/eff{}-cmodel-fold{}.h5".format(model_save_path , EFN, fold)

    return [
        # Cut the learning rate tenfold after 3 epochs without improvement.
        tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, patience=3, verbose=0, **watch),
        # Stop after 5 epochs without improvement.
        tf.keras.callbacks.EarlyStopping(patience=5, verbose=verbose, **watch),
        # Keep only the best weights seen so far.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=weights_path,
            save_weights_only=True,
            save_best_only=True,
            verbose=verbose,
            **watch),
    ]

In [None]:
from sklearn.model_selection import KFold
import tensorflow.keras.backend as K
VERBOSE =1
def train_cross_validate(folds = 5):
    """Train one freshly built model per K-fold split of the training shards.

    The split is made over TRAINING_FILENAMES (whole TFRecord files), with
    shuffling seeded by SEED. A new model is created for every fold so that no
    fold's validation files have been seen by that fold's model.

    Args:
        folds: number of K-fold splits.

    Returns:
        (histories, models, acc_per_fold, loss_per_fold): Keras History objects,
        the trained models, validation AUC * 100 per fold (the name is kept
        for callers) and validation loss per fold. Only the first fold is run
        when FIRST_FOLD_ONLY is set.
    """
    histories = []
    models = []
    # Define per-fold score containers
    acc_per_fold = []
    loss_per_fold = []
    kfold = KFold(folds, shuffle = True, random_state = SEED)
    for f, (trn_ind, val_ind) in enumerate(kfold.split(TRAINING_FILENAMES)):
        print(); print('#'*25)
        print('### FOLD',f+1)
        print('#'*25)
        train_files = [TRAINING_FILENAMES[i] for i in trn_ind]
        val_files = [TRAINING_FILENAMES[i] for i in val_ind]
        train_dataset = load_dataset(train_files, labeled = True)
        val_dataset = load_dataset(val_files, labeled = True, ordered = True)

        # Start from fresh weights: reusing one model across folds would train
        # it on the later folds' validation files and return the same object
        # once per fold.
        K.clear_session()
        model, _ = get_model()
        data_for_validation = get_validation_dataset_for_kfold(val_dataset , do_onehot=False)

        history = model.fit(
            get_training_dataset(train_dataset, do_aug=False , do_onehot=False),
            steps_per_epoch = STEPS_PER_EPOCH,
            epochs = EPOCHS,
            callbacks = create_callbacks(".",f, VERBOSE),
            validation_data = data_for_validation,
            verbose=2
        )
        # scores[0] is the loss, scores[1] the compiled metric.
        scores = model.evaluate(data_for_validation, verbose=0)
        print(f'Score for fold {f+1}: {model.metrics_names[0]} of {scores[0]}')
        acc_per_fold.append(scores[1] * 100)
        loss_per_fold.append(scores[0])
        model.save("model-fold{}.h5".format(f+1))
        models.append(model)
        histories.append(history)
        if FIRST_FOLD_ONLY: break
    return histories, models,acc_per_fold, loss_per_fold

In [None]:
def train_and_predict(folds = 5):
    """Run cross-validated training, print fold scores and average the test predictions.

    Args:
        folds: number of K-fold splits passed to train_cross_validate.

    Returns:
        (histories, models, probabilities, test_ds), where `probabilities` is
        the mean of the per-model predictions on the test images.
    """
    # since we are splitting the dataset and iterating separately on images and ids, order matters.
    test_ds = get_test_dataset(ordered=True)
    test_images_ds = test_ds.map(lambda image, idnum: image)

    print('Start training %i folds'%folds)
    histories, models, acc_per_fold, loss_per_fold = train_cross_validate(folds = folds)

    # == Provide average scores ==
    separator = '------------------------------------------------------------------------'
    print(separator)
    print('Score per fold')
    for i, (fold_loss, fold_acc) in enumerate(zip(loss_per_fold, acc_per_fold)):
        print(separator)
        print(f'> Fold {i+1} - Loss: {fold_loss} - Accuracy: {fold_acc}%')
    print(separator)
    print('Average scores for all folds:')
    print(f'> Accuracy: {np.mean(acc_per_fold)} (+- {np.std(acc_per_fold)})')
    print(f'> Loss: {np.mean(loss_per_fold)}')
    print(separator)

    print('Computing predictions...')
    # get the mean probability of the folds models
    n_models = 1 if FIRST_FOLD_ONLY else folds
    fold_predictions = [models[i].predict(test_images_ds) for i in range(n_models)]
    probabilities = np.average(fold_predictions, axis = 0)

    return histories, models, probabilities , test_ds

## Train and Predict

In [None]:
%%time
# Train all folds and collect the averaged test-set probabilities.
histories, models, probabilities, test_ds = train_and_predict(folds = FOLDS)

## Generate submission file 

In [None]:
print('Generating submission.csv file...')
# IDs come from the same ordered test dataset that produced `probabilities`.
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch

# One record per test image: the ID followed by one probability per target.
probability_columns = [probabilities[:,i] for i in range(probabilities.shape[1])]
submission_rows = np.rec.fromarrays([test_ids] + probability_columns)
column_formats = ['%s'] + ['%f'] * 11
csv_header = 'StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present'
np.savetxt('submission.csv', submission_rows, fmt=column_formats, delimiter=',', header=csv_header, comments='')


In [None]:
!head submission.csv


In [None]:
def display_confusion_matrix(cmat, score, precision, recall):
    """Plot a confusion matrix with CLASSES as tick labels and an optional metric summary.

    NOTE(review): this repeats the definition of the same name earlier in the
    file. The summary text is anchored at a hard-coded x position of 101,
    while the axes only carry len(CLASSES) ticks - confirm the placement.

    Args:
        cmat: confusion matrix passed to matshow.
        score, precision, recall: optional floats; each one that is not None
            adds a line to the summary text.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Build the summary from whichever metrics were supplied.
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Plot a training and a validation curve into one subplot.

    Args:
        training, validation: per-epoch metric values.
        title: metric name, used for the axes title and the y label.
        subplot: three-digit subplot code; a code ending in 1 also creates
            the figure.
    """
    is_first_panel = subplot%10==1
    if is_first_panel: # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()

    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

### Plot train/loss curves

In [None]:
# One figure per fold: loss on top, AUC underneath.
for history in histories :
    display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 211)
    # The tracked metric is AUC, so the panel is titled accordingly.
    display_training_curves(history.history['auc'], history.history['val_auc'], 'AUC', 212)


## Confusion Matrix

In [None]:
def get_class_name(label):
    """Return the index of the last entry of `label` equal to 1.

    Only the first len(CLASSES) entries are inspected; 9 is returned when
    none of them equals 1.
    """
    hits = [i for i in range(len(CLASSES)) if label[i] == 1]
    return hits[-1] if hits else 9

In [None]:
def get_correct_labels(cm_correct_labels):
    """Map each label vector to a single class index via get_class_name."""
    return [get_class_name(label_vector) for label_vector in cm_correct_labels]

In [None]:
%%time
# Out-of-fold inference: for each fold, predict that fold's validation files
# with models[j]. The KFold here uses the same shuffle seed as
# train_cross_validate, so it yields the same split when FOLDS matches.
all_labels = []; all_prob = []; all_pred = [];  raw_labels = []
kfold = KFold(FOLDS, shuffle = True, random_state = SEED)
for j, (trn_ind, val_ind) in enumerate( kfold.split(TRAINING_FILENAMES) ):
    print('Inferring fold',j+1,'validation images...')
    VAL_FILES = list(pd.DataFrame({'TRAINING_FILENAMES': TRAINING_FILENAMES}).loc[val_ind]['TRAINING_FILENAMES'])
    NUM_VALIDATION_IMAGES = count_data_items(VAL_FILES)
    # Ordered loading keeps the separately iterated images and labels aligned.
    cmdataset = get_validation_dataset_for_kfold(load_dataset(VAL_FILES, labeled = True, ordered = True), do_onehot=False)
    images_ds = cmdataset.map(lambda image, label: image)
    labels_ds = cmdataset.map(lambda image, label: label).unbatch()
    # get everything as one batch
    labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy()
    # Keep the full label vectors for the AUC computation further down.
    raw_labels.append(labels)
    # Collapse each label vector to a single class index.
    labels = get_correct_labels(labels)
    all_labels.append(labels  )
    prob = models[j].predict(images_ds)
    all_prob.append( prob )
    all_pred.append( np.argmax(prob, axis=-1) )
    if FIRST_FOLD_ONLY: break
# Stack the per-fold results into single arrays.
cm_correct_labels = np.concatenate(all_labels)
cm_probabilities = np.concatenate(all_prob)
cm_predictions = np.concatenate(all_pred)
cm_raw_labels = np.concatenate(raw_labels)

## Compute ROC AUC Score 

In [None]:
# Ground-truth table: one column per class, taken from the raw label vectors.
y_true = pd.DataFrame(
    {class_name: cm_raw_labels[:, col] for col, class_name in enumerate(CLASSES)},
    columns=CLASSES,
)

In [None]:
# Prediction table: one column per class, taken from the out-of-fold probabilities.
y_pred = pd.DataFrame(
    {class_name: cm_probabilities[:, col] for col, class_name in enumerate(CLASSES)},
    columns=CLASSES,
)

In [None]:
# Inspect the first four ground-truth rows.
y_true.head(4)

In [None]:
# Inspect the first four rows of predicted probabilities.
y_pred.head(4)

In [None]:
from sklearn.metrics import roc_auc_score
# Flatten every (sample, class) pair into one vector, so all targets are
# scored together as a single binary problem.
y_true_flat = y_true.values.reshape(-1)
y_pred_flat = y_pred.values.reshape(-1)
roc_auc_score(y_true_flat, y_pred_flat)

In [None]:
# Score the two-dimensional tables directly, using roc_auc_score's default averaging.
# NOTE(review): presumably this is the per-class (macro) mean - confirm against the sklearn docs.
roc_auc_score(y_true, y_pred)