<h2>Stratified GroupKFold with EFN & TFRecords</h2>

This is a sample notebook for the RANZCR CLiP competition. It demonstrates Stratified GroupKFold cross-validation and an EfficientNet architecture trained on TPUs. You can configure training by choosing the EfficientNet architecture, TFRecords with different image shapes, and many other parameters.

All the datasets used in this notebook are public and can be found at: [(128x128)](https://www.kaggle.com/prateek0x/ranzcr-128x128), [(256x256)](https://www.kaggle.com/prateek0x/ranzcr-256x256), [(384x384)](https://www.kaggle.com/prateek0x/ranzcr-384x384), [(512x512)](https://www.kaggle.com/prateek0x/ranzcr-512x512)


In [None]:
!pip install -q efficientnet

In [None]:
import pandas as pd
import numpy as np 
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import efficientnet.tfkeras as efn
import warnings, gc, random, math, os, re
from sklearn.model_selection import KFold
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler

def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set determinism env vars.

    Args:
        seed: Integer passed to `random.seed`, `np.random.seed` and
            `tf.random.set_seed`; its string form is also written to
            the `PYTHONHASHSEED` environment variable.
    """
    # Same order as before: stdlib RNG, then NumPy, then TensorFlow.
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)

    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })

# Seed every RNG once, up front. `seed` is also reused later as the
# `random_state` of the KFold splitter.
seed = 0
seed_everything(seed)
# Suppress all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

## Model Configuration

In [None]:
DEVICE = "TPU"

# Define both names up front so later cells can always test `tpu` and use
# `strategy`, even when no TPU is requested or reachable.
tpu = None
strategy = None

if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

    if tpu:
        try:
            print("initializing  TPU ...")
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
            print("TPU initialized")
        except Exception as err:
            # The previous `except _:` referenced an undefined name, so any
            # initialization failure surfaced as an unrelated error instead
            # of being reported here.
            print("failed to initialize TPU")
            print(f"reason: {err!r}")
            # Treat the TPU as unavailable so later cells do not try to
            # re-initialize it.
            tpu = None
            strategy = None
    else:
        print("Please select TPU as accelrator.")

if strategy is None:
    # Fall back to TensorFlow's default (CPU / single-GPU) strategy so the
    # rest of the notebook still has a valid `strategy` object.
    print("Using default distribution strategy")
    strategy = tf.distribute.get_strategy()


AUTO     = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Number of cross-validation folds (passed to KFold as n_splits).
FOLDS = 5
# Learning rate given to Adam; also used as LR_MAX by the LR schedule.
LEARNING_RATE = 1e-5
# Global batch size: 32 samples per replica.
BATCH_SIZE = 32 * REPLICAS
# Maximum number of epochs per fold (early stopping may end a fold sooner).
EPOCHS = 10
# HEIGHT/WIDTH pick the 'ranzcr-{HEIGHT}x{WIDTH}' TFRecord dataset and the
# size images are resized to. Other sizes listed by the author: 128, 256, 384.
HEIGHT = 512 # 128, 256, 384
WIDTH = 512 # 128, 256, 384
# NOTE(review): RS_HEIGHT / RS_WIDTH are not referenced anywhere else in this
# notebook — confirm whether they can be removed.
RS_HEIGHT = 512 
RS_WIDTH= 512
# Index into the EFNS list of EfficientNet constructors (4 -> EfficientNetB4).
EFN = 4
# Width of the sigmoid output layer; matches the 11 label columns stacked in
# read_tfrecord.
N_CLASSES =11
# EarlyStopping patience, in epochs without val_auc improvement.
ES_PATIENCE = 4
# Channel count used when random-cropping in data_augment.
CHANNELS =3
# When True, the training loop plots AUC/loss curves after every fold.
DISPLAY_PLOT =True

In [None]:
def count_data_items(filenames):
    """Add up the record counts embedded in a list of TFRecord file names.

    Every name must contain a ``-<digits>.`` fragment (for example
    ``Id_train00-2071.tfrec``); the digits are read as the number of
    records stored in that file.

    Args:
        filenames: Iterable of file names / paths.

    Returns:
        The NumPy sum of the per-file counts.
    """
    count_pattern = re.compile(r'-([0-9]*)\.')
    per_file_counts = []
    for name in filenames:
        digits = count_pattern.search(name).group(1)
        per_file_counts.append(int(digits))
    return np.sum(per_file_counts)

# GCS location of the public Kaggle dataset whose name encodes the configured
# image size (e.g. 'ranzcr-512x512'); the TFRecord shards are globbed from here.
GCS_PATH = KaggleDatasets().get_gcs_path('ranzcr-{0}x{1}'.format(HEIGHT,WIDTH))

In [None]:
# data augmentation @cdeotte kernel: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def transform_rotation(image, height, rotation):
    """Rotate a single square image by a random angle.

    Args:
        image: One image of size [dim, dim, 3] (not a batch [b, dim, dim, 3]).
        height: Side length `dim` of the image; the output is reshaped to
            [height, height, 3].
        rotation: Angle bound in degrees. The applied angle is
            `rotation * u` with `u` drawn from `tf.random.uniform`, so the
            sign of `rotation` fixes the rotation direction.

    Returns:
        The rotated image, shape [height, height, 3]. Source pixels are picked
        by integer index lookup (no interpolation); lookups are clipped to a
        fixed index range, so out-of-range positions reuse edge pixels.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    # 1 for odd sizes, 0 for even ones; shifts the lower clip bound below.
    XDIM = DIM%2 #fix for size 331
    
    # Scale the angle bound by a single random factor (shape [1]).
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    
    # ROTATION MATRIX (3x3, homogeneous coordinates)
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    # Coordinates are centred on the image middle; `idx` is a [3, DIM*DIM]
    # stack of (x, y, 1) columns.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the row/column indices computed below stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert centred coordinates back to (row, column) indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_shear(image, height, shear):
    """Shear a single square image by a random angle.

    Args:
        image: One image of size [dim, dim, 3] (not a batch [b, dim, dim, 3]).
        height: Side length `dim` of the image; the output is reshaped to
            [height, height, 3].
        shear: Shear angle bound in degrees. The applied angle is
            `shear * u` with `u` drawn from `tf.random.uniform`, so the sign
            of `shear` fixes the shear direction.

    Returns:
        The sheared image, shape [height, height, 3]. Source pixels are picked
        by integer index lookup (no interpolation); lookups are clipped to a
        fixed index range, so out-of-range positions reuse edge pixels.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    # 1 for odd sizes, 0 for even ones; shifts the lower clip bound below.
    XDIM = DIM%2 #fix for size 331
    
    # Scale the angle bound by a single random factor, then convert to radians.
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX (3x3, homogeneous coordinates)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    # Coordinates are centred on the image middle; `idx` is a [3, DIM*DIM]
    # stack of (x, y, 1) columns.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # SHEAR DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the row/column indices computed below stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert centred coordinates back to (row, column) indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

In [None]:
def data_augment(image, label):
    """Apply a random chain of augmentations to one image.

    Each stage is gated by its own uniform draw in [0, 1):
      * shear (p_shear > .2): bound +20 degrees when > .6, otherwise -20
      * rotation (p_rotation > .2): bound +45 degrees when > .6, otherwise -45
      * random left/right and up/down flips, plus a transpose when
        p_spatial > .75
      * quarter-turn rotation chosen from p_rotate (none when <= .25)
      * saturation / contrast / brightness jitter, each when its draw >= .4
      * central crop (p_crop > .7) or random square crop (.4 < p_crop <= .7)
    The result is always resized back to [HEIGHT, WIDTH].

    Args:
        image: A single image tensor.
            NOTE(review): the shear/rotation helpers are called with HEIGHT
            only, so this presumably expects square HEIGHT x HEIGHT x 3
            input — confirm.
        label: Passed through unchanged.

    Returns:
        Tuple `(augmented_image, label)`.
    """
    # One independent uniform draw per augmentation stage.
    # p_rotation gates the arbitrary-angle rotation; p_rotate gates the
    # 90-degree-step rotation.
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Shear
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
            
    # Rotation
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
    
    
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270 degrees
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180 degrees
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90 degrees
        
    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
        
    # Crops
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .4:
        # Random square crop whose side is drawn between 60% of HEIGHT and HEIGHT.
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
            
    # Crops change the spatial size, so always restore the configured size.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    return image, label

## TPU utility Functions

In [None]:
# Datasets utility functions
def decode_image(image_data):
    """
        Turn JPEG bytes into a float32 image tensor of shape [HEIGHT, WIDTH, 3].

        The JPEG is decoded as 3 channels, scaled from [0, 255] to [0, 1],
        resized to the configured size and finally reshaped so the static
        shape is fully defined.
    """
    target_size = [HEIGHT, WIDTH]
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 255.0
    resized = tf.image.resize(scaled, target_size)
    return tf.reshape(resized, target_size + [3])

def read_tfrecord(example, labeled=True):
    """
        Parse one serialized TFRecord example into an (image, target) pair.

        1. Parse the example: 'image' and 'StudyInstanceUID' always, plus the
           11 int64 label columns when `labeled` is True.
        2. Decode the image with `decode_image`.
        3. If `labeled`, return (image, label) where label is an int64 vector
           holding the label columns in the order listed below; otherwise
           return (image, StudyInstanceUID).
    """
    # Single source of truth for the label columns. The order here defines
    # the order of the model targets, so do not reorder.
    label_columns = (
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    )

    # Features present in both labeled and unlabeled records.
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'StudyInstanceUID': tf.io.FixedLenFeature([], tf.string),
    }
    if labeled:
        for column in label_columns:
            feature_spec[column] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])

    if labeled:
        # The features are already parsed as int64, so no cast is needed.
        label_or_name = tf.stack([parsed[column] for column in label_columns])
    else:
        label_or_name = parsed['StudyInstanceUID']

    return image, label_or_name

def load_dataset(filenames, labeled=True, ordered=False):
    """
        Build a tf.data pipeline that reads and parses the given TFRecord files.

        When `ordered` is False the deterministic-order option is switched
        off; when True the options object is left at its defaults. Every
        record is parsed with `read_tfrecord` using the `labeled` flag.
    """
    def _parse(record):
        return read_tfrecord(record, labeled=labeled)

    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(read_options).map(_parse, num_parallel_calls=AUTO)

def get_dataset(FILENAMES, labeled=True, ordered=False, repeated=False, augment=False):
    """
        Assemble a batched, prefetched dataset for training or inference.

        Stages, applied in this order:
          load -> augment (if `augment`) -> repeat (if `repeated`)
               -> shuffle with a 2048-element buffer (if not `ordered`)
               -> batch(BATCH_SIZE) -> prefetch(AUTO)
    """
    pipeline = load_dataset(FILENAMES, labeled=labeled, ordered=ordered)
    pipeline = pipeline.map(data_augment, num_parallel_calls=AUTO) if augment else pipeline
    pipeline = pipeline.repeat() if repeated else pipeline
    pipeline = pipeline if ordered else pipeline.shuffle(2048)
    return pipeline.batch(BATCH_SIZE).prefetch(AUTO)

In [None]:
# Learning-rate schedule parameters consumed by lrfn.
# NOTE(review): LEARNING_RATE is also 1e-5, so LR_START == LR_MIN == LR_MAX and
# lrfn yields 1e-5 for every epoch; raise LR_MAX if a real warm-up / cosine
# decay is intended.
LR_START = 1e-5            # learning rate at epoch 0
LR_MIN = 1e-5              # floor applied during the cosine phase
LR_MAX = LEARNING_RATE     # peak learning rate reached after the ramp-up
LR_RAMPUP_EPOCHS = 3       # epochs of linear ramp from LR_START to LR_MAX
LR_SUSTAIN_EPOCHS = 0      # epochs held at LR_MAX after the ramp-up
N_CYCLES = .5              # number of cosine cycles spanned by the decay phase


def lrfn(epoch):
    """Learning rate for a given epoch: linear ramp-up, hold, then cosine decay.

    The value is computed with plain Python math so every phase returns a
    Python float. Previously the cosine phase returned a tf.Tensor while the
    other phases returned floats.

    Args:
        epoch: Zero-based epoch index, as passed by LearningRateScheduler.

    Returns:
        The learning rate for `epoch` as a float.
    """
    # Phase 1: linear ramp from LR_START to LR_MAX.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    # Phase 2: hold at the peak.
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX

    # Phase 3: cosine decay over the remaining epochs, floored at LR_MIN.
    decay_epochs = EPOCHS - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    progress = (epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) / decay_epochs
    lr = LR_MAX * (0.5 * (1.0 + math.cos(math.pi * N_CYCLES * 2.0 * progress)))
    if LR_MIN is not None:
        lr = max(LR_MIN, lr)

    return lr

In [None]:
# EfficientNet constructors ordered by variant, so EFNS[n] builds EfficientNetB{n}
# (n = 0..6). The EFN config constant selects the entry used by model_fn.
EFNS = [efn.EfficientNetB0, efn.EfficientNetB1, efn.EfficientNetB2, efn.EfficientNetB3, 
        efn.EfficientNetB4, efn.EfficientNetB5, efn.EfficientNetB6]

def model_fn(dim=HEIGHT):
    """Build and compile the multi-label EfficientNet classifier.

    Architecture: EfficientNet backbone (variant chosen by `EFN`, loaded with
    'noisy-student' weights, no top) -> global average pooling -> dense layer
    with `N_CLASSES` sigmoid outputs.

    Args:
        dim: Side length of the square RGB input, i.e. input shape (dim, dim, 3).

    Returns:
        A compiled `tf.keras.Model` using Adam, binary cross-entropy loss and
        a multi-label AUC metric named 'auc'.
    """
    inp = tf.keras.layers.Input(shape=(dim,dim,3))
    base = EFNS[EFN](input_shape=(dim,dim,3),include_top=False,weights='noisy-student')
    x = base(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(N_CLASSES,activation='sigmoid')(x)
    model = tf.keras.Model(inputs=inp,outputs=x)

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    loss = "binary_crossentropy"
    # Name the metric explicitly: the training loop monitors 'val_auc' and reads
    # history['auc'], and an unnamed AUC can be auto-renamed (e.g. 'auc_1') when
    # several models are built in the same session.
    auc = tf.keras.metrics.AUC(multi_label=True, name='auc')
    model.compile(optimizer=opt,loss=loss,metrics=[auc])

    return model

# Build one model here only to print its layer summary. It is created outside
# strategy.scope(); the training loop builds its own model for every fold.
m = model_fn()
m.summary()

In [None]:
def _plot_fold_history(history, fold):
    """Plot train/validation AUC (left axis) and loss (right axis) for one fold.

    Args:
        history: The `History.history` dict returned by `model.fit`; must hold
            the keys 'auc', 'val_auc', 'loss' and 'val_loss'.
        fold: Zero-based fold index (shown one-based in the title).
    """
    plt.figure(figsize=(15,5))
    auc_epochs = np.arange(len(history['auc']))
    plt.plot(auc_epochs,history['auc'],'-o',label='Train AUC',color='#ff7f0e')
    plt.plot(auc_epochs,history['val_auc'],'-o',label='Val AUC',color='#1f77b4')
    best_epoch = np.argmax( history['val_auc'] )
    best_auc = np.max( history['val_auc'] )
    xdist = plt.xlim()[1] - plt.xlim()[0]
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(best_epoch,best_auc,s=200,color='#1f77b4')
    plt.text(best_epoch-0.03*xdist,best_auc-0.13*ydist,'max auc\n%.2f'%best_auc,size=14)
    plt.ylabel('AUC',size=14); plt.xlabel('Epoch',size=14)
    plt.legend(loc=2)

    # Second y-axis for the loss curves; twinx() makes it the current axes, so
    # the pyplot calls below draw on it.
    loss_axis = plt.gca().twinx()
    loss_epochs = np.arange(len(history['loss']))
    loss_axis.plot(loss_epochs,history['loss'],'-o',label='Train Loss',color='#2ca02c')
    loss_axis.plot(loss_epochs,history['val_loss'],'-o',label='Val Loss',color='#d62728')
    min_epoch = np.argmin( history['val_loss'] )
    min_loss = np.min( history['val_loss'] )
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(min_epoch,min_loss,s=200,color='#d62728')
    plt.text(min_epoch-0.03*xdist,min_loss+0.05*ydist,'min loss',size=14)
    plt.ylabel('Loss',size=14)
    plt.title('Fold %i,Image Size %i, EfficientNet B%i, batch_size %i '%
            (fold+1,HEIGHT,EFN,BATCH_SIZE),size=18)
    plt.legend(loc=3)
    plt.show()


skf = KFold(n_splits=FOLDS, shuffle=True, random_state=seed)
oof_pred = []; oof_labels = []; history_list = []

# The folds are built over TFRecord shard indices, not individual samples.
# NOTE(review): assumes the dataset has 15 shards named Id_train00..Id_train14 — confirm.
for fold,(idxT, idxV) in enumerate(skf.split(np.arange(15))):
    if tpu:
        # Reset the TPU before each fold.
        tf.tpu.experimental.initialize_tpu_system(tpu)

    print(f'\nFOLD: {fold+1}')
    print(f'TRAIN: {idxT} VALID: {idxV}')

    # Create train and validation sets
    TRAIN_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/Id_train%.2i*.tfrec' % x for x in idxT])
    VALID_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/Id_train%.2i*.tfrec' % x for x in idxV])

    np.random.shuffle(TRAIN_FILENAMES)

    ct_train = count_data_items(TRAIN_FILENAMES)
    ct_valid = count_data_items(VALID_FILENAMES)

    ## MODEL
    K.clear_session()
    with strategy.scope():
        model = model_fn()

    model_path = f'model_{fold}.h5'
    es = EarlyStopping(monitor='val_auc', mode='max',
                       patience=ES_PATIENCE, restore_best_weights=True, verbose=1)

    ## TRAIN
    # The training dataset repeats forever, so steps_per_epoch defines an epoch.
    history = model.fit(x=get_dataset(TRAIN_FILENAMES, labeled=True, ordered=False, repeated=True, augment=True),
                        validation_data=get_dataset(VALID_FILENAMES, labeled=True, ordered=True, repeated=False, augment=False),
                        steps_per_epoch=(ct_train // BATCH_SIZE),
                        callbacks=[es, LearningRateScheduler(lrfn, verbose=1)],
                        epochs=EPOCHS,
                        verbose=1).history

    history_list.append(history)
    # Save the weights held by the model once training has finished.
    model.save_weights(model_path)

    # OOF predictions
    ds_valid = get_dataset(VALID_FILENAMES, labeled=True, ordered=True, repeated=False, augment=False)
    oof_labels.append([target.numpy() for img, target in ds_valid.unbatch()])
    x_oof = ds_valid.map(lambda image, target: image)
    # Keep the per-class sigmoid probabilities: this is a multi-label problem,
    # so taking argmax over classes would throw away the predictions needed to
    # score the out-of-fold AUC.
    oof_pred.append(model.predict(x_oof))

    if DISPLAY_PLOT:
        _plot_fold_history(history, fold)

    ## RESULTS
    # The monitored metric is AUC, so report it as AUC rather than accuracy.
    print(f"#### FOLD {fold+1} OOF AUC = {np.max(history['val_auc']):.3f}")

<h1>!The End</h1>Updating....