In [None]:
! pip install -q efficientnet >> /dev/null

## Import modules

In [None]:
import os
import re
import math
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.mixed_precision import experimental as mixed_precision

import efficientnet.tfkeras as efn

# from functools import partial
# from albumentations import (
#     Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
#     Rotate
# )

## Set parameters

In [None]:
# Local Kaggle competition input directory. When DEVICE == 'TPU' a later cell
# replaces this with the dataset's GCS path.
DATA_PATH = '/kaggle/input/ranzcr-clip-catheter-line-classification'

# Output directory handed to create_callbacks, which writes the best
# checkpoint to f'{MODEL_PATH}/model.h5'.
MODEL_PATH = '/kaggle/working/models'

In [None]:

# Target hardware; selects the tf.distribute strategy and mixed-precision policy.
DEVICE = 'TPU' # one of ['CPU', 'GPU', 'TPU']

ENABLE_MIXED_PRECISION = True # [True, False]; mixed_float16 on GPU, mixed_bfloat16 on TPU

In [None]:
# When True, turn on TensorFlow's global XLA JIT compilation.
XLA_ACCELERATE = True


if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

In [None]:
SEED = 42 # random_state for the KFold split of TFRecord shards

FOLDS = 3 # number of cross-validation folds

IMG_SIZE = 600 # side length (pixels) of the square model input

BATCH_SIZE = 16 # per-replica batch; the input pipeline uses BATCH_SIZE*REPLICAS. Options: [8, 16, 32, 64, 128, 256, 512]

EPOCHS = 50 # maximum epochs per fold (EarlyStopping may stop earlier)

EFF_NET = 'B7' # EfficientNet variant: ['B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']

VERBOSE = 1 # [0: silent, 1: progress bar, 2: single line]

In [None]:
# Number of training TFRecord shards; the KFold split is done over shard indices.
# Only '*.tfrec' entries are counted, so other files in the directory cannot
# inflate the count (which would create fold indices with no matching shard).
NUM_TF_RECS = len([name for name in os.listdir(f'{DATA_PATH}/train_tfrecords')
                   if name.endswith('.tfrec')])

print(NUM_TF_RECS)

## Set up devices and settings

In [None]:
# For Kaggle TPUs: the input is read from the dataset's GCS bucket instead of
# the local /kaggle/input directory, so swap DATA_PATH for the GCS path.
if DEVICE == 'TPU':
    # Imported only when needed: kaggle_datasets is available only inside
    # Kaggle notebooks, so an unconditional import breaks CPU/GPU runs elsewhere.
    from kaggle_datasets import KaggleDatasets
    print('TPU')
    DATA_PATH = KaggleDatasets().get_gcs_path(DATA_PATH.split('/')[-1])

In [None]:
# Choose a tf.distribute strategy for DEVICE. If the requested accelerator is
# unavailable we fall back to the default (single-device) strategy, and
# ACTIVE_DEVICE records the hardware actually in use so the mixed-precision
# policy below matches it rather than the request.
ACTIVE_DEVICE = DEVICE

if DEVICE == 'GPU':

    gpu_accelerators = tf.config.list_physical_devices('GPU')

    if len(gpu_accelerators) > 1:
        strategy = tf.distribute.MirroredStrategy()
        print(f'Number of GPUs available: {len(gpu_accelerators)}')
        print('\n Using Mirrored Distribution Strategy')
    else:
        strategy = tf.distribute.get_strategy()
        if len(gpu_accelerators) == 1:
            print(f'Number of GPUs available: 1')
            print('\nUsing Default Distribution Strategy for GPU')
        else:
            print('ERROR: GPU not available')
            print('\nUsing Default Distribution Strategy  for CPU')
            ACTIVE_DEVICE = 'CPU'

elif DEVICE == 'TPU':

    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        tpu_accelerators = tf.config.list_logical_devices('TPU')
        print(f'Number of TPU cores available: {len(tpu_accelerators)}')
        print(f'\nUsing TPU Distribution Strategy')
    # `Exception` rather than a bare except so KeyboardInterrupt/SystemExit
    # still propagate; the reason for the fallback is printed, not hidden.
    except Exception as err:
        print('ERROR: TPU not available')
        print(f'\t...{type(err).__name__}: {err}')
        print('\nUsing Default Distribution Strategy for CPU')
        strategy = tf.distribute.get_strategy()
        ACTIVE_DEVICE = 'CPU'

else:
    if DEVICE != 'CPU':
        print(f'WARNING: unknown DEVICE {DEVICE!r}; expected CPU, GPU or TPU')
    strategy = tf.distribute.get_strategy()
    print('\nUsing Default Distribution Strategy  for CPU')
    ACTIVE_DEVICE = 'CPU'


# Mixed-precision policy for each accelerator; CPU has none and stays float32.
MIXED_PRECISION_POLICIES = {'GPU': 'mixed_float16', 'TPU': 'mixed_bfloat16'}

if ENABLE_MIXED_PRECISION:
    if ACTIVE_DEVICE in MIXED_PRECISION_POLICIES:
        print('\nMixed Precision enabled:')
        policy = mixed_precision.Policy(MIXED_PRECISION_POLICIES[ACTIVE_DEVICE])
        mixed_precision.set_policy(policy)
        print('\t...Compute dtype: %s' % policy.compute_dtype)
        print('\t...Variable dtype: %s' % policy.variable_dtype)
    else:
        print('\nMixed Precision requested but skipped: no policy configured for CPU')


REPLICAS = strategy.num_replicas_in_sync
print(f'\nREPLICAS: {REPLICAS}')

## Helper functions

In [None]:
class Dataset:
    """tf.data input pipeline for the RANZCR CLiP training TFRecords.

    Each record holds an encoded image plus 11 int64 label flags. `generator`
    yields batches of (image, label) where images are image_size x image_size,
    scaled by 1/255, and optionally augmented with `self.aug`.
    """

    # Parsing schema for tf.io.parse_single_example.
    feature_description = {
        "StudyInstanceUID"           : tf.io.FixedLenFeature([], tf.string),
        "image"                      : tf.io.FixedLenFeature([], tf.string),
        "ETT - Abnormal"             : tf.io.FixedLenFeature([], tf.int64),
        "ETT - Borderline"           : tf.io.FixedLenFeature([], tf.int64),
        "ETT - Normal"               : tf.io.FixedLenFeature([], tf.int64),
        "NGT - Abnormal"             : tf.io.FixedLenFeature([], tf.int64),
        "NGT - Borderline"           : tf.io.FixedLenFeature([], tf.int64),
        "NGT - Incompletely Imaged"  : tf.io.FixedLenFeature([], tf.int64),
        "NGT - Normal"               : tf.io.FixedLenFeature([], tf.int64),
        "CVC - Abnormal"             : tf.io.FixedLenFeature([], tf.int64),
        "CVC - Borderline"           : tf.io.FixedLenFeature([], tf.int64),
        "CVC - Normal"               : tf.io.FixedLenFeature([], tf.int64),
        "Swan Ganz Catheter Present" : tf.io.FixedLenFeature([], tf.int64),
    }

    # Label keys in the order they are stacked into the 11-element label vector.
    LABEL_KEYS = (
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    )

    def __init__(self, image_size):
        """Args: image_size: side length in pixels of every output image."""
        self.image_size = image_size
        preprocessing = tf.keras.layers.experimental.preprocessing
        # Random augmentation applied to whole float batches. Built per instance
        # so the final Resizing matches image_size: RandomWidth/RandomHeight
        # change the spatial size and Resizing restores it.
        self.aug = tf.keras.Sequential([
            # float32 identity layer: keeps augmentation in float32 regardless
            # of the global mixed-precision policy.
            tf.keras.layers.Activation(None, dtype='float32'),
            preprocessing.RandomFlip(dtype='float32'),
            preprocessing.RandomRotation(0.04, fill_mode='constant', dtype='float32'),
            preprocessing.RandomTranslation(0.15, 0.15, fill_mode='constant', dtype='float32'),
            preprocessing.RandomWidth(0.35, dtype='float32'),
            preprocessing.RandomHeight(0.35, dtype='float32'),
            preprocessing.Resizing(image_size, image_size, dtype='float32'),
        ])

    def parse_function(self, example_proto):
        """Decode one serialized example into (3-channel image, list of 11 int64 labels)."""
        example = tf.io.parse_single_example(example_proto, self.feature_description)
        image = tf.io.decode_image(example['image'], channels=3)
        label = [example[key] for key in self.LABEL_KEYS]
        return image, label

    def augment_function(self, image, label):
        """Apply the random augmentation stack to a batch of images."""
        return self.aug(image, training=True), label

    def process_function(self, image, label):
        """Fix static batch shapes and scale images by 1/255 as float32."""
        # NOTE(review): set_shape assumes every decoded image is already
        # image_size x image_size (ds.batch would fail on mixed sizes anyway);
        # the resize below then effectively only converts to float32.
        image.set_shape([None, self.image_size, self.image_size, 3])
        label.set_shape([None, 11])
        image = tf.image.resize(image, [self.image_size, self.image_size], 'bilinear')/255.
        return image, label

    def generator(self, files, batch_size=1, repeat=False, augment=False, shuffle=True, cache=False,
                  shuffle_buffer=2000):
        """Build a batched, prefetched tf.data pipeline over `files`.

        Args:
            files: TFRecord file paths.
            batch_size: global batch size.
            repeat: repeat indefinitely (use together with steps_per_epoch).
            augment: apply `self.aug` to every batch.
            shuffle: shuffle records (non-deterministic order); reshuffled on every pass.
            cache: cache the serialized records after the first pass.
            shuffle_buffer: shuffle buffer size, in records.

        Returns:
            tf.data.Dataset of (image, label) batches.
        """
        AUTO = tf.data.experimental.AUTOTUNE
        ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)

        if cache:
            # Cache BEFORE shuffling: a cache placed after shuffle replays the
            # first pass's order on every later epoch. Caching the serialized
            # records (rather than decoded images) also keeps memory much lower.
            ds = ds.cache()

        if shuffle:
            opt = tf.data.Options()
            opt.experimental_deterministic = False
            ds = ds.with_options(opt)
            ds = ds.shuffle(shuffle_buffer)

        ds = ds.map(self.parse_function, num_parallel_calls=AUTO)

        if repeat:
            ds = ds.repeat()

        ds = ds.batch(batch_size)
        ds = ds.map(self.process_function, num_parallel_calls=AUTO)

        if augment:
            ds = ds.map(self.augment_function, num_parallel_calls=AUTO)

        return ds.prefetch(AUTO)

In [None]:
# # A custom layer
# class SpatialAttentionModule(tf.keras.layers.Layer):
#     def __init__(self, kernel_size=3):
#         '''
#         paper: https://arxiv.org/abs/1807.06521
#         code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
#         '''
#         super(SpatialAttentionModule, self).__init__()
#         self.conv1 = tf.keras.layers.Conv2D(16, kernel_size=kernel_size, 
#                                             use_bias=False, 
#                                             kernel_initializer='he_normal',
#                                             strides=1, padding='same', 
#                                             activation=tf.nn.relu6)
# #         self.conv2 = tf.keras.layers.Conv2D(32, kernel_size=kernel_size, 
# #                                             use_bias=False, 
# #                                             kernel_initializer='he_normal',
# #                                             strides=1, padding='same', 
# #                                             activation=tf.nn.relu6)
# #         self.conv3 = tf.keras.layers.Conv2D(16, kernel_size=kernel_size, 
# #                                             use_bias=False, 
# #                                             kernel_initializer='he_normal',
# #                                             strides=1, padding='same', 
# #                                             activation=tf.nn.relu6)
#         self.conv2 = tf.keras.layers.Conv2D(1, kernel_size=kernel_size,  
#                                             use_bias=False,
#                                             kernel_initializer='he_normal',
#                                             strides=1, padding='same', 
#                                             activation=tf.math.sigmoid)

#     def call(self, inputs):
#         avg_out = tf.reduce_mean(inputs, axis=3)
#         max_out = tf.reduce_max(inputs,  axis=3)
#         x = tf.stack([avg_out, max_out], axis=3) 
#         x = self.conv1(x)
# #         x = self.conv2(x)
# #         x = self.conv3(x)
#         return self.conv2(x)

#     def get_config(self):

#         config = super().get_config().copy()
#         config.update({
#             'conv1': self.conv1,
#             'conv2': self.conv2,
#         })
#         return config
    
# # A custom layer
# class ChannelAttentionModule(tf.keras.layers.Layer):
#     def __init__(self, ratio=8):
#         '''
#         paper: https://arxiv.org/abs/1807.06521
#         code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
#         '''
#         super(ChannelAttentionModule, self).__init__()
#         self.ratio = ratio
#         self.gapavg = tf.keras.layers.GlobalAveragePooling2D()
#         self.gmpmax = tf.keras.layers.GlobalMaxPooling2D()
        
#     def build(self, input_shape):
#         self.conv1 = tf.keras.layers.Conv2D(input_shape[-1]//self.ratio, 
#                                             kernel_size=1, 
#                                             strides=1, padding='same',
#                                             use_bias=True, activation=tf.nn.relu)
    
#         self.conv2 = tf.keras.layers.Conv2D(input_shape[-1], 
#                                             kernel_size=1, 
#                                             strides=1, padding='same',
#                                             use_bias=True, activation=tf.nn.relu)
#         super(ChannelAttentionModule, self).build(input_shape)

#     def call(self, inputs):
#         # compute gap and gmp pooling 
#         gapavg = self.gapavg(inputs)
#         gmpmax = self.gmpmax(inputs)
#         gapavg = tf.keras.layers.Reshape((1, 1, gapavg.shape[1]))(gapavg)   
#         gmpmax = tf.keras.layers.Reshape((1, 1, gmpmax.shape[1]))(gmpmax)   
#         # forward passing to the respected layers
#         gapavg_out = self.conv2(self.conv1(gapavg))
#         gmpmax_out = self.conv2(self.conv1(gmpmax))
#         return tf.math.sigmoid(gapavg_out + gmpmax_out)
    
#     def get_output_shape_for(self, input_shape):
#         return self.compute_output_shape(input_shape)

#     def compute_output_shape(self, input_shape):
#         output_len = input_shape[3]
#         return (input_shape[0], output_len)
    
#     def get_config(self):

#         config = super().get_config().copy()
#         config.update({
#             'ratio': self.ratio,
#             'gapavg': self.gapavg,
#             'gmpmax': self.gmpmax,
#         })
#         return config

In [None]:
# # Original Src: https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/attlayer.py
# # Adoped and Modified: https://www.kaggle.com/c/human-protein-atlas-image-classification/discussion/77269#454482
# class AttentionWeightedAverage2D(tf.keras.layers.Layer):
#     def __init__(self, **kwargs):
#         self.init = tf.keras.initializers.get('uniform')
#         super(AttentionWeightedAverage2D, self).__init__(** kwargs)

#     def build(self, input_shape):
#         self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]
#         assert len(input_shape) == 4
#         self.W = self.add_weight(shape=(input_shape[3], 1),
#                                  name='{}_W'.format(self.name),
#                                  initializer=self.init)
#         self._trainable_weights = [self.W]
#         super(AttentionWeightedAverage2D, self).build(input_shape)

#     def call(self, x):
#         # computes a probability distribution over the timesteps
#         # uses 'max trick' for numerical stability
#         # reshape is done to avoid issue with Tensorflow
#         # and 2-dimensional weights
#         logits  = K.dot(x, self.W)
#         x_shape = K.shape(x)
#         logits  = K.reshape(logits, (x_shape[0], x_shape[1], x_shape[2]))
#         ai      = K.exp(logits - K.max(logits, axis=[1,2], keepdims=True))
        
#         att_weights    = ai / (K.sum(ai, axis=[1,2], keepdims=True) + K.epsilon())
#         weighted_input = x * K.expand_dims(att_weights)
#         result         = K.sum(weighted_input, axis=[1,2])
#         return result

#     def get_output_shape_for(self, input_shape):
#         return self.compute_output_shape(input_shape)

#     def compute_output_shape(self, input_shape):
#         output_len = input_shape[3]
#         return (input_shape[0], output_len)
    
#     def get_config(self):

#         config = super().get_config().copy()
#         config.update({
#             'init': self.init,
#         })
#         return config

In [None]:
def create_model(name, input_shape, classes, output_bias=None, dropout_rate=0.2):
    """Build an EfficientNet-based multi-label classifier.

    Architecture: EfficientNet backbone (ImageNet weights, no top, trainable)
    -> GlobalAveragePooling2D -> Dropout -> Dense(classes) -> float32 sigmoid.

    Args:
        name: EfficientNet variant key, one of 'B0' ... 'B7'.
        input_shape: (height, width, channels) of the input images.
        classes: number of independent sigmoid outputs (one per label).
        output_bias: optional initial value for the output Dense bias
            (scalar or per-class sequence); zeros when None.
        dropout_rate: dropout rate applied to the pooled features.

    Returns:
        An uncompiled tf.keras.Model.

    Raises:
        KeyError: if `name` is not a supported EfficientNet variant.
    """
    efficient_nets = {
        'B0': efn.EfficientNetB0,
        'B1': efn.EfficientNetB1,
        'B2': efn.EfficientNetB2,
        'B3': efn.EfficientNetB3,
        'B4': efn.EfficientNetB4,
        'B5': efn.EfficientNetB5,
        'B6': efn.EfficientNetB6,
        'B7': efn.EfficientNetB7,
    }
    if name not in efficient_nets:
        raise KeyError(f"Unknown EfficientNet variant {name!r}; "
                       f"expected one of {sorted(efficient_nets)}")

    # Output layer bias initialization.
    if output_bias is None:
        bias_initializer = 'zeros'
    else:
        bias_initializer = tf.keras.initializers.Constant(output_bias)

    # Backbone: ImageNet-pretrained, classification head removed, fine-tuned end to end.
    base_model = efficient_nets[name](include_top=False,
                                      weights='imagenet',
                                      input_shape=input_shape)
    base_model.trainable = True

    for layer in base_model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Keras updates the moving statistics as
            #   moving = momentum * moving + (1 - momentum) * batch_stat,
            # so 0.99 adapts slowly; a LOWER momentum would track the new
            # dataset faster. NOTE(review): 0.99 is also the Keras default, so
            # this may be a no-op depending on how efn builds its BN layers.
            layer.momentum = 0.99

    inputs = tf.keras.Input(shape=input_shape)
    x = base_model(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
    x = tf.keras.layers.Dense(classes, bias_initializer=bias_initializer)(x)
    # float32 output layer keeps probabilities (and the loss) in float32
    # under a mixed-precision policy.
    outputs = tf.keras.layers.Activation('sigmoid', dtype='float32')(x)

    return tf.keras.Model(inputs, outputs)

In [None]:
def compile_model(model, lr=0.0001):
    """Compile `model` for multi-label training.

    Uses Adam, binary cross-entropy, and a multi-label AUC metric named 'auc'
    (reported by fit as 'auc' / 'val_auc').

    Args:
        model: tf.keras.Model; assumed to end in per-label sigmoid outputs.
        lr: Adam learning rate.

    Returns:
        The same model instance, compiled.
    """
    # `learning_rate` is the documented argument; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    metrics = [
        tf.keras.metrics.AUC(name='auc', multi_label=True)
    ]

    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics)

    return model

In [None]:
# Preview model, used only for the architecture plot and summary. It uses the
# configured EFF_NET / IMG_SIZE so the preview matches what the fold loop
# trains; the fold loop builds its own model inside strategy.scope().
model = create_model(name=EFF_NET,
                     input_shape=(IMG_SIZE, IMG_SIZE, 3),
                     classes=11)
model = compile_model(model, lr=0.01)

In [None]:
# Render the preview model's layer graph, including tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# Per-layer summary of the preview model.
model.summary()

In [None]:
def compute_class_weights(labels):
    """
    Compute inverse-frequency class weights for Keras' `class_weight` argument.

    Class c gets weight (1 / positive_frequency[c]) / num_classes, so rarer
    labels receive larger weights.

    Args:
        labels (np.array): binary label matrix, size (num_examples, num_classes)
    Returns:
        dict: {class_index: weight} for every column of `labels`
    Raises:
        ValueError: if `labels` is not a non-empty 2-D matrix, or some class has
            no positive example (its weight would be infinite).

    NOTE(review): for 2-D targets Keras converts `class_weight` into a
    per-sample weight via the argmax of each label row, i.e. the weight of the
    FIRST positive label (class 0 for all-negative rows). For multi-label
    targets that is not per-label weighting; confirm it is intended, or use a
    weighted loss instead.
    """
    labels = np.asarray(labels)
    if labels.ndim != 2 or labels.shape[0] == 0:
        raise ValueError(f'labels must be a non-empty 2-D matrix, got shape {labels.shape}')

    num_classes = labels.shape[1]
    positive_frequencies = np.mean(labels, axis=0)

    empty_classes = np.flatnonzero(positive_frequencies == 0)
    if empty_classes.size:
        raise ValueError(f'classes {empty_classes.tolist()} have no positive examples; '
                         'cannot compute inverse-frequency weights')

    weights = (1 / positive_frequencies) / num_classes
    return dict(enumerate(weights))


In [None]:
import pandas as pd
# Training label table, read from the local Kaggle input copy. The relative
# path is presumably used because DATA_PATH may already point at GCS on TPU.
df = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train.csv')

In [None]:
df_sub = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/sample_submission.csv')

# Label columns come from the submission template (every column after the ID).
label_cols = df_sub.columns[1:]

# Class-weight indices computed from these columns must line up with the label
# vector built from the TFRecords, so fail fast if the two orderings diverge.
tfrecord_label_order = [key for key in Dataset.feature_description
                        if key not in ('StudyInstanceUID', 'image')]
assert list(label_cols) == tfrecord_label_order, 'label column order mismatch'

labels = df[label_cols].to_numpy()

In [None]:
# First rows of the training label table.
df.head()

In [None]:
# Inverse-frequency weight per label index; passed to model.fit(class_weight=...) in the fold loop.
class_weights = compute_class_weights(labels)
class_weights

In [None]:
def compute_class_freqs(labels):
    """
    Compute positive and negative frequencies for each class.

    Args:
        labels (np.array or nested list): binary label matrix,
                                          size (num_examples, num_classes)
    Returns:
        positive_frequencies (np.array): fraction of examples where each class
                                         is 1, size (num_classes)
        negative_frequencies (np.array): 1 - positive_frequencies,
                                         size (num_classes)
    """
    positive_frequencies = np.mean(labels, axis=0)
    negative_frequencies = 1 - positive_frequencies
    return positive_frequencies, negative_frequencies

In [None]:
# Per-label positive / negative frequencies, plotted in the next cell.
freq_pos, freq_neg = compute_class_freqs(labels)

In [None]:
import seaborn as sns

# Long-format table: one row per (class, Positive|Negative) frequency.
# Built in one go because DataFrame.append is deprecated (removed in pandas 2.0).
n_classes = len(label_cols)
data = pd.DataFrame({
    "Class": list(label_cols) * 2,
    "Label": ["Positive"] * n_classes + ["Negative"] * n_classes,
    "Value": np.concatenate([freq_pos, freq_neg]),
})

fig, ax = plt.subplots()
f = sns.barplot(x="Class", y="Value", hue="Label", data=data, ax=ax)
# Rotate after plotting so it applies to the tick labels seaborn creates.
ax.tick_params(axis='x', labelrotation=90)
ax.set_title('Positive vs negative frequency per label');

In [None]:
def create_callbacks(model_save_path, verbose=1, e_s=10, e=4):
    """Build the checkpoint / LR-schedule / early-stopping callbacks for one fold.

    All three callbacks monitor 'val_loss' (lower is better).

    Args:
        model_save_path: directory for the best checkpoint, created if missing;
            the checkpoint is written to f'{model_save_path}/model.h5'.
        verbose: > 0 enables checkpoint and early-stopping messages
            (ReduceLROnPlateau is always silent).
        e_s: EarlyStopping patience, in epochs.
        e: ReduceLROnPlateau patience, in epochs (LR is multiplied by 0.1).

    Returns:
        [ModelCheckpoint, ReduceLROnPlateau, EarlyStopping]
    """
    verbose = int(verbose > 0)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_save_path, exist_ok=True)

    cpk_path = f'{model_save_path}/model.h5'

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=cpk_path,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        verbose=verbose
    )

    reducelr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.1,
        patience=e,
        verbose=0
    )

    earlystop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=e_s,
        verbose=verbose
    )

    return [checkpoint, reducelr, earlystop]

In [None]:
# Per-shard example count embedded in the file name, e.g. '00-1881.tfrec' -> 1881.
_ITEM_COUNT_RE = re.compile(r"-([0-9]+)\.")


def count_items(filenames):
    """Return the total number of examples across TFRecord shards.

    The count for each shard is parsed from the first '-<digits>.' in its base
    name. Only the base name is searched, so hyphen-digit-dot sequences in
    directory or bucket names cannot be picked up by mistake.

    Args:
        filenames: iterable of local or GCS shard paths.

    Returns:
        int: summed example count (0 for no files).

    Raises:
        ValueError: if a file name carries no '-<digits>.' count.
    """
    total = 0
    for filename in filenames:
        match = _ITEM_COUNT_RE.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f'cannot parse example count from {filename!r}')
        total += int(match.group(1))
    return total

In [None]:
# import pandas as pd

# train = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train.csv')

In [None]:
# train.head(3)

In [None]:
# train = train.iloc[:,1:-1]

In [None]:
# train.head(1)

In [None]:
# labels_ = train.to_numpy()

In [None]:

# freq_pos, freq_neg = compute_class_freqs(labels_)

In [None]:
# labels = np.array(train.columns)

In [None]:
# labels

In [None]:
# import seaborn as sns

# data = pd.DataFrame({"Class": labels, "Label": "Positive", "Value": freq_pos})
# data = data.append([{"Class": labels[l], "Label": "Negative", "Value": v} for l,v in enumerate(freq_neg)], ignore_index=True)
# plt.xticks(rotation=90)
# f = sns.barplot(x="Class", y="Value", hue="Label" ,data=data)

In [None]:
# pos_weights = freq_neg
# neg_weights = freq_pos
# pos_contribution = freq_pos * pos_weights 
# neg_contribution = freq_neg * neg_weights

In [None]:
# data = pd.DataFrame({"Class": labels, "Label": "Positive", "Value": pos_contribution})
# data = data.append([{"Class": labels[l], "Label": "Negative", "Value": v} 
#                         for l,v in enumerate(neg_contribution)], ignore_index=True)
# plt.xticks(rotation=90)
# sns.barplot(x="Class", y="Value", hue="Label" ,data=data);

In [None]:
# def get_weighted_loss(pos_weights, neg_weights, epsilon=1e-7):
#     """
#     Return weighted loss function given negative weights and positive weights.

#     Args:
#       pos_weights (np.array): array of positive weights for each class, size (num_classes)
#       neg_weights (np.array): array of negative weights for each class, size (num_classes)
    
#     Returns:
#       weighted_loss (function): weighted loss function
#     """
#     def weighted_loss(y_true, y_pred):
#         """
#         Return weighted loss value. 

#         Args:
#             y_true (Tensor): Tensor of true labels, size is (num_examples, num_classes)
#             y_pred (Tensor): Tensor of predicted labels, size is (num_examples, num_classes)
#         Returns:
#             loss (Tensor): overall scalar loss summed across all classes
#         """
        
#         ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###

#         loss_pos = -1. * K.sum(K.mean(pos_weights * y_true * K.log(y_pred+epsilon), axis=0))
#         loss_neg = -1. * K.sum(K.mean(neg_weights * (1 - y_true) * K.log(1-y_pred+epsilon), axis=0))
#         return loss_pos+loss_neg
    
#         ### END CODE HERE ###
#     return weighted_loss

In [None]:
folds_val_auc = [None] * FOLDS # Best validation AUC reached in each fold

# Plain (unstratified) K-fold over TFRecord shard indices, not individual images.
kfold = KFold(n_splits=FOLDS, shuffle=True, random_state=SEED)

DISPLAY_PLOT = True

# create_callbacks saves the best checkpoint of every fold to this same file;
# folds > 0 start from it.
CHECKPOINT_PATH = f'{MODEL_PATH}/model.h5'


def _plot_history(history):
    """Plot train/val AUC (left axis) and train/val loss (right axis) for one fold."""
    # https://www.kaggle.com/cdeotte/triple-stratified-kfold-with-tfrecords
    hist = history.history
    epoch_axis = np.arange(len(hist['loss']))

    plt.figure(figsize=(15, 5))
    plt.plot(epoch_axis, hist['auc'], '-o', label='auc', color='#ff7f0e')
    plt.plot(epoch_axis, hist['val_auc'], '-o', label='Val auc', color='#1f77b4')

    x = np.argmax(hist['val_auc']); y = np.max(hist['val_auc'])
    xdist = plt.xlim()[1] - plt.xlim()[0]; ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(x, y, s=200, color='#1f77b4'); plt.text(x - 0.03 * xdist, y - 0.13 * ydist, 'max auc\n%.2f' % y, size=14)

    plt.ylabel('auc', size=14); plt.xlabel('Epoch', size=14)
    plt.legend(loc=2)

    plt2 = plt.gca().twinx()
    plt2.plot(epoch_axis, hist['loss'], '-o', label='Train Loss', color='#2ca02c')
    plt2.plot(epoch_axis, hist['val_loss'], '-o', label='Val Loss', color='#d62728')

    x = np.argmin(hist['val_loss']); y = np.min(hist['val_loss'])
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(x, y, s=200, color='#d62728'); plt.text(x - 0.03 * xdist, y + 0.05 * ydist, 'min loss', size=14)

    plt.ylabel('Loss', size=14)
    plt.legend(loc=3)
    plt.show()


print('Training...')

for fold, (train_idx, valid_idx) in enumerate(kfold.split(np.arange(NUM_TF_RECS))):

    print(f'\n\n{"*"*100} \nFOLD: {fold+1}')

    # Input pipeline ******************************************************

    train_files = tf.io.gfile.glob([f'{DATA_PATH}/train_tfrecords/{idx:02}*.tfrec' for idx in train_idx])
    valid_files = tf.io.gfile.glob([f'{DATA_PATH}/train_tfrecords/{idx:02}*.tfrec' for idx in valid_idx])

    ds = Dataset(IMG_SIZE)
    global_batch = BATCH_SIZE * REPLICAS

    train_ds = ds.generator(train_files,
                            global_batch,
                            repeat=True,
                            augment=True,
                            shuffle=True,
                            cache=True)

    valid_ds = ds.generator(valid_files,
                            global_batch,
                            repeat=False,
                            augment=False,
                            shuffle=False,
                            cache=False)

    # One "epoch" is two passes over the training shards.
    steps_per_epoch = count_items(train_files) // global_batch * 2

    # Build model ******************************************************

    # Fold 0 trains from ImageNet weights; later folds fine-tune from the best
    # checkpoint so far, with a lower LR and shorter patience.
    # NOTE(review): later folds therefore start from weights trained on shards
    # that are in THIS fold's validation set, so their validation AUC is
    # optimistically biased and not an independent CV estimate.
    if fold == 0:
        lr, e_s, e = 0.0001, 10, 4
    else:
        lr, e_s, e = 0.00001, 5, 2

    print('Learning Rate: ' + str(lr))

    tf.keras.backend.clear_session()

    with strategy.scope():
        # Always build a fresh model inside the strategy scope (after
        # clear_session) rather than reusing the previous fold's object.
        model = create_model(name=EFF_NET,
                             input_shape=(IMG_SIZE, IMG_SIZE, 3),
                             classes=11)
        if fold > 0:
            model.load_weights(CHECKPOINT_PATH)
        model = compile_model(model, lr=lr)

    print(f'\nModel initialized and compiled: EfficientNet-{EFF_NET}')

    # Train ******************************************************

    callbacks = create_callbacks(MODEL_PATH, verbose=VERBOSE, e_s=e_s, e=e)

    print('\nModel training...\n')

    # NOTE(review): for 2-D targets Keras turns class_weight into a per-sample
    # weight via argmax of each label row, i.e. only the first positive label's
    # weight is used; confirm that is intended for these multi-label targets.
    history = model.fit(train_ds,
                        epochs=EPOCHS,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=valid_ds,
                        callbacks=callbacks,
                        verbose=VERBOSE,
                        class_weight=class_weights)

    # Best validation AUC of the fold (the checkpoint itself tracks val_loss).
    folds_val_auc[fold] = max(history.history['val_auc'])

    print(f'\nModel trained \n\nFOLD-{fold+1} Validation AUC = {folds_val_auc[fold]}')

    if DISPLAY_PLOT:
        _plot_history(history)